mirror of https://github.com/sgoudham/Enso-Bot.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
147 lines
5.2 KiB
Python
147 lines
5.2 KiB
Python
import asyncio
|
|
import unittest
|
|
import unittest.mock
|
|
|
|
from .exceptions import PayloadTooBig, WebSocketProtocolError
|
|
from .framing import *
|
|
|
|
|
|
class FramingTests(unittest.TestCase):
    """Tests for the frame and close-payload helpers of the ``framing`` module.

    The helpers ``decode``, ``encode``, ``round_trip`` and ``round_trip_close``
    drive ``read_frame``, ``write_frame``, ``parse_close`` and
    ``serialize_close``; the ``test_*`` methods feed them literal byte strings.

    Every test runs on its own event loop, created in ``setUp`` and disposed
    of in ``tearDown``.
    """

    def setUp(self):
        """Create a fresh event loop and install it as the current loop."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        """Close the per-test event loop and uninstall it."""
        self.loop.close()
        # Do not leave a *closed* loop registered as the current event loop:
        # code running after this test would otherwise pick it up and fail
        # with confusing "Event loop is closed" errors.
        asyncio.set_event_loop(None)

    def decode(self, message, mask=False, max_size=None):
        """Parse ``message`` with ``read_frame`` and return the result.

        ``message`` is fed to a ``StreamReader`` followed by EOF, and the
        reader's ``readexactly`` is handed to ``read_frame`` together with
        ``mask`` and ``max_size``. Exceptions raised by ``read_frame``
        propagate to the caller. Fails the test if ``read_frame`` did not
        consume every byte of ``message``.
        """
        self.stream = asyncio.StreamReader(loop=self.loop)
        self.stream.feed_data(message)
        self.stream.feed_eof()
        frame = self.loop.run_until_complete(
            read_frame(self.stream.readexactly, mask, max_size=max_size)
        )
        # Make sure all the data was consumed.
        self.assertTrue(self.stream.at_eof())
        return frame

    def encode(self, frame, mask=False):
        """Serialize ``frame`` with ``write_frame`` and return the bytes.

        A ``Mock`` stands in for the writer callable so the data passed to it
        can be captured and the number of calls checked.
        """
        writer = unittest.mock.Mock()
        write_frame(frame, writer, mask)
        # Ensure the entire frame is sent with a single call to writer().
        # Multiple calls cause TCP fragmentation and degrade performance.
        self.assertEqual(writer.call_count, 1)
        # The frame data is the single positional argument of that call.
        return writer.call_args[0][0]

    def round_trip(self, message, expected, mask=False):
        """Check that ``message`` decodes to ``expected`` and encodes back.

        Without masking the re-encoded bytes must equal ``message``. With
        masking the encoded bytes are not compared directly (the encoding is
        non-deterministic); they are decoded again and compared to
        ``expected`` instead.
        """
        decoded = self.decode(message, mask)
        self.assertEqual(decoded, expected)
        encoded = self.encode(decoded, mask)
        if mask:    # non-deterministic encoding
            decoded = self.decode(encoded, mask)
            self.assertEqual(decoded, expected)
        else:       # deterministic encoding
            self.assertEqual(encoded, message)

    def round_trip_close(self, data, code, reason):
        """Check ``parse_close(data) == (code, reason)`` and the reverse.

        The reverse direction asserts that ``serialize_close(code, reason)``
        reproduces ``data`` exactly.
        """
        parsed = parse_close(data)
        self.assertEqual(parsed, (code, reason))
        serialized = serialize_close(code, reason)
        self.assertEqual(serialized, data)

    def test_text(self):
        # Unmasked text frame with a 4-byte payload.
        self.round_trip(b'\x81\x04Spam', Frame(True, OP_TEXT, b'Spam'))

    def test_text_masked(self):
        # Same payload as test_text, with the mask flag and a 4-byte key.
        self.round_trip(
            b'\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5',
            Frame(True, OP_TEXT, b'Spam'), mask=True)

    def test_binary(self):
        # Unmasked binary frame with a 4-byte payload.
        self.round_trip(b'\x82\x04Eggs', Frame(True, OP_BINARY, b'Eggs'))

    def test_binary_masked(self):
        # Same payload as test_binary, with the mask flag and a 4-byte key.
        self.round_trip(
            b'\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa',
            Frame(True, OP_BINARY, b'Eggs'), mask=True)

    def test_non_ascii_text(self):
        # The payload is the UTF-8 encoding of 'café' (5 bytes).
        self.round_trip(
            b'\x81\x05caf\xc3\xa9',
            Frame(True, OP_TEXT, 'café'.encode('utf-8')))

    def test_non_ascii_text_masked(self):
        # Masked variant of test_non_ascii_text.
        self.round_trip(
            b'\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd',
            Frame(True, OP_TEXT, 'café'.encode('utf-8')), mask=True)

    def test_close(self):
        # Close frame with an empty payload.
        self.round_trip(b'\x88\x00', Frame(True, OP_CLOSE, b''))

    def test_ping(self):
        self.round_trip(b'\x89\x04ping', Frame(True, OP_PING, b'ping'))

    def test_pong(self):
        self.round_trip(b'\x8a\x04pong', Frame(True, OP_PONG, b'pong'))

    def test_long(self):
        # 126-byte payload: length byte 0x7e followed by a 2-byte length.
        self.round_trip(
            b'\x82\x7e\x00\x7e' + 126 * b'a',
            Frame(True, OP_BINARY, 126 * b'a'))

    def test_very_long(self):
        # 65536-byte payload: length byte 0x7f followed by an 8-byte length.
        self.round_trip(
            b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + 65536 * b'a',
            Frame(True, OP_BINARY, 65536 * b'a'))

    def test_payload_too_big(self):
        # The header announces 0x0401 == 1025 bytes, one over max_size.
        with self.assertRaises(PayloadTooBig):
            self.decode(b'\x82\x7e\x04\x01' + 1025 * b'a', max_size=1024)

    def test_bad_reserved_bits(self):
        # Each first byte sets exactly one of the bits 0x40, 0x20, 0x10 in
        # addition to 0x80. subTest reports every offending input instead of
        # stopping at the first failure.
        for data in (b'\xc0\x00', b'\xa0\x00', b'\x90\x00'):
            with self.subTest(data=data):
                with self.assertRaises(WebSocketProtocolError):
                    self.decode(data)

    def test_bad_opcode(self):
        accepted = list(range(0x00, 0x03)) + list(range(0x08, 0x0b))
        rejected = list(range(0x03, 0x08)) + list(range(0x0b, 0x10))
        # These opcodes must decode without raising.
        for opcode in accepted:
            with self.subTest(opcode=opcode):
                self.decode(bytes([0x80 | opcode, 0]))
        # Every other 4-bit opcode value must be refused.
        for opcode in rejected:
            with self.subTest(opcode=opcode):
                with self.assertRaises(WebSocketProtocolError):
                    self.decode(bytes([0x80 | opcode, 0]))

    def test_bad_mask_flag(self):
        # A frame with the mask bit set is accepted only when mask=True...
        self.decode(b'\x80\x80\x00\x00\x00\x00', mask=True)
        with self.assertRaises(WebSocketProtocolError):
            self.decode(b'\x80\x80\x00\x00\x00\x00')
        # ...and a frame without it only when mask=False.
        self.decode(b'\x80\x00')
        with self.assertRaises(WebSocketProtocolError):
            self.decode(b'\x80\x00', mask=True)

    def test_control_frame_too_long(self):
        # Opcode 0x8 frame carrying a 126-byte payload.
        with self.assertRaises(WebSocketProtocolError):
            self.decode(b'\x88\x7e\x00\x7e' + 126 * b'a')

    def test_fragmented_control_frame(self):
        # Opcode 0x8 frame whose first byte lacks the 0x80 bit.
        with self.assertRaises(WebSocketProtocolError):
            self.decode(b'\x08\x00')

    def test_parse_close(self):
        # 0x03e8 == 1000, followed by an optional reason string.
        self.round_trip_close(b'\x03\xe8', 1000, '')
        self.round_trip_close(b'\x03\xe8OK', 1000, 'OK')

    def test_parse_close_empty(self):
        # An empty payload is parsed as code 1005 with an empty reason.
        self.assertEqual(parse_close(b''), (1005, ''))

    def test_parse_close_errors(self):
        cases = (
            (WebSocketProtocolError, b'\x03'),          # 1-byte payload
            (WebSocketProtocolError, b'\x03\xe7'),      # code 0x03e7 == 999
            (UnicodeDecodeError, b'\x03\xe8\xff\xff'),  # reason is not UTF-8
        )
        for error, data in cases:
            with self.subTest(data=data):
                with self.assertRaises(error):
                    parse_close(data)