import functools
import os
import struct

from cryptography.exceptions import InvalidTag
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import (
    Cipher, algorithms, modes
)
from cryptography.hazmat.primitives.serialization import (
    Encoding, PublicFormat
)
from cryptography.hazmat.primitives.asymmetric import ec

MAX_RECORD_SIZE = pow(2, 31) - 1
MIN_RECORD_SIZE = 3
KEY_LENGTH = 16
NONCE_LENGTH = 12
TAG_LENGTH = 16

# Valid content types (ordered from newest, to most obsolete)
# "pad" is the number of padding-related bytes each record carries: the
# trailing delimiter byte for aes128gcm, the leading pad-length prefix for
# the legacy encodings.
versions = {
    "aes128gcm": {"pad": 1},
    "aesgcm": {"pad": 2},
    "aesgcm128": {"pad": 1},
}


class ECEException(Exception):
    """Exception for ECE encryption functions"""

    def __init__(self, message):
        super(ECEException, self).__init__(message)
        self.message = message


def derive_key(mode, version, salt, key, private_key, dh, auth_secret,
               keyid, keylabel="P-256"):
    """Derive the content encryption key and the base nonce.

    :param mode: operational mode (encrypt or decrypt)
    :type mode: enumerate('encrypt', 'decrypt')
    :param version: Content Type identifier
    :type version: enumerate('aes128gcm', 'aesgcm', 'aesgcm128')
    :param salt: encryption salt value (must be 16 octets)
    :type salt: bytes
    :param key: raw key, used as the input secret when ``dh`` is None
    :type key: bytes
    :param private_key: local EC private key (required when ``dh`` is set)
    :type private_key: object
    :param dh: remote Diffie-Hellman public key, either as an
        ``EllipticCurvePublicKey`` or as an encoded point
    :type dh: bytes or object
    :param auth_secret: authorization secret; ignored for aes128gcm when
        ``dh`` is None
    :type auth_secret: bytes
    :param keyid: key identifier label (accepted for interface symmetry;
        not used by the derivation in this function)
    :type keyid: str
    :param keylabel: label for aesgcm/aesgcm128
    :type keylabel: str
    :return: tuple of (key, nonce) with lengths KEY_LENGTH and NONCE_LENGTH
    :raises ECEException: on an unknown version or mode, a bad salt, DH
        requested without a private key, or no usable secret
    """
    context = b""

    def build_info(base, info_context):
        return b"Content-Encoding: " + base + b"\0" + info_context

    def derive_dh(mode, version, private_key, dh, keylabel):
        """Run ECDH and build the key-derivation context for both keys."""
        def length_prefix(key):
            # 2-byte big-endian length followed by the key itself
            return struct.pack("!H", len(key)) + key

        if isinstance(dh, ec.EllipticCurvePublicKey):
            pubkey = dh
            dh = dh.public_bytes(
                Encoding.X962,
                PublicFormat.UncompressedPoint)
        else:
            pubkey = ec.EllipticCurvePublicKey.from_encoded_point(
                ec.SECP256R1(),
                dh
            )

        encoded = private_key.public_key().public_bytes(
            Encoding.X962,
            PublicFormat.UncompressedPoint)
        # The local key is the sender's when encrypting and the receiver's
        # when decrypting.
        if mode == "encrypt":
            sender_pub_key = encoded
            receiver_pub_key = dh
        else:
            sender_pub_key = dh
            receiver_pub_key = encoded

        if version == "aes128gcm":
            context = b"WebPush: info\x00" + receiver_pub_key + sender_pub_key
        else:
            context = (keylabel.encode('utf-8') + b"\0" +
                       length_prefix(receiver_pub_key) +
                       length_prefix(sender_pub_key))

        return private_key.exchange(ec.ECDH(), pubkey), context

    if version not in versions:
        raise ECEException(u"Invalid version")
    if mode not in ['encrypt', 'decrypt']:
        raise ECEException(u"unknown 'mode' specified: " + mode)
    if salt is None or len(salt) != KEY_LENGTH:
        raise ECEException(u"'salt' must be a 16 octet value")
    if dh is not None:
        if private_key is None:
            raise ECEException(u"DH requires a private_key")
        (secret, context) = derive_dh(mode=mode, version=version,
                                      private_key=private_key, dh=dh,
                                      keylabel=keylabel)
    else:
        secret = key

    if secret is None:
        raise ECEException(u"unable to determine the secret")

    # `version` was validated above, so exactly one branch always runs.
    if version == "aesgcm":
        keyinfo = build_info(b"aesgcm", context)
        nonceinfo = build_info(b"nonce", context)
    elif version == "aesgcm128":
        keyinfo = b"Content-Encoding: aesgcm128"
        nonceinfo = b"Content-Encoding: nonce"
    elif version == "aes128gcm":
        keyinfo = b"Content-Encoding: aes128gcm\x00"
        nonceinfo = b"Content-Encoding: nonce\x00"
        if dh is None:
            # Only mix the authentication secret when using DH for aes128gcm
            auth_secret = None

    if auth_secret is not None:
        if version == "aes128gcm":
            info = context
        else:
            info = build_info(b'auth', b'')
        hkdf_auth = HKDF(
            algorithm=hashes.SHA256(),
            length=32,
            salt=auth_secret,
            info=info,
            backend=default_backend()
        )
        secret = hkdf_auth.derive(secret)

    hkdf_key = HKDF(
        algorithm=hashes.SHA256(),
        length=KEY_LENGTH,
        salt=salt,
        info=keyinfo,
        backend=default_backend()
    )
    hkdf_nonce = HKDF(
        algorithm=hashes.SHA256(),
        length=NONCE_LENGTH,
        salt=salt,
        info=nonceinfo,
        backend=default_backend()
    )
    return hkdf_key.derive(secret), hkdf_nonce.derive(secret)


def iv(base, counter):
    """Generate an initialization vector.

    The first 4 bytes of ``base`` are kept as-is; the remaining 8 bytes are
    XORed with the big-endian record ``counter``.

    :param base: base nonce (4 bytes plus an 8 byte mask)
    :type base: bytes
    :param counter: record sequence number, must fit in 64 bits
    :type counter: int
    :return: the per-record nonce
    :raises ECEException: if the counter does not fit in 64 bits
    """
    if (counter >> 64) != 0:
        raise ECEException(u"Counter too big")
    (mask,) = struct.unpack("!Q", base[4:])
    return base[:4] + struct.pack("!Q", counter ^ mask)


def decrypt(content, salt=None, key=None, private_key=None, dh=None,
            auth_secret=None, keyid=None, keylabel="P-256",
            rs=4096, version="aes128gcm"):
    """
    Decrypt a data block

    For aes128gcm the salt, record size and keyid are read from the header
    at the start of ``content`` and replace the values passed in; if a
    ``private_key`` is given without ``dh``, the header keyid is used as
    the remote DH public key.

    :param content: Data to be decrypted
    :type content: bytes
    :param salt: Encryption salt
    :type salt: bytes
    :param key: raw key, used as the input secret when no DH is involved
    :type key: bytes
    :param private_key: DH private key
    :type private_key: object
    :param keyid: Internal key identifier for private key info
    :type keyid: str
    :param dh: Remote Diffie Hellman sequence (omit for aes128gcm)
    :type dh: bytes
    :param rs: Record size
    :type rs: int
    :param auth_secret: Authorization secret
    :type auth_secret: bytes
    :param version: ECE Method version
    :type version: enumerate('aes128gcm', 'aesgcm', 'aesgcm128')
    :return: Decrypted message content
    :rtype: bytes
    :raises ECEException: on an invalid version, unparsable header, record
        size too small, truncated legacy message, bad padding/delimiter or
        authentication failure

    """
    def parse_content_header(content):
        """Parse an aes128gcm content body and extract the header values.

        Layout: 16 byte salt, 4 byte big-endian record size, 1 byte keyid
        length, keyid, then the encrypted records.

        :param content: The encrypted body of the message
        :type content: bytes

        """
        id_len = struct.unpack("!B", content[20:21])[0]
        return {
            "salt": content[:16],
            "rs": struct.unpack("!L", content[16:20])[0],
            "keyid": content[21:21 + id_len],
            "content": content[21 + id_len:],
        }

    def decrypt_record(key, nonce, counter, content):
        # The GCM tag is the trailing TAG_LENGTH bytes of each record.
        decryptor = Cipher(
            algorithms.AES(key),
            modes.GCM(iv(nonce, counter), tag=content[-TAG_LENGTH:]),
            backend=default_backend()
        ).decryptor()
        return decryptor.update(content[:-TAG_LENGTH]) + decryptor.finalize()

    def unpad_legacy(data):
        # Legacy records start with a big-endian pad length (`pad_size`
        # bytes wide) followed by that many zero bytes.
        pad_size = versions[version]['pad']
        pad = functools.reduce(
            lambda x, y: x << 8 | y, struct.unpack(
                "!" + ("B" * pad_size), data[0:pad_size])
        )
        if pad_size + pad > len(data) or \
                data[pad_size:pad_size + pad] != (b"\x00" * pad):
            raise ECEException(u"Bad padding")
        return data[pad_size + pad:]

    def unpad(data, last):
        # aes128gcm records end with a delimiter byte (1, or 2 for the
        # last record) optionally followed by zero padding; scan backwards
        # for the first non-zero byte.
        for i in range(len(data) - 1, -1, -1):
            v = struct.unpack('B', data[i:i + 1])[0]
            if v != 0:
                if not last and v != 1:
                    raise ECEException(u'record delimiter != 1')
                if last and v != 2:
                    raise ECEException(u'last record delimiter != 2')
                return data[0:i]
        raise ECEException(u'all zero record plaintext')

    if version not in versions:
        raise ECEException(u"Invalid version")

    overhead = versions[version]['pad']
    if version == "aes128gcm":
        try:
            content_header = parse_content_header(content)
        except Exception:
            raise ECEException("Could not parse the content header")
        salt = content_header['salt']
        rs = content_header['rs']
        keyid = content_header['keyid']
        if private_key is not None and not dh:
            dh = keyid
        else:
            keyid = keyid.decode('utf-8')
        content = content_header['content']
        overhead += 16

    (key_, nonce_) = derive_key("decrypt", version=version,
                                salt=salt, key=key,
                                private_key=private_key, dh=dh,
                                auth_secret=auth_secret,
                                keyid=keyid, keylabel=keylabel)
    if rs <= overhead:
        raise ECEException(u"Record size too small")
    chunk = rs
    if version != "aes128gcm":
        chunk += 16  # account for tags in old versions
        # Legacy encodings mark the end of a message with a short final
        # record, so an exact multiple means the tail is missing. This
        # does not apply to aes128gcm, which uses an explicit last-record
        # delimiter and may legitimately end on a full record.
        if len(content) % chunk == 0:
            raise ECEException(u"Message truncated")

    plaintext_parts = []
    total = len(content)
    try:
        for counter, start in enumerate(range(0, total, chunk)):
            data = decrypt_record(key_, nonce_, counter,
                                  content[start:start + chunk])
            if version == 'aes128gcm':
                last = (start + chunk) >= total
                plaintext_parts.append(unpad(data, last))
            else:
                plaintext_parts.append(unpad_legacy(data))
    except InvalidTag as ex:
        raise ECEException("Decryption error: {}".format(repr(ex)))
    return b"".join(plaintext_parts)


def encrypt(content, salt=None, key=None, private_key=None, dh=None,
            auth_secret=None, keyid=None, keylabel="P-256",
            rs=4096, version="aes128gcm"):
    """
    Encrypt a data block

    For aes128gcm the returned body starts with a header (salt, record
    size, keyid) followed by the encrypted records; for the legacy
    versions only the encrypted records are returned.

    :param content: block of data to encrypt
    :type content: bytes
    :param salt: Encryption salt (16 random bytes are generated if None)
    :type salt: bytes
    :param key: Encryption key data
    :type key: bytes
    :param private_key: DH private key
    :type private_key: object
    :param keyid: Internal key identifier for private key info
    :type keyid: str
    :param dh: Remote Diffie Hellman sequence
    :type dh: bytes
    :param rs: Record size
    :type rs: int
    :param auth_secret: Authorization secret
    :type auth_secret: bytes
    :param version: ECE Method version
    :type version: enumerate('aes128gcm', 'aesgcm', 'aesgcm128')
    :return: Encrypted message content
    :rtype: bytes
    :raises ECEException: on an invalid version, a record size that is too
        small or too large, or an over-long keyid

    """
    def encrypt_record(key, nonce, counter, buf, last):
        encryptor = Cipher(
            algorithms.AES(key),
            modes.GCM(iv(nonce, counter)),
            backend=default_backend()
        ).encryptor()

        if version == 'aes128gcm':
            # trailing delimiter: 2 marks the last record, 1 any other
            data = encryptor.update(buf + (b'\x02' if last else b'\x01'))
        else:
            # legacy: leading zero-valued pad-length prefix (no padding)
            data = encryptor.update(
                (b"\x00" * versions[version]['pad']) + buf)
        data += encryptor.finalize()
        data += encryptor.tag
        return data

    def compose_aes128gcm(salt, content, rs, keyid):
        """Compose the header and content of an aes128gcm encrypted
        message body

        :param salt: The sender's salt value
        :type salt: bytes
        :param content: The encrypted body of the message
        :type content: bytes
        :param rs: Record size written into the header
        :type rs: int
        :param keyid: The keyid to use for this message
        :type keyid: bytes

        """
        if len(keyid) > 255:
            raise ECEException("keyid is too long")
        header = salt
        if rs > MAX_RECORD_SIZE:
            raise ECEException("Too much content")
        header += struct.pack("!L", rs)
        header += struct.pack("!B", len(keyid))
        header += keyid
        return header + content

    if version not in versions:
        raise ECEException(u"Invalid version")

    if salt is None:
        salt = os.urandom(16)

    (key_, nonce_) = derive_key("encrypt", version=version,
                                salt=salt, key=key,
                                private_key=private_key, dh=dh,
                                auth_secret=auth_secret,
                                keyid=keyid, keylabel=keylabel)

    overhead = versions[version]['pad']
    if version == 'aes128gcm':
        overhead += 16
        end = len(content)
    else:
        end = len(content) + 1
    if rs <= overhead:
        raise ECEException(u"Record size too small")
    chunk_size = rs - overhead

    # the extra one on the loop ensures that we produce a padding only
    # record if the data length is an exact multiple of the chunk size
    records = []
    for counter, start in enumerate(range(0, end, chunk_size)):
        records.append(
            encrypt_record(key_, nonce_, counter,
                           content[start:start + chunk_size],
                           (start + chunk_size) >= end))
    result = b"".join(records)

    if version == "aes128gcm":
        if keyid is None and private_key is not None:
            # advertise our public key so the receiver can run ECDH
            kid = private_key.public_key().public_bytes(
                Encoding.X962,
                PublicFormat.UncompressedPoint)
        else:
            kid = (keyid or '').encode('utf-8')
        return compose_aes128gcm(salt, result, rs, keyid=kid)
    return result