|
|
|
import asyncio
|
|
|
|
import json
|
|
|
|
import sys
|
|
|
|
import warnings
|
|
|
|
from collections import namedtuple
|
|
|
|
|
|
|
|
from . import Timeout, hdrs
|
|
|
|
from ._ws_impl import (CLOSED_MESSAGE, WebSocketError, WSMessage, WSMsgType,
|
|
|
|
do_handshake)
|
|
|
|
from .errors import ClientDisconnectedError, HttpProcessingError
|
|
|
|
from .web_exceptions import (HTTPBadRequest, HTTPInternalServerError,
|
|
|
|
HTTPMethodNotAllowed)
|
|
|
|
from .web_reqrep import StreamResponse
|
|
|
|
|
|
|
|
# Public names exported by ``from <module> import *``.
__all__ = (
    'WebSocketResponse',
    'WebSocketReady',
    'MsgType',
    'WSMsgType',
)

# Interpreter feature flags: PY_35 gates the async-iteration methods of
# WebSocketResponse, PY_352 decides whether ``__aiter__`` has to be
# wrapped into a coroutine.
PY_35 = sys.version_info[:2] >= (3, 5)
PY_352 = sys.version_info[:3] >= (3, 5, 2)

# receive() raises RuntimeError once it has been called this many times
# on a socket that is already closed.
THRESHOLD_CONNLOST_ACCESS = 5

# deprecated since 1.0
MsgType = WSMsgType
|
|
|
|
|
|
|
|
|
|
|
|
class WebSocketReady(namedtuple('WebSocketReady', 'ok protocol')):
    """Result of a handshake pre-check: ``(ok, protocol)``.

    The tuple is truthy exactly when ``ok`` is truthy, so it can be used
    directly in an ``if`` statement.
    """

    def __bool__(self):
        # __bool__ must return a real bool; coerce so that a truthy or
        # falsy non-bool ``ok`` does not make bool(self) raise TypeError.
        return bool(self.ok)
|
|
|
|
|
|
|
|
|
|
|
|
class WebSocketResponse(StreamResponse):
    """Server-side WebSocket endpoint layered on ``StreamResponse``.

    The response is created with status 101 and becomes usable once
    ``prepare()`` has run the handshake for a request.  Outgoing data goes
    through ``ping()``/``pong()``/``send_*()``, incoming messages are read
    with ``receive*()``; plain ``write()`` is rejected.
    """

    def __init__(self, *,
                 timeout=10.0, autoclose=True, autoping=True, protocols=()):
        """Create an unprepared WebSocket response.

        :param timeout: value passed to ``Timeout`` (and used as the loop
            bound) while ``close()`` waits for the peer's CLOSE message.
        :param autoclose: when true, ``receive()`` calls ``close()`` by
            itself after reading a CLOSE message.
        :param autoping: when true, ``receive()`` answers PING messages
            with ``pong()`` and skips PONG messages instead of returning
            them.
        :param protocols: subprotocols handed to ``do_handshake()``.
        """
        super().__init__(status=101)
        self._protocols = protocols
        self._protocol = None
        self._writer = None
        self._reader = None
        self._closed = False
        self._closing = False
        self._conn_lost = 0
        self._close_code = None
        self._loop = None
        self._waiting = False
        self._exception = None
        self._timeout = timeout
        self._autoclose = autoclose
        self._autoping = autoping

    @asyncio.coroutine
    def prepare(self, request):
        """Run the handshake for *request* and start the response.

        Returns whatever the base class' ``prepare()`` returns, or the
        already existing result of ``_start_pre_check()`` when there is
        one.  Handshake failures surface as the HTTP exceptions raised by
        ``_pre_start()``.
        """
        # make pre-check to don't hide it by do_handshake() exceptions
        resp_impl = self._start_pre_check(request)
        if resp_impl is not None:
            return resp_impl

        parser, protocol, writer = self._pre_start(request)
        resp_impl = yield from super().prepare(request)
        self._post_start(request, parser, protocol, writer)
        return resp_impl

    def _pre_start(self, request):
        """Perform the handshake and copy its status/headers onto self.

        Returns ``(parser, protocol, writer)`` as produced by
        ``do_handshake()``.

        Raises HTTPMethodNotAllowed for handshake error code 405,
        HTTPBadRequest for 400 and HTTPInternalServerError otherwise.
        """
        try:
            status, headers, parser, writer, protocol = do_handshake(
                request.method, request.headers, request.transport,
                self._protocols)
        except HttpProcessingError as err:
            # Every branch chains the original error explicitly so the
            # cause is reported the same way for all status codes.
            if err.code == 405:
                raise HTTPMethodNotAllowed(
                    request.method, [hdrs.METH_GET], body=b'') from err
            elif err.code == 400:
                raise HTTPBadRequest(
                    text=err.message, headers=err.headers) from err
            else:  # pragma: no cover
                raise HTTPInternalServerError() from err

        if self.status != status:
            self.set_status(status)
        for k, v in headers:
            self.headers[k] = v
        self.force_close()
        return parser, protocol, writer

    def _post_start(self, request, parser, protocol, writer):
        """Store the handshake results and the application's event loop."""
        self._reader = request._reader.set_parser(parser)
        self._writer = writer
        self._protocol = protocol
        self._loop = request.app.loop

    def start(self, request):
        """Deprecated non-coroutine variant of ``prepare()``.

        Emits a DeprecationWarning and delegates to the base class'
        ``start()`` instead of ``prepare()``.
        """
        warnings.warn('use .prepare(request) instead', DeprecationWarning)
        # make pre-check to don't hide it by do_handshake() exceptions
        resp_impl = self._start_pre_check(request)
        if resp_impl is not None:
            return resp_impl

        parser, protocol, writer = self._pre_start(request)
        resp_impl = super().start(request)
        self._post_start(request, parser, protocol, writer)
        return resp_impl

    def can_prepare(self, request):
        """Check whether *request* passes the WebSocket handshake.

        Returns ``WebSocketReady(True, protocol)`` on success and
        ``WebSocketReady(False, None)`` when ``do_handshake()`` raises
        HttpProcessingError.  The response itself is not modified.

        Raises RuntimeError if the response has already been prepared.
        """
        if self._writer is not None:
            raise RuntimeError('Already started')
        try:
            _, _, _, _, protocol = do_handshake(
                request.method, request.headers, request.transport,
                self._protocols)
        except HttpProcessingError:
            return WebSocketReady(False, None)
        else:
            return WebSocketReady(True, protocol)

    def can_start(self, request):
        """Deprecated alias of ``can_prepare()``."""
        warnings.warn('use .can_prepare(request) instead', DeprecationWarning)
        return self.can_prepare(request)

    @property
    def closed(self):
        """True once the socket has been closed or the peer disconnected."""
        return self._closed

    @property
    def close_code(self):
        """Close code recorded so far, or None."""
        return self._close_code

    @property
    def protocol(self):
        """Subprotocol selected by the handshake, None before prepare()."""
        return self._protocol

    def exception(self):
        """Return the last exception recorded by close()/receive()."""
        return self._exception

    def _ensure_open(self):
        """Raise RuntimeError unless the socket is prepared and not closed."""
        if self._writer is None:
            raise RuntimeError('Call .prepare() first')
        if self._closed:
            raise RuntimeError('websocket connection is closing')

    def ping(self, message='b'):
        """Pass *message* to the writer's ``ping()``.

        Raises RuntimeError before prepare() or after the socket closed.
        """
        # NOTE(review): the default is the str 'b', which looks like a
        # typo for b''; kept as is so the default payload does not change.
        self._ensure_open()
        self._writer.ping(message)

    def pong(self, message='b'):
        """Pass *message* to the writer's ``pong()`` (unsolicited pong).

        Raises RuntimeError before prepare() or after the socket closed.
        """
        # NOTE(review): same suspicious 'b' default as ping(); kept as is.
        self._ensure_open()
        self._writer.pong(message)

    def send_str(self, data):
        """Send *data* as a non-binary message.

        Raises RuntimeError before prepare() or after the socket closed,
        TypeError if *data* is not a ``str``.
        """
        self._ensure_open()
        if not isinstance(data, str):
            raise TypeError('data argument must be str (%r)' % type(data))
        self._writer.send(data, binary=False)

    def send_bytes(self, data):
        """Send *data* as a binary message.

        Raises RuntimeError before prepare() or after the socket closed,
        TypeError if *data* is not bytes, bytearray or memoryview.
        """
        self._ensure_open()
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError('data argument must be byte-ish (%r)' %
                            type(data))
        self._writer.send(data, binary=True)

    def send_json(self, data, *, dumps=json.dumps):
        """Serialize *data* with *dumps* and send it via ``send_str()``."""
        self.send_str(dumps(data))

    @asyncio.coroutine
    def write_eof(self):
        """Finish the response by closing the WebSocket.

        Does nothing when EOF was already sent; raises RuntimeError when
        the response has not been started.
        """
        # _eof_sent and _resp_impl are not set in this class; presumably
        # they are maintained by StreamResponse.
        if self._eof_sent:
            return
        if self._resp_impl is None:
            raise RuntimeError("Response has not been started")

        yield from self.close()
        self._eof_sent = True

    @asyncio.coroutine
    def close(self, *, code=1000, message=b''):
        """Close the socket and wait for the peer's CLOSE message.

        Returns False when the socket was already closed, True when this
        call performed the close.  A failure or timeout is recorded as
        close code 1006 and is available from ``exception()``.

        Raises RuntimeError before prepare(); CancelledError is re-raised
        (as is TimeoutError coming from the writer's ``close()``).
        """
        if self._writer is None:
            raise RuntimeError('Call .prepare() first')

        if self._closed:
            return False

        self._closed = True
        try:
            self._writer.close(code, message)
        except (asyncio.CancelledError, asyncio.TimeoutError):
            self._close_code = 1006
            raise
        except Exception as exc:
            self._close_code = 1006
            self._exception = exc
            return True

        # receive() sets _closing when the peer's CLOSE message has
        # already been read (or the reader failed), so there is nothing
        # left to wait for.
        if self._closing:
            return True

        begin = self._loop.time()
        # Skip non-CLOSE messages until the peer's CLOSE message shows up.
        # Every read gets the full self._timeout passed to Timeout.
        while self._loop.time() - begin < self._timeout:
            try:
                with Timeout(timeout=self._timeout, loop=self._loop):
                    msg = yield from self._reader.read()
            except asyncio.CancelledError:
                self._close_code = 1006
                raise
            except Exception as exc:
                self._close_code = 1006
                self._exception = exc
                return True

            if msg.type == WSMsgType.CLOSE:
                self._close_code = msg.data
                return True

        # The loop ran out of time without seeing a CLOSE message.
        self._close_code = 1006
        self._exception = asyncio.TimeoutError()
        return True

    @asyncio.coroutine
    def receive(self):
        """Read the next message from the peer.

        Returns the message read, ``CLOSED_MESSAGE`` when the socket is
        already closed, a CLOSE message when the client disconnected, or
        an ERROR message wrapping the exception when reading failed.

        Raises RuntimeError before prepare(), on a concurrent call, and
        once a closed socket has been read THRESHOLD_CONNLOST_ACCESS
        times.  CancelledError and TimeoutError from the reader propagate.
        """
        if self._reader is None:
            raise RuntimeError('Call .prepare() first')
        if self._waiting:
            raise RuntimeError('Concurrent call to receive() is not allowed')

        self._waiting = True
        try:
            while True:
                if self._closed:
                    # Count reads on a closed socket and fail loudly once
                    # the caller keeps reading past the threshold.
                    self._conn_lost += 1
                    if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
                        raise RuntimeError('WebSocket connection is closed.')
                    return CLOSED_MESSAGE

                try:
                    msg = yield from self._reader.read()
                except (asyncio.CancelledError, asyncio.TimeoutError):
                    raise
                except WebSocketError as exc:
                    self._close_code = exc.code
                    yield from self.close(code=exc.code)
                    return WSMessage(WSMsgType.ERROR, exc, None)
                except ClientDisconnectedError:
                    self._closed = True
                    self._close_code = 1006
                    return WSMessage(WSMsgType.CLOSE, None, None)
                except Exception as exc:
                    self._exception = exc
                    # Mark as closing so close() does not wait for a
                    # CLOSE message from the failed reader.
                    self._closing = True
                    self._close_code = 1006
                    yield from self.close()
                    return WSMessage(WSMsgType.ERROR, exc, None)

                if msg.type == WSMsgType.CLOSE:
                    self._closing = True
                    self._close_code = msg.data
                    if not self._closed and self._autoclose:
                        yield from self.close()
                    return msg
                if msg.type == WSMsgType.PING and self._autoping:
                    # Answer and keep waiting for the next message.
                    self.pong(msg.data)
                elif msg.type == WSMsgType.PONG and self._autoping:
                    continue
                else:
                    return msg
        finally:
            self._waiting = False

    @asyncio.coroutine
    def receive_msg(self):
        """Deprecated alias of ``receive()``."""
        warnings.warn(
            'receive_msg() coroutine is deprecated. use receive() instead',
            DeprecationWarning)
        return (yield from self.receive())

    @asyncio.coroutine
    def receive_str(self):
        """Receive a message and return its data.

        Raises TypeError if the message type is not TEXT.
        """
        msg = yield from self.receive()
        if msg.type != WSMsgType.TEXT:
            raise TypeError(
                "Received message {}:{!r} is not str".format(msg.type,
                                                             msg.data))
        return msg.data

    @asyncio.coroutine
    def receive_bytes(self):
        """Receive a message and return its data.

        Raises TypeError if the message type is not BINARY.
        """
        msg = yield from self.receive()
        if msg.type != WSMsgType.BINARY:
            raise TypeError(
                "Received message {}:{!r} is not bytes".format(msg.type,
                                                               msg.data))
        return msg.data

    @asyncio.coroutine
    def receive_json(self, *, loads=json.loads):
        """Receive a text message and decode it with *loads*."""
        data = yield from self.receive_str()
        return loads(data)

    def write(self, data):
        """Always raise RuntimeError: raw writes are not supported."""
        raise RuntimeError("Cannot call .write() for websocket")

    if PY_35:
        def __aiter__(self):
            """Return self so the socket can be used in ``async for``."""
            return self

        if not PY_352:  # pragma: no cover
            __aiter__ = asyncio.coroutine(__aiter__)

        @asyncio.coroutine
        def __anext__(self):
            """Return the next message, stopping once the socket closes."""
            msg = yield from self.receive()
            # receive() returns CLOSED_MESSAGE on an already closed
            # socket; stop on it as well as on a CLOSE message, otherwise
            # the loop would keep yielding the sentinel until receive()
            # raises RuntimeError.
            if msg.type in (WSMsgType.CLOSE, CLOSED_MESSAGE.type):
                raise StopAsyncIteration  # NOQA
            return msg
|