# Copyright 2026 jCloud Services GbR
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import io
from typing import Sequence

from cryptography.hazmat.primitives.ciphers.aead import AESGCM

from . import utils

__all__ = [
    'MessageFormatError',
    'sendmsg',
    'readmsg',
    'pack_fields',
    'unpack_fields',
]

# Marker byte written in front of every framed message.
_MSG_START_BYTE = b'\x01'


class MessageFormatError(Exception):
    '''Raised when received data does not follow the expected message format.'''


def _encode_length(n: int) -> bytes:
    '''
    Encodes ``n`` as a length prefix: one byte holding the size of the
    encoded integer, followed by ``utils.int_to_bytes(n)``.
    '''
    encoded = utils.int_to_bytes(n)
    return len(encoded).to_bytes(1, 'big') + encoded


def _read_exact(buffer: io.BytesIO, size: int) -> bytes:
    '''
    Reads exactly ``size`` bytes from ``buffer``.

    :raises MessageFormatError: If fewer than ``size`` bytes remain.
    '''
    data = buffer.read(size)
    if len(data) != size:
        raise MessageFormatError('truncated field data')
    return data


async def sendmsg(m: bytes | str,
                  writer: asyncio.StreamWriter,
                  aesgcm: AESGCM | None = None,
                  aesnonce: bytes | None = None,
                  send_headers: bool = True,
                  _content_length: int | None = None,
                  start_byte: bytes = _MSG_START_BYTE) -> None:
    '''
    Sends a framed message and waits for the writer to drain.

    With headers the bytes written are
    ``start_byte + len(L) (1 byte) + L + payload``, where ``L`` is the
    content length encoded by ``utils.int_to_bytes``. Without headers only
    ``start_byte + payload`` is written.

    :param m: The message. A ``str`` is encoded with ``str.encode()`` (UTF-8).
    :type m: bytes | str
    :param writer: The writer
    :type writer: asyncio.StreamWriter
    :param aesgcm: If given, the payload is encrypted with it before sending.
    :type aesgcm: AESGCM
    :param aesnonce: The nonce passed to ``aesgcm.encrypt``.
    :type aesnonce: bytes
    :param send_headers: Controls whether the length header is sent. The
        start byte is written in either case.
    :type send_headers: bool
    :param _content_length: Value written into the length header. If
        ``None``, the length of the (possibly encrypted) payload is used.
        Ignored when ``send_headers`` is ``False``.
    :type _content_length: int
    :param start_byte: The start byte(s) written before the message.
    :type start_byte: bytes
    '''
    if isinstance(m, str):
        m = m.encode()
    if aesgcm is not None:
        # NOTE(review): AES-GCM must never reuse a nonce with the same key;
        # this function uses whatever nonce the caller supplies. Confirm
        # that callers pass a fresh nonce for every message.
        m = aesgcm.encrypt(aesnonce, m, None)
    if send_headers:
        if _content_length is None:
            # The header describes the bytes actually sent, i.e. the
            # ciphertext when encryption is enabled.
            _content_length = len(m)
        header = _encode_length(_content_length)
    else:
        header = b''
    writer.write(start_byte + header + m)
    await writer.drain()


async def readmsg(reader: asyncio.StreamReader,
                  aesgcm: AESGCM | None = None,
                  aesnonce: bytes | None = None,
                  start_byte: bytes = _MSG_START_BYTE) -> bytes:
    '''
    Receives one framed message written by :func:`sendmsg` with headers.

    :param reader: The reader
    :type reader: asyncio.StreamReader
    :param aesgcm: If given, the payload is decrypted with it.
    :type aesgcm: AESGCM
    :param aesnonce: The nonce passed to ``aesgcm.decrypt``.
    :type aesnonce: bytes
    :param start_byte: The expected start byte(s).
    :type start_byte: bytes
    :raises MessageFormatError: If the message does not begin with
        ``start_byte``.
    :raises asyncio.IncompleteReadError: If the stream ends before the
        full message has been read.
    :return: The received (and, if enabled, decrypted) message
    :rtype: bytes
    '''
    # Read as many bytes as sendmsg wrote for the start marker, so that
    # multi-byte start markers round-trip as well.
    if await reader.readexactly(len(start_byte)) != start_byte:
        raise MessageFormatError('invalid message format')
    length_length = (await reader.readexactly(1))[0]
    length_bytes = await reader.readexactly(length_length)
    m = await reader.readexactly(int.from_bytes(length_bytes, 'big'))
    if aesgcm is not None:
        return aesgcm.decrypt(aesnonce, m, None)
    return m


def pack_fields(fields: Sequence[bytes]) -> bytes:
    '''
    Packs the fields into a compact bytestring.

    Each field is written as
    ``len(L) (1 byte) + L + field``, where ``L`` is the field length encoded
    by ``utils.int_to_bytes``. The result can be decoded with
    :func:`unpack_fields`.

    :param fields: The fields
    :type fields: Sequence[bytes]
    :return: The packed fields
    :rtype: bytes
    '''
    # b''.join avoids the quadratic cost of repeated bytes concatenation.
    return b''.join(_encode_length(len(field)) + field for field in fields)


def unpack_fields(raw: bytes) -> Sequence[bytes]:
    '''
    Unpacks fields from a bytestring produced by :func:`pack_fields`.

    :param raw: The packed fields
    :type raw: bytes
    :raises MessageFormatError: If ``raw`` ends in the middle of a field.
    :return: The unpacked fields, in order
    :rtype: Sequence[bytes]
    '''
    fields = []
    buffer = io.BytesIO(raw)
    end = len(raw)
    while buffer.tell() < end:
        # At least one byte remains here, so indexing [0] is safe.
        length_length = buffer.read(1)[0]
        length_bytes = _read_exact(buffer, length_length)
        fields.append(_read_exact(buffer, int.from_bytes(length_bytes, 'big')))
    return fields