mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Refactor aestransport to use a state enum (#691)
This commit is contained in:
parent
3f40410db3
commit
24c645746e
@ -8,7 +8,8 @@ import base64
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast
|
from enum import Enum, auto
|
||||||
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, Tuple, cast
|
||||||
|
|
||||||
from cryptography.hazmat.primitives import padding, serialization
|
from cryptography.hazmat.primitives import padding, serialization
|
||||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||||
@ -41,6 +42,14 @@ def _sha1(payload: bytes) -> str:
|
|||||||
return sha1_algo.hexdigest()
|
return sha1_algo.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class TransportState(Enum):
|
||||||
|
"""Enum for AES state."""
|
||||||
|
|
||||||
|
HANDSHAKE_REQUIRED = auto() # Handshake needed
|
||||||
|
LOGIN_REQUIRED = auto() # Login needed
|
||||||
|
ESTABLISHED = auto() # Ready to send requests
|
||||||
|
|
||||||
|
|
||||||
class AesTransport(BaseTransport):
|
class AesTransport(BaseTransport):
|
||||||
"""Implementation of the AES encryption protocol.
|
"""Implementation of the AES encryption protocol.
|
||||||
|
|
||||||
@ -79,21 +88,21 @@ class AesTransport(BaseTransport):
|
|||||||
self._default_credentials: Optional[Credentials] = None
|
self._default_credentials: Optional[Credentials] = None
|
||||||
self._http_client: HttpClient = HttpClient(config)
|
self._http_client: HttpClient = HttpClient(config)
|
||||||
|
|
||||||
self._handshake_done = False
|
self._state = TransportState.HANDSHAKE_REQUIRED
|
||||||
|
|
||||||
self._encryption_session: Optional[AesEncyptionSession] = None
|
self._encryption_session: Optional[AesEncyptionSession] = None
|
||||||
self._session_expire_at: Optional[float] = None
|
self._session_expire_at: Optional[float] = None
|
||||||
|
|
||||||
self._session_cookie: Optional[Dict[str, str]] = None
|
self._session_cookie: Optional[Dict[str, str]] = None
|
||||||
|
|
||||||
self._login_token = None
|
self._login_token: Optional[str] = None
|
||||||
|
|
||||||
self._key_pair: Optional[KeyPair] = None
|
self._key_pair: Optional[KeyPair] = None
|
||||||
|
|
||||||
_LOGGER.debug("Created AES transport for %s", self._host)
|
_LOGGER.debug("Created AES transport for %s", self._host)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_port(self):
|
def default_port(self) -> int:
|
||||||
"""Default port for the transport."""
|
"""Default port for the transport."""
|
||||||
return self.DEFAULT_PORT
|
return self.DEFAULT_PORT
|
||||||
|
|
||||||
@ -102,30 +111,25 @@ class AesTransport(BaseTransport):
|
|||||||
"""The hashed credentials used by the transport."""
|
"""The hashed credentials used by the transport."""
|
||||||
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
|
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
|
||||||
|
|
||||||
def _get_login_params(self, credentials):
|
def _get_login_params(self, credentials: Credentials) -> Dict[str, str]:
|
||||||
"""Get the login parameters based on the login_version."""
|
"""Get the login parameters based on the login_version."""
|
||||||
un, pw = self.hash_credentials(self._login_version == 2, credentials)
|
un, pw = self.hash_credentials(self._login_version == 2, credentials)
|
||||||
password_field_name = "password2" if self._login_version == 2 else "password"
|
password_field_name = "password2" if self._login_version == 2 else "password"
|
||||||
return {password_field_name: pw, "username": un}
|
return {password_field_name: pw, "username": un}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def hash_credentials(login_v2, credentials):
|
def hash_credentials(login_v2: bool, credentials: Credentials) -> Tuple[str, str]:
|
||||||
"""Hash the credentials."""
|
"""Hash the credentials."""
|
||||||
|
un = base64.b64encode(_sha1(credentials.username.encode()).encode()).decode()
|
||||||
if login_v2:
|
if login_v2:
|
||||||
un = base64.b64encode(
|
|
||||||
_sha1(credentials.username.encode()).encode()
|
|
||||||
).decode()
|
|
||||||
pw = base64.b64encode(
|
pw = base64.b64encode(
|
||||||
_sha1(credentials.password.encode()).encode()
|
_sha1(credentials.password.encode()).encode()
|
||||||
).decode()
|
).decode()
|
||||||
else:
|
else:
|
||||||
un = base64.b64encode(
|
|
||||||
_sha1(credentials.username.encode()).encode()
|
|
||||||
).decode()
|
|
||||||
pw = base64.b64encode(credentials.password.encode()).decode()
|
pw = base64.b64encode(credentials.password.encode()).decode()
|
||||||
return un, pw
|
return un, pw
|
||||||
|
|
||||||
def _handle_response_error_code(self, resp_dict: dict, msg: str):
|
def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
|
||||||
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||||
if error_code == SmartErrorCode.SUCCESS:
|
if error_code == SmartErrorCode.SUCCESS:
|
||||||
return
|
return
|
||||||
@ -135,12 +139,11 @@ class AesTransport(BaseTransport):
|
|||||||
if error_code in SMART_RETRYABLE_ERRORS:
|
if error_code in SMART_RETRYABLE_ERRORS:
|
||||||
raise RetryableException(msg, error_code=error_code)
|
raise RetryableException(msg, error_code=error_code)
|
||||||
if error_code in SMART_AUTHENTICATION_ERRORS:
|
if error_code in SMART_AUTHENTICATION_ERRORS:
|
||||||
self._handshake_done = False
|
self._state = TransportState.HANDSHAKE_REQUIRED
|
||||||
self._login_token = None
|
|
||||||
raise AuthenticationException(msg, error_code=error_code)
|
raise AuthenticationException(msg, error_code=error_code)
|
||||||
raise SmartDeviceException(msg, error_code=error_code)
|
raise SmartDeviceException(msg, error_code=error_code)
|
||||||
|
|
||||||
async def send_secure_passthrough(self, request: str):
|
async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
|
||||||
"""Send encrypted message as passthrough."""
|
"""Send encrypted message as passthrough."""
|
||||||
url = f"http://{self._host}/app"
|
url = f"http://{self._host}/app"
|
||||||
if self._login_token:
|
if self._login_token:
|
||||||
@ -165,16 +168,17 @@ class AesTransport(BaseTransport):
|
|||||||
+ f"status code {status_code} to passthrough"
|
+ f"status code {status_code} to passthrough"
|
||||||
)
|
)
|
||||||
|
|
||||||
resp_dict = cast(Dict, resp_dict)
|
|
||||||
self._handle_response_error_code(
|
self._handle_response_error_code(
|
||||||
resp_dict, "Error sending secure_passthrough message"
|
resp_dict, "Error sending secure_passthrough message"
|
||||||
)
|
)
|
||||||
|
|
||||||
response = self._encryption_session.decrypt( # type: ignore
|
if TYPE_CHECKING:
|
||||||
resp_dict["result"]["response"].encode()
|
resp_dict = cast(Dict[str, Any], resp_dict)
|
||||||
)
|
assert self._encryption_session is not None
|
||||||
resp_dict = json_loads(response)
|
|
||||||
return resp_dict
|
raw_response: str = resp_dict["result"]["response"]
|
||||||
|
response = self._encryption_session.decrypt(raw_response.encode())
|
||||||
|
return json_loads(response) # type: ignore[return-value]
|
||||||
|
|
||||||
async def perform_login(self):
|
async def perform_login(self):
|
||||||
"""Login to the device."""
|
"""Login to the device."""
|
||||||
@ -182,7 +186,7 @@ class AesTransport(BaseTransport):
|
|||||||
await self.try_login(self._login_params)
|
await self.try_login(self._login_params)
|
||||||
except AuthenticationException as aex:
|
except AuthenticationException as aex:
|
||||||
try:
|
try:
|
||||||
if aex.error_code != SmartErrorCode.LOGIN_ERROR:
|
if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
|
||||||
raise aex
|
raise aex
|
||||||
if self._default_credentials is None:
|
if self._default_credentials is None:
|
||||||
self._default_credentials = get_default_credentials(
|
self._default_credentials = get_default_credentials(
|
||||||
@ -203,9 +207,8 @@ class AesTransport(BaseTransport):
|
|||||||
ex,
|
ex,
|
||||||
) from ex
|
) from ex
|
||||||
|
|
||||||
async def try_login(self, login_params):
|
async def try_login(self, login_params: Dict[str, Any]) -> None:
|
||||||
"""Try to login with supplied login_params."""
|
"""Try to login with supplied login_params."""
|
||||||
self._login_token = None
|
|
||||||
login_request = {
|
login_request = {
|
||||||
"method": "login_device",
|
"method": "login_device",
|
||||||
"params": login_params,
|
"params": login_params,
|
||||||
@ -216,6 +219,7 @@ class AesTransport(BaseTransport):
|
|||||||
resp_dict = await self.send_secure_passthrough(request)
|
resp_dict = await self.send_secure_passthrough(request)
|
||||||
self._handle_response_error_code(resp_dict, "Error logging in")
|
self._handle_response_error_code(resp_dict, "Error logging in")
|
||||||
self._login_token = resp_dict["result"]["token"]
|
self._login_token = resp_dict["result"]["token"]
|
||||||
|
self._state = TransportState.ESTABLISHED
|
||||||
|
|
||||||
async def _generate_key_pair_payload(self) -> AsyncGenerator:
|
async def _generate_key_pair_payload(self) -> AsyncGenerator:
|
||||||
"""Generate the request body and return an ascyn_generator.
|
"""Generate the request body and return an ascyn_generator.
|
||||||
@ -236,12 +240,11 @@ class AesTransport(BaseTransport):
|
|||||||
_LOGGER.debug(f"Request {request_body}")
|
_LOGGER.debug(f"Request {request_body}")
|
||||||
yield json_dumps(request_body).encode()
|
yield json_dumps(request_body).encode()
|
||||||
|
|
||||||
async def perform_handshake(self):
|
async def perform_handshake(self) -> None:
|
||||||
"""Perform the handshake."""
|
"""Perform the handshake."""
|
||||||
_LOGGER.debug("Will perform handshaking...")
|
_LOGGER.debug("Will perform handshaking...")
|
||||||
|
|
||||||
self._key_pair = None
|
self._key_pair = None
|
||||||
self._handshake_done = False
|
|
||||||
self._session_expire_at = None
|
self._session_expire_at = None
|
||||||
self._session_cookie = None
|
self._session_cookie = None
|
||||||
|
|
||||||
@ -258,7 +261,7 @@ class AesTransport(BaseTransport):
|
|||||||
cookies_dict=self._session_cookie,
|
cookies_dict=self._session_cookie,
|
||||||
)
|
)
|
||||||
|
|
||||||
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
_LOGGER.debug("Device responded with: %s", resp_dict)
|
||||||
|
|
||||||
if status_code != 200:
|
if status_code != 200:
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
@ -268,6 +271,9 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
|
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
resp_dict = cast(Dict[str, Any], resp_dict)
|
||||||
|
|
||||||
handshake_key = resp_dict["result"]["key"]
|
handshake_key = resp_dict["result"]["key"]
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -283,12 +289,12 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
self._session_expire_at = time.time() + 86400
|
self._session_expire_at = time.time() + 86400
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert self._key_pair is not None # pragma: no cover
|
assert self._key_pair is not None
|
||||||
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
||||||
handshake_key, self._key_pair
|
handshake_key, self._key_pair
|
||||||
)
|
)
|
||||||
|
|
||||||
self._handshake_done = True
|
self._state = TransportState.LOGIN_REQUIRED
|
||||||
|
|
||||||
_LOGGER.debug("Handshake with %s complete", self._host)
|
_LOGGER.debug("Handshake with %s complete", self._host)
|
||||||
|
|
||||||
@ -299,17 +305,20 @@ class AesTransport(BaseTransport):
|
|||||||
or self._session_expire_at - time.time() <= 0
|
or self._session_expire_at - time.time() <= 0
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send(self, request: str):
|
async def send(self, request: str) -> Dict[str, Any]:
|
||||||
"""Send the request."""
|
"""Send the request."""
|
||||||
if not self._handshake_done or self._handshake_session_expired():
|
if (
|
||||||
|
self._state is TransportState.HANDSHAKE_REQUIRED
|
||||||
|
or self._handshake_session_expired()
|
||||||
|
):
|
||||||
await self.perform_handshake()
|
await self.perform_handshake()
|
||||||
if not self._login_token:
|
if self._state is not TransportState.ESTABLISHED:
|
||||||
try:
|
try:
|
||||||
await self.perform_login()
|
await self.perform_login()
|
||||||
# After a login failure handshake needs to
|
# After a login failure handshake needs to
|
||||||
# be redone or a 9999 error is received.
|
# be redone or a 9999 error is received.
|
||||||
except AuthenticationException as ex:
|
except AuthenticationException as ex:
|
||||||
self._handshake_done = False
|
self._state = TransportState.HANDSHAKE_REQUIRED
|
||||||
raise ex
|
raise ex
|
||||||
|
|
||||||
return await self.send_secure_passthrough(request)
|
return await self.send_secure_passthrough(request)
|
||||||
@ -321,8 +330,7 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
async def reset(self) -> None:
|
async def reset(self) -> None:
|
||||||
"""Reset internal handshake and login state."""
|
"""Reset internal handshake and login state."""
|
||||||
self._handshake_done = False
|
self._state = TransportState.HANDSHAKE_REQUIRED
|
||||||
self._login_token = None
|
|
||||||
|
|
||||||
|
|
||||||
class AesEncyptionSession:
|
class AesEncyptionSession:
|
||||||
|
@ -10,7 +10,7 @@ import pytest
|
|||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.hazmat.primitives import serialization
|
||||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||||
|
|
||||||
from ..aestransport import AesEncyptionSession, AesTransport
|
from ..aestransport import AesEncyptionSession, AesTransport, TransportState
|
||||||
from ..credentials import Credentials
|
from ..credentials import Credentials
|
||||||
from ..deviceconfig import DeviceConfig
|
from ..deviceconfig import DeviceConfig
|
||||||
from ..exceptions import (
|
from ..exceptions import (
|
||||||
@ -66,11 +66,11 @@ async def test_handshake(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert transport._encryption_session is None
|
assert transport._encryption_session is None
|
||||||
assert transport._handshake_done is False
|
assert transport._state is TransportState.HANDSHAKE_REQUIRED
|
||||||
with expectation:
|
with expectation:
|
||||||
await transport.perform_handshake()
|
await transport.perform_handshake()
|
||||||
assert transport._encryption_session is not None
|
assert transport._encryption_session is not None
|
||||||
assert transport._handshake_done is True
|
assert transport._state is TransportState.LOGIN_REQUIRED
|
||||||
|
|
||||||
|
|
||||||
@status_parameters
|
@status_parameters
|
||||||
@ -82,7 +82,7 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
|
|||||||
transport = AesTransport(
|
transport = AesTransport(
|
||||||
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||||
)
|
)
|
||||||
transport._handshake_done = True
|
transport._state = TransportState.LOGIN_REQUIRED
|
||||||
transport._session_expire_at = time.time() + 86400
|
transport._session_expire_at = time.time() + 86400
|
||||||
transport._encryption_session = mock_aes_device.encryption_session
|
transport._encryption_session = mock_aes_device.encryption_session
|
||||||
|
|
||||||
@ -129,7 +129,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
|
|||||||
transport = AesTransport(
|
transport = AesTransport(
|
||||||
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||||
)
|
)
|
||||||
transport._handshake_done = True
|
transport._state = TransportState.LOGIN_REQUIRED
|
||||||
transport._session_expire_at = time.time() + 86400
|
transport._session_expire_at = time.time() + 86400
|
||||||
transport._encryption_session = mock_aes_device.encryption_session
|
transport._encryption_session = mock_aes_device.encryption_session
|
||||||
|
|
||||||
|
@ -65,9 +65,16 @@ omit = ["kasa/tests/*"]
|
|||||||
|
|
||||||
[tool.coverage.report]
|
[tool.coverage.report]
|
||||||
exclude_lines = [
|
exclude_lines = [
|
||||||
# ignore abstract methods
|
# Don't complain if tests don't hit defensive assertion code:
|
||||||
|
"raise AssertionError",
|
||||||
"raise NotImplementedError",
|
"raise NotImplementedError",
|
||||||
"def __repr__"
|
# Don't complain about missing debug-only code:
|
||||||
|
"def __repr__",
|
||||||
|
# Have to re-enable the standard pragma
|
||||||
|
"pragma: no cover",
|
||||||
|
# TYPE_CHECKING and @overload blocks are never executed during pytest run
|
||||||
|
"if TYPE_CHECKING:",
|
||||||
|
"@overload"
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
Loading…
Reference in New Issue
Block a user