741 lines
26 KiB
Python
741 lines
26 KiB
Python
|
"""WebSocket protocol versions 13 and 8."""
|
||
|
|
||
|
import asyncio
|
||
|
import functools
|
||
|
import json
|
||
|
import random
|
||
|
import re
|
||
|
import sys
|
||
|
import zlib
|
||
|
from enum import IntEnum
|
||
|
from struct import Struct
|
||
|
from typing import (
|
||
|
Any,
|
||
|
Callable,
|
||
|
Final,
|
||
|
List,
|
||
|
NamedTuple,
|
||
|
Optional,
|
||
|
Pattern,
|
||
|
Set,
|
||
|
Tuple,
|
||
|
Union,
|
||
|
cast,
|
||
|
)
|
||
|
|
||
|
from .base_protocol import BaseProtocol
|
||
|
from .compression_utils import ZLibCompressor, ZLibDecompressor
|
||
|
from .helpers import NO_EXTENSIONS, set_exception
|
||
|
from .streams import DataQueue
|
||
|
|
||
|
# Public API of this module (consulted by ``from ... import *``).
__all__ = (
    # Sentinel messages and the handshake GUID.
    "WS_CLOSED_MESSAGE",
    "WS_CLOSING_MESSAGE",
    "WS_KEY",
    # Frame reader / writer.
    "WebSocketReader",
    "WebSocketWriter",
    # Message, error and enum types.
    "WSMessage",
    "WebSocketError",
    "WSMsgType",
    "WSCloseCode",
)
|
||
|
|
||
|
|
||
|
class WSCloseCode(IntEnum):
    """WebSocket close status codes.

    1000-1011 are defined by RFC 6455, section 7.4.1; 1012-1014 were added
    later to the IANA "WebSocket Close Code Number" registry.
    """

    OK = 1000  # normal closure
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    ABNORMAL_CLOSURE = 1006  # reserved by RFC 6455 for local use only
    INVALID_TEXT = 1007  # e.g. a text payload that is not valid UTF-8
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
    BAD_GATEWAY = 1014
|
||
|
|
||
|
|
||
|
# Every close code declared in WSCloseCode, as plain ints for membership tests.
ALLOWED_CLOSE_CODES: Final[Set[int]] = set(map(int, WSCloseCode))
|
||
|
|
||
|
# Low latency is essential for WebSockets: implementations generally expect
# messages to be sent and received promptly. This value is handed to the
# compressor as ``max_sync_chunk_size``; it is larger than the default so
# that fewer executor calls are made, because the executor adds significant
# latency and overhead when chunks are small. 5 KiB was chosen because
# python-zlib-ng uses the same value as its threshold for releasing the GIL.
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024
|
||
|
|
||
|
|
||
|
class WSMsgType(IntEnum):
    """Kinds of messages produced by the reader.

    Values up to 0xF are the 4-bit frame opcodes of RFC 6455 (section 5.2);
    the 0x1xx values are aiohttp-specific and can never be a wire opcode.
    """

    # Opcodes defined by the WebSocket specification.
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xA
    CLOSE = 0x8

    # aiohttp-internal pseudo message types.
    CLOSING = 0x100
    CLOSED = 0x101
    ERROR = 0x102

    # Lower-case aliases for the members above.
    text = TEXT
    binary = BINARY
    ping = PING
    pong = PONG
    close = CLOSE
    closing = CLOSING
    closed = CLOSED
    error = ERROR
|
||
|
|
||
|
|
||
|
# GUID from RFC 6455 (section 1.3) used when computing Sec-WebSocket-Accept.
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

# Big-endian (network order) helpers for frame headers and close payloads.
UNPACK_LEN2 = Struct("!H").unpack_from  # 16-bit extended payload length
UNPACK_LEN3 = Struct("!Q").unpack_from  # 64-bit extended payload length
UNPACK_CLOSE_CODE = Struct("!H").unpack  # 2-byte close status code
PACK_LEN1 = Struct("!BB").pack  # 2-byte header with a 7-bit length
PACK_LEN2 = Struct("!BBH").pack  # header + 16-bit extended length
PACK_LEN3 = Struct("!BBQ").pack  # header + 64-bit extended length
PACK_CLOSE_CODE = Struct("!H").pack  # 2-byte close status code

# Writer: payloads above MSG_SIZE are written separately from their header.
MSG_SIZE: Final[int] = 1 << 14
# Writer: default number of buffered bytes before draining the transport.
DEFAULT_LIMIT: Final[int] = 1 << 16
|
||
|
|
||
|
|
||
|
class WSMessage(NamedTuple):
    """One message delivered by the WebSocket reader.

    For TEXT messages ``data`` is the decoded ``str``; for BINARY it is the
    payload bytes; for CLOSE it is the close code and ``extra`` holds the
    close reason. Other readers' messages use ``extra=""``.
    """

    type: WSMsgType
    # Payload; its Python type depends on ``type`` (a precise annotation
    # would need a tagged union per message type).
    data: Any
    extra: Optional[str]

    def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
        """Decode ``data`` as JSON with ``loads`` and return the result.

        .. versionadded:: 0.22
        """
        decoded = loads(self.data)
        return decoded
|
||
|
|
||
|
|
||
|
# Shared sentinel messages for the aiohttp-specific CLOSED / CLOSING states.
WS_CLOSED_MESSAGE = WSMessage(type=WSMsgType.CLOSED, data=None, extra=None)
WS_CLOSING_MESSAGE = WSMessage(type=WSMsgType.CLOSING, data=None, extra=None)
|
||
|
|
||
|
|
||
|
class WebSocketError(Exception):
    """WebSocket protocol parser error.

    ``code`` is the close code describing the failure (this module passes
    WSCloseCode members); ``str(error)`` yields only the message text.
    """

    def __init__(self, code: int, message: str) -> None:
        super().__init__(code, message)
        self.code = code

    def __str__(self) -> str:
        description = self.args[1]
        return cast(str, description)
|
||
|
|
||
|
|
||
|
class WSHandshakeError(Exception):
    """WebSocket protocol handshake error.

    Within this module it is raised by ``ws_ext_parse`` when a
    permessage-deflate negotiation cannot be accepted.
    """
|
||
|
|
||
|
|
||
|
# Byte order of the running interpreter: "little" or "big".
native_byteorder: Final[str] = sys.byteorder
|
||
|
|
||
|
|
||
|
# Used by _websocket_mask_python
@functools.lru_cache
def _xor_table() -> List[bytes]:
    """Return 256 ``bytes.translate`` tables, one per possible mask byte.

    Table ``k`` maps every byte value ``v`` to ``v ^ k``. The result is
    computed once and cached.
    """
    return [bytes(map(mask_byte.__xor__, range(256))) for mask_byte in range(256)]
|
||
|
|
||
|
|
||
|
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
    """Websocket masking function.

    XORs every byte ``data[i]`` with ``mask[i % 4]`` in place, as specified
    in section 5.3 of RFC 6455. ``mask`` must be 4 bytes long and ``data``
    a ``bytearray`` of any length; ``data`` is mutated and nothing is
    returned. Applying the same mask twice restores the original data.

    This pure-python implementation may be replaced by an optimized
    version when available.
    """
    assert isinstance(data, bytearray), data
    assert len(mask) == 4, mask

    if not data:
        return

    tables = _xor_table()
    # Each of the four positions modulo 4 shares one mask byte, so one
    # C-level translate over an extended slice handles all of them at once.
    for offset, mask_byte in enumerate(mask):
        data[offset::4] = data[offset::4].translate(tables[mask_byte])
|
||
|
|
||
|
|
||
|
# Pick the masking implementation: the Cython one when extensions are
# enabled and importable, otherwise _websocket_mask_python.
_websocket_mask = _websocket_mask_python
if not NO_EXTENSIONS:
    try:
        from ._websocket import _websocket_mask_cython  # type: ignore[import-not-found]
    except ImportError:  # pragma: no cover
        pass
    else:
        _websocket_mask = _websocket_mask_cython
|
||
|
|
||
|
# Empty stored deflate block emitted by a sync flush. RFC 7692 strips it from
# compressed messages on send; the reader appends it again before inflating.
_WS_DEFLATE_TRAILING: Final[bytes] = b"\x00\x00\xff\xff"
|
||
|
|
||
|
|
||
|
# Parameter list that may follow "permessage-deflate". Capture groups:
#   1 server_no_context_takeover      2 client_no_context_takeover
#   3 server_max_window_bits[=N]      4 its value N
#   5 client_max_window_bits[=N]      6 its value N
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
    r"^(?:;\s*(?:"
    r"(server_no_context_takeover)|"
    r"(client_no_context_takeover)|"
    r"(server_max_window_bits(?:=(\d+))?)|"
    r"(client_max_window_bits(?:=(\d+))?)))*$"
)

# Finds each "permessage-deflate" offer in a comma-separated header value;
# group 1 is its parameter text up to the next comma (absent if none).
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")


def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
    """Negotiate permessage-deflate from a WebSocket extensions header value.

    Returns ``(compress, notakeover)``. ``compress`` is the window-bits value
    (9-15, 15 when no explicit value is given), or 0 when deflate is not
    used. ``notakeover`` is True when the accepted offer carries
    ``server_no_context_takeover`` (server side) or
    ``client_no_context_takeover`` (client side).

    Server side (``isserver=True``): offers that cannot be parsed or that ask
    for an unsupported ``server_max_window_bits`` are skipped. Client side:
    an unparsable offer or an out-of-range ``client_max_window_bits`` raises
    :class:`WSHandshakeError`.
    """
    if not extstr:
        return 0, False

    # Regex groups relevant to this side: (window-bits value, no-takeover flag).
    bits_group, takeover_group = (4, 1) if isserver else (6, 2)

    compress = 0
    notakeover = False
    for offer in _WS_EXT_RE_SPLIT.finditer(extstr):
        params = offer.group(1)
        if not params:
            # Bare "permessage-deflate": use the default 15-bit window.
            compress = 15
            break

        parsed = _WS_EXT_RE.match(params)
        if parsed is None:
            if isserver:
                # A server simply declines offers it does not understand.
                continue
            raise WSHandshakeError("Extension for deflate not supported" + params)

        compress = 15
        bits = parsed.group(bits_group)
        if bits:
            compress = int(bits)
            # zlib does not support wbits=8, so only 9-15 are usable.
            if compress > 15 or compress < 9:
                if isserver:
                    compress = 0
                    continue
                raise WSHandshakeError("Invalid window size")
        if parsed.group(takeover_group):
            notakeover = True
        # Parameters that describe the other side are deliberately ignored.
        break

    return compress, notakeover
|
||
|
|
||
|
|
||
|
def ws_ext_gen(
    compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
    """Build the ``permessage-deflate`` extension value to send to the peer.

    ``compress`` is the window-bits value; anything below 15 is advertised as
    ``server_max_window_bits``. A client (``isserver=False``) also announces
    ``client_max_window_bits``; ``server_notakeover`` adds
    ``server_no_context_takeover``. Parameters are joined with ``"; "``.

    Raises ValueError when ``compress`` is outside 9-15 (zlib does not
    support wbits=8).
    """
    if compress < 9 or compress > 15:
        raise ValueError(
            "Compress wbits must between 9 and 15, zlib does not support wbits=8"
        )

    optional_params = (
        (not isserver, "client_max_window_bits"),
        (compress < 15, "server_max_window_bits=" + str(compress)),
        (server_notakeover, "server_no_context_takeover"),
    )
    selected = [text for wanted, text in optional_params if wanted]
    return "; ".join(["permessage-deflate", *selected])
|
||
|
|
||
|
|
||
|
class WSParserState(IntEnum):
    """Stages of the incremental frame decoder in WebSocketReader.parse_frame."""

    READ_HEADER = 1  # 2 bytes: FIN/RSV/opcode, mask bit and 7-bit length
    READ_PAYLOAD_LENGTH = 2  # optional 16- or 64-bit extended length
    READ_PAYLOAD_MASK = 3  # 4-byte masking key (masked frames only)
    READ_PAYLOAD = 4  # the payload bytes themselves
|
||
|
|
||
|
|
||
|
class WebSocketReader:
    """Incremental WebSocket parser that feeds complete messages to a queue.

    Raw bytes are pushed in with :meth:`feed_data`. :meth:`parse_frame`
    splits them into frames and keeps any incomplete remainder for the next
    call. Data frames are reassembled into messages, inflated when they were
    sent with permessage-deflate, and delivered to ``queue`` as
    :class:`WSMessage` objects. The first error is stored and propagated to
    ``queue``; after that, all further input is ignored.
    """

    def __init__(
        self, queue: DataQueue[WSMessage], max_msg_size: int, compress: bool = True
    ) -> None:
        self.queue = queue
        # Limit for a reassembled (and decompressed) message; a falsy value
        # disables the size checks.
        self._max_msg_size = max_msg_size

        self._exc: Optional[BaseException] = None
        # Payload collected so far for the current data message.
        self._partial = bytearray()
        self._state = WSParserState.READ_HEADER

        # Opcode (TEXT/BINARY) of the fragmented message in progress, or None.
        self._opcode: Optional[int] = None
        self._frame_fin = False
        self._frame_opcode: Optional[int] = None
        self._frame_payload = bytearray()

        # Unparsed bytes carried over to the next parse_frame() call.
        self._tail = b""
        self._has_mask = False
        self._frame_mask: Optional[bytes] = None
        self._payload_length = 0
        self._payload_length_flag = 0
        # Whether the data message currently being received is compressed
        # (RSV1 on its first frame); None until the first data frame.
        self._compressed: Optional[bool] = None
        # FIN bit of the most recent *data* frame. Control frames may be
        # interleaved between fragments, so they must not decide whether the
        # next data frame starts a new message.
        self._data_frame_fin = True
        self._decompressobj: Optional[ZLibDecompressor] = None
        # Whether RSV1 (permessage-deflate) is acceptable at all.
        self._compress = compress

    def feed_eof(self) -> None:
        """Signal end of input to the queue."""
        self.queue.feed_eof()

    def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Parse ``data`` and return ``(failed, unconsumed)``.

        Normally returns ``(False, b"")``. The first exception raised while
        parsing is stored and set on ``queue``, and that call returns
        ``(True, b"")``. Every later call returns ``(True, data)`` without
        parsing.
        """
        if self._exc:
            return True, data

        try:
            return self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            set_exception(self.queue, exc)
            return True, b""

    def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Parse frames from ``data`` and dispatch each one by opcode."""
        for fin, opcode, payload, compressed in self.parse_frame(data):
            if compressed and not self._decompressobj:
                self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
            if opcode == WSMsgType.CLOSE:
                if len(payload) >= 2:
                    close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                    if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            f"Invalid close code: {close_code}",
                        )
                    try:
                        close_message = payload[2:].decode("utf-8")
                    except UnicodeDecodeError as exc:
                        raise WebSocketError(
                            WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                        ) from exc
                    msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
                elif payload:
                    # A 1-byte close payload cannot hold a status code.
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close frame: {fin} {opcode} {payload!r}",
                    )
                else:
                    msg = WSMessage(WSMsgType.CLOSE, 0, "")

                self.queue.feed_data(msg, 0)

            elif opcode == WSMsgType.PING:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PING, payload, ""), len(payload)
                )

            elif opcode == WSMsgType.PONG:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PONG, payload, ""), len(payload)
                )

            elif opcode not in (
                WSMsgType.CONTINUATION,
                WSMsgType.TEXT,
                WSMsgType.BINARY,
            ) or (opcode == WSMsgType.CONTINUATION and self._opcode is None):
                # Reserved opcodes are never valid (even in the middle of a
                # fragmented message), and a continuation frame needs a
                # fragmented message to continue.
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
                )
            else:
                self._handle_data_frame(fin, opcode, payload, compressed)

        return False, b""

    def _handle_data_frame(
        self,
        fin: bool,
        opcode: Optional[int],
        payload: bytearray,
        compressed: Optional[bool],
    ) -> None:
        """Accumulate a TEXT/BINARY/CONTINUATION frame; emit the message on FIN."""
        if opcode != WSMsgType.CONTINUATION and self._opcode is not None:
            # Fragments of different messages must not be interleaved: while
            # a fragmented message is in progress only continuations may
            # follow, even if its earlier fragments carried no bytes.
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR,
                "The opcode in non-fin frame is expected "
                "to be zero, got {!r}".format(opcode),
            )

        self._partial.extend(payload)
        self._check_partial_size()

        if not fin:
            # More fragments follow; the first one determines the type.
            if opcode != WSMsgType.CONTINUATION:
                self._opcode = opcode
            return

        if opcode == WSMsgType.CONTINUATION:
            assert self._opcode is not None
            opcode = self._opcode
            self._opcode = None

        # Decompression must happen only after all fragments are received.
        if compressed:
            assert self._decompressobj is not None
            self._partial.extend(_WS_DEFLATE_TRAILING)
            payload_merged = self._decompressobj.decompress_sync(
                self._partial, self._max_msg_size
            )
            if self._decompressobj.unconsumed_tail:
                left = len(self._decompressobj.unconsumed_tail)
                raise WebSocketError(
                    WSCloseCode.MESSAGE_TOO_BIG,
                    "Decompressed message size {} exceeds limit {}".format(
                        self._max_msg_size + left, self._max_msg_size
                    ),
                )
        else:
            payload_merged = bytes(self._partial)

        self._partial.clear()

        if opcode == WSMsgType.TEXT:
            try:
                text = payload_merged.decode("utf-8")
            except UnicodeDecodeError as exc:
                raise WebSocketError(
                    WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                ) from exc
            self.queue.feed_data(WSMessage(WSMsgType.TEXT, text, ""), len(text))
        else:
            self.queue.feed_data(
                WSMessage(WSMsgType.BINARY, payload_merged, ""),
                len(payload_merged),
            )

    def _check_partial_size(self) -> None:
        """Raise MESSAGE_TOO_BIG once the reassembled payload reaches the limit."""
        if self._max_msg_size and len(self._partial) >= self._max_msg_size:
            raise WebSocketError(
                WSCloseCode.MESSAGE_TOO_BIG,
                "Message size {} exceeds limit {}".format(
                    len(self._partial), self._max_msg_size
                ),
            )

    def parse_frame(
        self, buf: bytes
    ) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]:
        """Return the frames completed by ``buf``.

        Each item is ``(fin, opcode, payload, compressed)``; payloads are
        already unmasked, and ``compressed`` is always False for control
        frames. Leftover bytes are kept in ``self._tail`` and a partially
        received payload in ``self._frame_payload``, so parsing resumes on
        the next call.
        """
        frames = []
        if self._tail:
            buf, self._tail = self._tail + buf, b""

        start_pos = 0
        buf_length = len(buf)

        while True:
            # read header
            if self._state == WSParserState.READ_HEADER:
                if buf_length - start_pos >= 2:
                    data = buf[start_pos : start_pos + 2]
                    start_pos += 2
                    first_byte, second_byte = data

                    fin = (first_byte >> 7) & 1
                    rsv1 = (first_byte >> 6) & 1
                    rsv2 = (first_byte >> 5) & 1
                    rsv3 = (first_byte >> 4) & 1
                    opcode = first_byte & 0xF

                    # frame-fin = %x0 ; more frames of this message follow
                    #           / %x1 ; final frame of this message
                    # frame-rsv1 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    # frame-rsv2 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    # frame-rsv3 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    #
                    # RSV1 is allowed only when permessage-deflate is enabled.
                    if rsv2 or rsv3 or (rsv1 and not self._compress):
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received frame with non-zero reserved bits",
                        )

                    if opcode > 0x7 and fin == 0:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received fragmented control frame",
                        )

                    has_mask = (second_byte >> 7) & 1
                    length = second_byte & 0x7F

                    # Control frames MUST have a payload
                    # length of 125 bytes or less
                    if opcode > 0x7 and length > 125:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Control frame payload cannot be larger than 125 bytes",
                        )

                    if opcode > 0x7:
                        # Control frames are never compressed (RFC 7692,
                        # section 6) and may arrive between the fragments of
                        # a data message, so they must leave that message's
                        # compression state untouched.
                        if rsv1:
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                "Received frame with non-zero reserved bits",
                            )
                    else:
                        if self._data_frame_fin:
                            # First frame of a new data message: RSV1 tells
                            # whether the whole message is compressed.
                            self._compressed = bool(rsv1)
                        elif rsv1:
                            # RSV1 is only allowed on the first fragment.
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                "Received frame with non-zero reserved bits",
                            )
                        self._data_frame_fin = bool(fin)

                    self._frame_fin = bool(fin)
                    self._frame_opcode = opcode
                    self._has_mask = bool(has_mask)
                    self._payload_length_flag = length
                    self._state = WSParserState.READ_PAYLOAD_LENGTH
                else:
                    break

            # read payload length
            if self._state == WSParserState.READ_PAYLOAD_LENGTH:
                length = self._payload_length_flag
                if length == 126:
                    if buf_length - start_pos >= 2:
                        data = buf[start_pos : start_pos + 2]
                        start_pos += 2
                        length = UNPACK_LEN2(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD
                        )
                    else:
                        break
                elif length > 126:
                    if buf_length - start_pos >= 8:
                        data = buf[start_pos : start_pos + 8]
                        start_pos += 8
                        length = UNPACK_LEN3(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD
                        )
                    else:
                        break
                else:
                    self._payload_length = length
                    self._state = (
                        WSParserState.READ_PAYLOAD_MASK
                        if self._has_mask
                        else WSParserState.READ_PAYLOAD
                    )

            # read payload mask
            if self._state == WSParserState.READ_PAYLOAD_MASK:
                if buf_length - start_pos >= 4:
                    self._frame_mask = buf[start_pos : start_pos + 4]
                    start_pos += 4
                    self._state = WSParserState.READ_PAYLOAD
                else:
                    break

            if self._state == WSParserState.READ_PAYLOAD:
                length = self._payload_length
                payload = self._frame_payload

                chunk_len = buf_length - start_pos
                if length >= chunk_len:
                    # The rest of the buffer belongs to this payload.
                    self._payload_length = length - chunk_len
                    payload.extend(buf[start_pos:])
                    start_pos = buf_length
                else:
                    self._payload_length = 0
                    payload.extend(buf[start_pos : start_pos + length])
                    start_pos = start_pos + length

                if self._payload_length == 0:
                    if self._has_mask:
                        assert self._frame_mask is not None
                        _websocket_mask(self._frame_mask, payload)

                    # Data frames report their message's compression state;
                    # control frames never carry compressed payloads.
                    is_control = (
                        self._frame_opcode is not None and self._frame_opcode > 0x7
                    )
                    frames.append(
                        (
                            self._frame_fin,
                            self._frame_opcode,
                            payload,
                            False if is_control else self._compressed,
                        )
                    )

                    self._frame_payload = bytearray()
                    self._state = WSParserState.READ_HEADER
                else:
                    break

        self._tail = buf[start_pos:]

        return frames
|
||
|
|
||
|
|
||
|
class WebSocketWriter:
    """Serialize WebSocket frames and write them to an asyncio transport.

    Every frame is written with the FIN bit set (no outgoing fragmentation).
    Data frames may be compressed with permessage-deflate, and when
    ``use_mask`` is true each frame is masked with a fresh 4-byte key. After
    more than ``limit`` bytes have been written since the last drain,
    ``protocol._drain_helper()`` is awaited.
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        transport: asyncio.Transport,
        *,
        use_mask: bool = False,
        limit: int = DEFAULT_LIMIT,
        # NOTE: this default Random() is created once, when the function is
        # defined, and is shared by every writer that does not pass its own
        # generator. The parameter shadows the ``random`` module inside
        # __init__ only.
        random: random.Random = random.Random(),
        # Window bits for compressing data frames; 0 disables compression.
        compress: int = 0,
        # When true, each message is flushed with Z_FULL_FLUSH (no context
        # takeover) instead of Z_SYNC_FLUSH.
        notakeover: bool = False,
    ) -> None:
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        self.randrange = random.randrange
        self.compress = compress
        self.notakeover = notakeover
        # Set by close(); afterwards only control frames may be sent.
        self._closing = False
        self._limit = limit
        # Bytes written since the last drain.
        self._output_size = 0
        # Shared ZLibCompressor, created lazily for the first frame that uses
        # self.compress.
        self._compressobj: Any = None  # actually compressobj

    async def _send_frame(
        self, message: bytes, opcode: int, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket with message as its payload.

        ``compress`` (window bits) overrides ``self.compress`` for this frame
        only, using a throwaway compressor. Raises ConnectionResetError when
        the writer is closing (for non-control frames) or the transport is
        gone.
        """
        # ``opcode & WSMsgType.CLOSE`` is non-zero for every control opcode
        # (CLOSE 0x8, PING 0x9, PONG 0xA), so those are still allowed while
        # closing. NOTE(review): presumably intentional — confirm.
        if self._closing and not (opcode & WSMsgType.CLOSE):
            raise ConnectionResetError("Cannot write to closing transport")

        rsv = 0

        # Only compress larger packets (disabled)
        # Does small packet needs to be compressed?
        # if self.compress and opcode < 8 and len(message) > 124:
        # Only data frames (opcode < 8) are ever compressed.
        if (compress or self.compress) and opcode < 8:
            if compress:
                # A per-frame setting uses a fresh compressor so the shared
                # self._compressobj stream is left untouched.
                compressobj = self._make_compress_obj(compress)
            else:  # self.compress
                if not self._compressobj:
                    self._compressobj = self._make_compress_obj(self.compress)
                compressobj = self._compressobj

            message = await compressobj.compress(message)
            # It's critical that we do not return control to the event
            # loop until we have finished sending all the compressed
            # data. Otherwise we could end up mixing compressed frames
            # if there are multiple coroutines compressing data.
            message += compressobj.flush(
                zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
            )
            # RFC 7692: drop the trailing empty deflate block; the receiver
            # appends it again before inflating.
            if message.endswith(_WS_DEFLATE_TRAILING):
                message = message[:-4]
            # RSV1 marks the message as compressed.
            rsv = rsv | 0x40

        msg_length = len(message)

        use_mask = self.use_mask
        if use_mask:
            mask_bit = 0x80
        else:
            mask_bit = 0

        # First byte: FIN (0x80) | RSV bits | opcode. Second byte: mask bit |
        # 7-bit length, or 126/127 followed by a 16/64-bit length.
        if msg_length < 126:
            header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit)
        elif msg_length < (1 << 16):
            header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
        else:
            header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)
        if use_mask:
            # NOTE: randrange's upper bound is exclusive, so the key
            # 0xFFFFFFFF is never chosen.
            mask_int = self.randrange(0, 0xFFFFFFFF)
            mask = mask_int.to_bytes(4, "big")
            message = bytearray(message)
            _websocket_mask(mask, message)
            self._write(header + mask + message)
            self._output_size += len(header) + len(mask) + msg_length
        else:
            # Large payloads are written separately from their header,
            # presumably to avoid copying them into one concatenated buffer.
            if msg_length > MSG_SIZE:
                self._write(header)
                self._write(message)
            else:
                self._write(header + message)

            self._output_size += len(header) + msg_length

        # It is safe to return control to the event loop when using compression
        # after this point as we have already sent or buffered all the data.

        if self._output_size > self._limit:
            self._output_size = 0
            await self.protocol._drain_helper()

    def _make_compress_obj(self, compress: int) -> ZLibCompressor:
        """Create a raw-deflate compressor (negative wbits) for ``compress`` bits."""
        return ZLibCompressor(
            level=zlib.Z_BEST_SPEED,
            wbits=-compress,
            max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
        )

    def _write(self, data: bytes) -> None:
        """Write ``data`` to the transport; raise ConnectionResetError if it is closing."""
        if self.transport is None or self.transport.is_closing():
            raise ConnectionResetError("Cannot write to closing transport")
        self.transport.write(data)

    async def pong(self, message: Union[bytes, str] = b"") -> None:
        """Send pong message (str payloads are UTF-8 encoded)."""
        if isinstance(message, str):
            message = message.encode("utf-8")
        await self._send_frame(message, WSMsgType.PONG)

    async def ping(self, message: Union[bytes, str] = b"") -> None:
        """Send ping message (str payloads are UTF-8 encoded)."""
        if isinstance(message, str):
            message = message.encode("utf-8")
        await self._send_frame(message, WSMsgType.PING)

    async def send(
        self,
        message: Union[str, bytes],
        binary: bool = False,
        compress: Optional[int] = None,
    ) -> None:
        """Send a frame over the websocket with message as its payload.

        str messages are UTF-8 encoded. ``binary`` selects a BINARY frame
        instead of TEXT; ``compress`` overrides the writer's compression
        window bits for this message only.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        if binary:
            await self._send_frame(message, WSMsgType.BINARY, compress)
        else:
            await self._send_frame(message, WSMsgType.TEXT, compress)

    async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
        """Close the websocket, sending the specified code and message.

        The CLOSE payload is the 2-byte big-endian ``code`` followed by the
        (UTF-8 encoded) ``message``. The writer is marked as closing even if
        sending the frame fails.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        try:
            await self._send_frame(
                PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
            )
        finally:
            self._closing = True
|