Reduce the number of times creating the cipher in klap (#712)

This commit is contained in:
J. Nick Koston 2024-01-26 07:44:41 -10:00 committed by GitHub
parent dd38225f51
commit 7e2be35e4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -46,6 +46,7 @@ import datetime
import hashlib import hashlib
import logging import logging
import secrets import secrets
import struct
import time import time
from pprint import pformat as pf from pprint import pformat as pf
from typing import Any, Dict, Optional, Tuple, cast from typing import Any, Dict, Optional, Tuple, cast
@ -66,6 +67,8 @@ _LOGGER = logging.getLogger(__name__)
ONE_DAY_SECONDS = 86400 ONE_DAY_SECONDS = 86400
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20 SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
PACK_SIGNED_LONG = struct.Struct(">l").pack
def _sha256(payload: bytes) -> bytes: def _sha256(payload: bytes) -> bytes:
return hashlib.sha256(payload).digest() # noqa: S324 return hashlib.sha256(payload).digest() # noqa: S324
@ -421,12 +424,15 @@ class KlapEncryptionSession:
i.e. sequence number which the device expects to increment. i.e. sequence number which the device expects to increment.
""" """
_cipher: Cipher
def __init__(self, local_seed, remote_seed, user_hash): def __init__(self, local_seed, remote_seed, user_hash):
self.local_seed = local_seed self.local_seed = local_seed
self.remote_seed = remote_seed self.remote_seed = remote_seed
self.user_hash = user_hash self.user_hash = user_hash
self._key = self._key_derive(local_seed, remote_seed, user_hash) self._key = self._key_derive(local_seed, remote_seed, user_hash)
(self._iv, self._seq) = self._iv_derive(local_seed, remote_seed, user_hash) (self._iv, self._seq) = self._iv_derive(local_seed, remote_seed, user_hash)
self._aes = algorithms.AES(self._key)
self._sig = self._sig_derive(local_seed, remote_seed, user_hash) self._sig = self._sig_derive(local_seed, remote_seed, user_hash)
def _key_derive(self, local_seed, remote_seed, user_hash): def _key_derive(self, local_seed, remote_seed, user_hash):
@ -446,31 +452,31 @@ class KlapEncryptionSession:
payload = b"ldk" + local_seed + remote_seed + user_hash payload = b"ldk" + local_seed + remote_seed + user_hash
return hashlib.sha256(payload).digest()[:28] return hashlib.sha256(payload).digest()[:28]
def _iv_seq(self): def _generate_cipher(self):
seq = self._seq.to_bytes(4, "big", signed=True) iv_seq = self._iv + PACK_SIGNED_LONG(self._seq)
iv = self._iv + seq cbc = modes.CBC(iv_seq)
return iv self._cipher = Cipher(self._aes, cbc)
def encrypt(self, msg): def encrypt(self, msg):
"""Encrypt the data and increment the sequence number.""" """Encrypt the data and increment the sequence number."""
self._seq = self._seq + 1 self._seq += 1
self._generate_cipher()
if isinstance(msg, str): if isinstance(msg, str):
msg = msg.encode("utf-8") msg = msg.encode("utf-8")
cipher = Cipher(algorithms.AES(self._key), modes.CBC(self._iv_seq())) encryptor = self._cipher.encryptor()
encryptor = cipher.encryptor()
padder = padding.PKCS7(128).padder() padder = padding.PKCS7(128).padder()
padded_data = padder.update(msg) + padder.finalize() padded_data = padder.update(msg) + padder.finalize()
ciphertext = encryptor.update(padded_data) + encryptor.finalize() ciphertext = encryptor.update(padded_data) + encryptor.finalize()
signature = hashlib.sha256( signature = hashlib.sha256(
self._sig + self._seq.to_bytes(4, "big", signed=True) + ciphertext self._sig + PACK_SIGNED_LONG(self._seq) + ciphertext
).digest() ).digest()
return (signature + ciphertext, self._seq) return (signature + ciphertext, self._seq)
def decrypt(self, msg): def decrypt(self, msg):
"""Decrypt the data.""" """Decrypt the data."""
cipher = Cipher(algorithms.AES(self._key), modes.CBC(self._iv_seq())) decryptor = self._cipher.decryptor()
decryptor = cipher.decryptor()
dp = decryptor.update(msg[32:]) + decryptor.finalize() dp = decryptor.update(msg[32:]) + decryptor.finalize()
unpadder = padding.PKCS7(128).unpadder() unpadder = padding.PKCS7(128).unpadder()
plaintextbytes = unpadder.update(dp) + unpadder.finalize() plaintextbytes = unpadder.update(dp) + unpadder.finalize()