Do login entirely within AesTransport (#580)

* Do login entirely within AesTransport

* Remove login and handshake attributes from BaseTransport

* Add AesTransport tests

* Synchronise transport and protocol __init__ signatures and rename internal variables

* Update after review
This commit is contained in:
sdb9696 2023-12-19 14:11:59 +00:00 committed by GitHub
parent 209391c422
commit 20ea6700a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 468 additions and 237 deletions

View File

@ -8,7 +8,7 @@ import base64
import hashlib import hashlib
import logging import logging
import time import time
from typing import Optional, Union from typing import Optional
import httpx import httpx
from cryptography.hazmat.primitives import padding, serialization from cryptography.hazmat.primitives import padding, serialization
@ -47,6 +47,7 @@ class AesTransport(BaseTransport):
protocol, sometimes used by newer firmware versions on kasa devices. protocol, sometimes used by newer firmware versions on kasa devices.
""" """
DEFAULT_PORT = 80
DEFAULT_TIMEOUT = 5 DEFAULT_TIMEOUT = 5
SESSION_COOKIE_NAME = "TP_SESSIONID" SESSION_COOKIE_NAME = "TP_SESSIONID"
COMMON_HEADERS = { COMMON_HEADERS = {
@ -59,12 +60,16 @@ class AesTransport(BaseTransport):
self, self,
host: str, host: str,
*, *,
port: Optional[int] = None,
credentials: Optional[Credentials] = None, credentials: Optional[Credentials] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__(host=host) super().__init__(
host,
self._credentials = credentials or Credentials(username="", password="") port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)
self._handshake_done = False self._handshake_done = False
@ -77,7 +82,7 @@ class AesTransport(BaseTransport):
self._http_client: httpx.AsyncClient = httpx.AsyncClient() self._http_client: httpx.AsyncClient = httpx.AsyncClient()
self._login_token = None self._login_token = None
_LOGGER.debug("Created AES object for %s", self.host) _LOGGER.debug("Created AES transport for %s", self._host)
def hash_credentials(self, login_v2): def hash_credentials(self, login_v2):
"""Hash the credentials.""" """Hash the credentials."""
@ -123,7 +128,7 @@ class AesTransport(BaseTransport):
if ( if (
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS: ) != SmartErrorCode.SUCCESS:
msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})" msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
if error_code in SMART_TIMEOUT_ERRORS: if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg) raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS: if error_code in SMART_RETRYABLE_ERRORS:
@ -136,7 +141,7 @@ class AesTransport(BaseTransport):
async def send_secure_passthrough(self, request: str): async def send_secure_passthrough(self, request: str):
"""Send encrypted message as passthrough.""" """Send encrypted message as passthrough."""
url = f"http://{self.host}/app" url = f"http://{self._host}/app"
if self._login_token: if self._login_token:
url += f"?token={self._login_token}" url += f"?token={self._login_token}"
@ -150,7 +155,7 @@ class AesTransport(BaseTransport):
if status_code != 200: if status_code != 200:
raise SmartDeviceException( raise SmartDeviceException(
f"{self.host} responded with an unexpected " f"{self._host} responded with an unexpected "
+ f"status code {status_code} to passthrough" + f"status code {status_code} to passthrough"
) )
@ -164,49 +169,31 @@ class AesTransport(BaseTransport):
resp_dict = json_loads(response) resp_dict = json_loads(response)
return resp_dict return resp_dict
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool): async def _perform_login_for_version(self, *, login_version: int = 1):
"""Login to the device.""" """Login to the device."""
self._login_token = None self._login_token = None
un, pw = self.hash_credentials(login_version == 2)
if isinstance(login_request, str): password_field_name = "password2" if login_version == 2 else "password"
login_request_dict: dict = json_loads(login_request) login_request = {
else: "method": "login_device",
login_request_dict = login_request "params": {password_field_name: pw, "username": un},
"request_time_milis": round(time.time() * 1000),
un, pw = self.hash_credentials(login_v2) }
login_request_dict["params"] = {"password": pw, "username": un} request = json_dumps(login_request)
request = json_dumps(login_request_dict)
try: try:
resp_dict = await self.send_secure_passthrough(request) resp_dict = await self.send_secure_passthrough(request)
except SmartDeviceException as ex: except SmartDeviceException as ex:
raise AuthenticationException(ex) from ex raise AuthenticationException(ex) from ex
self._login_token = resp_dict["result"]["token"] self._login_token = resp_dict["result"]["token"]
@property async def perform_login(self) -> None:
def needs_login(self) -> bool:
"""Return true if the transport needs to do a login."""
return self._login_token is None
async def login(self, request: str) -> None:
"""Login to the device.""" """Login to the device."""
try: try:
if self.needs_handshake: await self._perform_login_for_version(login_version=2)
raise SmartDeviceException(
"Handshake must be complete before trying to login"
)
await self.perform_login(request, login_v2=False)
except AuthenticationException: except AuthenticationException:
_LOGGER.warning("Login version 2 failed, trying version 1")
await self.perform_handshake() await self.perform_handshake()
await self.perform_login(request, login_v2=True) await self._perform_login_for_version(login_version=1)
@property
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
return not self._handshake_done or self._handshake_session_expired()
async def handshake(self) -> None:
"""Perform the encryption handshake."""
await self.perform_handshake()
async def perform_handshake(self): async def perform_handshake(self):
"""Perform the handshake.""" """Perform the handshake."""
@ -217,7 +204,7 @@ class AesTransport(BaseTransport):
self._session_expire_at = None self._session_expire_at = None
self._session_cookie = None self._session_cookie = None
url = f"http://{self.host}/app" url = f"http://{self._host}/app"
key_pair = KeyPair.create_key_pair() key_pair = KeyPair.create_key_pair()
pub_key = ( pub_key = (
@ -238,7 +225,7 @@ class AesTransport(BaseTransport):
if status_code != 200: if status_code != 200:
raise SmartDeviceException( raise SmartDeviceException(
f"{self.host} responded with an unexpected " f"{self._host} responded with an unexpected "
+ f"status code {status_code} to handshake" + f"status code {status_code} to handshake"
) )
@ -261,7 +248,7 @@ class AesTransport(BaseTransport):
self._handshake_done = True self._handshake_done = True
_LOGGER.debug("Handshake with %s complete", self.host) _LOGGER.debug("Handshake with %s complete", self._host)
def _handshake_session_expired(self): def _handshake_session_expired(self):
"""Return true if session has expired.""" """Return true if session has expired."""
@ -272,12 +259,10 @@ class AesTransport(BaseTransport):
async def send(self, request: str): async def send(self, request: str):
"""Send the request.""" """Send the request."""
if self.needs_handshake: if not self._handshake_done or self._handshake_session_expired():
raise SmartDeviceException( await self.perform_handshake()
"Handshake must be complete before trying to send" if not self._login_token:
) await self.perform_login()
if self.needs_login:
raise SmartDeviceException("Login must be complete before trying to send")
return await self.send_secure_passthrough(request) return await self.send_secure_passthrough(request)

View File

@ -74,7 +74,12 @@ async def connect(
host=host, port=port, credentials=credentials, timeout=timeout host=host, port=port, credentials=credentials, timeout=timeout
) )
if protocol_class is not None: if protocol_class is not None:
dev.protocol = protocol_class(host, credentials=credentials) dev.protocol = protocol_class(
host,
transport=AesTransport(
host, port=port, credentials=credentials, timeout=timeout
),
)
await dev.update() await dev.update()
if debug_enabled: if debug_enabled:
end_time = time.perf_counter() end_time = time.perf_counter()
@ -90,7 +95,13 @@ async def connect(
host=host, port=port, credentials=credentials, timeout=timeout host=host, port=port, credentials=credentials, timeout=timeout
) )
if protocol_class is not None: if protocol_class is not None:
unknown_dev.protocol = protocol_class(host, credentials=credentials) # TODO this will be replaced with connection params
unknown_dev.protocol = protocol_class(
host,
transport=AesTransport(
host, port=port, credentials=credentials, timeout=timeout
),
)
await unknown_dev.update() await unknown_dev.update()
device_class = get_device_class_from_sys_info(unknown_dev.internal_state) device_class = get_device_class_from_sys_info(unknown_dev.internal_state)
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout) dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
@ -163,7 +174,5 @@ def get_protocol_from_connection_name(
protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore
transport: BaseTransport = transport_class(host, credentials=credentials) transport: BaseTransport = transport_class(host, credentials=credentials)
protocol: TPLinkProtocol = protocol_class( protocol: TPLinkProtocol = protocol_class(host, transport=transport)
host, credentials=credentials, transport=transport
)
return protocol return protocol

View File

@ -1,14 +1,12 @@
"""Module for the IOT legacy IOT KASA protocol.""" """Module for the IOT legacy IOT KASA protocol."""
import asyncio import asyncio
import logging import logging
from typing import Dict, Optional, Union from typing import Dict, Union
import httpx import httpx
from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps from .json import dumps as json_dumps
from .klaptransport import KlapTransport
from .protocol import BaseTransport, TPLinkProtocol from .protocol import BaseTransport, TPLinkProtocol
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -17,24 +15,14 @@ _LOGGER = logging.getLogger(__name__)
class IotProtocol(TPLinkProtocol): class IotProtocol(TPLinkProtocol):
"""Class for the legacy TPLink IOT KASA Protocol.""" """Class for the legacy TPLink IOT KASA Protocol."""
DEFAULT_PORT = 80
def __init__( def __init__(
self, self,
host: str, host: str,
*, *,
transport: Optional[BaseTransport] = None, transport: BaseTransport,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT) """Create a protocol object."""
super().__init__(host, transport=transport)
self._credentials: Credentials = credentials or Credentials(
username="", password=""
)
self._transport: BaseTransport = transport or KlapTransport(
host, credentials=self._credentials, timeout=timeout
)
self._query_lock = asyncio.Lock() self._query_lock = asyncio.Lock()
@ -54,30 +42,32 @@ class IotProtocol(TPLinkProtocol):
except httpx.CloseError as sdex: except httpx.CloseError as sdex:
await self.close() await self.close()
if retry >= retry_count: if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}" f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex ) from sdex
continue continue
except httpx.ConnectError as cex: except httpx.ConnectError as cex:
await self.close() await self.close()
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}" f"Unable to connect to the device: {self._host}: {cex}"
) from cex ) from cex
except TimeoutError as tex: except TimeoutError as tex:
await self.close() await self.close()
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self.host}: {tex}" f"Unable to connect to the device, timed out: {self._host}: {tex}"
) from tex ) from tex
except AuthenticationException as auex: except AuthenticationException as auex:
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) _LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex raise auex
except Exception as ex: except Exception as ex:
await self.close() await self.close()
if retry >= retry_count: if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}" f"Unable to connect to the device: {self._host}: {ex}"
) from ex ) from ex
continue continue
@ -85,14 +75,6 @@ class IotProtocol(TPLinkProtocol):
raise SmartDeviceException("Query reached somehow to unreachable") raise SmartDeviceException("Query reached somehow to unreachable")
async def _execute_query(self, request: str, retry_count: int) -> Dict: async def _execute_query(self, request: str, retry_count: int) -> Dict:
if self._transport.needs_handshake:
await self._transport.handshake()
if self._transport.needs_login: # This shouln't happen
raise SmartDeviceException(
"IOT Protocol needs to login to transport but is not login aware"
)
return await self._transport.send(request) return await self._transport.send(request)
async def close(self) -> None: async def close(self) -> None:

View File

@ -82,7 +82,7 @@ class KlapTransport(BaseTransport):
protocol, used by newer firmware versions. protocol, used by newer firmware versions.
""" """
DEFAULT_TIMEOUT = 5 DEFAULT_PORT = 80
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}} DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
KASA_SETUP_EMAIL = "kasa@tp-link.net" KASA_SETUP_EMAIL = "kasa@tp-link.net"
KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105 KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105
@ -92,12 +92,17 @@ class KlapTransport(BaseTransport):
self, self,
host: str, host: str,
*, *,
port: Optional[int] = None,
credentials: Optional[Credentials] = None, credentials: Optional[Credentials] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__(host=host) super().__init__(
host,
port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)
self._credentials = credentials or Credentials(username="", password="")
self._local_seed: Optional[bytes] = None self._local_seed: Optional[bytes] = None
self._local_auth_hash = self.generate_auth_hash(self._credentials) self._local_auth_hash = self.generate_auth_hash(self._credentials)
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
@ -110,11 +115,10 @@ class KlapTransport(BaseTransport):
self._encryption_session: Optional[KlapEncryptionSession] = None self._encryption_session: Optional[KlapEncryptionSession] = None
self._session_expire_at: Optional[float] = None self._session_expire_at: Optional[float] = None
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
self._session_cookie = None self._session_cookie = None
self._http_client: httpx.AsyncClient = httpx.AsyncClient() self._http_client: httpx.AsyncClient = httpx.AsyncClient()
_LOGGER.debug("Created KLAP object for %s", self.host) _LOGGER.debug("Created KLAP transport for %s", self._host)
async def client_post(self, url, params=None, data=None): async def client_post(self, url, params=None, data=None):
"""Send an http post request to the device.""" """Send an http post request to the device."""
@ -148,7 +152,7 @@ class KlapTransport(BaseTransport):
payload = local_seed payload = local_seed
url = f"http://{self.host}/app/handshake1" url = f"http://{self._host}/app/handshake1"
response_status, response_data = await self.client_post(url, data=payload) response_status, response_data = await self.client_post(url, data=payload)
@ -157,14 +161,14 @@ class KlapTransport(BaseTransport):
"Handshake1 posted at %s. Host is %s, Response" "Handshake1 posted at %s. Host is %s, Response"
+ "status is %s, Request was %s", + "status is %s, Request was %s",
datetime.datetime.now(), datetime.datetime.now(),
self.host, self._host,
response_status, response_status,
payload.hex(), payload.hex(),
) )
if response_status != 200: if response_status != 200:
raise AuthenticationException( raise AuthenticationException(
f"Device {self.host} responded with {response_status} to handshake1" f"Device {self._host} responded with {response_status} to handshake1"
) )
remote_seed: bytes = response_data[0:16] remote_seed: bytes = response_data[0:16]
@ -175,7 +179,7 @@ class KlapTransport(BaseTransport):
"Handshake1 success at %s. Host is %s, " "Handshake1 success at %s. Host is %s, "
+ "Server remote_seed is: %s, server hash is: %s", + "Server remote_seed is: %s, server hash is: %s",
datetime.datetime.now(), datetime.datetime.now(),
self.host, self._host,
remote_seed.hex(), remote_seed.hex(),
server_hash.hex(), server_hash.hex(),
) )
@ -207,7 +211,7 @@ class KlapTransport(BaseTransport):
_LOGGER.debug( _LOGGER.debug(
"Server response doesn't match our expected hash on ip %s" "Server response doesn't match our expected hash on ip %s"
+ " but an authentication with kasa setup credentials matched", + " but an authentication with kasa setup credentials matched",
self.host, self._host,
) )
return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore
@ -226,11 +230,11 @@ class KlapTransport(BaseTransport):
_LOGGER.debug( _LOGGER.debug(
"Server response doesn't match our expected hash on ip %s" "Server response doesn't match our expected hash on ip %s"
+ " but an authentication with blank credentials matched", + " but an authentication with blank credentials matched",
self.host, self._host,
) )
return local_seed, remote_seed, self._blank_auth_hash # type: ignore return local_seed, remote_seed, self._blank_auth_hash # type: ignore
msg = f"Server response doesn't match our challenge on ip {self.host}" msg = f"Server response doesn't match our challenge on ip {self._host}"
_LOGGER.debug(msg) _LOGGER.debug(msg)
raise AuthenticationException(msg) raise AuthenticationException(msg)
@ -241,7 +245,7 @@ class KlapTransport(BaseTransport):
# Handshake 2 has the following payload: # Handshake 2 has the following payload:
# sha256(serverBytes | authenticator) # sha256(serverBytes | authenticator)
url = f"http://{self.host}/app/handshake2" url = f"http://{self._host}/app/handshake2"
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash) payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
@ -252,44 +256,24 @@ class KlapTransport(BaseTransport):
"Handshake2 posted %s. Host is %s, Response status is %s, " "Handshake2 posted %s. Host is %s, Response status is %s, "
+ "Request was %s", + "Request was %s",
datetime.datetime.now(), datetime.datetime.now(),
self.host, self._host,
response_status, response_status,
payload.hex(), payload.hex(),
) )
if response_status != 200: if response_status != 200:
raise AuthenticationException( raise AuthenticationException(
f"Device {self.host} responded with {response_status} to handshake2" f"Device {self._host} responded with {response_status} to handshake2"
) )
return KlapEncryptionSession(local_seed, remote_seed, auth_hash) return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
@property
def needs_login(self) -> bool:
"""Will return false as KLAP does not do a login."""
return False
async def login(self, request: str) -> None:
"""Will raise and exception as KLAP does not do a login."""
raise SmartDeviceException(
"KLAP does not perform logins and return needs_login == False"
)
@property
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
return not self._handshake_done or self._handshake_session_expired()
async def handshake(self) -> None:
"""Perform the encryption handshake."""
await self.perform_handshake()
async def perform_handshake(self) -> Any: async def perform_handshake(self) -> Any:
"""Perform handshake1 and handshake2. """Perform handshake1 and handshake2.
Sets the encryption_session if successful. Sets the encryption_session if successful.
""" """
_LOGGER.debug("Starting handshake with %s", self.host) _LOGGER.debug("Starting handshake with %s", self._host)
self._handshake_done = False self._handshake_done = False
self._session_expire_at = None self._session_expire_at = None
self._session_cookie = None self._session_cookie = None
@ -307,7 +291,7 @@ class KlapTransport(BaseTransport):
) )
self._handshake_done = True self._handshake_done = True
_LOGGER.debug("Handshake with %s complete", self.host) _LOGGER.debug("Handshake with %s complete", self._host)
def _handshake_session_expired(self): def _handshake_session_expired(self):
"""Return true if session has expired.""" """Return true if session has expired."""
@ -318,18 +302,14 @@ class KlapTransport(BaseTransport):
async def send(self, request: str): async def send(self, request: str):
"""Send the request.""" """Send the request."""
if self.needs_handshake: if not self._handshake_done or self._handshake_session_expired():
raise SmartDeviceException( await self.perform_handshake()
"Handshake must be complete before trying to send"
)
if self.needs_login:
raise SmartDeviceException("Login must be complete before trying to send")
# Check for mypy # Check for mypy
if self._encryption_session is not None: if self._encryption_session is not None:
payload, seq = self._encryption_session.encrypt(request.encode()) payload, seq = self._encryption_session.encrypt(request.encode())
url = f"http://{self.host}/app/request" url = f"http://{self._host}/app/request"
response_status, response_data = await self.client_post( response_status, response_data = await self.client_post(
url, url,
@ -338,7 +318,7 @@ class KlapTransport(BaseTransport):
) )
msg = ( msg = (
f"at {datetime.datetime.now()}. Host is {self.host}, " f"at {datetime.datetime.now()}. Host is {self._host}, "
+ f"Sequence is {seq}, " + f"Sequence is {seq}, "
+ f"Response status is {response_status}, Request was {request}" + f"Response status is {response_status}, Request was {request}"
) )
@ -348,12 +328,12 @@ class KlapTransport(BaseTransport):
if response_status == 403: if response_status == 403:
self._handshake_done = False self._handshake_done = False
raise AuthenticationException( raise AuthenticationException(
f"Got a security error from {self.host} after handshake " f"Got a security error from {self._host} after handshake "
+ "completed" + "completed"
) )
else: else:
raise SmartDeviceException( raise SmartDeviceException(
f"Device {self.host} responded with {response_status} to" f"Device {self._host} responded with {response_status} to"
+ f"request with seq {seq}" + f"request with seq {seq}"
) )
else: else:
@ -367,7 +347,7 @@ class KlapTransport(BaseTransport):
_LOGGER.debug( _LOGGER.debug(
"%s << %s", "%s << %s",
self.host, self._host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(json_payload), _LOGGER.isEnabledFor(logging.DEBUG) and pf(json_payload),
) )

View File

@ -44,35 +44,21 @@ def md5(payload: bytes) -> bytes:
class BaseTransport(ABC): class BaseTransport(ABC):
"""Base class for all TP-Link protocol transports.""" """Base class for all TP-Link protocol transports."""
DEFAULT_TIMEOUT = 5
def __init__( def __init__(
self, self,
host: str, host: str,
*, *,
port: Optional[int] = None, port: Optional[int] = None,
credentials: Optional[Credentials] = None, credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
"""Create a protocol object.""" """Create a protocol object."""
self.host = host self._host = host
self.port = port self._port = port
self.credentials = credentials self._credentials = credentials or Credentials(username="", password="")
self._timeout = timeout or self.DEFAULT_TIMEOUT
@property
@abstractmethod
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
@property
@abstractmethod
def needs_login(self) -> bool:
"""Return true if the transport needs to do a login."""
@abstractmethod
async def login(self, request: str) -> None:
"""Login to the device."""
@abstractmethod
async def handshake(self) -> None:
"""Perform the encryption handshake."""
@abstractmethod @abstractmethod
async def send(self, request: str) -> Dict: async def send(self, request: str) -> Dict:
@ -90,14 +76,14 @@ class TPLinkProtocol(ABC):
self, self,
host: str, host: str,
*, *,
port: Optional[int] = None, transport: BaseTransport,
credentials: Optional[Credentials] = None,
transport: Optional[BaseTransport] = None,
) -> None: ) -> None:
"""Create a protocol object.""" """Create a protocol object."""
self.host = host self._transport = transport
self.port = port
self.credentials = credentials @property
def _host(self):
return self._transport._host
@abstractmethod @abstractmethod
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
@ -108,6 +94,40 @@ class TPLinkProtocol(ABC):
"""Close the protocol. Abstract method to be overriden.""" """Close the protocol. Abstract method to be overriden."""
class _XorTransport(BaseTransport):
"""Implementation of the Xor encryption transport.
WIP, currently only to ensure consistent __init__ method signatures
for protocol classes. Will eventually incorporate the logic from
TPLinkSmartHomeProtocol to simplify the API and re-use the IotProtocol
class.
"""
DEFAULT_PORT = 9999
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(
host,
port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)
async def send(self, request: str) -> Dict:
"""Send a message to the device and return a response."""
return {}
async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""
class TPLinkSmartHomeProtocol(TPLinkProtocol): class TPLinkSmartHomeProtocol(TPLinkProtocol):
"""Implementation of the TP-Link Smart Home protocol.""" """Implementation of the TP-Link Smart Home protocol."""
@ -120,20 +140,18 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
self, self,
host: str, host: str,
*, *,
port: Optional[int] = None, transport: BaseTransport,
timeout: Optional[int] = None,
credentials: Optional[Credentials] = None,
) -> None: ) -> None:
"""Create a protocol object.""" """Create a protocol object."""
super().__init__( super().__init__(host, transport=transport)
host=host, port=port or self.DEFAULT_PORT, credentials=credentials
)
self.reader: Optional[asyncio.StreamReader] = None self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None self.writer: Optional[asyncio.StreamWriter] = None
self.query_lock = asyncio.Lock() self.query_lock = asyncio.Lock()
self.loop: Optional[asyncio.AbstractEventLoop] = None self.loop: Optional[asyncio.AbstractEventLoop] = None
self.timeout = timeout or TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
self._timeout = self._transport._timeout
self._port = self._transport._port
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Request information from a TP-Link SmartHome Device. """Request information from a TP-Link SmartHome Device.
@ -149,7 +167,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
assert isinstance(request, str) # noqa: S101 assert isinstance(request, str) # noqa: S101
async with self.query_lock: async with self.query_lock:
return await self._query(request, retry_count, self.timeout) return await self._query(request, retry_count, self._timeout)
async def _connect(self, timeout: int) -> None: async def _connect(self, timeout: int) -> None:
"""Try to connect or reconnect to the device.""" """Try to connect or reconnect to the device."""
@ -157,7 +175,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
return return
self.reader = self.writer = None self.reader = self.writer = None
task = asyncio.open_connection(self.host, self.port) task = asyncio.open_connection(self._host, self._port)
async with asyncio_timeout(timeout): async with asyncio_timeout(timeout):
self.reader, self.writer = await task self.reader, self.writer = await task
sock: socket.socket = self.writer.get_extra_info("socket") sock: socket.socket = self.writer.get_extra_info("socket")
@ -174,7 +192,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
debug_log = _LOGGER.isEnabledFor(logging.DEBUG) debug_log = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_log: if debug_log:
_LOGGER.debug("%s >> %s", self.host, request) _LOGGER.debug("%s >> %s", self._host, request)
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request)) self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await self.writer.drain() await self.writer.drain()
@ -185,7 +203,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
response = TPLinkSmartHomeProtocol.decrypt(buffer) response = TPLinkSmartHomeProtocol.decrypt(buffer)
json_payload = json_loads(response) json_payload = json_loads(response)
if debug_log: if debug_log:
_LOGGER.debug("%s << %s", self.host, pf(json_payload)) _LOGGER.debug("%s << %s", self._host, pf(json_payload))
return json_payload return json_payload
@ -219,23 +237,23 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
except ConnectionRefusedError as ex: except ConnectionRefusedError as ex:
await self.close() await self.close()
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device: {self.host}:{self.port}: {ex}" f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
) from ex ) from ex
except OSError as ex: except OSError as ex:
await self.close() await self.close()
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count: if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device:" f"Unable to connect to the device:"
f" {self.host}:{self.port}: {ex}" f" {self._host}:{self._port}: {ex}"
) from ex ) from ex
continue continue
except Exception as ex: except Exception as ex:
await self.close() await self.close()
if retry >= retry_count: if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device:" f"Unable to connect to the device:"
f" {self.host}:{self.port}: {ex}" f" {self._host}:{self._port}: {ex}"
) from ex ) from ex
continue continue
@ -247,13 +265,13 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
except Exception as ex: except Exception as ex:
await self.close() await self.close()
if retry >= retry_count: if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to query the device {self.host}:{self.port}: {ex}" f"Unable to query the device {self._host}:{self._port}: {ex}"
) from ex ) from ex
_LOGGER.debug( _LOGGER.debug(
"Unable to query the device %s, retrying: %s", self.host, ex "Unable to query the device %s, retrying: %s", self._host, ex
) )
# make mypy happy, this should never be reached.. # make mypy happy, this should never be reached..

View File

@ -24,7 +24,7 @@ from .device_type import DeviceType
from .emeterstatus import EmeterStatus from .emeterstatus import EmeterStatus
from .exceptions import SmartDeviceException from .exceptions import SmartDeviceException
from .modules import Emeter, Module from .modules import Emeter, Module
from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol, _XorTransport
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -202,7 +202,7 @@ class SmartDevice:
self.host = host self.host = host
self.port = port self.port = port
self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol( self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol(
host, port=port, timeout=timeout host, transport=_XorTransport(host, port=port, timeout=timeout)
) )
self.credentials = credentials self.credentials = credentials
_LOGGER.debug("Initializing %s of type %s", self.host, type(self)) _LOGGER.debug("Initializing %s of type %s", self.host, type(self))

View File

@ -10,12 +10,10 @@ import logging
import time import time
import uuid import uuid
from pprint import pformat as pf from pprint import pformat as pf
from typing import Dict, Optional, Union from typing import Dict, Union
import httpx import httpx
from .aestransport import AesTransport
from .credentials import Credentials
from .exceptions import ( from .exceptions import (
SMART_AUTHENTICATION_ERRORS, SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS, SMART_RETRYABLE_ERRORS,
@ -36,26 +34,17 @@ logging.getLogger("httpx").propagate = False
class SmartProtocol(TPLinkProtocol): class SmartProtocol(TPLinkProtocol):
"""Class for the new TPLink SMART protocol.""" """Class for the new TPLink SMART protocol."""
DEFAULT_PORT = 80
SLEEP_SECONDS_AFTER_TIMEOUT = 1 SLEEP_SECONDS_AFTER_TIMEOUT = 1
def __init__( def __init__(
self, self,
host: str, host: str,
*, *,
transport: Optional[BaseTransport] = None, transport: BaseTransport,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None: ) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT) """Create a protocol object."""
super().__init__(host, transport=transport)
self._credentials: Credentials = credentials or Credentials( self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
username="", password=""
)
self._transport: BaseTransport = transport or AesTransport(
host, credentials=self._credentials, timeout=timeout
)
self._terminal_uuid: Optional[str] = None
self._request_id_generator = SnowflakeId(1, 1) self._request_id_generator = SnowflakeId(1, 1)
self._query_lock = asyncio.Lock() self._query_lock = asyncio.Lock()
@ -79,7 +68,7 @@ class SmartProtocol(TPLinkProtocol):
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS: ) != SmartErrorCode.SUCCESS:
msg = ( msg = (
f"Error querying device: {self.host}: " f"Error querying device: {self._host}: "
+ f"{error_code.name}({error_code.value})" + f"{error_code.name}({error_code.value})"
) )
if error_code in SMART_TIMEOUT_ERRORS: if error_code in SMART_TIMEOUT_ERRORS:
@ -101,51 +90,53 @@ class SmartProtocol(TPLinkProtocol):
except httpx.CloseError as sdex: except httpx.CloseError as sdex:
await self.close() await self.close()
if retry >= retry_count: if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}" f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex ) from sdex
continue continue
except httpx.ConnectError as cex: except httpx.ConnectError as cex:
await self.close() await self.close()
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}" f"Unable to connect to the device: {self._host}: {cex}"
) from cex ) from cex
except TimeoutError as tex: except TimeoutError as tex:
if retry >= retry_count: if retry >= retry_count:
await self.close() await self.close()
raise SmartDeviceException( raise SmartDeviceException(
"Unable to connect to the device, " "Unable to connect to the device, "
+ f"timed out: {self.host}: {tex}" + f"timed out: {self._host}: {tex}"
) from tex ) from tex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
continue continue
except AuthenticationException as auex: except AuthenticationException as auex:
await self.close() await self.close()
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) _LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex raise auex
except RetryableException as ex: except RetryableException as ex:
if retry >= retry_count: if retry >= retry_count:
await self.close() await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex raise ex
continue continue
except TimeoutException as ex: except TimeoutException as ex:
if retry >= retry_count: if retry >= retry_count:
await self.close() await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex raise ex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
continue continue
except Exception as ex: except Exception as ex:
if retry >= retry_count: if retry >= retry_count:
await self.close() await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to query the device {self.host}:{self.port}: {ex}" f"Unable to connect to the device: {self._host}: {ex}"
) from ex ) from ex
_LOGGER.debug( _LOGGER.debug(
"Unable to query the device %s, retrying: %s", self.host, ex "Unable to query the device %s, retrying: %s", self._host, ex
) )
continue continue
@ -160,27 +151,17 @@ class SmartProtocol(TPLinkProtocol):
smart_method = request smart_method = request
smart_params = None smart_params = None
if self._transport.needs_handshake:
await self._transport.handshake()
if self._transport.needs_login:
self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode(
"UTF-8"
)
login_request = self.get_smart_request("login_device")
await self._transport.login(login_request)
smart_request = self.get_smart_request(smart_method, smart_params) smart_request = self.get_smart_request(smart_method, smart_params)
_LOGGER.debug( _LOGGER.debug(
"%s >> %s", "%s >> %s",
self.host, self._host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request), _LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request),
) )
response_data = await self._transport.send(smart_request) response_data = await self._transport.send(smart_request)
_LOGGER.debug( _LOGGER.debug(
"%s << %s", "%s << %s",
self.host, self._host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data), _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
) )

View File

@ -4,6 +4,7 @@ import logging
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Set, cast from typing import Any, Dict, Optional, Set, cast
from ..aestransport import AesTransport
from ..credentials import Credentials from ..credentials import Credentials
from ..exceptions import AuthenticationException from ..exceptions import AuthenticationException
from ..smartdevice import SmartDevice from ..smartdevice import SmartDevice
@ -27,7 +28,12 @@ class TapoDevice(SmartDevice):
self._components: Optional[Dict[str, Any]] = None self._components: Optional[Dict[str, Any]] = None
self._state_information: Dict[str, Any] = {} self._state_information: Dict[str, Any] = {}
self._discovery_info: Optional[Dict[str, Any]] = None self._discovery_info: Optional[Dict[str, Any]] = None
self.protocol = SmartProtocol(host, credentials=credentials, timeout=timeout) self.protocol = SmartProtocol(
host,
transport=AesTransport(
host, credentials=credentials, timeout=timeout, port=port
),
)
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True):
"""Update the device.""" """Update the device."""

View File

@ -301,6 +301,9 @@ class FakeSmartProtocol(SmartProtocol):
class FakeSmartTransport(BaseTransport): class FakeSmartTransport(BaseTransport):
def __init__(self, info): def __init__(self, info):
super().__init__(
"127.0.0.123",
)
self.info = info self.info = info
@property @property

View File

@ -0,0 +1,174 @@
import base64
import json
import time
from contextlib import nullcontext as does_not_raise
from json import dumps as json_dumps
from json import loads as json_loads
import httpx
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from ..aestransport import AesEncyptionSession, AesTransport
from ..credentials import Credentials
from ..exceptions import SmartDeviceException
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t"
iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa"
KEY_IV = key + iv
def test_encrypt():
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
d = json.dumps({"foo": 1, "bar": 2})
encrypted = encryption_session.encrypt(d.encode())
assert d == encryption_session.decrypt(encrypted)
# test encrypt unicode
d = "{'snowman': '\u2603'}"
encrypted = encryption_session.encrypt(d.encode())
assert d == encryption_session.decrypt(encrypted)
status_parameters = pytest.mark.parametrize(
"status_code, error_code, inner_error_code, expectation",
[
(200, 0, 0, does_not_raise()),
(400, 0, 0, pytest.raises(SmartDeviceException)),
(200, -1, 0, pytest.raises(SmartDeviceException)),
],
ids=("success", "status_code", "error_code"),
)
@status_parameters
async def test_handshake(
mocker, status_code, error_code, inner_error_code, expectation
):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
assert transport._encryption_session is None
assert transport._handshake_done is False
with expectation:
await transport.perform_handshake()
assert transport._encryption_session is not None
assert transport._handshake_done is True
@status_parameters
async def test_login(mocker, status_code, error_code, inner_error_code, expectation):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
assert transport._login_token is None
with expectation:
await transport.perform_login()
assert transport._login_token == mock_aes_device.token
@status_parameters
async def test_send(mocker, status_code, error_code, inner_error_code, expectation):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._login_token = mock_aes_device.token
un, pw = transport.hash_credentials(True)
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
with expectation:
res = await transport.send(json_dumps(request))
assert "result" in res
class MockAesDevice:
class _mock_response:
def __init__(self, status_code, json: dict):
self.status_code = status_code
self._json = json
def json(self):
return self._json
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
token = "test_token" # noqa
def __init__(self, host, status_code=200, error_code=0, inner_error_code=0):
self.host = host
self.status_code = status_code
self.error_code = error_code
self.inner_error_code = inner_error_code
async def post(self, url, params=None, json=None, *_, **__):
return await self._post(url, json)
async def _post(self, url, json):
if json["method"] == "handshake":
return await self._return_handshake_response(url, json)
elif json["method"] == "securePassthrough":
return await self._return_secure_passthrough_response(url, json)
elif json["method"] == "login_device":
return await self._return_login_response(url, json)
else:
assert url == f"http://{self.host}/app?token={self.token}"
return await self._return_send_response(url, json)
async def _return_handshake_response(self, url, json):
start = len("-----BEGIN PUBLIC KEY-----\n")
end = len("\n-----END PUBLIC KEY-----\n")
client_pub_key = json["params"]["key"][start:-end]
client_pub_key_data = base64.b64decode(client_pub_key.encode())
client_pub_key = serialization.load_der_public_key(client_pub_key_data, None)
encrypted_key = client_pub_key.encrypt(KEY_IV, asymmetric_padding.PKCS1v15())
key_64 = base64.b64encode(encrypted_key).decode()
return self._mock_response(
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
)
async def _return_secure_passthrough_response(self, url, json):
encrypted_request = json["params"]["request"]
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
decrypted_request_dict = json_loads(decrypted_request)
decrypted_response = await self._post(url, decrypted_request_dict)
decrypted_response_dict = decrypted_response.json()
encrypted_response = self.encryption_session.encrypt(
json_dumps(decrypted_response_dict).encode()
)
result = {
"result": {"response": encrypted_response.decode()},
"error_code": self.error_code,
}
return self._mock_response(self.status_code, result)
async def _return_login_response(self, url, json):
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
return self._mock_response(self.status_code, result)
async def _return_send_response(self, url, json):
result = {"result": {"method": None}, "error_code": self.inner_error_code}
return self._mock_response(self.status_code, result)

View File

@ -96,10 +96,9 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
return mock_response return mock_response
mocker.patch.object( mocker.patch.object(transport_class, "perform_handshake")
transport_class, "needs_handshake", property(lambda self: False) if hasattr(transport_class, "perform_login"):
) mocker.patch.object(transport_class, "perform_login")
mocker.patch.object(transport_class, "needs_login", property(lambda self: False))
send_mock = mocker.patch.object( send_mock = mocker.patch.object(
transport_class, transport_class,
@ -128,7 +127,7 @@ async def test_protocol_logging(mocker, caplog, log_level):
seed = secrets.token_bytes(16) seed = secrets.token_bytes(16)
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash) encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
protocol = IotProtocol("127.0.0.1") protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1"))
protocol._transport._handshake_done = True protocol._transport._handshake_done = True
protocol._transport._session_expire_at = time.time() + 86400 protocol._transport._session_expire_at = time.time() + 86400
@ -206,7 +205,10 @@ async def test_handshake1(mocker, device_credentials, expectation):
httpx.AsyncClient, "post", side_effect=_return_handshake1_response httpx.AsyncClient, "post", side_effect=_return_handshake1_response
) )
protocol = IotProtocol("127.0.0.1", credentials=client_credentials) protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
protocol._transport.http_client = httpx.AsyncClient() protocol._transport.http_client = httpx.AsyncClient()
with expectation: with expectation:
@ -243,7 +245,10 @@ async def test_handshake(mocker):
httpx.AsyncClient, "post", side_effect=_return_handshake_response httpx.AsyncClient, "post", side_effect=_return_handshake_response
) )
protocol = IotProtocol("127.0.0.1", credentials=client_credentials) protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
protocol._transport.http_client = httpx.AsyncClient() protocol._transport.http_client = httpx.AsyncClient()
response_status = 200 response_status = 200
@ -289,7 +294,10 @@ async def test_query(mocker):
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = IotProtocol("127.0.0.1", credentials=client_credentials) protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
for _ in range(10): for _ in range(10):
resp = await protocol.query({}) resp = await protocol.query({})
@ -333,7 +341,10 @@ async def test_authentication_failures(mocker, response_status, expectation):
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = IotProtocol("127.0.0.1", credentials=client_credentials) protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
with expectation: with expectation:
await protocol.query({}) await protocol.query({})

View File

@ -1,13 +1,21 @@
import errno import errno
import importlib
import inspect
import json import json
import logging import logging
import pkgutil
import struct import struct
import sys import sys
import pytest import pytest
from ..exceptions import SmartDeviceException from ..exceptions import SmartDeviceException
from ..protocol import TPLinkSmartHomeProtocol from ..protocol import (
BaseTransport,
TPLinkProtocol,
TPLinkSmartHomeProtocol,
_XorTransport,
)
@pytest.mark.parametrize("retry_count", [1, 3, 5]) @pytest.mark.parametrize("retry_count", [1, 3, 5])
@ -24,7 +32,9 @@ async def test_protocol_retries(mocker, retry_count):
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count) await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=retry_count)
assert conn.call_count == retry_count + 1 assert conn.call_count == retry_count + 1
@ -35,7 +45,9 @@ async def test_protocol_no_retry_on_unreachable(mocker):
side_effect=OSError(errno.EHOSTUNREACH, "No route to host"), side_effect=OSError(errno.EHOSTUNREACH, "No route to host"),
) )
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5) await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=5)
assert conn.call_count == 1 assert conn.call_count == 1
@ -46,7 +58,9 @@ async def test_protocol_no_retry_connection_refused(mocker):
side_effect=ConnectionRefusedError, side_effect=ConnectionRefusedError,
) )
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5) await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=5)
assert conn.call_count == 1 assert conn.call_count == 1
@ -57,7 +71,9 @@ async def test_protocol_retry_recoverable_error(mocker):
side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"), side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"),
) )
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5) await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=5)
assert conn.call_count == 6 assert conn.call_count == 6
@ -91,7 +107,9 @@ async def test_protocol_reconnect(mocker, retry_count):
mocker.patch.object(reader, "readexactly", _mock_read) mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer return reader, writer
protocol = TPLinkSmartHomeProtocol("127.0.0.1") protocol = TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
)
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}, retry_count=retry_count) response = await protocol.query({}, retry_count=retry_count)
assert response == {"great": "success"} assert response == {"great": "success"}
@ -119,7 +137,9 @@ async def test_protocol_logging(mocker, caplog, log_level):
mocker.patch.object(reader, "readexactly", _mock_read) mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer return reader, writer
protocol = TPLinkSmartHomeProtocol("127.0.0.1") protocol = TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
)
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}) response = await protocol.query({})
assert response == {"great": "success"} assert response == {"great": "success"}
@ -153,7 +173,9 @@ async def test_protocol_custom_port(mocker, custom_port):
mocker.patch.object(reader, "readexactly", _mock_read) mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer return reader, writer
protocol = TPLinkSmartHomeProtocol("127.0.0.1", port=custom_port) protocol = TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port)
)
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}) response = await protocol.query({})
assert response == {"great": "success"} assert response == {"great": "success"}
@ -227,3 +249,63 @@ def test_decrypt_unicode():
d = "{'snowman': '\u2603'}" d = "{'snowman': '\u2603'}"
assert d == TPLinkSmartHomeProtocol.decrypt(e) assert d == TPLinkSmartHomeProtocol.decrypt(e)
def _get_subclasses(of_class):
import kasa
package = sys.modules["kasa"]
subclasses = set()
for _, modname, _ in pkgutil.iter_modules(package.__path__):
importlib.import_module("." + modname, package="kasa")
module = sys.modules["kasa." + modname]
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, of_class):
subclasses.add((name, obj))
return subclasses
@pytest.mark.parametrize(
"class_name_obj", _get_subclasses(TPLinkProtocol), ids=lambda t: t[0]
)
def test_protocol_init_signature(class_name_obj):
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
assert len(params) == 3
assert (
params[0].name == "self"
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert (
params[1].name == "host"
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert (
params[2].name == "transport"
and params[2].kind == inspect.Parameter.KEYWORD_ONLY
)
@pytest.mark.parametrize(
"class_name_obj", _get_subclasses(BaseTransport), ids=lambda t: t[0]
)
def test_transport_init_signature(class_name_obj):
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
assert len(params) == 5
assert (
params[0].name == "self"
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert (
params[1].name == "host"
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY
assert (
params[3].name == "credentials"
and params[3].kind == inspect.Parameter.KEYWORD_ONLY
)
assert (
params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY
)

View File

@ -232,7 +232,7 @@ async def test_modules_preserved(dev: SmartDevice):
async def test_create_smart_device_with_timeout(): async def test_create_smart_device_with_timeout():
"""Make sure timeout is passed to the protocol.""" """Make sure timeout is passed to the protocol."""
dev = SmartDevice(host="127.0.0.1", timeout=100) dev = SmartDevice(host="127.0.0.1", timeout=100)
assert dev.protocol.timeout == 100 assert dev.protocol._transport._timeout == 100
async def test_create_thin_wrapper(): async def test_create_thin_wrapper():