131 lines
4.0 KiB
Python
131 lines
4.0 KiB
Python
# Copyright 2026 jCloud Services GbR
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import asyncio
|
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
|
import io
|
|
from . import utils
|
|
from typing import Sequence
|
|
|
|
# Names exported by ``from <module> import *``; everything else is internal.
__all__ = ['MessageFormatError', 'sendmsg', 'readmsg', 'pack_fields', 'unpack_fields']
|
|
|
|
# Default marker byte written by sendmsg before each frame and checked by readmsg.
_MSG_START_BYTE = bytes((0x01,))
|
|
|
|
class MessageFormatError(Exception):
    '''Raised when received data does not match the expected message framing.'''
|
|
|
|
async def sendmsg(m: bytes | str, writer: asyncio.StreamWriter, aesgcm: AESGCM | None = None, aesnonce: bytes | None = None, send_headers: bool = True, _content_length: int | None = None, start_byte: bytes = _MSG_START_BYTE):
    '''
    Sends a single framed message.

    The data written is ``start_byte``, then (if ``send_headers`` is true) a
    one-byte size ``N`` followed by ``N`` bytes encoding the content length
    (via ``utils.int_to_bytes``), then the payload. If ``aesgcm`` is given, the
    payload is AES-GCM encrypted before the length is computed, so the
    authentication tag is included in the length.

    :param m: The message. A ``str`` is encoded to bytes (UTF-8) first.
    :type m: bytes | str
    :param writer: The stream to write the frame to.
    :type writer: asyncio.StreamWriter
    :param aesgcm: Optional cipher used to encrypt the payload.
    :type aesgcm: AESGCM | None
    :param aesnonce: The nonce passed to ``aesgcm.encrypt``; required when ``aesgcm`` is given.
    :type aesnonce: bytes | None
    :param send_headers: Controls whether the length header is sent. The start byte is written either way.
    :type send_headers: bool
    :param _content_length: The content length to announce. If ``None``, the
        length of the (possibly encrypted) payload is used.
    :type _content_length: int | None
    :param start_byte: The marker written at the start of the frame.
    :type start_byte: bytes
    '''
    if isinstance(m, str):
        m = m.encode()
    if aesgcm is not None:
        # Encrypt before measuring: the ciphertext carries the GCM tag and is
        # longer than the plaintext.
        m = aesgcm.encrypt(aesnonce, m, None)
    if _content_length is None:
        _content_length = len(m)
    if send_headers:
        length_field = utils.int_to_bytes(_content_length)
        # Prefix the length field with its own size so the reader knows how
        # many bytes to consume (readmsg reads this single byte first).
        header = bytes([len(length_field)]) + length_field
    else:
        header = b''
    writer.write(start_byte + header + m)
    await writer.drain()
|
|
|
|
async def readmsg(reader: asyncio.StreamReader, aesgcm: AESGCM | None = None, aesnonce: bytes | None = None, start_byte: bytes = _MSG_START_BYTE, max_content_length: int | None = None) -> bytes:
    '''
    Receives a single framed message, as written by ``sendmsg`` with headers.

    The frame is ``start_byte``, a one-byte size ``N``, ``N`` bytes holding the
    big-endian payload length, and then the payload itself.

    :param reader: The stream to read the frame from.
    :type reader: asyncio.StreamReader
    :param aesgcm: Optional cipher used to decrypt the payload.
    :type aesgcm: AESGCM | None
    :param aesnonce: The nonce passed to ``aesgcm.decrypt``; required when ``aesgcm`` is given.
    :type aesnonce: bytes | None
    :param start_byte: The marker expected at the start of the frame.
    :type start_byte: bytes
    :param max_content_length: Optional upper bound on the payload length
        announced on the wire (before decryption). ``None`` means no limit.
        Set this when the peer is untrusted so that a forged header cannot
        make the reader buffer an arbitrarily large payload.
    :type max_content_length: int | None

    :raises MessageFormatError: If the start marker does not match, or the
        announced length exceeds ``max_content_length``.
    :raises asyncio.IncompleteReadError: If the stream ends mid-frame.
    :raises cryptography.exceptions.InvalidTag: If decryption fails authentication.

    :return: The received (and, if ``aesgcm`` is given, decrypted) message
    :rtype: bytes
    '''
    # Read exactly as many bytes as the marker has, so any marker that
    # sendmsg can write (including multi-byte or empty ones) is accepted.
    if await reader.readexactly(len(start_byte)) != start_byte:
        raise MessageFormatError('invalid message format')
    length_field_size = (await reader.readexactly(1))[0]
    content_length = int.from_bytes(await reader.readexactly(length_field_size))
    if max_content_length is not None and content_length > max_content_length:
        raise MessageFormatError('message too large')
    m = await reader.readexactly(content_length)
    if aesgcm is not None:
        return aesgcm.decrypt(aesnonce, m, None)
    return m
|
|
|
|
def pack_fields(fields: list[bytes]) -> bytes:
    '''
    Packs the fields into a compact bytestring.

    Each field is emitted as a one-byte size ``N``, then ``N`` bytes holding
    the field's length (from ``utils.int_to_bytes``; ``unpack_fields`` decodes
    it as big-endian), then the field bytes themselves.

    :param fields: The fields to pack, in order.
    :type fields: list[bytes]

    :return: The concatenated encoding of all fields.
    :rtype: bytes
    '''
    parts = []
    for field in fields:
        length_field = utils.int_to_bytes(len(field))
        parts.append(len(length_field).to_bytes(1))
        parts.append(length_field)
        parts.append(field)
    # A single join avoids the quadratic cost of repeated bytes concatenation.
    return b''.join(parts)
|
|
|
|
def unpack_fields(raw: bytes) -> Sequence[bytes]:
|
|
'''
|
|
Unpacks the field from a bytestring.
|
|
|
|
:param raw: The packed fields
|
|
:type raw: bytes
|
|
|
|
:return: The unpacked fields
|
|
:rtype: Sequence[bytes]
|
|
'''
|
|
packed = []
|
|
buffer = io.BytesIO(raw)
|
|
while buffer.tell() < len(raw):
|
|
try:
|
|
field_length_length = int.from_bytes(buffer.read(1))
|
|
field_length = int.from_bytes(buffer.read(field_length_length))
|
|
field = buffer.read(field_length)
|
|
packed.append(field)
|
|
except:
|
|
pass
|
|
return packed |