Refactor aestransport to use a state enum (#691)

This commit is contained in:
J. Nick Koston 2024-01-23 22:50:25 -10:00 committed by GitHub
parent 3f40410db3
commit 24c645746e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 58 additions and 43 deletions

View File

@ -8,7 +8,8 @@ import base64
import hashlib import hashlib
import logging import logging
import time import time
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast from enum import Enum, auto
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, Tuple, cast
from cryptography.hazmat.primitives import padding, serialization from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
@ -41,6 +42,14 @@ def _sha1(payload: bytes) -> str:
return sha1_algo.hexdigest() return sha1_algo.hexdigest()
class TransportState(Enum):
"""Enum for AES state."""
HANDSHAKE_REQUIRED = auto() # Handshake needed
LOGIN_REQUIRED = auto() # Login needed
ESTABLISHED = auto() # Ready to send requests
class AesTransport(BaseTransport): class AesTransport(BaseTransport):
"""Implementation of the AES encryption protocol. """Implementation of the AES encryption protocol.
@ -79,21 +88,21 @@ class AesTransport(BaseTransport):
self._default_credentials: Optional[Credentials] = None self._default_credentials: Optional[Credentials] = None
self._http_client: HttpClient = HttpClient(config) self._http_client: HttpClient = HttpClient(config)
self._handshake_done = False self._state = TransportState.HANDSHAKE_REQUIRED
self._encryption_session: Optional[AesEncyptionSession] = None self._encryption_session: Optional[AesEncyptionSession] = None
self._session_expire_at: Optional[float] = None self._session_expire_at: Optional[float] = None
self._session_cookie: Optional[Dict[str, str]] = None self._session_cookie: Optional[Dict[str, str]] = None
self._login_token = None self._login_token: Optional[str] = None
self._key_pair: Optional[KeyPair] = None self._key_pair: Optional[KeyPair] = None
_LOGGER.debug("Created AES transport for %s", self._host) _LOGGER.debug("Created AES transport for %s", self._host)
@property @property
def default_port(self): def default_port(self) -> int:
"""Default port for the transport.""" """Default port for the transport."""
return self.DEFAULT_PORT return self.DEFAULT_PORT
@ -102,30 +111,25 @@ class AesTransport(BaseTransport):
"""The hashed credentials used by the transport.""" """The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode() return base64.b64encode(json_dumps(self._login_params).encode()).decode()
def _get_login_params(self, credentials): def _get_login_params(self, credentials: Credentials) -> Dict[str, str]:
"""Get the login parameters based on the login_version.""" """Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2, credentials) un, pw = self.hash_credentials(self._login_version == 2, credentials)
password_field_name = "password2" if self._login_version == 2 else "password" password_field_name = "password2" if self._login_version == 2 else "password"
return {password_field_name: pw, "username": un} return {password_field_name: pw, "username": un}
@staticmethod @staticmethod
def hash_credentials(login_v2, credentials): def hash_credentials(login_v2: bool, credentials: Credentials) -> Tuple[str, str]:
"""Hash the credentials.""" """Hash the credentials."""
un = base64.b64encode(_sha1(credentials.username.encode()).encode()).decode()
if login_v2: if login_v2:
un = base64.b64encode(
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode( pw = base64.b64encode(
_sha1(credentials.password.encode()).encode() _sha1(credentials.password.encode()).encode()
).decode() ).decode()
else: else:
un = base64.b64encode(
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(credentials.password.encode()).decode() pw = base64.b64encode(credentials.password.encode()).decode()
return un, pw return un, pw
def _handle_response_error_code(self, resp_dict: dict, msg: str): def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS: if error_code == SmartErrorCode.SUCCESS:
return return
@ -135,12 +139,11 @@ class AesTransport(BaseTransport):
if error_code in SMART_RETRYABLE_ERRORS: if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg, error_code=error_code) raise RetryableException(msg, error_code=error_code)
if error_code in SMART_AUTHENTICATION_ERRORS: if error_code in SMART_AUTHENTICATION_ERRORS:
self._handshake_done = False self._state = TransportState.HANDSHAKE_REQUIRED
self._login_token = None
raise AuthenticationException(msg, error_code=error_code) raise AuthenticationException(msg, error_code=error_code)
raise SmartDeviceException(msg, error_code=error_code) raise SmartDeviceException(msg, error_code=error_code)
async def send_secure_passthrough(self, request: str): async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
"""Send encrypted message as passthrough.""" """Send encrypted message as passthrough."""
url = f"http://{self._host}/app" url = f"http://{self._host}/app"
if self._login_token: if self._login_token:
@ -165,16 +168,17 @@ class AesTransport(BaseTransport):
+ f"status code {status_code} to passthrough" + f"status code {status_code} to passthrough"
) )
resp_dict = cast(Dict, resp_dict)
self._handle_response_error_code( self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message" resp_dict, "Error sending secure_passthrough message"
) )
response = self._encryption_session.decrypt( # type: ignore if TYPE_CHECKING:
resp_dict["result"]["response"].encode() resp_dict = cast(Dict[str, Any], resp_dict)
) assert self._encryption_session is not None
resp_dict = json_loads(response)
return resp_dict raw_response: str = resp_dict["result"]["response"]
response = self._encryption_session.decrypt(raw_response.encode())
return json_loads(response) # type: ignore[return-value]
async def perform_login(self): async def perform_login(self):
"""Login to the device.""" """Login to the device."""
@ -182,7 +186,7 @@ class AesTransport(BaseTransport):
await self.try_login(self._login_params) await self.try_login(self._login_params)
except AuthenticationException as aex: except AuthenticationException as aex:
try: try:
if aex.error_code != SmartErrorCode.LOGIN_ERROR: if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
raise aex raise aex
if self._default_credentials is None: if self._default_credentials is None:
self._default_credentials = get_default_credentials( self._default_credentials = get_default_credentials(
@ -203,9 +207,8 @@ class AesTransport(BaseTransport):
ex, ex,
) from ex ) from ex
async def try_login(self, login_params): async def try_login(self, login_params: Dict[str, Any]) -> None:
"""Try to login with supplied login_params.""" """Try to login with supplied login_params."""
self._login_token = None
login_request = { login_request = {
"method": "login_device", "method": "login_device",
"params": login_params, "params": login_params,
@ -216,6 +219,7 @@ class AesTransport(BaseTransport):
resp_dict = await self.send_secure_passthrough(request) resp_dict = await self.send_secure_passthrough(request)
self._handle_response_error_code(resp_dict, "Error logging in") self._handle_response_error_code(resp_dict, "Error logging in")
self._login_token = resp_dict["result"]["token"] self._login_token = resp_dict["result"]["token"]
self._state = TransportState.ESTABLISHED
async def _generate_key_pair_payload(self) -> AsyncGenerator: async def _generate_key_pair_payload(self) -> AsyncGenerator:
"""Generate the request body and return an ascyn_generator. """Generate the request body and return an ascyn_generator.
@ -236,12 +240,11 @@ class AesTransport(BaseTransport):
_LOGGER.debug(f"Request {request_body}") _LOGGER.debug(f"Request {request_body}")
yield json_dumps(request_body).encode() yield json_dumps(request_body).encode()
async def perform_handshake(self): async def perform_handshake(self) -> None:
"""Perform the handshake.""" """Perform the handshake."""
_LOGGER.debug("Will perform handshaking...") _LOGGER.debug("Will perform handshaking...")
self._key_pair = None self._key_pair = None
self._handshake_done = False
self._session_expire_at = None self._session_expire_at = None
self._session_cookie = None self._session_cookie = None
@ -258,7 +261,7 @@ class AesTransport(BaseTransport):
cookies_dict=self._session_cookie, cookies_dict=self._session_cookie,
) )
_LOGGER.debug(f"Device responded with: {resp_dict}") _LOGGER.debug("Device responded with: %s", resp_dict)
if status_code != 200: if status_code != 200:
raise SmartDeviceException( raise SmartDeviceException(
@ -268,6 +271,9 @@ class AesTransport(BaseTransport):
self._handle_response_error_code(resp_dict, "Unable to complete handshake") self._handle_response_error_code(resp_dict, "Unable to complete handshake")
if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)
handshake_key = resp_dict["result"]["key"] handshake_key = resp_dict["result"]["key"]
if ( if (
@ -283,12 +289,12 @@ class AesTransport(BaseTransport):
self._session_expire_at = time.time() + 86400 self._session_expire_at = time.time() + 86400
if TYPE_CHECKING: if TYPE_CHECKING:
assert self._key_pair is not None # pragma: no cover assert self._key_pair is not None
self._encryption_session = AesEncyptionSession.create_from_keypair( self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, self._key_pair handshake_key, self._key_pair
) )
self._handshake_done = True self._state = TransportState.LOGIN_REQUIRED
_LOGGER.debug("Handshake with %s complete", self._host) _LOGGER.debug("Handshake with %s complete", self._host)
@ -299,17 +305,20 @@ class AesTransport(BaseTransport):
or self._session_expire_at - time.time() <= 0 or self._session_expire_at - time.time() <= 0
) )
async def send(self, request: str): async def send(self, request: str) -> Dict[str, Any]:
"""Send the request.""" """Send the request."""
if not self._handshake_done or self._handshake_session_expired(): if (
self._state is TransportState.HANDSHAKE_REQUIRED
or self._handshake_session_expired()
):
await self.perform_handshake() await self.perform_handshake()
if not self._login_token: if self._state is not TransportState.ESTABLISHED:
try: try:
await self.perform_login() await self.perform_login()
# After a login failure handshake needs to # After a login failure handshake needs to
# be redone or a 9999 error is received. # be redone or a 9999 error is received.
except AuthenticationException as ex: except AuthenticationException as ex:
self._handshake_done = False self._state = TransportState.HANDSHAKE_REQUIRED
raise ex raise ex
return await self.send_secure_passthrough(request) return await self.send_secure_passthrough(request)
@ -321,8 +330,7 @@ class AesTransport(BaseTransport):
async def reset(self) -> None: async def reset(self) -> None:
"""Reset internal handshake and login state.""" """Reset internal handshake and login state."""
self._handshake_done = False self._state = TransportState.HANDSHAKE_REQUIRED
self._login_token = None
class AesEncyptionSession: class AesEncyptionSession:

View File

@ -10,7 +10,7 @@ import pytest
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from ..aestransport import AesEncyptionSession, AesTransport from ..aestransport import AesEncyptionSession, AesTransport, TransportState
from ..credentials import Credentials from ..credentials import Credentials
from ..deviceconfig import DeviceConfig from ..deviceconfig import DeviceConfig
from ..exceptions import ( from ..exceptions import (
@ -66,11 +66,11 @@ async def test_handshake(
) )
assert transport._encryption_session is None assert transport._encryption_session is None
assert transport._handshake_done is False assert transport._state is TransportState.HANDSHAKE_REQUIRED
with expectation: with expectation:
await transport.perform_handshake() await transport.perform_handshake()
assert transport._encryption_session is not None assert transport._encryption_session is not None
assert transport._handshake_done is True assert transport._state is TransportState.LOGIN_REQUIRED
@status_parameters @status_parameters
@ -82,7 +82,7 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
transport = AesTransport( transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar")) config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
) )
transport._handshake_done = True transport._state = TransportState.LOGIN_REQUIRED
transport._session_expire_at = time.time() + 86400 transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session transport._encryption_session = mock_aes_device.encryption_session
@ -129,7 +129,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
transport = AesTransport( transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar")) config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
) )
transport._handshake_done = True transport._state = TransportState.LOGIN_REQUIRED
transport._session_expire_at = time.time() + 86400 transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session transport._encryption_session = mock_aes_device.encryption_session

View File

@ -65,9 +65,16 @@ omit = ["kasa/tests/*"]
[tool.coverage.report] [tool.coverage.report]
exclude_lines = [ exclude_lines = [
# ignore abstract methods # Don't complain if tests don't hit defensive assertion code:
"raise AssertionError",
"raise NotImplementedError", "raise NotImplementedError",
"def __repr__" # Don't complain about missing debug-only code:
"def __repr__",
# Have to re-enable the standard pragma
"pragma: no cover",
# TYPE_CHECKING and @overload blocks are never executed during pytest run
"if TYPE_CHECKING:",
"@overload"
] ]
[tool.pytest.ini_options] [tool.pytest.ini_options]