neue Datei: .gitignore
neue Datei: README.md neue Datei: pyproject.toml neue Datei: src/jeb_utils/__init__.py neue Datei: src/jeb_utils/auth_utils.py neue Datei: src/jeb_utils/crypto_utils.py neue Datei: src/jeb_utils/exceptions.py neue Datei: src/jeb_utils/jeb_utils.py neue Datei: src/jeb_utils/jebp_utils.py neue Datei: src/jeb_utils/utils.py
This commit is contained in:
@@ -0,0 +1,116 @@
|
||||
import asyncio
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
import io
|
||||
from . import utils
|
||||
from typing import Sequence
|
||||
|
||||
__all__ = [
|
||||
'MessageFormatError',
|
||||
'sendmsg',
|
||||
'readmsg',
|
||||
'pack_fields',
|
||||
'unpack_fields',
|
||||
]
|
||||
|
||||
_MSG_START_BYTE = b'\x01'
|
||||
|
||||
class MessageFormatError(Exception): ...
|
||||
|
||||
async def sendmsg(m: bytes, writer: asyncio.StreamWriter, aesgcm: AESGCM = None, aesnonce: bytes = None, send_headers: bool = True, _content_length: int = None, start_byte: bytes = _MSG_START_BYTE):
|
||||
'''
|
||||
Sends a message.
|
||||
|
||||
:param m: The message
|
||||
:type m: bytes
|
||||
:param writer: The writer
|
||||
:type writer: asyncio.StreamWriter
|
||||
:param aesgcm: AESGCM
|
||||
:type aesgcm: AESGCM
|
||||
:param aesnonce: The nonce
|
||||
:type aesnonce: bytes
|
||||
:param send_headers: Controls whether the protocol headers are sent.
|
||||
:type send_headers: bool
|
||||
:param _content_length: The content length. If ``None``, the content length will be calculated automatically.
|
||||
:type _content_length: int
|
||||
:param start_byte: The start byte.
|
||||
:type start_byte: bytes
|
||||
'''
|
||||
if type(m) == str:
|
||||
m = m.encode()
|
||||
if aesgcm:
|
||||
m = aesgcm.encrypt(aesnonce, m, None)
|
||||
if _content_length == None:
|
||||
_content_length = len(m)
|
||||
if send_headers:
|
||||
content_length = utils.int_to_bytes(_content_length)
|
||||
content_length = bytes([len(content_length)]) + content_length
|
||||
else:
|
||||
content_length = b''
|
||||
writer.write(start_byte + content_length + m)
|
||||
await writer.drain()
|
||||
|
||||
async def readmsg(reader: asyncio.StreamReader, aesgcm: AESGCM = None, aesnonce: bytes = None, start_byte: bytes = _MSG_START_BYTE) -> bytes:
|
||||
'''
|
||||
Receives a message.
|
||||
|
||||
:param reader: The reader
|
||||
:type reader: asyncio.StreamReader
|
||||
:param aesgcm: AESGCM
|
||||
:type aesgcm: AESGCM
|
||||
:param aesnonce: The nonce
|
||||
:type aesnonce: bytes
|
||||
:param start_byte: The start byte
|
||||
:type start_byte: bytes
|
||||
|
||||
:raises MessageFormatError: If the message format is invalid.
|
||||
|
||||
:return: The received message
|
||||
:rtype: bytes
|
||||
'''
|
||||
if await reader.readexactly(1) != start_byte:
|
||||
raise MessageFormatError('invalid message format')
|
||||
content_length_length = await reader.readexactly(1)
|
||||
content_length = await reader.readexactly(int.from_bytes(content_length_length))
|
||||
m = await reader.readexactly(int.from_bytes(content_length))
|
||||
if aesgcm:
|
||||
return aesgcm.decrypt(aesnonce, m, None)
|
||||
return m
|
||||
|
||||
def pack_fields(*fields: bytes) -> bytes:
|
||||
'''
|
||||
Packs the fields into a compact bytestring.
|
||||
|
||||
:param fields: The fields
|
||||
:type fields: bytes
|
||||
|
||||
:return: The result
|
||||
:rtype: bytes
|
||||
'''
|
||||
result = b''
|
||||
for field in fields:
|
||||
field_length = utils.int_to_bytes(len(field))
|
||||
field_length_length = utils.int_to_bytes(len(field_length))
|
||||
result += field_length_length + field_length + field
|
||||
return result
|
||||
|
||||
def unpack_fields(packed: bytes) -> Sequence[bytes]:
|
||||
'''
|
||||
Unpacks the field from a bytestring.
|
||||
|
||||
:param packed: The packed fields
|
||||
:type packed: bytes
|
||||
|
||||
:return: The unpacked fields
|
||||
:rtype: Sequence[bytes]
|
||||
'''
|
||||
packed = []
|
||||
buffer = io.BytesIO(packed)
|
||||
while buffer.tell() < len(packed):
|
||||
try:
|
||||
field_length_length = int.from_bytes(buffer.read(1))
|
||||
field_length = int.from_bytes(buffer.read(field_length_length))
|
||||
field = buffer.read(field_length)
|
||||
packed += (field,)
|
||||
except:
|
||||
pass
|
||||
return packed
|
||||
Reference in New Issue
Block a user