Reduce the number of times we recreate the cipher in klap

This commit is contained in:
J. Nick Koston 2024-01-25 22:23:09 -10:00
parent c318303255
commit dcd9322cfe
No known key found for this signature in database

View File

@ -46,6 +46,7 @@ import datetime
import hashlib
import logging
import secrets
import struct
import time
from pprint import pformat as pf
from typing import Any, Dict, Optional, Tuple, cast
@ -66,6 +67,8 @@ _LOGGER = logging.getLogger(__name__)
ONE_DAY_SECONDS = 86400
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
PACK_SIGNED_LONG = struct.Struct(">l").pack
def _sha256(payload: bytes) -> bytes:
digest = hashes.Hash(hashes.SHA256()) # noqa: S303
@ -432,6 +435,8 @@ class KlapEncryptionSession:
self.user_hash = user_hash
self._key = self._key_derive(local_seed, remote_seed, user_hash)
(self._iv, self._seq) = self._iv_derive(local_seed, remote_seed, user_hash)
self._aes = algorithms.AES(self._key)
self._generate_cipher()
self._sig = self._sig_derive(local_seed, remote_seed, user_hash)
def _key_derive(self, local_seed, remote_seed, user_hash):
@ -451,19 +456,20 @@ class KlapEncryptionSession:
payload = b"ldk" + local_seed + remote_seed + user_hash
return hashlib.sha256(payload).digest()[:28]
def _iv_seq(self):
seq = self._seq.to_bytes(4, "big", signed=True)
iv = self._iv + seq
return iv
def _generate_cipher(self):
iv_seq = self._iv + PACK_SIGNED_LONG(self._seq)
cbc = modes.CBC(iv_seq)
self._cipher = Cipher(self._aes, cbc)
def encrypt(self, msg):
"""Encrypt the data and increment the sequence number."""
self._seq = self._seq + 1
self._seq += 1
self._generate_cipher()
if isinstance(msg, str):
msg = msg.encode("utf-8")
cipher = Cipher(algorithms.AES(self._key), modes.CBC(self._iv_seq()))
encryptor = cipher.encryptor()
encryptor = self._cipher.encryptor()
padder = padding.PKCS7(128).padder()
padded_data = padder.update(msg) + padder.finalize()
ciphertext = encryptor.update(padded_data) + encryptor.finalize()
@ -478,8 +484,7 @@ class KlapEncryptionSession:
def decrypt(self, msg):
"""Decrypt the data."""
cipher = Cipher(algorithms.AES(self._key), modes.CBC(self._iv_seq()))
decryptor = cipher.decryptor()
decryptor = self._cipher.decryptor()
dp = decryptor.update(msg[32:]) + decryptor.finalize()
unpadder = padding.PKCS7(128).unpadder()
plaintextbytes = unpadder.update(dp) + unpadder.finalize()