Add tests for jebp_utils

This commit is contained in:
2026-04-02 12:34:21 +02:00
parent b12555d883
commit 32ed782fef
5 changed files with 119 additions and 4 deletions
+5 -4
View File
@@ -90,7 +90,7 @@ async def readmsg(reader: asyncio.StreamReader, aesgcm: AESGCM = None, aesnonce:
return aesgcm.decrypt(aesnonce, m, None)
return m
def pack_fields(*fields: bytes) -> bytes:
def pack_fields(fields: list[bytes]) -> bytes:
'''
Packs the fields into a compact bytestring.
@@ -100,11 +100,12 @@ def pack_fields(*fields: bytes) -> bytes:
:return: The result
:rtype: bytes
'''
result = b''
for field in fields:
field_length = utils.int_to_bytes(len(field))
field_length_length = utils.int_to_bytes(len(field_length))
result += field_length_length + field_length + field
payload_length = utils.int_to_bytes(len(field))
payload_length_length = len(payload_length).to_bytes(1)
result += payload_length_length + payload_length + field
return result
def unpack_fields(raw: bytes) -> Sequence[bytes]:
+20
View File
@@ -0,0 +1,20 @@
from src.jeb_utils.jebp_utils import pack_fields
import pytest
@pytest.mark.parametrize('fields,expected', [
([], b''),
([b''], b'\x01\x00'),
([b'', b''], b'\x01\x00\x01\x00'),
([b'', b'a'], b'\x01\x00\x01\x01a'),
([b'', b'ab'], b'\x01\x00\x01\x02ab'),
([b'a', b'ab'], b'\x01\x01a\x01\x02ab'),
([b'a', b'ab', b'', b'Hello, World!'], b'\x01\x01a\x01\x02ab\x01\x00\x01\x0dHello, World!'),
([b'a' * 255], b'\x01\xff' + b'a' * 255),
([b'a' * 255, b'a' * 255], (b'\x01\xff' + b'a' * 255) * 2),
([b'a' * 256], b'\x02\x01\x00' + b'a' * 256),
([b'a' * 256, b'a' * 256], (b'\x02\x01\x00' + b'a' * 256) * 2),
([b'a' * 257], b'\x02\x01\x01' + b'a' * 257),
([b'a' * 257, b'abc'], b'\x02\x01\x01' + b'a' * 257 + b'\x01\x03abc'),
])
def test_pack_fields(fields, expected):
assert pack_fields(fields) == expected
+44
View File
@@ -0,0 +1,44 @@
from src.jeb_utils.jebp_utils import readmsg, _MSG_START_BYTE, MessageFormatError
import pytest
class AsyncTestStreamReader:
def __init__(self, content: bytes) -> None:
self.content = content
async def readexactly(self, n: int) -> bytes:
if len(self.content) < n:
raise EOFError
result = self.content[:n]
self.content = self.content[n:]
return result
@pytest.mark.asyncio
@pytest.mark.parametrize('expected,reader,start_byte', [
(b'', AsyncTestStreamReader(b'\x01\x01\x00'), b'\x01'),
(b'', AsyncTestStreamReader(b'\x02\x01\x00'), b'\x02'),
(b'a', AsyncTestStreamReader(b'\x02\x01\x01a'), b'\x02'),
(b'', AsyncTestStreamReader(b'\x01\x00'), b'\x01'),
(b'a' * 255, AsyncTestStreamReader(b'\x01\x01\xff' + b'a' * 255), b'\x01'),
(b'a' * 256, AsyncTestStreamReader(b'\x01\x02\x01\x00' + b'a' * 256), b'\x01'),
(b'a' * 256, AsyncTestStreamReader(b'\x01\x02\x01\x00' + b'a' * 257), b'\x01'),
(b'a' * 257, AsyncTestStreamReader(b'\x01\x02\x01\x01' + b'a' * 257), b'\x01'),
])
async def test_readmsg(expected, reader, start_byte):
assert await readmsg(reader, start_byte = start_byte) == expected
@pytest.mark.asyncio
@pytest.mark.parametrize('expected_exception,reader,start_byte', [
(EOFError, AsyncTestStreamReader(b'\x01\x01\x01'), b'\x01'),
(EOFError, AsyncTestStreamReader(b'\x01\x02\x01a'), b'\x01'),
(EOFError, AsyncTestStreamReader(b'\x01'), b'\x01'),
(EOFError, AsyncTestStreamReader(b'\x01\x01'), b'\x01'),
(EOFError, AsyncTestStreamReader(b'\x01\x02\x00'), b'\x01'),
(MessageFormatError, AsyncTestStreamReader(b'\x01\x02\x01a'), b'\x02'),
(MessageFormatError, AsyncTestStreamReader(b'\x01\x01\x01'), b'\x02'),
(MessageFormatError, AsyncTestStreamReader(b'\x01'), b'\x02'),
])
async def test_readmsg_exceptions(expected_exception, reader, start_byte):
with pytest.raises(expected_exception):
await readmsg(reader, start_byte = start_byte)
+29
View File
@@ -0,0 +1,29 @@
from src.jeb_utils.jebp_utils import sendmsg
import pytest
class AsyncTestStreamWriter:
def __init__(self):
self.content = b''
def write(self, data):
self.content += data
async def drain(self):
pass
@pytest.mark.asyncio
@pytest.mark.parametrize('expected,message,send_headers,_content_length,start_byte', [
(b'\x01\x01\x00', b'', True, None, b'\x01'),
(b'\x02\x01\x00', b'', True, None, b'\x02'),
(b'\x01\x01\x01a', b'a', True, None, b'\x01'),
(b'\x01a', b'a', False, None, b'\x01'),
(b'\x01a', b'a', False, 42, b'\x01'),
(b'\x01\x01\x2aa', b'a', True, 42, b'\x01'),
(b'\x01\x0209a', b'a', True, 12345, b'\x01'),
(b'\x01a', b'a', False, 42, b'\x01'),
])
async def test_sendmsg(expected, message, send_headers, _content_length, start_byte):
writer = AsyncTestStreamWriter()
await sendmsg(message, writer, send_headers = send_headers, _content_length = _content_length, start_byte = start_byte)
assert writer.content == expected
@@ -0,0 +1,21 @@
from src.jeb_utils.jebp_utils import unpack_fields
import pytest
@pytest.mark.parametrize('packed,expected', [
(b'', []),
(b'\x02\x00\x00', [b'']),
(b'\x01\x00', [b'']),
(b'\x01\x00\x01\x00', [b'', b'']),
(b'\x01\x00\x01\x01a', [b'', b'a']),
(b'\x01\x00\x01\x02ab', [b'', b'ab']),
(b'\x01\x01a\x01\x02ab', [b'a', b'ab']),
(b'\x01\x01a\x01\x02ab\x01\x00\x01\x0dHello, World!', [b'a', b'ab', b'', b'Hello, World!']),
(b'\x01\xff' + b'a' * 255, [b'a' * 255]),
((b'\x01\xff' + b'a' * 255) * 2, [b'a' * 255, b'a' * 255]),
(b'\x02\x01\x00' + b'a' * 256, [b'a' * 256]),
((b'\x02\x01\x00' + b'a' * 256) * 2, [b'a' * 256, b'a' * 256]),
(b'\x02\x01\x01' + b'a' * 257, [b'a' * 257]),
(b'\x02\x01\x01' + b'a' * 257 + b'\x01\x03abc', [b'a' * 257, b'abc']),
])
def test_unpack_fields(packed, expected):
assert unpack_fields(packed) == expected