Add tests for jebp_utils
This commit is contained in:
@@ -90,7 +90,7 @@ async def readmsg(reader: asyncio.StreamReader, aesgcm: AESGCM = None, aesnonce:
|
|||||||
return aesgcm.decrypt(aesnonce, m, None)
|
return aesgcm.decrypt(aesnonce, m, None)
|
||||||
return m
|
return m
|
||||||
|
|
||||||
def pack_fields(*fields: bytes) -> bytes:
|
def pack_fields(fields: list[bytes]) -> bytes:
|
||||||
'''
|
'''
|
||||||
Packs the fields into a compact bytestring.
|
Packs the fields into a compact bytestring.
|
||||||
|
|
||||||
@@ -100,11 +100,12 @@ def pack_fields(*fields: bytes) -> bytes:
|
|||||||
:return: The result
|
:return: The result
|
||||||
:rtype: bytes
|
:rtype: bytes
|
||||||
'''
|
'''
|
||||||
|
|
||||||
result = b''
|
result = b''
|
||||||
for field in fields:
|
for field in fields:
|
||||||
field_length = utils.int_to_bytes(len(field))
|
payload_length = utils.int_to_bytes(len(field))
|
||||||
field_length_length = utils.int_to_bytes(len(field_length))
|
payload_length_length = len(payload_length).to_bytes(1)
|
||||||
result += field_length_length + field_length + field
|
result += payload_length_length + payload_length + field
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def unpack_fields(raw: bytes) -> Sequence[bytes]:
|
def unpack_fields(raw: bytes) -> Sequence[bytes]:
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
from src.jeb_utils.jebp_utils import pack_fields
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('fields,expected', [
|
||||||
|
([], b''),
|
||||||
|
([b''], b'\x01\x00'),
|
||||||
|
([b'', b''], b'\x01\x00\x01\x00'),
|
||||||
|
([b'', b'a'], b'\x01\x00\x01\x01a'),
|
||||||
|
([b'', b'ab'], b'\x01\x00\x01\x02ab'),
|
||||||
|
([b'a', b'ab'], b'\x01\x01a\x01\x02ab'),
|
||||||
|
([b'a', b'ab', b'', b'Hello, World!'], b'\x01\x01a\x01\x02ab\x01\x00\x01\x0dHello, World!'),
|
||||||
|
([b'a' * 255], b'\x01\xff' + b'a' * 255),
|
||||||
|
([b'a' * 255, b'a' * 255], (b'\x01\xff' + b'a' * 255) * 2),
|
||||||
|
([b'a' * 256], b'\x02\x01\x00' + b'a' * 256),
|
||||||
|
([b'a' * 256, b'a' * 256], (b'\x02\x01\x00' + b'a' * 256) * 2),
|
||||||
|
([b'a' * 257], b'\x02\x01\x01' + b'a' * 257),
|
||||||
|
([b'a' * 257, b'abc'], b'\x02\x01\x01' + b'a' * 257 + b'\x01\x03abc'),
|
||||||
|
])
|
||||||
|
def test_pack_fields(fields, expected):
|
||||||
|
assert pack_fields(fields) == expected
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
from src.jeb_utils.jebp_utils import readmsg, _MSG_START_BYTE, MessageFormatError
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
class AsyncTestStreamReader:
|
||||||
|
def __init__(self, content: bytes) -> None:
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
async def readexactly(self, n: int) -> bytes:
|
||||||
|
if len(self.content) < n:
|
||||||
|
raise EOFError
|
||||||
|
result = self.content[:n]
|
||||||
|
self.content = self.content[n:]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize('expected,reader,start_byte', [
|
||||||
|
(b'', AsyncTestStreamReader(b'\x01\x01\x00'), b'\x01'),
|
||||||
|
(b'', AsyncTestStreamReader(b'\x02\x01\x00'), b'\x02'),
|
||||||
|
(b'a', AsyncTestStreamReader(b'\x02\x01\x01a'), b'\x02'),
|
||||||
|
(b'', AsyncTestStreamReader(b'\x01\x00'), b'\x01'),
|
||||||
|
(b'a' * 255, AsyncTestStreamReader(b'\x01\x01\xff' + b'a' * 255), b'\x01'),
|
||||||
|
(b'a' * 256, AsyncTestStreamReader(b'\x01\x02\x01\x00' + b'a' * 256), b'\x01'),
|
||||||
|
(b'a' * 256, AsyncTestStreamReader(b'\x01\x02\x01\x00' + b'a' * 257), b'\x01'),
|
||||||
|
(b'a' * 257, AsyncTestStreamReader(b'\x01\x02\x01\x01' + b'a' * 257), b'\x01'),
|
||||||
|
])
|
||||||
|
async def test_readmsg(expected, reader, start_byte):
|
||||||
|
assert await readmsg(reader, start_byte = start_byte) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize('expected_exception,reader,start_byte', [
|
||||||
|
(EOFError, AsyncTestStreamReader(b'\x01\x01\x01'), b'\x01'),
|
||||||
|
(EOFError, AsyncTestStreamReader(b'\x01\x02\x01a'), b'\x01'),
|
||||||
|
(EOFError, AsyncTestStreamReader(b'\x01'), b'\x01'),
|
||||||
|
(EOFError, AsyncTestStreamReader(b'\x01\x01'), b'\x01'),
|
||||||
|
(EOFError, AsyncTestStreamReader(b'\x01\x02\x00'), b'\x01'),
|
||||||
|
(MessageFormatError, AsyncTestStreamReader(b'\x01\x02\x01a'), b'\x02'),
|
||||||
|
(MessageFormatError, AsyncTestStreamReader(b'\x01\x01\x01'), b'\x02'),
|
||||||
|
(MessageFormatError, AsyncTestStreamReader(b'\x01'), b'\x02'),
|
||||||
|
])
|
||||||
|
async def test_readmsg_exceptions(expected_exception, reader, start_byte):
|
||||||
|
with pytest.raises(expected_exception):
|
||||||
|
await readmsg(reader, start_byte = start_byte)
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
from src.jeb_utils.jebp_utils import sendmsg
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
class AsyncTestStreamWriter:
|
||||||
|
def __init__(self):
|
||||||
|
self.content = b''
|
||||||
|
|
||||||
|
def write(self, data):
|
||||||
|
self.content += data
|
||||||
|
|
||||||
|
async def drain(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize('expected,message,send_headers,_content_length,start_byte', [
|
||||||
|
(b'\x01\x01\x00', b'', True, None, b'\x01'),
|
||||||
|
(b'\x02\x01\x00', b'', True, None, b'\x02'),
|
||||||
|
(b'\x01\x01\x01a', b'a', True, None, b'\x01'),
|
||||||
|
(b'\x01a', b'a', False, None, b'\x01'),
|
||||||
|
(b'\x01a', b'a', False, 42, b'\x01'),
|
||||||
|
(b'\x01\x01\x2aa', b'a', True, 42, b'\x01'),
|
||||||
|
(b'\x01\x0209a', b'a', True, 12345, b'\x01'),
|
||||||
|
(b'\x01a', b'a', False, 42, b'\x01'),
|
||||||
|
])
|
||||||
|
async def test_sendmsg(expected, message, send_headers, _content_length, start_byte):
|
||||||
|
writer = AsyncTestStreamWriter()
|
||||||
|
await sendmsg(message, writer, send_headers = send_headers, _content_length = _content_length, start_byte = start_byte)
|
||||||
|
assert writer.content == expected
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
from src.jeb_utils.jebp_utils import unpack_fields
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('packed,expected', [
|
||||||
|
(b'', []),
|
||||||
|
(b'\x02\x00\x00', [b'']),
|
||||||
|
(b'\x01\x00', [b'']),
|
||||||
|
(b'\x01\x00\x01\x00', [b'', b'']),
|
||||||
|
(b'\x01\x00\x01\x01a', [b'', b'a']),
|
||||||
|
(b'\x01\x00\x01\x02ab', [b'', b'ab']),
|
||||||
|
(b'\x01\x01a\x01\x02ab', [b'a', b'ab']),
|
||||||
|
(b'\x01\x01a\x01\x02ab\x01\x00\x01\x0dHello, World!', [b'a', b'ab', b'', b'Hello, World!']),
|
||||||
|
(b'\x01\xff' + b'a' * 255, [b'a' * 255]),
|
||||||
|
((b'\x01\xff' + b'a' * 255) * 2, [b'a' * 255, b'a' * 255]),
|
||||||
|
(b'\x02\x01\x00' + b'a' * 256, [b'a' * 256]),
|
||||||
|
((b'\x02\x01\x00' + b'a' * 256) * 2, [b'a' * 256, b'a' * 256]),
|
||||||
|
(b'\x02\x01\x01' + b'a' * 257, [b'a' * 257]),
|
||||||
|
(b'\x02\x01\x01' + b'a' * 257 + b'\x01\x03abc', [b'a' * 257, b'abc']),
|
||||||
|
])
|
||||||
|
def test_unpack_fields(packed, expected):
|
||||||
|
assert unpack_fields(packed) == expected
|
||||||
Reference in New Issue
Block a user