Reduce the number of times we recreate the cipher in klap

This commit is contained in:
J. Nick Koston 2024-01-25 22:23:09 -10:00
parent c318303255
commit dcd9322cfe
No known key found for this signature in database

View File

@ -46,6 +46,7 @@ import datetime
import hashlib import hashlib
import logging import logging
import secrets import secrets
import struct
import time import time
from pprint import pformat as pf from pprint import pformat as pf
from typing import Any, Dict, Optional, Tuple, cast from typing import Any, Dict, Optional, Tuple, cast
@ -66,6 +67,8 @@ _LOGGER = logging.getLogger(__name__)
ONE_DAY_SECONDS = 86400 ONE_DAY_SECONDS = 86400
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20 SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
PACK_SIGNED_LONG = struct.Struct(">l").pack
def _sha256(payload: bytes) -> bytes: def _sha256(payload: bytes) -> bytes:
digest = hashes.Hash(hashes.SHA256()) # noqa: S303 digest = hashes.Hash(hashes.SHA256()) # noqa: S303
@ -432,6 +435,8 @@ class KlapEncryptionSession:
self.user_hash = user_hash self.user_hash = user_hash
self._key = self._key_derive(local_seed, remote_seed, user_hash) self._key = self._key_derive(local_seed, remote_seed, user_hash)
(self._iv, self._seq) = self._iv_derive(local_seed, remote_seed, user_hash) (self._iv, self._seq) = self._iv_derive(local_seed, remote_seed, user_hash)
self._aes = algorithms.AES(self._key)
self._generate_cipher()
self._sig = self._sig_derive(local_seed, remote_seed, user_hash) self._sig = self._sig_derive(local_seed, remote_seed, user_hash)
def _key_derive(self, local_seed, remote_seed, user_hash): def _key_derive(self, local_seed, remote_seed, user_hash):
@ -451,19 +456,20 @@ class KlapEncryptionSession:
payload = b"ldk" + local_seed + remote_seed + user_hash payload = b"ldk" + local_seed + remote_seed + user_hash
return hashlib.sha256(payload).digest()[:28] return hashlib.sha256(payload).digest()[:28]
def _iv_seq(self): def _generate_cipher(self):
seq = self._seq.to_bytes(4, "big", signed=True) iv_seq = self._iv + PACK_SIGNED_LONG(self._seq)
iv = self._iv + seq cbc = modes.CBC(iv_seq)
return iv self._cipher = Cipher(self._aes, cbc)
def encrypt(self, msg): def encrypt(self, msg):
"""Encrypt the data and increment the sequence number.""" """Encrypt the data and increment the sequence number."""
self._seq = self._seq + 1 self._seq += 1
self._generate_cipher()
if isinstance(msg, str): if isinstance(msg, str):
msg = msg.encode("utf-8") msg = msg.encode("utf-8")
cipher = Cipher(algorithms.AES(self._key), modes.CBC(self._iv_seq())) encryptor = self._cipher.encryptor()
encryptor = cipher.encryptor()
padder = padding.PKCS7(128).padder() padder = padding.PKCS7(128).padder()
padded_data = padder.update(msg) + padder.finalize() padded_data = padder.update(msg) + padder.finalize()
ciphertext = encryptor.update(padded_data) + encryptor.finalize() ciphertext = encryptor.update(padded_data) + encryptor.finalize()
@ -478,8 +484,7 @@ class KlapEncryptionSession:
def decrypt(self, msg): def decrypt(self, msg):
"""Decrypt the data.""" """Decrypt the data."""
cipher = Cipher(algorithms.AES(self._key), modes.CBC(self._iv_seq())) decryptor = self._cipher.decryptor()
decryptor = cipher.decryptor()
dp = decryptor.update(msg[32:]) + decryptor.finalize() dp = decryptor.update(msg[32:]) + decryptor.finalize()
unpadder = padding.PKCS7(128).unpadder() unpadder = padding.PKCS7(128).unpadder()
plaintextbytes = unpadder.update(dp) + unpadder.finalize() plaintextbytes = unpadder.update(dp) + unpadder.finalize()