import asyncio
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography import x509
from datetime import datetime, timezone

# Marker prefixed to every framed message written by sendmsg / expected by readmsg.
MSG_START_BYTE = b'\x01'


class MessageFormatError(Exception):
    """Raised when an incoming frame does not match the expected wire format."""


class InvalidCertificateError(Exception):
    """Exception type for rejected certificates."""


def int_to_bytes(n: int, signed: bool = False) -> bytes:
    """Encode *n* as a minimal-width big-endian byte string.

    Zero encodes to ``b''`` (which ``int.from_bytes`` decodes back to 0).

    Args:
        n: value to encode; coerced with ``int()`` first.
        signed: use two's-complement encoding so negative values are allowed.

    Raises:
        OverflowError: if *n* is negative and *signed* is False.
    """
    n = int(n)
    if signed and n != 0:
        # Two's complement needs one extra bit for the sign; ~n gives the
        # magnitude bits of a negative value (e.g. -128 -> 127 -> 7 bits + 1).
        bits = (n if n >= 0 else ~n).bit_length() + 1
    else:
        bits = n.bit_length()
    return n.to_bytes((bits + 7) // 8, 'big', signed=signed)


async def sendmsg(m, writer: asyncio.StreamWriter, aesgcm: AESGCM = None,
                  aesnonce=None, start_byte=MSG_START_BYTE):
    """Frame *m* and write it to *writer*, optionally AES-GCM encrypting it.

    Wire format::

        start_byte | N (1 byte) | payload length (N bytes, big-endian) | payload

    Args:
        m: payload; a ``str`` is UTF-8 encoded, anything else is written as-is
            (expected to be bytes-like).
        writer: stream the frame is written to; drained before returning.
        aesgcm: if given, the payload is replaced by
            ``aesgcm.encrypt(aesnonce, m, None)`` before framing.
        aesnonce: nonce passed to ``encrypt``.
        start_byte: marker written in front of the frame.

    NOTE(review): AES-GCM is only secure when a nonce is never reused under
    the same key. If callers pass the same ``aesnonce`` for every message,
    both confidentiality and integrity are lost — confirm callers supply a
    fresh nonce per message.
    """
    if isinstance(m, str):
        m = m.encode()
    if aesgcm is not None:
        m = aesgcm.encrypt(aesnonce, m, None)
    length_field = int_to_bytes(len(m))
    # A single byte stores the width of the length field (at most 255 bytes).
    writer.write(start_byte + bytes([len(length_field)]) + length_field + m)
    await writer.drain()


async def readmsg(reader: asyncio.StreamReader, aesgcm: AESGCM = None,
                  aesnonce=None, start_byte=MSG_START_BYTE, max_length=None):
    """Read one frame written by :func:`sendmsg` and return its payload bytes.

    Args:
        reader: stream to read the frame from.
        aesgcm: if given, the payload is passed through
            ``aesgcm.decrypt(aesnonce, payload, None)`` before returning.
        aesnonce: nonce passed to ``decrypt``.
        start_byte: marker the frame must begin with.
        max_length: optional upper bound on the announced payload length;
            ``None`` (default) accepts any length. Set this when reading
            from untrusted peers, since the length comes from the wire.

    Raises:
        MessageFormatError: wrong start marker, or length above *max_length*.
        asyncio.IncompleteReadError: the stream ended mid-frame.
        cryptography.exceptions.InvalidTag: decryption/authentication failed.
    """
    # Read exactly as many bytes as sendmsg wrote for the marker.
    if await reader.readexactly(len(start_byte)) != start_byte:
        raise MessageFormatError('invalid message format')
    length_width = (await reader.readexactly(1))[0]
    length = int.from_bytes(await reader.readexactly(length_width), 'big')
    if max_length is not None and length > max_length:
        raise MessageFormatError('message too long')
    m = await reader.readexactly(length)
    if aesgcm is not None:
        return aesgcm.decrypt(aesnonce, m, None)
    return m


def validate_cert(cert: x509.Certificate, common_name, issuer_common_name,
                  now=None):
    """Check *cert*'s validity window and its subject/issuer common names.

    Args:
        cert: certificate to check.
        common_name: value that must appear among the subject's CN attributes.
        issuer_common_name: value that must appear among the issuer's CN
            attributes.
        now: timezone-aware instant to check validity against; defaults to
            the current UTC time at call time. A naive datetime will fail
            the comparison against the certificate's UTC bounds.

    Returns:
        True if all checks pass, otherwise False.

    NOTE(review): this does not verify the certificate's signature against
    the issuer's public key, nor any chain of trust; it only matches names
    and dates. Confirm callers perform signature verification separately.
    """
    # Resolve per call: a default of datetime.now(...) in the signature
    # would be frozen at import time.
    if now is None:
        now = datetime.now(timezone.utc)
    if not cert.not_valid_before_utc <= now <= cert.not_valid_after_utc:
        return False
    subject_cns = [a.value for a in
                   cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)]
    if common_name not in subject_cns:
        return False
    issuer_cns = [a.value for a in
                  cert.issuer.get_attributes_for_oid(x509.NameOID.COMMON_NAME)]
    return issuer_common_name in issuer_cns