mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Do login entirely within AesTransport (#580)
* Do login entirely within AesTransport * Remove login and handshake attributes from BaseTransport * Add AesTransport tests * Synchronise transport and protocol __init__ signatures and rename internal variables * Update after review
This commit is contained in:
parent
209391c422
commit
20ea6700a5
@ -8,7 +8,7 @@ import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from cryptography.hazmat.primitives import padding, serialization
|
||||
@ -47,6 +47,7 @@ class AesTransport(BaseTransport):
|
||||
protocol, sometimes used by newer firmware versions on kasa devices.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT = 80
|
||||
DEFAULT_TIMEOUT = 5
|
||||
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
||||
COMMON_HEADERS = {
|
||||
@ -59,12 +60,16 @@ class AesTransport(BaseTransport):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(host=host)
|
||||
|
||||
self._credentials = credentials or Credentials(username="", password="")
|
||||
super().__init__(
|
||||
host,
|
||||
port=port or self.DEFAULT_PORT,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
self._handshake_done = False
|
||||
|
||||
@ -77,7 +82,7 @@ class AesTransport(BaseTransport):
|
||||
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
|
||||
self._login_token = None
|
||||
|
||||
_LOGGER.debug("Created AES object for %s", self.host)
|
||||
_LOGGER.debug("Created AES transport for %s", self._host)
|
||||
|
||||
def hash_credentials(self, login_v2):
|
||||
"""Hash the credentials."""
|
||||
@ -123,7 +128,7 @@ class AesTransport(BaseTransport):
|
||||
if (
|
||||
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||
) != SmartErrorCode.SUCCESS:
|
||||
msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})"
|
||||
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
|
||||
if error_code in SMART_TIMEOUT_ERRORS:
|
||||
raise TimeoutException(msg)
|
||||
if error_code in SMART_RETRYABLE_ERRORS:
|
||||
@ -136,7 +141,7 @@ class AesTransport(BaseTransport):
|
||||
|
||||
async def send_secure_passthrough(self, request: str):
|
||||
"""Send encrypted message as passthrough."""
|
||||
url = f"http://{self.host}/app"
|
||||
url = f"http://{self._host}/app"
|
||||
if self._login_token:
|
||||
url += f"?token={self._login_token}"
|
||||
|
||||
@ -150,7 +155,7 @@ class AesTransport(BaseTransport):
|
||||
|
||||
if status_code != 200:
|
||||
raise SmartDeviceException(
|
||||
f"{self.host} responded with an unexpected "
|
||||
f"{self._host} responded with an unexpected "
|
||||
+ f"status code {status_code} to passthrough"
|
||||
)
|
||||
|
||||
@ -164,49 +169,31 @@ class AesTransport(BaseTransport):
|
||||
resp_dict = json_loads(response)
|
||||
return resp_dict
|
||||
|
||||
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
|
||||
async def _perform_login_for_version(self, *, login_version: int = 1):
|
||||
"""Login to the device."""
|
||||
self._login_token = None
|
||||
|
||||
if isinstance(login_request, str):
|
||||
login_request_dict: dict = json_loads(login_request)
|
||||
else:
|
||||
login_request_dict = login_request
|
||||
|
||||
un, pw = self.hash_credentials(login_v2)
|
||||
login_request_dict["params"] = {"password": pw, "username": un}
|
||||
request = json_dumps(login_request_dict)
|
||||
un, pw = self.hash_credentials(login_version == 2)
|
||||
password_field_name = "password2" if login_version == 2 else "password"
|
||||
login_request = {
|
||||
"method": "login_device",
|
||||
"params": {password_field_name: pw, "username": un},
|
||||
"request_time_milis": round(time.time() * 1000),
|
||||
}
|
||||
request = json_dumps(login_request)
|
||||
try:
|
||||
resp_dict = await self.send_secure_passthrough(request)
|
||||
except SmartDeviceException as ex:
|
||||
raise AuthenticationException(ex) from ex
|
||||
self._login_token = resp_dict["result"]["token"]
|
||||
|
||||
@property
|
||||
def needs_login(self) -> bool:
|
||||
"""Return true if the transport needs to do a login."""
|
||||
return self._login_token is None
|
||||
|
||||
async def login(self, request: str) -> None:
|
||||
async def perform_login(self) -> None:
|
||||
"""Login to the device."""
|
||||
try:
|
||||
if self.needs_handshake:
|
||||
raise SmartDeviceException(
|
||||
"Handshake must be complete before trying to login"
|
||||
)
|
||||
await self.perform_login(request, login_v2=False)
|
||||
await self._perform_login_for_version(login_version=2)
|
||||
except AuthenticationException:
|
||||
_LOGGER.warning("Login version 2 failed, trying version 1")
|
||||
await self.perform_handshake()
|
||||
await self.perform_login(request, login_v2=True)
|
||||
|
||||
@property
|
||||
def needs_handshake(self) -> bool:
|
||||
"""Return true if the transport needs to do a handshake."""
|
||||
return not self._handshake_done or self._handshake_session_expired()
|
||||
|
||||
async def handshake(self) -> None:
|
||||
"""Perform the encryption handshake."""
|
||||
await self.perform_handshake()
|
||||
await self._perform_login_for_version(login_version=1)
|
||||
|
||||
async def perform_handshake(self):
|
||||
"""Perform the handshake."""
|
||||
@ -217,7 +204,7 @@ class AesTransport(BaseTransport):
|
||||
self._session_expire_at = None
|
||||
self._session_cookie = None
|
||||
|
||||
url = f"http://{self.host}/app"
|
||||
url = f"http://{self._host}/app"
|
||||
key_pair = KeyPair.create_key_pair()
|
||||
|
||||
pub_key = (
|
||||
@ -238,7 +225,7 @@ class AesTransport(BaseTransport):
|
||||
|
||||
if status_code != 200:
|
||||
raise SmartDeviceException(
|
||||
f"{self.host} responded with an unexpected "
|
||||
f"{self._host} responded with an unexpected "
|
||||
+ f"status code {status_code} to handshake"
|
||||
)
|
||||
|
||||
@ -261,7 +248,7 @@ class AesTransport(BaseTransport):
|
||||
|
||||
self._handshake_done = True
|
||||
|
||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||
_LOGGER.debug("Handshake with %s complete", self._host)
|
||||
|
||||
def _handshake_session_expired(self):
|
||||
"""Return true if session has expired."""
|
||||
@ -272,12 +259,10 @@ class AesTransport(BaseTransport):
|
||||
|
||||
async def send(self, request: str):
|
||||
"""Send the request."""
|
||||
if self.needs_handshake:
|
||||
raise SmartDeviceException(
|
||||
"Handshake must be complete before trying to send"
|
||||
)
|
||||
if self.needs_login:
|
||||
raise SmartDeviceException("Login must be complete before trying to send")
|
||||
if not self._handshake_done or self._handshake_session_expired():
|
||||
await self.perform_handshake()
|
||||
if not self._login_token:
|
||||
await self.perform_login()
|
||||
|
||||
return await self.send_secure_passthrough(request)
|
||||
|
||||
|
@ -74,7 +74,12 @@ async def connect(
|
||||
host=host, port=port, credentials=credentials, timeout=timeout
|
||||
)
|
||||
if protocol_class is not None:
|
||||
dev.protocol = protocol_class(host, credentials=credentials)
|
||||
dev.protocol = protocol_class(
|
||||
host,
|
||||
transport=AesTransport(
|
||||
host, port=port, credentials=credentials, timeout=timeout
|
||||
),
|
||||
)
|
||||
await dev.update()
|
||||
if debug_enabled:
|
||||
end_time = time.perf_counter()
|
||||
@ -90,7 +95,13 @@ async def connect(
|
||||
host=host, port=port, credentials=credentials, timeout=timeout
|
||||
)
|
||||
if protocol_class is not None:
|
||||
unknown_dev.protocol = protocol_class(host, credentials=credentials)
|
||||
# TODO this will be replaced with connection params
|
||||
unknown_dev.protocol = protocol_class(
|
||||
host,
|
||||
transport=AesTransport(
|
||||
host, port=port, credentials=credentials, timeout=timeout
|
||||
),
|
||||
)
|
||||
await unknown_dev.update()
|
||||
device_class = get_device_class_from_sys_info(unknown_dev.internal_state)
|
||||
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
|
||||
@ -163,7 +174,5 @@ def get_protocol_from_connection_name(
|
||||
|
||||
protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore
|
||||
transport: BaseTransport = transport_class(host, credentials=credentials)
|
||||
protocol: TPLinkProtocol = protocol_class(
|
||||
host, credentials=credentials, transport=transport
|
||||
)
|
||||
protocol: TPLinkProtocol = protocol_class(host, transport=transport)
|
||||
return protocol
|
||||
|
@ -1,14 +1,12 @@
|
||||
"""Module for the IOT legacy IOT KASA protocol."""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from .credentials import Credentials
|
||||
from .exceptions import AuthenticationException, SmartDeviceException
|
||||
from .json import dumps as json_dumps
|
||||
from .klaptransport import KlapTransport
|
||||
from .protocol import BaseTransport, TPLinkProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -17,24 +15,14 @@ _LOGGER = logging.getLogger(__name__)
|
||||
class IotProtocol(TPLinkProtocol):
|
||||
"""Class for the legacy TPLink IOT KASA Protocol."""
|
||||
|
||||
DEFAULT_PORT = 80
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
transport: Optional[BaseTransport] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
transport: BaseTransport,
|
||||
) -> None:
|
||||
super().__init__(host=host, port=self.DEFAULT_PORT)
|
||||
|
||||
self._credentials: Credentials = credentials or Credentials(
|
||||
username="", password=""
|
||||
)
|
||||
self._transport: BaseTransport = transport or KlapTransport(
|
||||
host, credentials=self._credentials, timeout=timeout
|
||||
)
|
||||
"""Create a protocol object."""
|
||||
super().__init__(host, transport=transport)
|
||||
|
||||
self._query_lock = asyncio.Lock()
|
||||
|
||||
@ -54,30 +42,32 @@ class IotProtocol(TPLinkProtocol):
|
||||
except httpx.CloseError as sdex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {sdex}"
|
||||
f"Unable to connect to the device: {self._host}: {sdex}"
|
||||
) from sdex
|
||||
continue
|
||||
except httpx.ConnectError as cex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {cex}"
|
||||
f"Unable to connect to the device: {self._host}: {cex}"
|
||||
) from cex
|
||||
except TimeoutError as tex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
||||
f"Unable to connect to the device, timed out: {self._host}: {tex}"
|
||||
) from tex
|
||||
except AuthenticationException as auex:
|
||||
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
||||
_LOGGER.debug(
|
||||
"Unable to authenticate with %s, not retrying", self._host
|
||||
)
|
||||
raise auex
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {ex}"
|
||||
f"Unable to connect to the device: {self._host}: {ex}"
|
||||
) from ex
|
||||
continue
|
||||
|
||||
@ -85,14 +75,6 @@ class IotProtocol(TPLinkProtocol):
|
||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||
|
||||
async def _execute_query(self, request: str, retry_count: int) -> Dict:
|
||||
if self._transport.needs_handshake:
|
||||
await self._transport.handshake()
|
||||
|
||||
if self._transport.needs_login: # This shouln't happen
|
||||
raise SmartDeviceException(
|
||||
"IOT Protocol needs to login to transport but is not login aware"
|
||||
)
|
||||
|
||||
return await self._transport.send(request)
|
||||
|
||||
async def close(self) -> None:
|
||||
|
@ -82,7 +82,7 @@ class KlapTransport(BaseTransport):
|
||||
protocol, used by newer firmware versions.
|
||||
"""
|
||||
|
||||
DEFAULT_TIMEOUT = 5
|
||||
DEFAULT_PORT = 80
|
||||
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
|
||||
KASA_SETUP_EMAIL = "kasa@tp-link.net"
|
||||
KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105
|
||||
@ -92,12 +92,17 @@ class KlapTransport(BaseTransport):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(host=host)
|
||||
super().__init__(
|
||||
host,
|
||||
port=port or self.DEFAULT_PORT,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
self._credentials = credentials or Credentials(username="", password="")
|
||||
self._local_seed: Optional[bytes] = None
|
||||
self._local_auth_hash = self.generate_auth_hash(self._credentials)
|
||||
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
|
||||
@ -110,11 +115,10 @@ class KlapTransport(BaseTransport):
|
||||
self._encryption_session: Optional[KlapEncryptionSession] = None
|
||||
self._session_expire_at: Optional[float] = None
|
||||
|
||||
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
||||
self._session_cookie = None
|
||||
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
|
||||
|
||||
_LOGGER.debug("Created KLAP object for %s", self.host)
|
||||
_LOGGER.debug("Created KLAP transport for %s", self._host)
|
||||
|
||||
async def client_post(self, url, params=None, data=None):
|
||||
"""Send an http post request to the device."""
|
||||
@ -148,7 +152,7 @@ class KlapTransport(BaseTransport):
|
||||
|
||||
payload = local_seed
|
||||
|
||||
url = f"http://{self.host}/app/handshake1"
|
||||
url = f"http://{self._host}/app/handshake1"
|
||||
|
||||
response_status, response_data = await self.client_post(url, data=payload)
|
||||
|
||||
@ -157,14 +161,14 @@ class KlapTransport(BaseTransport):
|
||||
"Handshake1 posted at %s. Host is %s, Response"
|
||||
+ "status is %s, Request was %s",
|
||||
datetime.datetime.now(),
|
||||
self.host,
|
||||
self._host,
|
||||
response_status,
|
||||
payload.hex(),
|
||||
)
|
||||
|
||||
if response_status != 200:
|
||||
raise AuthenticationException(
|
||||
f"Device {self.host} responded with {response_status} to handshake1"
|
||||
f"Device {self._host} responded with {response_status} to handshake1"
|
||||
)
|
||||
|
||||
remote_seed: bytes = response_data[0:16]
|
||||
@ -175,7 +179,7 @@ class KlapTransport(BaseTransport):
|
||||
"Handshake1 success at %s. Host is %s, "
|
||||
+ "Server remote_seed is: %s, server hash is: %s",
|
||||
datetime.datetime.now(),
|
||||
self.host,
|
||||
self._host,
|
||||
remote_seed.hex(),
|
||||
server_hash.hex(),
|
||||
)
|
||||
@ -207,7 +211,7 @@ class KlapTransport(BaseTransport):
|
||||
_LOGGER.debug(
|
||||
"Server response doesn't match our expected hash on ip %s"
|
||||
+ " but an authentication with kasa setup credentials matched",
|
||||
self.host,
|
||||
self._host,
|
||||
)
|
||||
return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore
|
||||
|
||||
@ -226,11 +230,11 @@ class KlapTransport(BaseTransport):
|
||||
_LOGGER.debug(
|
||||
"Server response doesn't match our expected hash on ip %s"
|
||||
+ " but an authentication with blank credentials matched",
|
||||
self.host,
|
||||
self._host,
|
||||
)
|
||||
return local_seed, remote_seed, self._blank_auth_hash # type: ignore
|
||||
|
||||
msg = f"Server response doesn't match our challenge on ip {self.host}"
|
||||
msg = f"Server response doesn't match our challenge on ip {self._host}"
|
||||
_LOGGER.debug(msg)
|
||||
raise AuthenticationException(msg)
|
||||
|
||||
@ -241,7 +245,7 @@ class KlapTransport(BaseTransport):
|
||||
# Handshake 2 has the following payload:
|
||||
# sha256(serverBytes | authenticator)
|
||||
|
||||
url = f"http://{self.host}/app/handshake2"
|
||||
url = f"http://{self._host}/app/handshake2"
|
||||
|
||||
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
|
||||
|
||||
@ -252,44 +256,24 @@ class KlapTransport(BaseTransport):
|
||||
"Handshake2 posted %s. Host is %s, Response status is %s, "
|
||||
+ "Request was %s",
|
||||
datetime.datetime.now(),
|
||||
self.host,
|
||||
self._host,
|
||||
response_status,
|
||||
payload.hex(),
|
||||
)
|
||||
|
||||
if response_status != 200:
|
||||
raise AuthenticationException(
|
||||
f"Device {self.host} responded with {response_status} to handshake2"
|
||||
f"Device {self._host} responded with {response_status} to handshake2"
|
||||
)
|
||||
|
||||
return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
|
||||
|
||||
@property
|
||||
def needs_login(self) -> bool:
|
||||
"""Will return false as KLAP does not do a login."""
|
||||
return False
|
||||
|
||||
async def login(self, request: str) -> None:
|
||||
"""Will raise and exception as KLAP does not do a login."""
|
||||
raise SmartDeviceException(
|
||||
"KLAP does not perform logins and return needs_login == False"
|
||||
)
|
||||
|
||||
@property
|
||||
def needs_handshake(self) -> bool:
|
||||
"""Return true if the transport needs to do a handshake."""
|
||||
return not self._handshake_done or self._handshake_session_expired()
|
||||
|
||||
async def handshake(self) -> None:
|
||||
"""Perform the encryption handshake."""
|
||||
await self.perform_handshake()
|
||||
|
||||
async def perform_handshake(self) -> Any:
|
||||
"""Perform handshake1 and handshake2.
|
||||
|
||||
Sets the encryption_session if successful.
|
||||
"""
|
||||
_LOGGER.debug("Starting handshake with %s", self.host)
|
||||
_LOGGER.debug("Starting handshake with %s", self._host)
|
||||
self._handshake_done = False
|
||||
self._session_expire_at = None
|
||||
self._session_cookie = None
|
||||
@ -307,7 +291,7 @@ class KlapTransport(BaseTransport):
|
||||
)
|
||||
self._handshake_done = True
|
||||
|
||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
||||
_LOGGER.debug("Handshake with %s complete", self._host)
|
||||
|
||||
def _handshake_session_expired(self):
|
||||
"""Return true if session has expired."""
|
||||
@ -318,18 +302,14 @@ class KlapTransport(BaseTransport):
|
||||
|
||||
async def send(self, request: str):
|
||||
"""Send the request."""
|
||||
if self.needs_handshake:
|
||||
raise SmartDeviceException(
|
||||
"Handshake must be complete before trying to send"
|
||||
)
|
||||
if self.needs_login:
|
||||
raise SmartDeviceException("Login must be complete before trying to send")
|
||||
if not self._handshake_done or self._handshake_session_expired():
|
||||
await self.perform_handshake()
|
||||
|
||||
# Check for mypy
|
||||
if self._encryption_session is not None:
|
||||
payload, seq = self._encryption_session.encrypt(request.encode())
|
||||
|
||||
url = f"http://{self.host}/app/request"
|
||||
url = f"http://{self._host}/app/request"
|
||||
|
||||
response_status, response_data = await self.client_post(
|
||||
url,
|
||||
@ -338,7 +318,7 @@ class KlapTransport(BaseTransport):
|
||||
)
|
||||
|
||||
msg = (
|
||||
f"at {datetime.datetime.now()}. Host is {self.host}, "
|
||||
f"at {datetime.datetime.now()}. Host is {self._host}, "
|
||||
+ f"Sequence is {seq}, "
|
||||
+ f"Response status is {response_status}, Request was {request}"
|
||||
)
|
||||
@ -348,12 +328,12 @@ class KlapTransport(BaseTransport):
|
||||
if response_status == 403:
|
||||
self._handshake_done = False
|
||||
raise AuthenticationException(
|
||||
f"Got a security error from {self.host} after handshake "
|
||||
f"Got a security error from {self._host} after handshake "
|
||||
+ "completed"
|
||||
)
|
||||
else:
|
||||
raise SmartDeviceException(
|
||||
f"Device {self.host} responded with {response_status} to"
|
||||
f"Device {self._host} responded with {response_status} to"
|
||||
+ f"request with seq {seq}"
|
||||
)
|
||||
else:
|
||||
@ -367,7 +347,7 @@ class KlapTransport(BaseTransport):
|
||||
|
||||
_LOGGER.debug(
|
||||
"%s << %s",
|
||||
self.host,
|
||||
self._host,
|
||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(json_payload),
|
||||
)
|
||||
|
||||
|
108
kasa/protocol.py
108
kasa/protocol.py
@ -44,35 +44,21 @@ def md5(payload: bytes) -> bytes:
|
||||
class BaseTransport(ABC):
|
||||
"""Base class for all TP-Link protocol transports."""
|
||||
|
||||
DEFAULT_TIMEOUT = 5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.credentials = credentials
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def needs_handshake(self) -> bool:
|
||||
"""Return true if the transport needs to do a handshake."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def needs_login(self) -> bool:
|
||||
"""Return true if the transport needs to do a login."""
|
||||
|
||||
@abstractmethod
|
||||
async def login(self, request: str) -> None:
|
||||
"""Login to the device."""
|
||||
|
||||
@abstractmethod
|
||||
async def handshake(self) -> None:
|
||||
"""Perform the encryption handshake."""
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._credentials = credentials or Credentials(username="", password="")
|
||||
self._timeout = timeout or self.DEFAULT_TIMEOUT
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, request: str) -> Dict:
|
||||
@ -90,14 +76,14 @@ class TPLinkProtocol(ABC):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
transport: Optional[BaseTransport] = None,
|
||||
transport: BaseTransport,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.credentials = credentials
|
||||
self._transport = transport
|
||||
|
||||
@property
|
||||
def _host(self):
|
||||
return self._transport._host
|
||||
|
||||
@abstractmethod
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
@ -108,6 +94,40 @@ class TPLinkProtocol(ABC):
|
||||
"""Close the protocol. Abstract method to be overriden."""
|
||||
|
||||
|
||||
class _XorTransport(BaseTransport):
|
||||
"""Implementation of the Xor encryption transport.
|
||||
|
||||
WIP, currently only to ensure consistent __init__ method signatures
|
||||
for protocol classes. Will eventually incorporate the logic from
|
||||
TPLinkSmartHomeProtocol to simplify the API and re-use the IotProtocol
|
||||
class.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT = 9999
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
host,
|
||||
port=port or self.DEFAULT_PORT,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def send(self, request: str) -> Dict:
|
||||
"""Send a message to the device and return a response."""
|
||||
return {}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the transport. Abstract method to be overriden."""
|
||||
|
||||
|
||||
class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
"""Implementation of the TP-Link Smart Home protocol."""
|
||||
|
||||
@ -120,20 +140,18 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
timeout: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
transport: BaseTransport,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
super().__init__(
|
||||
host=host, port=port or self.DEFAULT_PORT, credentials=credentials
|
||||
)
|
||||
super().__init__(host, transport=transport)
|
||||
|
||||
self.reader: Optional[asyncio.StreamReader] = None
|
||||
self.writer: Optional[asyncio.StreamWriter] = None
|
||||
self.query_lock = asyncio.Lock()
|
||||
self.loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self.timeout = timeout or TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
|
||||
|
||||
self._timeout = self._transport._timeout
|
||||
self._port = self._transport._port
|
||||
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Request information from a TP-Link SmartHome Device.
|
||||
@ -149,7 +167,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
assert isinstance(request, str) # noqa: S101
|
||||
|
||||
async with self.query_lock:
|
||||
return await self._query(request, retry_count, self.timeout)
|
||||
return await self._query(request, retry_count, self._timeout)
|
||||
|
||||
async def _connect(self, timeout: int) -> None:
|
||||
"""Try to connect or reconnect to the device."""
|
||||
@ -157,7 +175,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
return
|
||||
self.reader = self.writer = None
|
||||
|
||||
task = asyncio.open_connection(self.host, self.port)
|
||||
task = asyncio.open_connection(self._host, self._port)
|
||||
async with asyncio_timeout(timeout):
|
||||
self.reader, self.writer = await task
|
||||
sock: socket.socket = self.writer.get_extra_info("socket")
|
||||
@ -174,7 +192,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
debug_log = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
|
||||
if debug_log:
|
||||
_LOGGER.debug("%s >> %s", self.host, request)
|
||||
_LOGGER.debug("%s >> %s", self._host, request)
|
||||
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
||||
await self.writer.drain()
|
||||
|
||||
@ -185,7 +203,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
response = TPLinkSmartHomeProtocol.decrypt(buffer)
|
||||
json_payload = json_loads(response)
|
||||
if debug_log:
|
||||
_LOGGER.debug("%s << %s", self.host, pf(json_payload))
|
||||
_LOGGER.debug("%s << %s", self._host, pf(json_payload))
|
||||
|
||||
return json_payload
|
||||
|
||||
@ -219,23 +237,23 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
except ConnectionRefusedError as ex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}:{self.port}: {ex}"
|
||||
f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
except OSError as ex:
|
||||
await self.close()
|
||||
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device:"
|
||||
f" {self.host}:{self.port}: {ex}"
|
||||
f" {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
continue
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device:"
|
||||
f" {self.host}:{self.port}: {ex}"
|
||||
f" {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
continue
|
||||
|
||||
@ -247,13 +265,13 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to query the device {self.host}:{self.port}: {ex}"
|
||||
f"Unable to query the device {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
|
||||
_LOGGER.debug(
|
||||
"Unable to query the device %s, retrying: %s", self.host, ex
|
||||
"Unable to query the device %s, retrying: %s", self._host, ex
|
||||
)
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
|
@ -24,7 +24,7 @@ from .device_type import DeviceType
|
||||
from .emeterstatus import EmeterStatus
|
||||
from .exceptions import SmartDeviceException
|
||||
from .modules import Emeter, Module
|
||||
from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||
from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol, _XorTransport
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -202,7 +202,7 @@ class SmartDevice:
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol(
|
||||
host, port=port, timeout=timeout
|
||||
host, transport=_XorTransport(host, port=port, timeout=timeout)
|
||||
)
|
||||
self.credentials = credentials
|
||||
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
|
||||
|
@ -10,12 +10,10 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from pprint import pformat as pf
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from .aestransport import AesTransport
|
||||
from .credentials import Credentials
|
||||
from .exceptions import (
|
||||
SMART_AUTHENTICATION_ERRORS,
|
||||
SMART_RETRYABLE_ERRORS,
|
||||
@ -36,26 +34,17 @@ logging.getLogger("httpx").propagate = False
|
||||
class SmartProtocol(TPLinkProtocol):
|
||||
"""Class for the new TPLink SMART protocol."""
|
||||
|
||||
DEFAULT_PORT = 80
|
||||
SLEEP_SECONDS_AFTER_TIMEOUT = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
transport: Optional[BaseTransport] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
transport: BaseTransport,
|
||||
) -> None:
|
||||
super().__init__(host=host, port=self.DEFAULT_PORT)
|
||||
|
||||
self._credentials: Credentials = credentials or Credentials(
|
||||
username="", password=""
|
||||
)
|
||||
self._transport: BaseTransport = transport or AesTransport(
|
||||
host, credentials=self._credentials, timeout=timeout
|
||||
)
|
||||
self._terminal_uuid: Optional[str] = None
|
||||
"""Create a protocol object."""
|
||||
super().__init__(host, transport=transport)
|
||||
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
|
||||
self._request_id_generator = SnowflakeId(1, 1)
|
||||
self._query_lock = asyncio.Lock()
|
||||
|
||||
@ -79,7 +68,7 @@ class SmartProtocol(TPLinkProtocol):
|
||||
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||
) != SmartErrorCode.SUCCESS:
|
||||
msg = (
|
||||
f"Error querying device: {self.host}: "
|
||||
f"Error querying device: {self._host}: "
|
||||
+ f"{error_code.name}({error_code.value})"
|
||||
)
|
||||
if error_code in SMART_TIMEOUT_ERRORS:
|
||||
@ -101,51 +90,53 @@ class SmartProtocol(TPLinkProtocol):
|
||||
except httpx.CloseError as sdex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {sdex}"
|
||||
f"Unable to connect to the device: {self._host}: {sdex}"
|
||||
) from sdex
|
||||
continue
|
||||
except httpx.ConnectError as cex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}: {cex}"
|
||||
f"Unable to connect to the device: {self._host}: {cex}"
|
||||
) from cex
|
||||
except TimeoutError as tex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
"Unable to connect to the device, "
|
||||
+ f"timed out: {self.host}: {tex}"
|
||||
+ f"timed out: {self._host}: {tex}"
|
||||
) from tex
|
||||
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
||||
continue
|
||||
except AuthenticationException as auex:
|
||||
await self.close()
|
||||
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
||||
_LOGGER.debug(
|
||||
"Unable to authenticate with %s, not retrying", self._host
|
||||
)
|
||||
raise auex
|
||||
except RetryableException as ex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
continue
|
||||
except TimeoutException as ex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise ex
|
||||
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
||||
continue
|
||||
except Exception as ex:
|
||||
if retry >= retry_count:
|
||||
await self.close()
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to query the device {self.host}:{self.port}: {ex}"
|
||||
f"Unable to connect to the device: {self._host}: {ex}"
|
||||
) from ex
|
||||
_LOGGER.debug(
|
||||
"Unable to query the device %s, retrying: %s", self.host, ex
|
||||
"Unable to query the device %s, retrying: %s", self._host, ex
|
||||
)
|
||||
continue
|
||||
|
||||
@ -160,27 +151,17 @@ class SmartProtocol(TPLinkProtocol):
|
||||
smart_method = request
|
||||
smart_params = None
|
||||
|
||||
if self._transport.needs_handshake:
|
||||
await self._transport.handshake()
|
||||
|
||||
if self._transport.needs_login:
|
||||
self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode(
|
||||
"UTF-8"
|
||||
)
|
||||
login_request = self.get_smart_request("login_device")
|
||||
await self._transport.login(login_request)
|
||||
|
||||
smart_request = self.get_smart_request(smart_method, smart_params)
|
||||
_LOGGER.debug(
|
||||
"%s >> %s",
|
||||
self.host,
|
||||
self._host,
|
||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request),
|
||||
)
|
||||
response_data = await self._transport.send(smart_request)
|
||||
|
||||
_LOGGER.debug(
|
||||
"%s << %s",
|
||||
self.host,
|
||||
self._host,
|
||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
|
||||
)
|
||||
|
||||
|
@ -4,6 +4,7 @@ import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Set, cast
|
||||
|
||||
from ..aestransport import AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..exceptions import AuthenticationException
|
||||
from ..smartdevice import SmartDevice
|
||||
@ -27,7 +28,12 @@ class TapoDevice(SmartDevice):
|
||||
self._components: Optional[Dict[str, Any]] = None
|
||||
self._state_information: Dict[str, Any] = {}
|
||||
self._discovery_info: Optional[Dict[str, Any]] = None
|
||||
self.protocol = SmartProtocol(host, credentials=credentials, timeout=timeout)
|
||||
self.protocol = SmartProtocol(
|
||||
host,
|
||||
transport=AesTransport(
|
||||
host, credentials=credentials, timeout=timeout, port=port
|
||||
),
|
||||
)
|
||||
|
||||
async def update(self, update_children: bool = True):
|
||||
"""Update the device."""
|
||||
|
@ -301,6 +301,9 @@ class FakeSmartProtocol(SmartProtocol):
|
||||
|
||||
class FakeSmartTransport(BaseTransport):
|
||||
def __init__(self, info):
|
||||
super().__init__(
|
||||
"127.0.0.123",
|
||||
)
|
||||
self.info = info
|
||||
|
||||
@property
|
||||
|
174
kasa/tests/test_aestransport.py
Normal file
174
kasa/tests/test_aestransport.py
Normal file
@ -0,0 +1,174 @@
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from contextlib import nullcontext as does_not_raise
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||
|
||||
from ..aestransport import AesEncyptionSession, AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..exceptions import SmartDeviceException
|
||||
|
||||
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||
|
||||
key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t"
|
||||
iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa"
|
||||
KEY_IV = key + iv
|
||||
|
||||
|
||||
def test_encrypt():
|
||||
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
|
||||
|
||||
d = json.dumps({"foo": 1, "bar": 2})
|
||||
encrypted = encryption_session.encrypt(d.encode())
|
||||
assert d == encryption_session.decrypt(encrypted)
|
||||
|
||||
# test encrypt unicode
|
||||
d = "{'snowman': '\u2603'}"
|
||||
encrypted = encryption_session.encrypt(d.encode())
|
||||
assert d == encryption_session.decrypt(encrypted)
|
||||
|
||||
|
||||
status_parameters = pytest.mark.parametrize(
|
||||
"status_code, error_code, inner_error_code, expectation",
|
||||
[
|
||||
(200, 0, 0, does_not_raise()),
|
||||
(400, 0, 0, pytest.raises(SmartDeviceException)),
|
||||
(200, -1, 0, pytest.raises(SmartDeviceException)),
|
||||
],
|
||||
ids=("success", "status_code", "error_code"),
|
||||
)
|
||||
|
||||
|
||||
@status_parameters
|
||||
async def test_handshake(
|
||||
mocker, status_code, error_code, inner_error_code, expectation
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
|
||||
assert transport._encryption_session is None
|
||||
assert transport._handshake_done is False
|
||||
with expectation:
|
||||
await transport.perform_handshake()
|
||||
assert transport._encryption_session is not None
|
||||
assert transport._handshake_done is True
|
||||
|
||||
|
||||
@status_parameters
|
||||
async def test_login(mocker, status_code, error_code, inner_error_code, expectation):
|
||||
host = "127.0.0.1"
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
transport._handshake_done = True
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
|
||||
assert transport._login_token is None
|
||||
with expectation:
|
||||
await transport.perform_login()
|
||||
assert transport._login_token == mock_aes_device.token
|
||||
|
||||
|
||||
@status_parameters
|
||||
async def test_send(mocker, status_code, error_code, inner_error_code, expectation):
|
||||
host = "127.0.0.1"
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
transport._handshake_done = True
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
transport._login_token = mock_aes_device.token
|
||||
|
||||
un, pw = transport.hash_credentials(True)
|
||||
request = {
|
||||
"method": "get_device_info",
|
||||
"params": None,
|
||||
"request_time_milis": round(time.time() * 1000),
|
||||
"requestID": 1,
|
||||
"terminal_uuid": "foobar",
|
||||
}
|
||||
with expectation:
|
||||
res = await transport.send(json_dumps(request))
|
||||
assert "result" in res
|
||||
|
||||
|
||||
class MockAesDevice:
|
||||
class _mock_response:
|
||||
def __init__(self, status_code, json: dict):
|
||||
self.status_code = status_code
|
||||
self._json = json
|
||||
|
||||
def json(self):
|
||||
return self._json
|
||||
|
||||
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
|
||||
token = "test_token" # noqa
|
||||
|
||||
def __init__(self, host, status_code=200, error_code=0, inner_error_code=0):
|
||||
self.host = host
|
||||
self.status_code = status_code
|
||||
self.error_code = error_code
|
||||
self.inner_error_code = inner_error_code
|
||||
|
||||
async def post(self, url, params=None, json=None, *_, **__):
|
||||
return await self._post(url, json)
|
||||
|
||||
async def _post(self, url, json):
|
||||
if json["method"] == "handshake":
|
||||
return await self._return_handshake_response(url, json)
|
||||
elif json["method"] == "securePassthrough":
|
||||
return await self._return_secure_passthrough_response(url, json)
|
||||
elif json["method"] == "login_device":
|
||||
return await self._return_login_response(url, json)
|
||||
else:
|
||||
assert url == f"http://{self.host}/app?token={self.token}"
|
||||
return await self._return_send_response(url, json)
|
||||
|
||||
async def _return_handshake_response(self, url, json):
|
||||
start = len("-----BEGIN PUBLIC KEY-----\n")
|
||||
end = len("\n-----END PUBLIC KEY-----\n")
|
||||
client_pub_key = json["params"]["key"][start:-end]
|
||||
|
||||
client_pub_key_data = base64.b64decode(client_pub_key.encode())
|
||||
client_pub_key = serialization.load_der_public_key(client_pub_key_data, None)
|
||||
encrypted_key = client_pub_key.encrypt(KEY_IV, asymmetric_padding.PKCS1v15())
|
||||
key_64 = base64.b64encode(encrypted_key).decode()
|
||||
return self._mock_response(
|
||||
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
|
||||
)
|
||||
|
||||
async def _return_secure_passthrough_response(self, url, json):
|
||||
encrypted_request = json["params"]["request"]
|
||||
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
|
||||
decrypted_request_dict = json_loads(decrypted_request)
|
||||
decrypted_response = await self._post(url, decrypted_request_dict)
|
||||
decrypted_response_dict = decrypted_response.json()
|
||||
encrypted_response = self.encryption_session.encrypt(
|
||||
json_dumps(decrypted_response_dict).encode()
|
||||
)
|
||||
result = {
|
||||
"result": {"response": encrypted_response.decode()},
|
||||
"error_code": self.error_code,
|
||||
}
|
||||
return self._mock_response(self.status_code, result)
|
||||
|
||||
async def _return_login_response(self, url, json):
|
||||
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
|
||||
return self._mock_response(self.status_code, result)
|
||||
|
||||
async def _return_send_response(self, url, json):
|
||||
result = {"result": {"method": None}, "error_code": self.inner_error_code}
|
||||
return self._mock_response(self.status_code, result)
|
@ -96,10 +96,9 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
|
||||
|
||||
return mock_response
|
||||
|
||||
mocker.patch.object(
|
||||
transport_class, "needs_handshake", property(lambda self: False)
|
||||
)
|
||||
mocker.patch.object(transport_class, "needs_login", property(lambda self: False))
|
||||
mocker.patch.object(transport_class, "perform_handshake")
|
||||
if hasattr(transport_class, "perform_login"):
|
||||
mocker.patch.object(transport_class, "perform_login")
|
||||
|
||||
send_mock = mocker.patch.object(
|
||||
transport_class,
|
||||
@ -128,7 +127,7 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
||||
seed = secrets.token_bytes(16)
|
||||
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
|
||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||
protocol = IotProtocol("127.0.0.1")
|
||||
protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1"))
|
||||
|
||||
protocol._transport._handshake_done = True
|
||||
protocol._transport._session_expire_at = time.time() + 86400
|
||||
@ -206,7 +205,10 @@ async def test_handshake1(mocker, device_credentials, expectation):
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
|
||||
)
|
||||
|
||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
|
||||
protocol._transport.http_client = httpx.AsyncClient()
|
||||
with expectation:
|
||||
@ -243,7 +245,10 @@ async def test_handshake(mocker):
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake_response
|
||||
)
|
||||
|
||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
protocol._transport.http_client = httpx.AsyncClient()
|
||||
|
||||
response_status = 200
|
||||
@ -289,7 +294,10 @@ async def test_query(mocker):
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
|
||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
|
||||
for _ in range(10):
|
||||
resp = await protocol.query({})
|
||||
@ -333,7 +341,10 @@ async def test_authentication_failures(mocker, response_status, expectation):
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
|
||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
|
||||
with expectation:
|
||||
await protocol.query({})
|
||||
|
@ -1,13 +1,21 @@
|
||||
import errno
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import pkgutil
|
||||
import struct
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from ..exceptions import SmartDeviceException
|
||||
from ..protocol import TPLinkSmartHomeProtocol
|
||||
from ..protocol import (
|
||||
BaseTransport,
|
||||
TPLinkProtocol,
|
||||
TPLinkSmartHomeProtocol,
|
||||
_XorTransport,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||
@ -24,7 +32,9 @@ async def test_protocol_retries(mocker, retry_count):
|
||||
|
||||
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count)
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=retry_count)
|
||||
|
||||
assert conn.call_count == retry_count + 1
|
||||
|
||||
@ -35,7 +45,9 @@ async def test_protocol_no_retry_on_unreachable(mocker):
|
||||
side_effect=OSError(errno.EHOSTUNREACH, "No route to host"),
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=5)
|
||||
|
||||
assert conn.call_count == 1
|
||||
|
||||
@ -46,7 +58,9 @@ async def test_protocol_no_retry_connection_refused(mocker):
|
||||
side_effect=ConnectionRefusedError,
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=5)
|
||||
|
||||
assert conn.call_count == 1
|
||||
|
||||
@ -57,7 +71,9 @@ async def test_protocol_retry_recoverable_error(mocker):
|
||||
side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"),
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=5)
|
||||
|
||||
assert conn.call_count == 6
|
||||
|
||||
@ -91,7 +107,9 @@ async def test_protocol_reconnect(mocker, retry_count):
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol("127.0.0.1")
|
||||
protocol = TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
)
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
response = await protocol.query({}, retry_count=retry_count)
|
||||
assert response == {"great": "success"}
|
||||
@ -119,7 +137,9 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol("127.0.0.1")
|
||||
protocol = TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
)
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
@ -153,7 +173,9 @@ async def test_protocol_custom_port(mocker, custom_port):
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol("127.0.0.1", port=custom_port)
|
||||
protocol = TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port)
|
||||
)
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
@ -227,3 +249,63 @@ def test_decrypt_unicode():
|
||||
d = "{'snowman': '\u2603'}"
|
||||
|
||||
assert d == TPLinkSmartHomeProtocol.decrypt(e)
|
||||
|
||||
|
||||
def _get_subclasses(of_class):
|
||||
import kasa
|
||||
|
||||
package = sys.modules["kasa"]
|
||||
subclasses = set()
|
||||
for _, modname, _ in pkgutil.iter_modules(package.__path__):
|
||||
importlib.import_module("." + modname, package="kasa")
|
||||
module = sys.modules["kasa." + modname]
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, of_class):
|
||||
subclasses.add((name, obj))
|
||||
return subclasses
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"class_name_obj", _get_subclasses(TPLinkProtocol), ids=lambda t: t[0]
|
||||
)
|
||||
def test_protocol_init_signature(class_name_obj):
|
||||
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
|
||||
|
||||
assert len(params) == 3
|
||||
assert (
|
||||
params[0].name == "self"
|
||||
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert (
|
||||
params[1].name == "host"
|
||||
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert (
|
||||
params[2].name == "transport"
|
||||
and params[2].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"class_name_obj", _get_subclasses(BaseTransport), ids=lambda t: t[0]
|
||||
)
|
||||
def test_transport_init_signature(class_name_obj):
|
||||
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
|
||||
|
||||
assert len(params) == 5
|
||||
assert (
|
||||
params[0].name == "self"
|
||||
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert (
|
||||
params[1].name == "host"
|
||||
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
assert (
|
||||
params[3].name == "credentials"
|
||||
and params[3].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
)
|
||||
assert (
|
||||
params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
)
|
||||
|
@ -232,7 +232,7 @@ async def test_modules_preserved(dev: SmartDevice):
|
||||
async def test_create_smart_device_with_timeout():
|
||||
"""Make sure timeout is passed to the protocol."""
|
||||
dev = SmartDevice(host="127.0.0.1", timeout=100)
|
||||
assert dev.protocol.timeout == 100
|
||||
assert dev.protocol._transport._timeout == 100
|
||||
|
||||
|
||||
async def test_create_thin_wrapper():
|
||||
|
Loading…
Reference in New Issue
Block a user