Add tests for jebp_utils
This commit is contained in:
@@ -90,7 +90,7 @@ async def readmsg(reader: asyncio.StreamReader, aesgcm: AESGCM = None, aesnonce:
|
||||
return aesgcm.decrypt(aesnonce, m, None)
|
||||
return m
|
||||
|
||||
def pack_fields(*fields: bytes) -> bytes:
|
||||
def pack_fields(fields: list[bytes]) -> bytes:
|
||||
'''
|
||||
Packs the fields into a compact bytestring.
|
||||
|
||||
@@ -100,11 +100,12 @@ def pack_fields(*fields: bytes) -> bytes:
|
||||
:return: The result
|
||||
:rtype: bytes
|
||||
'''
|
||||
|
||||
result = b''
|
||||
for field in fields:
|
||||
field_length = utils.int_to_bytes(len(field))
|
||||
field_length_length = utils.int_to_bytes(len(field_length))
|
||||
result += field_length_length + field_length + field
|
||||
payload_length = utils.int_to_bytes(len(field))
|
||||
payload_length_length = len(payload_length).to_bytes(1)
|
||||
result += payload_length_length + payload_length + field
|
||||
return result
|
||||
|
||||
def unpack_fields(raw: bytes) -> Sequence[bytes]:
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
from src.jeb_utils.jebp_utils import pack_fields
|
||||
import pytest
|
||||
|
||||
@pytest.mark.parametrize('fields,expected', [
|
||||
([], b''),
|
||||
([b''], b'\x01\x00'),
|
||||
([b'', b''], b'\x01\x00\x01\x00'),
|
||||
([b'', b'a'], b'\x01\x00\x01\x01a'),
|
||||
([b'', b'ab'], b'\x01\x00\x01\x02ab'),
|
||||
([b'a', b'ab'], b'\x01\x01a\x01\x02ab'),
|
||||
([b'a', b'ab', b'', b'Hello, World!'], b'\x01\x01a\x01\x02ab\x01\x00\x01\x0dHello, World!'),
|
||||
([b'a' * 255], b'\x01\xff' + b'a' * 255),
|
||||
([b'a' * 255, b'a' * 255], (b'\x01\xff' + b'a' * 255) * 2),
|
||||
([b'a' * 256], b'\x02\x01\x00' + b'a' * 256),
|
||||
([b'a' * 256, b'a' * 256], (b'\x02\x01\x00' + b'a' * 256) * 2),
|
||||
([b'a' * 257], b'\x02\x01\x01' + b'a' * 257),
|
||||
([b'a' * 257, b'abc'], b'\x02\x01\x01' + b'a' * 257 + b'\x01\x03abc'),
|
||||
])
|
||||
def test_pack_fields(fields, expected):
|
||||
assert pack_fields(fields) == expected
|
||||
@@ -0,0 +1,44 @@
|
||||
from src.jeb_utils.jebp_utils import readmsg, _MSG_START_BYTE, MessageFormatError
|
||||
import pytest
|
||||
|
||||
class AsyncTestStreamReader:
|
||||
def __init__(self, content: bytes) -> None:
|
||||
self.content = content
|
||||
|
||||
async def readexactly(self, n: int) -> bytes:
|
||||
if len(self.content) < n:
|
||||
raise EOFError
|
||||
result = self.content[:n]
|
||||
self.content = self.content[n:]
|
||||
return result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize('expected,reader,start_byte', [
|
||||
(b'', AsyncTestStreamReader(b'\x01\x01\x00'), b'\x01'),
|
||||
(b'', AsyncTestStreamReader(b'\x02\x01\x00'), b'\x02'),
|
||||
(b'a', AsyncTestStreamReader(b'\x02\x01\x01a'), b'\x02'),
|
||||
(b'', AsyncTestStreamReader(b'\x01\x00'), b'\x01'),
|
||||
(b'a' * 255, AsyncTestStreamReader(b'\x01\x01\xff' + b'a' * 255), b'\x01'),
|
||||
(b'a' * 256, AsyncTestStreamReader(b'\x01\x02\x01\x00' + b'a' * 256), b'\x01'),
|
||||
(b'a' * 256, AsyncTestStreamReader(b'\x01\x02\x01\x00' + b'a' * 257), b'\x01'),
|
||||
(b'a' * 257, AsyncTestStreamReader(b'\x01\x02\x01\x01' + b'a' * 257), b'\x01'),
|
||||
])
|
||||
async def test_readmsg(expected, reader, start_byte):
|
||||
assert await readmsg(reader, start_byte = start_byte) == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize('expected_exception,reader,start_byte', [
|
||||
(EOFError, AsyncTestStreamReader(b'\x01\x01\x01'), b'\x01'),
|
||||
(EOFError, AsyncTestStreamReader(b'\x01\x02\x01a'), b'\x01'),
|
||||
(EOFError, AsyncTestStreamReader(b'\x01'), b'\x01'),
|
||||
(EOFError, AsyncTestStreamReader(b'\x01\x01'), b'\x01'),
|
||||
(EOFError, AsyncTestStreamReader(b'\x01\x02\x00'), b'\x01'),
|
||||
(MessageFormatError, AsyncTestStreamReader(b'\x01\x02\x01a'), b'\x02'),
|
||||
(MessageFormatError, AsyncTestStreamReader(b'\x01\x01\x01'), b'\x02'),
|
||||
(MessageFormatError, AsyncTestStreamReader(b'\x01'), b'\x02'),
|
||||
])
|
||||
async def test_readmsg_exceptions(expected_exception, reader, start_byte):
|
||||
with pytest.raises(expected_exception):
|
||||
await readmsg(reader, start_byte = start_byte)
|
||||
@@ -0,0 +1,29 @@
|
||||
from src.jeb_utils.jebp_utils import sendmsg
|
||||
import pytest
|
||||
|
||||
class AsyncTestStreamWriter:
|
||||
def __init__(self):
|
||||
self.content = b''
|
||||
|
||||
def write(self, data):
|
||||
self.content += data
|
||||
|
||||
async def drain(self):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize('expected,message,send_headers,_content_length,start_byte', [
|
||||
(b'\x01\x01\x00', b'', True, None, b'\x01'),
|
||||
(b'\x02\x01\x00', b'', True, None, b'\x02'),
|
||||
(b'\x01\x01\x01a', b'a', True, None, b'\x01'),
|
||||
(b'\x01a', b'a', False, None, b'\x01'),
|
||||
(b'\x01a', b'a', False, 42, b'\x01'),
|
||||
(b'\x01\x01\x2aa', b'a', True, 42, b'\x01'),
|
||||
(b'\x01\x0209a', b'a', True, 12345, b'\x01'),
|
||||
(b'\x01a', b'a', False, 42, b'\x01'),
|
||||
])
|
||||
async def test_sendmsg(expected, message, send_headers, _content_length, start_byte):
|
||||
writer = AsyncTestStreamWriter()
|
||||
await sendmsg(message, writer, send_headers = send_headers, _content_length = _content_length, start_byte = start_byte)
|
||||
assert writer.content == expected
|
||||
@@ -0,0 +1,21 @@
|
||||
from src.jeb_utils.jebp_utils import unpack_fields
|
||||
import pytest
|
||||
|
||||
@pytest.mark.parametrize('packed,expected', [
|
||||
(b'', []),
|
||||
(b'\x02\x00\x00', [b'']),
|
||||
(b'\x01\x00', [b'']),
|
||||
(b'\x01\x00\x01\x00', [b'', b'']),
|
||||
(b'\x01\x00\x01\x01a', [b'', b'a']),
|
||||
(b'\x01\x00\x01\x02ab', [b'', b'ab']),
|
||||
(b'\x01\x01a\x01\x02ab', [b'a', b'ab']),
|
||||
(b'\x01\x01a\x01\x02ab\x01\x00\x01\x0dHello, World!', [b'a', b'ab', b'', b'Hello, World!']),
|
||||
(b'\x01\xff' + b'a' * 255, [b'a' * 255]),
|
||||
((b'\x01\xff' + b'a' * 255) * 2, [b'a' * 255, b'a' * 255]),
|
||||
(b'\x02\x01\x00' + b'a' * 256, [b'a' * 256]),
|
||||
((b'\x02\x01\x00' + b'a' * 256) * 2, [b'a' * 256, b'a' * 256]),
|
||||
(b'\x02\x01\x01' + b'a' * 257, [b'a' * 257]),
|
||||
(b'\x02\x01\x01' + b'a' * 257 + b'\x01\x03abc', [b'a' * 257, b'abc']),
|
||||
])
|
||||
def test_unpack_fields(packed, expected):
|
||||
assert unpack_fields(packed) == expected
|
||||
Reference in New Issue
Block a user