Do login entirely within AesTransport (#580)

* Do login entirely within AesTransport

* Remove login and handshake attributes from BaseTransport

* Add AesTransport tests

* Synchronise transport and protocol __init__ signatures and rename internal variables

* Update after review
This commit is contained in:
sdb9696 2023-12-19 14:11:59 +00:00 committed by GitHub
parent 209391c422
commit 20ea6700a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 468 additions and 237 deletions

View File

@ -8,7 +8,7 @@ import base64
import hashlib
import logging
import time
from typing import Optional, Union
from typing import Optional
import httpx
from cryptography.hazmat.primitives import padding, serialization
@ -47,6 +47,7 @@ class AesTransport(BaseTransport):
protocol, sometimes used by newer firmware versions on kasa devices.
"""
DEFAULT_PORT = 80
DEFAULT_TIMEOUT = 5
SESSION_COOKIE_NAME = "TP_SESSIONID"
COMMON_HEADERS = {
@ -59,12 +60,16 @@ class AesTransport(BaseTransport):
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host)
self._credentials = credentials or Credentials(username="", password="")
super().__init__(
host,
port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)
self._handshake_done = False
@ -77,7 +82,7 @@ class AesTransport(BaseTransport):
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
self._login_token = None
_LOGGER.debug("Created AES object for %s", self.host)
_LOGGER.debug("Created AES transport for %s", self._host)
def hash_credentials(self, login_v2):
"""Hash the credentials."""
@ -123,7 +128,7 @@ class AesTransport(BaseTransport):
if (
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS:
msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})"
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS:
@ -136,7 +141,7 @@ class AesTransport(BaseTransport):
async def send_secure_passthrough(self, request: str):
"""Send encrypted message as passthrough."""
url = f"http://{self.host}/app"
url = f"http://{self._host}/app"
if self._login_token:
url += f"?token={self._login_token}"
@ -150,7 +155,7 @@ class AesTransport(BaseTransport):
if status_code != 200:
raise SmartDeviceException(
f"{self.host} responded with an unexpected "
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to passthrough"
)
@ -164,49 +169,31 @@ class AesTransport(BaseTransport):
resp_dict = json_loads(response)
return resp_dict
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
async def _perform_login_for_version(self, *, login_version: int = 1):
"""Login to the device."""
self._login_token = None
if isinstance(login_request, str):
login_request_dict: dict = json_loads(login_request)
else:
login_request_dict = login_request
un, pw = self.hash_credentials(login_v2)
login_request_dict["params"] = {"password": pw, "username": un}
request = json_dumps(login_request_dict)
un, pw = self.hash_credentials(login_version == 2)
password_field_name = "password2" if login_version == 2 else "password"
login_request = {
"method": "login_device",
"params": {password_field_name: pw, "username": un},
"request_time_milis": round(time.time() * 1000),
}
request = json_dumps(login_request)
try:
resp_dict = await self.send_secure_passthrough(request)
except SmartDeviceException as ex:
raise AuthenticationException(ex) from ex
self._login_token = resp_dict["result"]["token"]
@property
def needs_login(self) -> bool:
"""Return true if the transport needs to do a login."""
return self._login_token is None
async def login(self, request: str) -> None:
async def perform_login(self) -> None:
"""Login to the device."""
try:
if self.needs_handshake:
raise SmartDeviceException(
"Handshake must be complete before trying to login"
)
await self.perform_login(request, login_v2=False)
await self._perform_login_for_version(login_version=2)
except AuthenticationException:
_LOGGER.warning("Login version 2 failed, trying version 1")
await self.perform_handshake()
await self.perform_login(request, login_v2=True)
@property
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
return not self._handshake_done or self._handshake_session_expired()
async def handshake(self) -> None:
"""Perform the encryption handshake."""
await self.perform_handshake()
await self._perform_login_for_version(login_version=1)
async def perform_handshake(self):
"""Perform the handshake."""
@ -217,7 +204,7 @@ class AesTransport(BaseTransport):
self._session_expire_at = None
self._session_cookie = None
url = f"http://{self.host}/app"
url = f"http://{self._host}/app"
key_pair = KeyPair.create_key_pair()
pub_key = (
@ -238,7 +225,7 @@ class AesTransport(BaseTransport):
if status_code != 200:
raise SmartDeviceException(
f"{self.host} responded with an unexpected "
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to handshake"
)
@ -261,7 +248,7 @@ class AesTransport(BaseTransport):
self._handshake_done = True
_LOGGER.debug("Handshake with %s complete", self.host)
_LOGGER.debug("Handshake with %s complete", self._host)
def _handshake_session_expired(self):
"""Return true if session has expired."""
@ -272,12 +259,10 @@ class AesTransport(BaseTransport):
async def send(self, request: str):
"""Send the request."""
if self.needs_handshake:
raise SmartDeviceException(
"Handshake must be complete before trying to send"
)
if self.needs_login:
raise SmartDeviceException("Login must be complete before trying to send")
if not self._handshake_done or self._handshake_session_expired():
await self.perform_handshake()
if not self._login_token:
await self.perform_login()
return await self.send_secure_passthrough(request)

View File

@ -74,7 +74,12 @@ async def connect(
host=host, port=port, credentials=credentials, timeout=timeout
)
if protocol_class is not None:
dev.protocol = protocol_class(host, credentials=credentials)
dev.protocol = protocol_class(
host,
transport=AesTransport(
host, port=port, credentials=credentials, timeout=timeout
),
)
await dev.update()
if debug_enabled:
end_time = time.perf_counter()
@ -90,7 +95,13 @@ async def connect(
host=host, port=port, credentials=credentials, timeout=timeout
)
if protocol_class is not None:
unknown_dev.protocol = protocol_class(host, credentials=credentials)
# TODO this will be replaced with connection params
unknown_dev.protocol = protocol_class(
host,
transport=AesTransport(
host, port=port, credentials=credentials, timeout=timeout
),
)
await unknown_dev.update()
device_class = get_device_class_from_sys_info(unknown_dev.internal_state)
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
@ -163,7 +174,5 @@ def get_protocol_from_connection_name(
protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore
transport: BaseTransport = transport_class(host, credentials=credentials)
protocol: TPLinkProtocol = protocol_class(
host, credentials=credentials, transport=transport
)
protocol: TPLinkProtocol = protocol_class(host, transport=transport)
return protocol

View File

@ -1,14 +1,12 @@
"""Module for the IOT legacy IOT KASA protocol."""
import asyncio
import logging
from typing import Dict, Optional, Union
from typing import Dict, Union
import httpx
from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps
from .klaptransport import KlapTransport
from .protocol import BaseTransport, TPLinkProtocol
_LOGGER = logging.getLogger(__name__)
@ -17,24 +15,14 @@ _LOGGER = logging.getLogger(__name__)
class IotProtocol(TPLinkProtocol):
"""Class for the legacy TPLink IOT KASA Protocol."""
DEFAULT_PORT = 80
def __init__(
self,
host: str,
*,
transport: Optional[BaseTransport] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
transport: BaseTransport,
) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT)
self._credentials: Credentials = credentials or Credentials(
username="", password=""
)
self._transport: BaseTransport = transport or KlapTransport(
host, credentials=self._credentials, timeout=timeout
)
"""Create a protocol object."""
super().__init__(host, transport=transport)
self._query_lock = asyncio.Lock()
@ -54,30 +42,32 @@ class IotProtocol(TPLinkProtocol):
except httpx.CloseError as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}"
f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}"
f"Unable to connect to the device: {self._host}: {cex}"
) from cex
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self.host}: {tex}"
f"Unable to connect to the device, timed out: {self._host}: {tex}"
) from tex
except AuthenticationException as auex:
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
_LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
f"Unable to connect to the device: {self._host}: {ex}"
) from ex
continue
@ -85,14 +75,6 @@ class IotProtocol(TPLinkProtocol):
raise SmartDeviceException("Query reached somehow to unreachable")
async def _execute_query(self, request: str, retry_count: int) -> Dict:
if self._transport.needs_handshake:
await self._transport.handshake()
if self._transport.needs_login: # This shouln't happen
raise SmartDeviceException(
"IOT Protocol needs to login to transport but is not login aware"
)
return await self._transport.send(request)
async def close(self) -> None:

View File

@ -82,7 +82,7 @@ class KlapTransport(BaseTransport):
protocol, used by newer firmware versions.
"""
DEFAULT_TIMEOUT = 5
DEFAULT_PORT = 80
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
KASA_SETUP_EMAIL = "kasa@tp-link.net"
KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105
@ -92,12 +92,17 @@ class KlapTransport(BaseTransport):
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host)
super().__init__(
host,
port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)
self._credentials = credentials or Credentials(username="", password="")
self._local_seed: Optional[bytes] = None
self._local_auth_hash = self.generate_auth_hash(self._credentials)
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
@ -110,11 +115,10 @@ class KlapTransport(BaseTransport):
self._encryption_session: Optional[KlapEncryptionSession] = None
self._session_expire_at: Optional[float] = None
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
self._session_cookie = None
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
_LOGGER.debug("Created KLAP object for %s", self.host)
_LOGGER.debug("Created KLAP transport for %s", self._host)
async def client_post(self, url, params=None, data=None):
"""Send an http post request to the device."""
@ -148,7 +152,7 @@ class KlapTransport(BaseTransport):
payload = local_seed
url = f"http://{self.host}/app/handshake1"
url = f"http://{self._host}/app/handshake1"
response_status, response_data = await self.client_post(url, data=payload)
@ -157,14 +161,14 @@ class KlapTransport(BaseTransport):
"Handshake1 posted at %s. Host is %s, Response"
+ "status is %s, Request was %s",
datetime.datetime.now(),
self.host,
self._host,
response_status,
payload.hex(),
)
if response_status != 200:
raise AuthenticationException(
f"Device {self.host} responded with {response_status} to handshake1"
f"Device {self._host} responded with {response_status} to handshake1"
)
remote_seed: bytes = response_data[0:16]
@ -175,7 +179,7 @@ class KlapTransport(BaseTransport):
"Handshake1 success at %s. Host is %s, "
+ "Server remote_seed is: %s, server hash is: %s",
datetime.datetime.now(),
self.host,
self._host,
remote_seed.hex(),
server_hash.hex(),
)
@ -207,7 +211,7 @@ class KlapTransport(BaseTransport):
_LOGGER.debug(
"Server response doesn't match our expected hash on ip %s"
+ " but an authentication with kasa setup credentials matched",
self.host,
self._host,
)
return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore
@ -226,11 +230,11 @@ class KlapTransport(BaseTransport):
_LOGGER.debug(
"Server response doesn't match our expected hash on ip %s"
+ " but an authentication with blank credentials matched",
self.host,
self._host,
)
return local_seed, remote_seed, self._blank_auth_hash # type: ignore
msg = f"Server response doesn't match our challenge on ip {self.host}"
msg = f"Server response doesn't match our challenge on ip {self._host}"
_LOGGER.debug(msg)
raise AuthenticationException(msg)
@ -241,7 +245,7 @@ class KlapTransport(BaseTransport):
# Handshake 2 has the following payload:
# sha256(serverBytes | authenticator)
url = f"http://{self.host}/app/handshake2"
url = f"http://{self._host}/app/handshake2"
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
@ -252,44 +256,24 @@ class KlapTransport(BaseTransport):
"Handshake2 posted %s. Host is %s, Response status is %s, "
+ "Request was %s",
datetime.datetime.now(),
self.host,
self._host,
response_status,
payload.hex(),
)
if response_status != 200:
raise AuthenticationException(
f"Device {self.host} responded with {response_status} to handshake2"
f"Device {self._host} responded with {response_status} to handshake2"
)
return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
@property
def needs_login(self) -> bool:
"""Will return false as KLAP does not do a login."""
return False
async def login(self, request: str) -> None:
"""Will raise and exception as KLAP does not do a login."""
raise SmartDeviceException(
"KLAP does not perform logins and return needs_login == False"
)
@property
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
return not self._handshake_done or self._handshake_session_expired()
async def handshake(self) -> None:
"""Perform the encryption handshake."""
await self.perform_handshake()
async def perform_handshake(self) -> Any:
"""Perform handshake1 and handshake2.
Sets the encryption_session if successful.
"""
_LOGGER.debug("Starting handshake with %s", self.host)
_LOGGER.debug("Starting handshake with %s", self._host)
self._handshake_done = False
self._session_expire_at = None
self._session_cookie = None
@ -307,7 +291,7 @@ class KlapTransport(BaseTransport):
)
self._handshake_done = True
_LOGGER.debug("Handshake with %s complete", self.host)
_LOGGER.debug("Handshake with %s complete", self._host)
def _handshake_session_expired(self):
"""Return true if session has expired."""
@ -318,18 +302,14 @@ class KlapTransport(BaseTransport):
async def send(self, request: str):
"""Send the request."""
if self.needs_handshake:
raise SmartDeviceException(
"Handshake must be complete before trying to send"
)
if self.needs_login:
raise SmartDeviceException("Login must be complete before trying to send")
if not self._handshake_done or self._handshake_session_expired():
await self.perform_handshake()
# Check for mypy
if self._encryption_session is not None:
payload, seq = self._encryption_session.encrypt(request.encode())
url = f"http://{self.host}/app/request"
url = f"http://{self._host}/app/request"
response_status, response_data = await self.client_post(
url,
@ -338,7 +318,7 @@ class KlapTransport(BaseTransport):
)
msg = (
f"at {datetime.datetime.now()}. Host is {self.host}, "
f"at {datetime.datetime.now()}. Host is {self._host}, "
+ f"Sequence is {seq}, "
+ f"Response status is {response_status}, Request was {request}"
)
@ -348,12 +328,12 @@ class KlapTransport(BaseTransport):
if response_status == 403:
self._handshake_done = False
raise AuthenticationException(
f"Got a security error from {self.host} after handshake "
f"Got a security error from {self._host} after handshake "
+ "completed"
)
else:
raise SmartDeviceException(
f"Device {self.host} responded with {response_status} to"
f"Device {self._host} responded with {response_status} to"
+ f"request with seq {seq}"
)
else:
@ -367,7 +347,7 @@ class KlapTransport(BaseTransport):
_LOGGER.debug(
"%s << %s",
self.host,
self._host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(json_payload),
)

View File

@ -44,35 +44,21 @@ def md5(payload: bytes) -> bytes:
class BaseTransport(ABC):
"""Base class for all TP-Link protocol transports."""
DEFAULT_TIMEOUT = 5
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
"""Create a protocol object."""
self.host = host
self.port = port
self.credentials = credentials
@property
@abstractmethod
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
@property
@abstractmethod
def needs_login(self) -> bool:
"""Return true if the transport needs to do a login."""
@abstractmethod
async def login(self, request: str) -> None:
"""Login to the device."""
@abstractmethod
async def handshake(self) -> None:
"""Perform the encryption handshake."""
self._host = host
self._port = port
self._credentials = credentials or Credentials(username="", password="")
self._timeout = timeout or self.DEFAULT_TIMEOUT
@abstractmethod
async def send(self, request: str) -> Dict:
@ -90,14 +76,14 @@ class TPLinkProtocol(ABC):
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
transport: Optional[BaseTransport] = None,
transport: BaseTransport,
) -> None:
"""Create a protocol object."""
self.host = host
self.port = port
self.credentials = credentials
self._transport = transport
@property
def _host(self):
return self._transport._host
@abstractmethod
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
@ -108,6 +94,40 @@ class TPLinkProtocol(ABC):
"""Close the protocol. Abstract method to be overriden."""
class _XorTransport(BaseTransport):
"""Implementation of the Xor encryption transport.
WIP, currently only to ensure consistent __init__ method signatures
for protocol classes. Will eventually incorporate the logic from
TPLinkSmartHomeProtocol to simplify the API and re-use the IotProtocol
class.
"""
DEFAULT_PORT = 9999
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(
host,
port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)
async def send(self, request: str) -> Dict:
"""Send a message to the device and return a response."""
return {}
async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""
class TPLinkSmartHomeProtocol(TPLinkProtocol):
"""Implementation of the TP-Link Smart Home protocol."""
@ -120,20 +140,18 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
self,
host: str,
*,
port: Optional[int] = None,
timeout: Optional[int] = None,
credentials: Optional[Credentials] = None,
transport: BaseTransport,
) -> None:
"""Create a protocol object."""
super().__init__(
host=host, port=port or self.DEFAULT_PORT, credentials=credentials
)
super().__init__(host, transport=transport)
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
self.query_lock = asyncio.Lock()
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.timeout = timeout or TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
self._timeout = self._transport._timeout
self._port = self._transport._port
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Request information from a TP-Link SmartHome Device.
@ -149,7 +167,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
assert isinstance(request, str) # noqa: S101
async with self.query_lock:
return await self._query(request, retry_count, self.timeout)
return await self._query(request, retry_count, self._timeout)
async def _connect(self, timeout: int) -> None:
"""Try to connect or reconnect to the device."""
@ -157,7 +175,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
return
self.reader = self.writer = None
task = asyncio.open_connection(self.host, self.port)
task = asyncio.open_connection(self._host, self._port)
async with asyncio_timeout(timeout):
self.reader, self.writer = await task
sock: socket.socket = self.writer.get_extra_info("socket")
@ -174,7 +192,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
debug_log = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_log:
_LOGGER.debug("%s >> %s", self.host, request)
_LOGGER.debug("%s >> %s", self._host, request)
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await self.writer.drain()
@ -185,7 +203,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
response = TPLinkSmartHomeProtocol.decrypt(buffer)
json_payload = json_loads(response)
if debug_log:
_LOGGER.debug("%s << %s", self.host, pf(json_payload))
_LOGGER.debug("%s << %s", self._host, pf(json_payload))
return json_payload
@ -219,23 +237,23 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
except ConnectionRefusedError as ex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}:{self.port}: {ex}"
f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
) from ex
except OSError as ex:
await self.close()
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
raise SmartDeviceException(
f"Unable to connect to the device:"
f" {self.host}:{self.port}: {ex}"
f" {self._host}:{self._port}: {ex}"
) from ex
continue
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device:"
f" {self.host}:{self.port}: {ex}"
f" {self._host}:{self._port}: {ex}"
) from ex
continue
@ -247,13 +265,13 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to query the device {self.host}:{self.port}: {ex}"
f"Unable to query the device {self._host}:{self._port}: {ex}"
) from ex
_LOGGER.debug(
"Unable to query the device %s, retrying: %s", self.host, ex
"Unable to query the device %s, retrying: %s", self._host, ex
)
# make mypy happy, this should never be reached..

View File

@ -24,7 +24,7 @@ from .device_type import DeviceType
from .emeterstatus import EmeterStatus
from .exceptions import SmartDeviceException
from .modules import Emeter, Module
from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol, _XorTransport
_LOGGER = logging.getLogger(__name__)
@ -202,7 +202,7 @@ class SmartDevice:
self.host = host
self.port = port
self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol(
host, port=port, timeout=timeout
host, transport=_XorTransport(host, port=port, timeout=timeout)
)
self.credentials = credentials
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))

View File

@ -10,12 +10,10 @@ import logging
import time
import uuid
from pprint import pformat as pf
from typing import Dict, Optional, Union
from typing import Dict, Union
import httpx
from .aestransport import AesTransport
from .credentials import Credentials
from .exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
@ -36,26 +34,17 @@ logging.getLogger("httpx").propagate = False
class SmartProtocol(TPLinkProtocol):
"""Class for the new TPLink SMART protocol."""
DEFAULT_PORT = 80
SLEEP_SECONDS_AFTER_TIMEOUT = 1
def __init__(
self,
host: str,
*,
transport: Optional[BaseTransport] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
transport: BaseTransport,
) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT)
self._credentials: Credentials = credentials or Credentials(
username="", password=""
)
self._transport: BaseTransport = transport or AesTransport(
host, credentials=self._credentials, timeout=timeout
)
self._terminal_uuid: Optional[str] = None
"""Create a protocol object."""
super().__init__(host, transport=transport)
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
self._request_id_generator = SnowflakeId(1, 1)
self._query_lock = asyncio.Lock()
@ -79,7 +68,7 @@ class SmartProtocol(TPLinkProtocol):
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS:
msg = (
f"Error querying device: {self.host}: "
f"Error querying device: {self._host}: "
+ f"{error_code.name}({error_code.value})"
)
if error_code in SMART_TIMEOUT_ERRORS:
@ -101,51 +90,53 @@ class SmartProtocol(TPLinkProtocol):
except httpx.CloseError as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}"
f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}"
f"Unable to connect to the device: {self._host}: {cex}"
) from cex
except TimeoutError as tex:
if retry >= retry_count:
await self.close()
raise SmartDeviceException(
"Unable to connect to the device, "
+ f"timed out: {self.host}: {tex}"
+ f"timed out: {self._host}: {tex}"
) from tex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
continue
except AuthenticationException as auex:
await self.close()
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
_LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex
except RetryableException as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
continue
except TimeoutException as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
continue
except Exception as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to query the device {self.host}:{self.port}: {ex}"
f"Unable to connect to the device: {self._host}: {ex}"
) from ex
_LOGGER.debug(
"Unable to query the device %s, retrying: %s", self.host, ex
"Unable to query the device %s, retrying: %s", self._host, ex
)
continue
@ -160,27 +151,17 @@ class SmartProtocol(TPLinkProtocol):
smart_method = request
smart_params = None
if self._transport.needs_handshake:
await self._transport.handshake()
if self._transport.needs_login:
self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode(
"UTF-8"
)
login_request = self.get_smart_request("login_device")
await self._transport.login(login_request)
smart_request = self.get_smart_request(smart_method, smart_params)
_LOGGER.debug(
"%s >> %s",
self.host,
self._host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request),
)
response_data = await self._transport.send(smart_request)
_LOGGER.debug(
"%s << %s",
self.host,
self._host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
)

View File

@ -4,6 +4,7 @@ import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Set, cast
from ..aestransport import AesTransport
from ..credentials import Credentials
from ..exceptions import AuthenticationException
from ..smartdevice import SmartDevice
@ -27,7 +28,12 @@ class TapoDevice(SmartDevice):
self._components: Optional[Dict[str, Any]] = None
self._state_information: Dict[str, Any] = {}
self._discovery_info: Optional[Dict[str, Any]] = None
self.protocol = SmartProtocol(host, credentials=credentials, timeout=timeout)
self.protocol = SmartProtocol(
host,
transport=AesTransport(
host, credentials=credentials, timeout=timeout, port=port
),
)
async def update(self, update_children: bool = True):
"""Update the device."""

View File

@ -301,6 +301,9 @@ class FakeSmartProtocol(SmartProtocol):
class FakeSmartTransport(BaseTransport):
def __init__(self, info):
super().__init__(
"127.0.0.123",
)
self.info = info
@property

View File

@ -0,0 +1,174 @@
import base64
import json
import time
from contextlib import nullcontext as does_not_raise
from json import dumps as json_dumps
from json import loads as json_loads
import httpx
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from ..aestransport import AesEncyptionSession, AesTransport
from ..credentials import Credentials
from ..exceptions import SmartDeviceException
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t"
iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa"
KEY_IV = key + iv
def test_encrypt():
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
d = json.dumps({"foo": 1, "bar": 2})
encrypted = encryption_session.encrypt(d.encode())
assert d == encryption_session.decrypt(encrypted)
# test encrypt unicode
d = "{'snowman': '\u2603'}"
encrypted = encryption_session.encrypt(d.encode())
assert d == encryption_session.decrypt(encrypted)
status_parameters = pytest.mark.parametrize(
"status_code, error_code, inner_error_code, expectation",
[
(200, 0, 0, does_not_raise()),
(400, 0, 0, pytest.raises(SmartDeviceException)),
(200, -1, 0, pytest.raises(SmartDeviceException)),
],
ids=("success", "status_code", "error_code"),
)
@status_parameters
async def test_handshake(
mocker, status_code, error_code, inner_error_code, expectation
):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
assert transport._encryption_session is None
assert transport._handshake_done is False
with expectation:
await transport.perform_handshake()
assert transport._encryption_session is not None
assert transport._handshake_done is True
@status_parameters
async def test_login(mocker, status_code, error_code, inner_error_code, expectation):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
assert transport._login_token is None
with expectation:
await transport.perform_login()
assert transport._login_token == mock_aes_device.token
@status_parameters
async def test_send(mocker, status_code, error_code, inner_error_code, expectation):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._login_token = mock_aes_device.token
un, pw = transport.hash_credentials(True)
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
with expectation:
res = await transport.send(json_dumps(request))
assert "result" in res
class MockAesDevice:
class _mock_response:
def __init__(self, status_code, json: dict):
self.status_code = status_code
self._json = json
def json(self):
return self._json
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
token = "test_token" # noqa
def __init__(self, host, status_code=200, error_code=0, inner_error_code=0):
self.host = host
self.status_code = status_code
self.error_code = error_code
self.inner_error_code = inner_error_code
async def post(self, url, params=None, json=None, *_, **__):
return await self._post(url, json)
async def _post(self, url, json):
if json["method"] == "handshake":
return await self._return_handshake_response(url, json)
elif json["method"] == "securePassthrough":
return await self._return_secure_passthrough_response(url, json)
elif json["method"] == "login_device":
return await self._return_login_response(url, json)
else:
assert url == f"http://{self.host}/app?token={self.token}"
return await self._return_send_response(url, json)
async def _return_handshake_response(self, url, json):
start = len("-----BEGIN PUBLIC KEY-----\n")
end = len("\n-----END PUBLIC KEY-----\n")
client_pub_key = json["params"]["key"][start:-end]
client_pub_key_data = base64.b64decode(client_pub_key.encode())
client_pub_key = serialization.load_der_public_key(client_pub_key_data, None)
encrypted_key = client_pub_key.encrypt(KEY_IV, asymmetric_padding.PKCS1v15())
key_64 = base64.b64encode(encrypted_key).decode()
return self._mock_response(
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
)
async def _return_secure_passthrough_response(self, url, json):
encrypted_request = json["params"]["request"]
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
decrypted_request_dict = json_loads(decrypted_request)
decrypted_response = await self._post(url, decrypted_request_dict)
decrypted_response_dict = decrypted_response.json()
encrypted_response = self.encryption_session.encrypt(
json_dumps(decrypted_response_dict).encode()
)
result = {
"result": {"response": encrypted_response.decode()},
"error_code": self.error_code,
}
return self._mock_response(self.status_code, result)
async def _return_login_response(self, url, json):
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
return self._mock_response(self.status_code, result)
async def _return_send_response(self, url, json):
result = {"result": {"method": None}, "error_code": self.inner_error_code}
return self._mock_response(self.status_code, result)

View File

@ -96,10 +96,9 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
return mock_response
mocker.patch.object(
transport_class, "needs_handshake", property(lambda self: False)
)
mocker.patch.object(transport_class, "needs_login", property(lambda self: False))
mocker.patch.object(transport_class, "perform_handshake")
if hasattr(transport_class, "perform_login"):
mocker.patch.object(transport_class, "perform_login")
send_mock = mocker.patch.object(
transport_class,
@ -128,7 +127,7 @@ async def test_protocol_logging(mocker, caplog, log_level):
seed = secrets.token_bytes(16)
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
protocol = IotProtocol("127.0.0.1")
protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1"))
protocol._transport._handshake_done = True
protocol._transport._session_expire_at = time.time() + 86400
@ -206,7 +205,10 @@ async def test_handshake1(mocker, device_credentials, expectation):
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
)
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
protocol._transport.http_client = httpx.AsyncClient()
with expectation:
@ -243,7 +245,10 @@ async def test_handshake(mocker):
httpx.AsyncClient, "post", side_effect=_return_handshake_response
)
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
protocol._transport.http_client = httpx.AsyncClient()
response_status = 200
@ -289,7 +294,10 @@ async def test_query(mocker):
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
for _ in range(10):
resp = await protocol.query({})
@ -333,7 +341,10 @@ async def test_authentication_failures(mocker, response_status, expectation):
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
with expectation:
await protocol.query({})

View File

@ -1,13 +1,21 @@
import errno
import importlib
import inspect
import json
import logging
import pkgutil
import struct
import sys
import pytest
from ..exceptions import SmartDeviceException
from ..protocol import TPLinkSmartHomeProtocol
from ..protocol import (
BaseTransport,
TPLinkProtocol,
TPLinkSmartHomeProtocol,
_XorTransport,
)
@pytest.mark.parametrize("retry_count", [1, 3, 5])
@ -24,7 +32,9 @@ async def test_protocol_retries(mocker, retry_count):
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count)
await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=retry_count)
assert conn.call_count == retry_count + 1
@ -35,7 +45,9 @@ async def test_protocol_no_retry_on_unreachable(mocker):
side_effect=OSError(errno.EHOSTUNREACH, "No route to host"),
)
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)
await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=5)
assert conn.call_count == 1
@ -46,7 +58,9 @@ async def test_protocol_no_retry_connection_refused(mocker):
side_effect=ConnectionRefusedError,
)
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)
await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=5)
assert conn.call_count == 1
@ -57,7 +71,9 @@ async def test_protocol_retry_recoverable_error(mocker):
side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"),
)
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)
await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=5)
assert conn.call_count == 6
@ -91,7 +107,9 @@ async def test_protocol_reconnect(mocker, retry_count):
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
protocol = TPLinkSmartHomeProtocol("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
)
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}, retry_count=retry_count)
assert response == {"great": "success"}
@ -119,7 +137,9 @@ async def test_protocol_logging(mocker, caplog, log_level):
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
protocol = TPLinkSmartHomeProtocol("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
)
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({})
assert response == {"great": "success"}
@ -153,7 +173,9 @@ async def test_protocol_custom_port(mocker, custom_port):
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
protocol = TPLinkSmartHomeProtocol("127.0.0.1", port=custom_port)
protocol = TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port)
)
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({})
assert response == {"great": "success"}
@ -227,3 +249,63 @@ def test_decrypt_unicode():
d = "{'snowman': '\u2603'}"
assert d == TPLinkSmartHomeProtocol.decrypt(e)
def _get_subclasses(of_class):
import kasa
package = sys.modules["kasa"]
subclasses = set()
for _, modname, _ in pkgutil.iter_modules(package.__path__):
importlib.import_module("." + modname, package="kasa")
module = sys.modules["kasa." + modname]
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, of_class):
subclasses.add((name, obj))
return subclasses
@pytest.mark.parametrize(
"class_name_obj", _get_subclasses(TPLinkProtocol), ids=lambda t: t[0]
)
def test_protocol_init_signature(class_name_obj):
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
assert len(params) == 3
assert (
params[0].name == "self"
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert (
params[1].name == "host"
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert (
params[2].name == "transport"
and params[2].kind == inspect.Parameter.KEYWORD_ONLY
)
@pytest.mark.parametrize(
"class_name_obj", _get_subclasses(BaseTransport), ids=lambda t: t[0]
)
def test_transport_init_signature(class_name_obj):
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
assert len(params) == 5
assert (
params[0].name == "self"
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert (
params[1].name == "host"
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY
assert (
params[3].name == "credentials"
and params[3].kind == inspect.Parameter.KEYWORD_ONLY
)
assert (
params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY
)

View File

@ -232,7 +232,7 @@ async def test_modules_preserved(dev: SmartDevice):
async def test_create_smart_device_with_timeout():
"""Make sure timeout is passed to the protocol."""
dev = SmartDevice(host="127.0.0.1", timeout=100)
assert dev.protocol.timeout == 100
assert dev.protocol._transport._timeout == 100
async def test_create_thin_wrapper():