import asyncio
import io
from typing import Optional, Sequence

from cryptography.hazmat.primitives.ciphers.aead import AESGCM

from . import utils

__all__ = [
    'MessageFormatError',
    'sendmsg',
    'readmsg',
    'pack_fields',
    'unpack_fields',
]

# Marker written before every message by sendmsg and checked by readmsg.
_MSG_START_BYTE = b'\x01'


class MessageFormatError(Exception):
    '''Raised when received data does not follow the expected wire format.'''


async def sendmsg(m: bytes, writer: asyncio.StreamWriter,
                  aesgcm: Optional[AESGCM] = None,
                  aesnonce: Optional[bytes] = None,
                  send_headers: bool = True,
                  _content_length: Optional[int] = None,
                  start_byte: bytes = _MSG_START_BYTE):
    '''
    Sends a message.

    The bytes written are ``start_byte``, then (when ``send_headers`` is
    true) one byte holding the size of the encoded content length, the
    content length as encoded by ``utils.int_to_bytes``, and finally the
    (optionally encrypted) payload.

    :param m: The message. A ``str`` is UTF-8 encoded before sending.
    :type m: bytes
    :param writer: The writer
    :type writer: asyncio.StreamWriter
    :param aesgcm: If given, the payload is encrypted with this cipher
        before it is sent.
    :type aesgcm: AESGCM
    :param aesnonce: The nonce passed to ``aesgcm.encrypt``. AES-GCM must
        never reuse a nonce with the same key; guaranteeing that is the
        caller's responsibility.
    :type aesnonce: bytes
    :param send_headers: Controls whether the length header is sent. The
        start byte is sent either way.
    :type send_headers: bool
    :param _content_length: The content length to advertise. If ``None``,
        it is the length of the (possibly encrypted) payload.
    :type _content_length: int
    :param start_byte: The start byte.
    :type start_byte: bytes
    '''
    if isinstance(m, str):
        m = m.encode()
    if aesgcm is not None:
        # The ciphertext includes the authentication tag, so the length
        # must be measured after encrypting.
        m = aesgcm.encrypt(aesnonce, m, None)
    if _content_length is None:
        _content_length = len(m)
    if send_headers:
        content_length = utils.int_to_bytes(_content_length)
        # A single byte carries the size of the encoded length field;
        # readmsg reads exactly one byte for it.
        header = bytes([len(content_length)]) + content_length
    else:
        header = b''
    writer.write(start_byte + header + m)
    await writer.drain()


async def readmsg(reader: asyncio.StreamReader,
                  aesgcm: Optional[AESGCM] = None,
                  aesnonce: Optional[bytes] = None,
                  start_byte: bytes = _MSG_START_BYTE,
                  max_length: Optional[int] = None) -> bytes:
    '''
    Receives a message.

    :param reader: The reader
    :type reader: asyncio.StreamReader
    :param aesgcm: If given, the payload is decrypted with this cipher.
    :type aesgcm: AESGCM
    :param aesnonce: The nonce passed to ``aesgcm.decrypt``.
    :type aesnonce: bytes
    :param start_byte: The start byte (all of it is read and compared).
    :type start_byte: bytes
    :param max_length: Optional upper bound on the advertised content
        length. ``None`` (the default) means no limit.
    :type max_length: int
    :raises MessageFormatError: If the message format is invalid or the
        advertised length exceeds ``max_length``.
    :raises asyncio.IncompleteReadError: If the stream ends mid-message.
    :return: The received message
    :rtype: bytes
    '''
    # Read the whole marker, matching what sendmsg writes.
    if await reader.readexactly(len(start_byte)) != start_byte:
        raise MessageFormatError('invalid message format')
    content_length_length = (await reader.readexactly(1))[0]
    # Explicit byteorder: the default only exists on Python >= 3.11
    # (where it is 'big').
    content_length = int.from_bytes(
        await reader.readexactly(content_length_length), 'big')
    # The length comes from the peer; bound it before buffering.
    if max_length is not None and content_length > max_length:
        raise MessageFormatError('message too long')
    m = await reader.readexactly(content_length)
    if aesgcm is not None:
        return aesgcm.decrypt(aesnonce, m, None)
    return m


def pack_fields(*fields: bytes) -> bytes:
    '''
    Packs the fields into a compact bytestring.

    Each field is encoded as: one byte giving the size of the length field,
    the field length as encoded by ``utils.int_to_bytes``, then the field
    bytes. This is the layout ``unpack_fields`` reads.

    :param fields: The fields
    :type fields: bytes
    :return: The result
    :rtype: bytes
    '''
    parts = []
    for field in fields:
        field_length = utils.int_to_bytes(len(field))
        # unpack_fields reads exactly one byte for the length-of-length, so
        # emit exactly one byte here (utils.int_to_bytes may not guarantee
        # a one-byte width, e.g. for 0).
        parts.append(bytes([len(field_length)]))
        parts.append(field_length)
        parts.append(field)
    return b''.join(parts)


def unpack_fields(packed: bytes) -> Sequence[bytes]:
    '''
    Unpacks the fields from a bytestring produced by ``pack_fields``.

    :param packed: The packed fields
    :type packed: bytes
    :raises MessageFormatError: If the data ends in the middle of a field.
    :return: The unpacked fields
    :rtype: Sequence[bytes]
    '''
    fields = []
    buffer = io.BytesIO(packed)
    total = len(packed)
    while buffer.tell() < total:
        field_length_length = buffer.read(1)[0]
        field_length_bytes = buffer.read(field_length_length)
        if len(field_length_bytes) != field_length_length:
            raise MessageFormatError('truncated field length')
        # Explicit byteorder for Python < 3.11 compatibility ('big' is the
        # 3.11+ default).
        field_length = int.from_bytes(field_length_bytes, 'big')
        field = buffer.read(field_length)
        if len(field) != field_length:
            raise MessageFormatError('truncated field')
        fields.append(field)
    return fields