Files
jeb-utils/src/jeb_utils/jebp_utils.py
T
2026-04-02 12:34:21 +02:00

131 lines
4.0 KiB
Python

# Copyright 2026 jCloud Services GbR
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
import io
from . import utils
from typing import Sequence
# Names exported by ``from ... import *``; everything else here is internal.
__all__ = ['MessageFormatError', 'sendmsg', 'readmsg', 'pack_fields', 'unpack_fields']

# Default first byte of every framed message written by sendmsg and checked by readmsg.
_MSG_START_BYTE = b'\x01'
class MessageFormatError(Exception):
    '''Raised when incoming data does not follow the expected message framing
    (for example, readmsg sees a first byte other than the expected start byte).'''
async def sendmsg(m: bytes | str, writer: asyncio.StreamWriter, aesgcm: AESGCM | None = None, aesnonce: bytes | None = None, send_headers: bool = True, _content_length: int | None = None, start_byte: bytes = _MSG_START_BYTE) -> None:
    '''
    Sends a single framed message.

    Wire layout: ``start_byte``, then (only if ``send_headers``) one byte ``n``
    holding the size of the length field, followed by the ``n``-byte length
    produced by ``utils.int_to_bytes``, then the payload. The whole frame is
    written with one ``writer.write`` call and then drained.

    :param m: The message. ``str`` (including subclasses) is encoded with
        ``str.encode()``, i.e. UTF-8.
    :type m: bytes | str
    :param writer: The stream to write the frame to.
    :type writer: asyncio.StreamWriter
    :param aesgcm: If given, the payload is replaced by
        ``aesgcm.encrypt(aesnonce, m, None)`` before framing.
    :type aesgcm: AESGCM | None
    :param aesnonce: Nonce passed to ``aesgcm.encrypt``. AES-GCM requires a
        nonce that is never reused with the same key; choosing it is the
        caller's responsibility. Unused when ``aesgcm`` is ``None``.
    :type aesnonce: bytes | None
    :param send_headers: Controls whether the length header is sent. The start
        byte is written either way.
    :type send_headers: bool
    :param _content_length: The length written into the header. If ``None``,
        the length of the (possibly encrypted) payload is used. Only has an
        effect when ``send_headers`` is true.
    :type _content_length: int | None
    :param start_byte: The bytes written at the very start of the frame.
    :type start_byte: bytes
    '''
    # isinstance (not ``type(...) ==``) so str subclasses are encoded too;
    # otherwise ``bytes + str`` below would raise TypeError.
    if isinstance(m, str):
        m = m.encode()
    if aesgcm is not None:
        m = aesgcm.encrypt(aesnonce, m, None)
    if _content_length is None:
        # Measured after encryption so the header matches what is on the wire.
        _content_length = len(m)
    if send_headers:
        length_bytes = utils.int_to_bytes(_content_length)
        # One byte announcing how many length bytes follow.
        header = bytes([len(length_bytes)]) + length_bytes
    else:
        header = b''
    writer.write(start_byte + header + m)
    await writer.drain()
async def readmsg(reader: asyncio.StreamReader, aesgcm: AESGCM | None = None, aesnonce: bytes | None = None, start_byte: bytes = _MSG_START_BYTE, max_length: int | None = None) -> bytes:
    '''
    Receives a single framed message (the counterpart of ``sendmsg`` with
    ``send_headers=True``).

    Reads one byte and compares it to ``start_byte``. Next it reads one byte
    ``n``, then an ``n``-byte big-endian payload length, then the payload
    itself.

    :param reader: The stream to read from.
    :type reader: asyncio.StreamReader
    :param aesgcm: If given, the payload is returned as
        ``aesgcm.decrypt(aesnonce, payload, None)``.
    :type aesgcm: AESGCM | None
    :param aesnonce: Nonce passed to ``aesgcm.decrypt``; unused when ``aesgcm``
        is ``None``.
    :type aesnonce: bytes | None
    :param start_byte: The expected start byte. Exactly one byte is read and
        compared, so a longer value can never match.
    :type start_byte: bytes
    :param max_length: Optional upper bound on the announced payload length
        (as read from the wire, i.e. before decryption). ``None`` means no
        limit. When exceeded, the header has already been consumed, so the
        stream is left mid-frame.
    :type max_length: int | None
    :raises MessageFormatError: If the start byte does not match, or the
        announced length exceeds ``max_length``.
    :raises asyncio.IncompleteReadError: If the stream ends before a full
        frame was read.
    :return: The received (and, if ``aesgcm`` is given, decrypted) message
    :rtype: bytes
    '''
    if await reader.readexactly(1) != start_byte:
        raise MessageFormatError('invalid message format')
    # Size, in bytes, of the length field that follows.
    length_size = (await reader.readexactly(1))[0]
    # Explicit 'big' matches the int.from_bytes default on Python 3.11+.
    content_length = int.from_bytes(await reader.readexactly(length_size), 'big')
    # Peer-controlled length: optionally refuse to buffer oversized payloads.
    if max_length is not None and content_length > max_length:
        raise MessageFormatError('message too long')
    m = await reader.readexactly(content_length)
    if aesgcm is not None:
        return aesgcm.decrypt(aesnonce, m, None)
    return m
def pack_fields(fields: list[bytes]) -> bytes:
    '''
    Packs the fields into a compact bytestring.

    Each field is emitted as one byte ``n`` (the size of the length prefix),
    then the ``n``-byte length produced by ``utils.int_to_bytes``, then the
    field's raw bytes. ``unpack_fields`` decodes the length big-endian, so this
    assumes ``utils.int_to_bytes`` is big-endian (TODO confirm).

    :param fields: The fields, in order
    :type fields: list[bytes]
    :return: The concatenated encoding; ``b''`` for no fields
    :rtype: bytes
    '''
    # Collect parts and join once; repeated ``bytes +=`` is quadratic.
    parts: list[bytes] = []
    for field in fields:
        length_bytes = utils.int_to_bytes(len(field))
        parts.append(len(length_bytes).to_bytes(1, 'big'))
        parts.append(length_bytes)
        parts.append(field)
    return b''.join(parts)
def unpack_fields(raw: bytes) -> Sequence[bytes]:
    '''
    Unpacks the fields from a bytestring produced by ``pack_fields``.

    Each field is read as one byte ``n``, then an ``n``-byte big-endian length,
    then that many payload bytes. Parsing is lenient: if ``raw`` ends early, a
    truncated length prefix is decoded from whatever bytes remain and a
    truncated final field is returned shorter than its declared length.

    :param raw: The packed fields
    :type raw: bytes
    :return: The unpacked fields, in order
    :rtype: Sequence[bytes]
    '''
    fields: list[bytes] = []
    buffer = io.BytesIO(raw)
    end = len(raw)
    # Every iteration consumes at least the one-byte prefix, so this terminates.
    # No blanket ``except`` here: BytesIO.read never raises on short data, and
    # a bare except would also swallow KeyboardInterrupt/SystemExit.
    while buffer.tell() < end:
        length_size = buffer.read(1)[0]
        field_length = int.from_bytes(buffer.read(length_size), 'big')
        fields.append(buffer.read(field_length))
    return fields