mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-04-26 16:46:23 +00:00
Do login entirely within AesTransport (#580)
* Do login entirely within AesTransport * Remove login and handshake attributes from BaseTransport * Add AesTransport tests * Synchronise transport and protocol __init__ signatures and rename internal variables * Update after review
This commit is contained in:
parent
209391c422
commit
20ea6700a5
@ -8,7 +8,7 @@ import base64
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Union
|
from typing import Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from cryptography.hazmat.primitives import padding, serialization
|
from cryptography.hazmat.primitives import padding, serialization
|
||||||
@ -47,6 +47,7 @@ class AesTransport(BaseTransport):
|
|||||||
protocol, sometimes used by newer firmware versions on kasa devices.
|
protocol, sometimes used by newer firmware versions on kasa devices.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
DEFAULT_PORT = 80
|
||||||
DEFAULT_TIMEOUT = 5
|
DEFAULT_TIMEOUT = 5
|
||||||
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
SESSION_COOKIE_NAME = "TP_SESSIONID"
|
||||||
COMMON_HEADERS = {
|
COMMON_HEADERS = {
|
||||||
@ -59,12 +60,16 @@ class AesTransport(BaseTransport):
|
|||||||
self,
|
self,
|
||||||
host: str,
|
host: str,
|
||||||
*,
|
*,
|
||||||
|
port: Optional[int] = None,
|
||||||
credentials: Optional[Credentials] = None,
|
credentials: Optional[Credentials] = None,
|
||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(host=host)
|
super().__init__(
|
||||||
|
host,
|
||||||
self._credentials = credentials or Credentials(username="", password="")
|
port=port or self.DEFAULT_PORT,
|
||||||
|
credentials=credentials,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
self._handshake_done = False
|
self._handshake_done = False
|
||||||
|
|
||||||
@ -77,7 +82,7 @@ class AesTransport(BaseTransport):
|
|||||||
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
|
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
|
||||||
self._login_token = None
|
self._login_token = None
|
||||||
|
|
||||||
_LOGGER.debug("Created AES object for %s", self.host)
|
_LOGGER.debug("Created AES transport for %s", self._host)
|
||||||
|
|
||||||
def hash_credentials(self, login_v2):
|
def hash_credentials(self, login_v2):
|
||||||
"""Hash the credentials."""
|
"""Hash the credentials."""
|
||||||
@ -123,7 +128,7 @@ class AesTransport(BaseTransport):
|
|||||||
if (
|
if (
|
||||||
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||||
) != SmartErrorCode.SUCCESS:
|
) != SmartErrorCode.SUCCESS:
|
||||||
msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})"
|
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
|
||||||
if error_code in SMART_TIMEOUT_ERRORS:
|
if error_code in SMART_TIMEOUT_ERRORS:
|
||||||
raise TimeoutException(msg)
|
raise TimeoutException(msg)
|
||||||
if error_code in SMART_RETRYABLE_ERRORS:
|
if error_code in SMART_RETRYABLE_ERRORS:
|
||||||
@ -136,7 +141,7 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
async def send_secure_passthrough(self, request: str):
|
async def send_secure_passthrough(self, request: str):
|
||||||
"""Send encrypted message as passthrough."""
|
"""Send encrypted message as passthrough."""
|
||||||
url = f"http://{self.host}/app"
|
url = f"http://{self._host}/app"
|
||||||
if self._login_token:
|
if self._login_token:
|
||||||
url += f"?token={self._login_token}"
|
url += f"?token={self._login_token}"
|
||||||
|
|
||||||
@ -150,7 +155,7 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
if status_code != 200:
|
if status_code != 200:
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"{self.host} responded with an unexpected "
|
f"{self._host} responded with an unexpected "
|
||||||
+ f"status code {status_code} to passthrough"
|
+ f"status code {status_code} to passthrough"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -164,49 +169,31 @@ class AesTransport(BaseTransport):
|
|||||||
resp_dict = json_loads(response)
|
resp_dict = json_loads(response)
|
||||||
return resp_dict
|
return resp_dict
|
||||||
|
|
||||||
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
|
async def _perform_login_for_version(self, *, login_version: int = 1):
|
||||||
"""Login to the device."""
|
"""Login to the device."""
|
||||||
self._login_token = None
|
self._login_token = None
|
||||||
|
un, pw = self.hash_credentials(login_version == 2)
|
||||||
if isinstance(login_request, str):
|
password_field_name = "password2" if login_version == 2 else "password"
|
||||||
login_request_dict: dict = json_loads(login_request)
|
login_request = {
|
||||||
else:
|
"method": "login_device",
|
||||||
login_request_dict = login_request
|
"params": {password_field_name: pw, "username": un},
|
||||||
|
"request_time_milis": round(time.time() * 1000),
|
||||||
un, pw = self.hash_credentials(login_v2)
|
}
|
||||||
login_request_dict["params"] = {"password": pw, "username": un}
|
request = json_dumps(login_request)
|
||||||
request = json_dumps(login_request_dict)
|
|
||||||
try:
|
try:
|
||||||
resp_dict = await self.send_secure_passthrough(request)
|
resp_dict = await self.send_secure_passthrough(request)
|
||||||
except SmartDeviceException as ex:
|
except SmartDeviceException as ex:
|
||||||
raise AuthenticationException(ex) from ex
|
raise AuthenticationException(ex) from ex
|
||||||
self._login_token = resp_dict["result"]["token"]
|
self._login_token = resp_dict["result"]["token"]
|
||||||
|
|
||||||
@property
|
async def perform_login(self) -> None:
|
||||||
def needs_login(self) -> bool:
|
|
||||||
"""Return true if the transport needs to do a login."""
|
|
||||||
return self._login_token is None
|
|
||||||
|
|
||||||
async def login(self, request: str) -> None:
|
|
||||||
"""Login to the device."""
|
"""Login to the device."""
|
||||||
try:
|
try:
|
||||||
if self.needs_handshake:
|
await self._perform_login_for_version(login_version=2)
|
||||||
raise SmartDeviceException(
|
|
||||||
"Handshake must be complete before trying to login"
|
|
||||||
)
|
|
||||||
await self.perform_login(request, login_v2=False)
|
|
||||||
except AuthenticationException:
|
except AuthenticationException:
|
||||||
|
_LOGGER.warning("Login version 2 failed, trying version 1")
|
||||||
await self.perform_handshake()
|
await self.perform_handshake()
|
||||||
await self.perform_login(request, login_v2=True)
|
await self._perform_login_for_version(login_version=1)
|
||||||
|
|
||||||
@property
|
|
||||||
def needs_handshake(self) -> bool:
|
|
||||||
"""Return true if the transport needs to do a handshake."""
|
|
||||||
return not self._handshake_done or self._handshake_session_expired()
|
|
||||||
|
|
||||||
async def handshake(self) -> None:
|
|
||||||
"""Perform the encryption handshake."""
|
|
||||||
await self.perform_handshake()
|
|
||||||
|
|
||||||
async def perform_handshake(self):
|
async def perform_handshake(self):
|
||||||
"""Perform the handshake."""
|
"""Perform the handshake."""
|
||||||
@ -217,7 +204,7 @@ class AesTransport(BaseTransport):
|
|||||||
self._session_expire_at = None
|
self._session_expire_at = None
|
||||||
self._session_cookie = None
|
self._session_cookie = None
|
||||||
|
|
||||||
url = f"http://{self.host}/app"
|
url = f"http://{self._host}/app"
|
||||||
key_pair = KeyPair.create_key_pair()
|
key_pair = KeyPair.create_key_pair()
|
||||||
|
|
||||||
pub_key = (
|
pub_key = (
|
||||||
@ -238,7 +225,7 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
if status_code != 200:
|
if status_code != 200:
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"{self.host} responded with an unexpected "
|
f"{self._host} responded with an unexpected "
|
||||||
+ f"status code {status_code} to handshake"
|
+ f"status code {status_code} to handshake"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -261,7 +248,7 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
self._handshake_done = True
|
self._handshake_done = True
|
||||||
|
|
||||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
_LOGGER.debug("Handshake with %s complete", self._host)
|
||||||
|
|
||||||
def _handshake_session_expired(self):
|
def _handshake_session_expired(self):
|
||||||
"""Return true if session has expired."""
|
"""Return true if session has expired."""
|
||||||
@ -272,12 +259,10 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
async def send(self, request: str):
|
async def send(self, request: str):
|
||||||
"""Send the request."""
|
"""Send the request."""
|
||||||
if self.needs_handshake:
|
if not self._handshake_done or self._handshake_session_expired():
|
||||||
raise SmartDeviceException(
|
await self.perform_handshake()
|
||||||
"Handshake must be complete before trying to send"
|
if not self._login_token:
|
||||||
)
|
await self.perform_login()
|
||||||
if self.needs_login:
|
|
||||||
raise SmartDeviceException("Login must be complete before trying to send")
|
|
||||||
|
|
||||||
return await self.send_secure_passthrough(request)
|
return await self.send_secure_passthrough(request)
|
||||||
|
|
||||||
|
@ -74,7 +74,12 @@ async def connect(
|
|||||||
host=host, port=port, credentials=credentials, timeout=timeout
|
host=host, port=port, credentials=credentials, timeout=timeout
|
||||||
)
|
)
|
||||||
if protocol_class is not None:
|
if protocol_class is not None:
|
||||||
dev.protocol = protocol_class(host, credentials=credentials)
|
dev.protocol = protocol_class(
|
||||||
|
host,
|
||||||
|
transport=AesTransport(
|
||||||
|
host, port=port, credentials=credentials, timeout=timeout
|
||||||
|
),
|
||||||
|
)
|
||||||
await dev.update()
|
await dev.update()
|
||||||
if debug_enabled:
|
if debug_enabled:
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
@ -90,7 +95,13 @@ async def connect(
|
|||||||
host=host, port=port, credentials=credentials, timeout=timeout
|
host=host, port=port, credentials=credentials, timeout=timeout
|
||||||
)
|
)
|
||||||
if protocol_class is not None:
|
if protocol_class is not None:
|
||||||
unknown_dev.protocol = protocol_class(host, credentials=credentials)
|
# TODO this will be replaced with connection params
|
||||||
|
unknown_dev.protocol = protocol_class(
|
||||||
|
host,
|
||||||
|
transport=AesTransport(
|
||||||
|
host, port=port, credentials=credentials, timeout=timeout
|
||||||
|
),
|
||||||
|
)
|
||||||
await unknown_dev.update()
|
await unknown_dev.update()
|
||||||
device_class = get_device_class_from_sys_info(unknown_dev.internal_state)
|
device_class = get_device_class_from_sys_info(unknown_dev.internal_state)
|
||||||
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
|
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
|
||||||
@ -163,7 +174,5 @@ def get_protocol_from_connection_name(
|
|||||||
|
|
||||||
protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore
|
protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore
|
||||||
transport: BaseTransport = transport_class(host, credentials=credentials)
|
transport: BaseTransport = transport_class(host, credentials=credentials)
|
||||||
protocol: TPLinkProtocol = protocol_class(
|
protocol: TPLinkProtocol = protocol_class(host, transport=transport)
|
||||||
host, credentials=credentials, transport=transport
|
|
||||||
)
|
|
||||||
return protocol
|
return protocol
|
||||||
|
@ -1,14 +1,12 @@
|
|||||||
"""Module for the IOT legacy IOT KASA protocol."""
|
"""Module for the IOT legacy IOT KASA protocol."""
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from .credentials import Credentials
|
|
||||||
from .exceptions import AuthenticationException, SmartDeviceException
|
from .exceptions import AuthenticationException, SmartDeviceException
|
||||||
from .json import dumps as json_dumps
|
from .json import dumps as json_dumps
|
||||||
from .klaptransport import KlapTransport
|
|
||||||
from .protocol import BaseTransport, TPLinkProtocol
|
from .protocol import BaseTransport, TPLinkProtocol
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@ -17,24 +15,14 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
class IotProtocol(TPLinkProtocol):
|
class IotProtocol(TPLinkProtocol):
|
||||||
"""Class for the legacy TPLink IOT KASA Protocol."""
|
"""Class for the legacy TPLink IOT KASA Protocol."""
|
||||||
|
|
||||||
DEFAULT_PORT = 80
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
host: str,
|
host: str,
|
||||||
*,
|
*,
|
||||||
transport: Optional[BaseTransport] = None,
|
transport: BaseTransport,
|
||||||
credentials: Optional[Credentials] = None,
|
|
||||||
timeout: Optional[int] = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(host=host, port=self.DEFAULT_PORT)
|
"""Create a protocol object."""
|
||||||
|
super().__init__(host, transport=transport)
|
||||||
self._credentials: Credentials = credentials or Credentials(
|
|
||||||
username="", password=""
|
|
||||||
)
|
|
||||||
self._transport: BaseTransport = transport or KlapTransport(
|
|
||||||
host, credentials=self._credentials, timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
self._query_lock = asyncio.Lock()
|
self._query_lock = asyncio.Lock()
|
||||||
|
|
||||||
@ -54,30 +42,32 @@ class IotProtocol(TPLinkProtocol):
|
|||||||
except httpx.CloseError as sdex:
|
except httpx.CloseError as sdex:
|
||||||
await self.close()
|
await self.close()
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Unable to connect to the device: {self.host}: {sdex}"
|
f"Unable to connect to the device: {self._host}: {sdex}"
|
||||||
) from sdex
|
) from sdex
|
||||||
continue
|
continue
|
||||||
except httpx.ConnectError as cex:
|
except httpx.ConnectError as cex:
|
||||||
await self.close()
|
await self.close()
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Unable to connect to the device: {self.host}: {cex}"
|
f"Unable to connect to the device: {self._host}: {cex}"
|
||||||
) from cex
|
) from cex
|
||||||
except TimeoutError as tex:
|
except TimeoutError as tex:
|
||||||
await self.close()
|
await self.close()
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Unable to connect to the device, timed out: {self.host}: {tex}"
|
f"Unable to connect to the device, timed out: {self._host}: {tex}"
|
||||||
) from tex
|
) from tex
|
||||||
except AuthenticationException as auex:
|
except AuthenticationException as auex:
|
||||||
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
_LOGGER.debug(
|
||||||
|
"Unable to authenticate with %s, not retrying", self._host
|
||||||
|
)
|
||||||
raise auex
|
raise auex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
await self.close()
|
await self.close()
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Unable to connect to the device: {self.host}: {ex}"
|
f"Unable to connect to the device: {self._host}: {ex}"
|
||||||
) from ex
|
) from ex
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -85,14 +75,6 @@ class IotProtocol(TPLinkProtocol):
|
|||||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||||
|
|
||||||
async def _execute_query(self, request: str, retry_count: int) -> Dict:
|
async def _execute_query(self, request: str, retry_count: int) -> Dict:
|
||||||
if self._transport.needs_handshake:
|
|
||||||
await self._transport.handshake()
|
|
||||||
|
|
||||||
if self._transport.needs_login: # This shouln't happen
|
|
||||||
raise SmartDeviceException(
|
|
||||||
"IOT Protocol needs to login to transport but is not login aware"
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self._transport.send(request)
|
return await self._transport.send(request)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
|
@ -82,7 +82,7 @@ class KlapTransport(BaseTransport):
|
|||||||
protocol, used by newer firmware versions.
|
protocol, used by newer firmware versions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_TIMEOUT = 5
|
DEFAULT_PORT = 80
|
||||||
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
|
DISCOVERY_QUERY = {"system": {"get_sysinfo": None}}
|
||||||
KASA_SETUP_EMAIL = "kasa@tp-link.net"
|
KASA_SETUP_EMAIL = "kasa@tp-link.net"
|
||||||
KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105
|
KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105
|
||||||
@ -92,12 +92,17 @@ class KlapTransport(BaseTransport):
|
|||||||
self,
|
self,
|
||||||
host: str,
|
host: str,
|
||||||
*,
|
*,
|
||||||
|
port: Optional[int] = None,
|
||||||
credentials: Optional[Credentials] = None,
|
credentials: Optional[Credentials] = None,
|
||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(host=host)
|
super().__init__(
|
||||||
|
host,
|
||||||
|
port=port or self.DEFAULT_PORT,
|
||||||
|
credentials=credentials,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
self._credentials = credentials or Credentials(username="", password="")
|
|
||||||
self._local_seed: Optional[bytes] = None
|
self._local_seed: Optional[bytes] = None
|
||||||
self._local_auth_hash = self.generate_auth_hash(self._credentials)
|
self._local_auth_hash = self.generate_auth_hash(self._credentials)
|
||||||
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
|
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
|
||||||
@ -110,11 +115,10 @@ class KlapTransport(BaseTransport):
|
|||||||
self._encryption_session: Optional[KlapEncryptionSession] = None
|
self._encryption_session: Optional[KlapEncryptionSession] = None
|
||||||
self._session_expire_at: Optional[float] = None
|
self._session_expire_at: Optional[float] = None
|
||||||
|
|
||||||
self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT
|
|
||||||
self._session_cookie = None
|
self._session_cookie = None
|
||||||
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
|
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
|
||||||
|
|
||||||
_LOGGER.debug("Created KLAP object for %s", self.host)
|
_LOGGER.debug("Created KLAP transport for %s", self._host)
|
||||||
|
|
||||||
async def client_post(self, url, params=None, data=None):
|
async def client_post(self, url, params=None, data=None):
|
||||||
"""Send an http post request to the device."""
|
"""Send an http post request to the device."""
|
||||||
@ -148,7 +152,7 @@ class KlapTransport(BaseTransport):
|
|||||||
|
|
||||||
payload = local_seed
|
payload = local_seed
|
||||||
|
|
||||||
url = f"http://{self.host}/app/handshake1"
|
url = f"http://{self._host}/app/handshake1"
|
||||||
|
|
||||||
response_status, response_data = await self.client_post(url, data=payload)
|
response_status, response_data = await self.client_post(url, data=payload)
|
||||||
|
|
||||||
@ -157,14 +161,14 @@ class KlapTransport(BaseTransport):
|
|||||||
"Handshake1 posted at %s. Host is %s, Response"
|
"Handshake1 posted at %s. Host is %s, Response"
|
||||||
+ "status is %s, Request was %s",
|
+ "status is %s, Request was %s",
|
||||||
datetime.datetime.now(),
|
datetime.datetime.now(),
|
||||||
self.host,
|
self._host,
|
||||||
response_status,
|
response_status,
|
||||||
payload.hex(),
|
payload.hex(),
|
||||||
)
|
)
|
||||||
|
|
||||||
if response_status != 200:
|
if response_status != 200:
|
||||||
raise AuthenticationException(
|
raise AuthenticationException(
|
||||||
f"Device {self.host} responded with {response_status} to handshake1"
|
f"Device {self._host} responded with {response_status} to handshake1"
|
||||||
)
|
)
|
||||||
|
|
||||||
remote_seed: bytes = response_data[0:16]
|
remote_seed: bytes = response_data[0:16]
|
||||||
@ -175,7 +179,7 @@ class KlapTransport(BaseTransport):
|
|||||||
"Handshake1 success at %s. Host is %s, "
|
"Handshake1 success at %s. Host is %s, "
|
||||||
+ "Server remote_seed is: %s, server hash is: %s",
|
+ "Server remote_seed is: %s, server hash is: %s",
|
||||||
datetime.datetime.now(),
|
datetime.datetime.now(),
|
||||||
self.host,
|
self._host,
|
||||||
remote_seed.hex(),
|
remote_seed.hex(),
|
||||||
server_hash.hex(),
|
server_hash.hex(),
|
||||||
)
|
)
|
||||||
@ -207,7 +211,7 @@ class KlapTransport(BaseTransport):
|
|||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Server response doesn't match our expected hash on ip %s"
|
"Server response doesn't match our expected hash on ip %s"
|
||||||
+ " but an authentication with kasa setup credentials matched",
|
+ " but an authentication with kasa setup credentials matched",
|
||||||
self.host,
|
self._host,
|
||||||
)
|
)
|
||||||
return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore
|
return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore
|
||||||
|
|
||||||
@ -226,11 +230,11 @@ class KlapTransport(BaseTransport):
|
|||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Server response doesn't match our expected hash on ip %s"
|
"Server response doesn't match our expected hash on ip %s"
|
||||||
+ " but an authentication with blank credentials matched",
|
+ " but an authentication with blank credentials matched",
|
||||||
self.host,
|
self._host,
|
||||||
)
|
)
|
||||||
return local_seed, remote_seed, self._blank_auth_hash # type: ignore
|
return local_seed, remote_seed, self._blank_auth_hash # type: ignore
|
||||||
|
|
||||||
msg = f"Server response doesn't match our challenge on ip {self.host}"
|
msg = f"Server response doesn't match our challenge on ip {self._host}"
|
||||||
_LOGGER.debug(msg)
|
_LOGGER.debug(msg)
|
||||||
raise AuthenticationException(msg)
|
raise AuthenticationException(msg)
|
||||||
|
|
||||||
@ -241,7 +245,7 @@ class KlapTransport(BaseTransport):
|
|||||||
# Handshake 2 has the following payload:
|
# Handshake 2 has the following payload:
|
||||||
# sha256(serverBytes | authenticator)
|
# sha256(serverBytes | authenticator)
|
||||||
|
|
||||||
url = f"http://{self.host}/app/handshake2"
|
url = f"http://{self._host}/app/handshake2"
|
||||||
|
|
||||||
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
|
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
|
||||||
|
|
||||||
@ -252,44 +256,24 @@ class KlapTransport(BaseTransport):
|
|||||||
"Handshake2 posted %s. Host is %s, Response status is %s, "
|
"Handshake2 posted %s. Host is %s, Response status is %s, "
|
||||||
+ "Request was %s",
|
+ "Request was %s",
|
||||||
datetime.datetime.now(),
|
datetime.datetime.now(),
|
||||||
self.host,
|
self._host,
|
||||||
response_status,
|
response_status,
|
||||||
payload.hex(),
|
payload.hex(),
|
||||||
)
|
)
|
||||||
|
|
||||||
if response_status != 200:
|
if response_status != 200:
|
||||||
raise AuthenticationException(
|
raise AuthenticationException(
|
||||||
f"Device {self.host} responded with {response_status} to handshake2"
|
f"Device {self._host} responded with {response_status} to handshake2"
|
||||||
)
|
)
|
||||||
|
|
||||||
return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
|
return KlapEncryptionSession(local_seed, remote_seed, auth_hash)
|
||||||
|
|
||||||
@property
|
|
||||||
def needs_login(self) -> bool:
|
|
||||||
"""Will return false as KLAP does not do a login."""
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def login(self, request: str) -> None:
|
|
||||||
"""Will raise and exception as KLAP does not do a login."""
|
|
||||||
raise SmartDeviceException(
|
|
||||||
"KLAP does not perform logins and return needs_login == False"
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def needs_handshake(self) -> bool:
|
|
||||||
"""Return true if the transport needs to do a handshake."""
|
|
||||||
return not self._handshake_done or self._handshake_session_expired()
|
|
||||||
|
|
||||||
async def handshake(self) -> None:
|
|
||||||
"""Perform the encryption handshake."""
|
|
||||||
await self.perform_handshake()
|
|
||||||
|
|
||||||
async def perform_handshake(self) -> Any:
|
async def perform_handshake(self) -> Any:
|
||||||
"""Perform handshake1 and handshake2.
|
"""Perform handshake1 and handshake2.
|
||||||
|
|
||||||
Sets the encryption_session if successful.
|
Sets the encryption_session if successful.
|
||||||
"""
|
"""
|
||||||
_LOGGER.debug("Starting handshake with %s", self.host)
|
_LOGGER.debug("Starting handshake with %s", self._host)
|
||||||
self._handshake_done = False
|
self._handshake_done = False
|
||||||
self._session_expire_at = None
|
self._session_expire_at = None
|
||||||
self._session_cookie = None
|
self._session_cookie = None
|
||||||
@ -307,7 +291,7 @@ class KlapTransport(BaseTransport):
|
|||||||
)
|
)
|
||||||
self._handshake_done = True
|
self._handshake_done = True
|
||||||
|
|
||||||
_LOGGER.debug("Handshake with %s complete", self.host)
|
_LOGGER.debug("Handshake with %s complete", self._host)
|
||||||
|
|
||||||
def _handshake_session_expired(self):
|
def _handshake_session_expired(self):
|
||||||
"""Return true if session has expired."""
|
"""Return true if session has expired."""
|
||||||
@ -318,18 +302,14 @@ class KlapTransport(BaseTransport):
|
|||||||
|
|
||||||
async def send(self, request: str):
|
async def send(self, request: str):
|
||||||
"""Send the request."""
|
"""Send the request."""
|
||||||
if self.needs_handshake:
|
if not self._handshake_done or self._handshake_session_expired():
|
||||||
raise SmartDeviceException(
|
await self.perform_handshake()
|
||||||
"Handshake must be complete before trying to send"
|
|
||||||
)
|
|
||||||
if self.needs_login:
|
|
||||||
raise SmartDeviceException("Login must be complete before trying to send")
|
|
||||||
|
|
||||||
# Check for mypy
|
# Check for mypy
|
||||||
if self._encryption_session is not None:
|
if self._encryption_session is not None:
|
||||||
payload, seq = self._encryption_session.encrypt(request.encode())
|
payload, seq = self._encryption_session.encrypt(request.encode())
|
||||||
|
|
||||||
url = f"http://{self.host}/app/request"
|
url = f"http://{self._host}/app/request"
|
||||||
|
|
||||||
response_status, response_data = await self.client_post(
|
response_status, response_data = await self.client_post(
|
||||||
url,
|
url,
|
||||||
@ -338,7 +318,7 @@ class KlapTransport(BaseTransport):
|
|||||||
)
|
)
|
||||||
|
|
||||||
msg = (
|
msg = (
|
||||||
f"at {datetime.datetime.now()}. Host is {self.host}, "
|
f"at {datetime.datetime.now()}. Host is {self._host}, "
|
||||||
+ f"Sequence is {seq}, "
|
+ f"Sequence is {seq}, "
|
||||||
+ f"Response status is {response_status}, Request was {request}"
|
+ f"Response status is {response_status}, Request was {request}"
|
||||||
)
|
)
|
||||||
@ -348,12 +328,12 @@ class KlapTransport(BaseTransport):
|
|||||||
if response_status == 403:
|
if response_status == 403:
|
||||||
self._handshake_done = False
|
self._handshake_done = False
|
||||||
raise AuthenticationException(
|
raise AuthenticationException(
|
||||||
f"Got a security error from {self.host} after handshake "
|
f"Got a security error from {self._host} after handshake "
|
||||||
+ "completed"
|
+ "completed"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Device {self.host} responded with {response_status} to"
|
f"Device {self._host} responded with {response_status} to"
|
||||||
+ f"request with seq {seq}"
|
+ f"request with seq {seq}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -367,7 +347,7 @@ class KlapTransport(BaseTransport):
|
|||||||
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s << %s",
|
"%s << %s",
|
||||||
self.host,
|
self._host,
|
||||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(json_payload),
|
_LOGGER.isEnabledFor(logging.DEBUG) and pf(json_payload),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
108
kasa/protocol.py
108
kasa/protocol.py
@ -44,35 +44,21 @@ def md5(payload: bytes) -> bytes:
|
|||||||
class BaseTransport(ABC):
|
class BaseTransport(ABC):
|
||||||
"""Base class for all TP-Link protocol transports."""
|
"""Base class for all TP-Link protocol transports."""
|
||||||
|
|
||||||
|
DEFAULT_TIMEOUT = 5
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
host: str,
|
host: str,
|
||||||
*,
|
*,
|
||||||
port: Optional[int] = None,
|
port: Optional[int] = None,
|
||||||
credentials: Optional[Credentials] = None,
|
credentials: Optional[Credentials] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a protocol object."""
|
"""Create a protocol object."""
|
||||||
self.host = host
|
self._host = host
|
||||||
self.port = port
|
self._port = port
|
||||||
self.credentials = credentials
|
self._credentials = credentials or Credentials(username="", password="")
|
||||||
|
self._timeout = timeout or self.DEFAULT_TIMEOUT
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def needs_handshake(self) -> bool:
|
|
||||||
"""Return true if the transport needs to do a handshake."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def needs_login(self) -> bool:
|
|
||||||
"""Return true if the transport needs to do a login."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def login(self, request: str) -> None:
|
|
||||||
"""Login to the device."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def handshake(self) -> None:
|
|
||||||
"""Perform the encryption handshake."""
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def send(self, request: str) -> Dict:
|
async def send(self, request: str) -> Dict:
|
||||||
@ -90,14 +76,14 @@ class TPLinkProtocol(ABC):
|
|||||||
self,
|
self,
|
||||||
host: str,
|
host: str,
|
||||||
*,
|
*,
|
||||||
port: Optional[int] = None,
|
transport: BaseTransport,
|
||||||
credentials: Optional[Credentials] = None,
|
|
||||||
transport: Optional[BaseTransport] = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a protocol object."""
|
"""Create a protocol object."""
|
||||||
self.host = host
|
self._transport = transport
|
||||||
self.port = port
|
|
||||||
self.credentials = credentials
|
@property
|
||||||
|
def _host(self):
|
||||||
|
return self._transport._host
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||||
@ -108,6 +94,40 @@ class TPLinkProtocol(ABC):
|
|||||||
"""Close the protocol. Abstract method to be overriden."""
|
"""Close the protocol. Abstract method to be overriden."""
|
||||||
|
|
||||||
|
|
||||||
|
class _XorTransport(BaseTransport):
|
||||||
|
"""Implementation of the Xor encryption transport.
|
||||||
|
|
||||||
|
WIP, currently only to ensure consistent __init__ method signatures
|
||||||
|
for protocol classes. Will eventually incorporate the logic from
|
||||||
|
TPLinkSmartHomeProtocol to simplify the API and re-use the IotProtocol
|
||||||
|
class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_PORT = 9999
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
*,
|
||||||
|
port: Optional[int] = None,
|
||||||
|
credentials: Optional[Credentials] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
host,
|
||||||
|
port=port or self.DEFAULT_PORT,
|
||||||
|
credentials=credentials,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send(self, request: str) -> Dict:
|
||||||
|
"""Send a message to the device and return a response."""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the transport. Abstract method to be overriden."""
|
||||||
|
|
||||||
|
|
||||||
class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||||
"""Implementation of the TP-Link Smart Home protocol."""
|
"""Implementation of the TP-Link Smart Home protocol."""
|
||||||
|
|
||||||
@ -120,20 +140,18 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
|||||||
self,
|
self,
|
||||||
host: str,
|
host: str,
|
||||||
*,
|
*,
|
||||||
port: Optional[int] = None,
|
transport: BaseTransport,
|
||||||
timeout: Optional[int] = None,
|
|
||||||
credentials: Optional[Credentials] = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a protocol object."""
|
"""Create a protocol object."""
|
||||||
super().__init__(
|
super().__init__(host, transport=transport)
|
||||||
host=host, port=port or self.DEFAULT_PORT, credentials=credentials
|
|
||||||
)
|
|
||||||
|
|
||||||
self.reader: Optional[asyncio.StreamReader] = None
|
self.reader: Optional[asyncio.StreamReader] = None
|
||||||
self.writer: Optional[asyncio.StreamWriter] = None
|
self.writer: Optional[asyncio.StreamWriter] = None
|
||||||
self.query_lock = asyncio.Lock()
|
self.query_lock = asyncio.Lock()
|
||||||
self.loop: Optional[asyncio.AbstractEventLoop] = None
|
self.loop: Optional[asyncio.AbstractEventLoop] = None
|
||||||
self.timeout = timeout or TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
|
|
||||||
|
self._timeout = self._transport._timeout
|
||||||
|
self._port = self._transport._port
|
||||||
|
|
||||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||||
"""Request information from a TP-Link SmartHome Device.
|
"""Request information from a TP-Link SmartHome Device.
|
||||||
@ -149,7 +167,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
|||||||
assert isinstance(request, str) # noqa: S101
|
assert isinstance(request, str) # noqa: S101
|
||||||
|
|
||||||
async with self.query_lock:
|
async with self.query_lock:
|
||||||
return await self._query(request, retry_count, self.timeout)
|
return await self._query(request, retry_count, self._timeout)
|
||||||
|
|
||||||
async def _connect(self, timeout: int) -> None:
|
async def _connect(self, timeout: int) -> None:
|
||||||
"""Try to connect or reconnect to the device."""
|
"""Try to connect or reconnect to the device."""
|
||||||
@ -157,7 +175,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
|||||||
return
|
return
|
||||||
self.reader = self.writer = None
|
self.reader = self.writer = None
|
||||||
|
|
||||||
task = asyncio.open_connection(self.host, self.port)
|
task = asyncio.open_connection(self._host, self._port)
|
||||||
async with asyncio_timeout(timeout):
|
async with asyncio_timeout(timeout):
|
||||||
self.reader, self.writer = await task
|
self.reader, self.writer = await task
|
||||||
sock: socket.socket = self.writer.get_extra_info("socket")
|
sock: socket.socket = self.writer.get_extra_info("socket")
|
||||||
@ -174,7 +192,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
|||||||
debug_log = _LOGGER.isEnabledFor(logging.DEBUG)
|
debug_log = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||||
|
|
||||||
if debug_log:
|
if debug_log:
|
||||||
_LOGGER.debug("%s >> %s", self.host, request)
|
_LOGGER.debug("%s >> %s", self._host, request)
|
||||||
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
||||||
await self.writer.drain()
|
await self.writer.drain()
|
||||||
|
|
||||||
@ -185,7 +203,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
|||||||
response = TPLinkSmartHomeProtocol.decrypt(buffer)
|
response = TPLinkSmartHomeProtocol.decrypt(buffer)
|
||||||
json_payload = json_loads(response)
|
json_payload = json_loads(response)
|
||||||
if debug_log:
|
if debug_log:
|
||||||
_LOGGER.debug("%s << %s", self.host, pf(json_payload))
|
_LOGGER.debug("%s << %s", self._host, pf(json_payload))
|
||||||
|
|
||||||
return json_payload
|
return json_payload
|
||||||
|
|
||||||
@ -219,23 +237,23 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
|||||||
except ConnectionRefusedError as ex:
|
except ConnectionRefusedError as ex:
|
||||||
await self.close()
|
await self.close()
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Unable to connect to the device: {self.host}:{self.port}: {ex}"
|
f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
|
||||||
) from ex
|
) from ex
|
||||||
except OSError as ex:
|
except OSError as ex:
|
||||||
await self.close()
|
await self.close()
|
||||||
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
|
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Unable to connect to the device:"
|
f"Unable to connect to the device:"
|
||||||
f" {self.host}:{self.port}: {ex}"
|
f" {self._host}:{self._port}: {ex}"
|
||||||
) from ex
|
) from ex
|
||||||
continue
|
continue
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
await self.close()
|
await self.close()
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Unable to connect to the device:"
|
f"Unable to connect to the device:"
|
||||||
f" {self.host}:{self.port}: {ex}"
|
f" {self._host}:{self._port}: {ex}"
|
||||||
) from ex
|
) from ex
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -247,13 +265,13 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
|||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
await self.close()
|
await self.close()
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Unable to query the device {self.host}:{self.port}: {ex}"
|
f"Unable to query the device {self._host}:{self._port}: {ex}"
|
||||||
) from ex
|
) from ex
|
||||||
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Unable to query the device %s, retrying: %s", self.host, ex
|
"Unable to query the device %s, retrying: %s", self._host, ex
|
||||||
)
|
)
|
||||||
|
|
||||||
# make mypy happy, this should never be reached..
|
# make mypy happy, this should never be reached..
|
||||||
|
@ -24,7 +24,7 @@ from .device_type import DeviceType
|
|||||||
from .emeterstatus import EmeterStatus
|
from .emeterstatus import EmeterStatus
|
||||||
from .exceptions import SmartDeviceException
|
from .exceptions import SmartDeviceException
|
||||||
from .modules import Emeter, Module
|
from .modules import Emeter, Module
|
||||||
from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol, _XorTransport
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -202,7 +202,7 @@ class SmartDevice:
|
|||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol(
|
self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol(
|
||||||
host, port=port, timeout=timeout
|
host, transport=_XorTransport(host, port=port, timeout=timeout)
|
||||||
)
|
)
|
||||||
self.credentials = credentials
|
self.credentials = credentials
|
||||||
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
|
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
|
||||||
|
@ -10,12 +10,10 @@ import logging
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from pprint import pformat as pf
|
from pprint import pformat as pf
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from .aestransport import AesTransport
|
|
||||||
from .credentials import Credentials
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
SMART_AUTHENTICATION_ERRORS,
|
SMART_AUTHENTICATION_ERRORS,
|
||||||
SMART_RETRYABLE_ERRORS,
|
SMART_RETRYABLE_ERRORS,
|
||||||
@ -36,26 +34,17 @@ logging.getLogger("httpx").propagate = False
|
|||||||
class SmartProtocol(TPLinkProtocol):
|
class SmartProtocol(TPLinkProtocol):
|
||||||
"""Class for the new TPLink SMART protocol."""
|
"""Class for the new TPLink SMART protocol."""
|
||||||
|
|
||||||
DEFAULT_PORT = 80
|
|
||||||
SLEEP_SECONDS_AFTER_TIMEOUT = 1
|
SLEEP_SECONDS_AFTER_TIMEOUT = 1
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
host: str,
|
host: str,
|
||||||
*,
|
*,
|
||||||
transport: Optional[BaseTransport] = None,
|
transport: BaseTransport,
|
||||||
credentials: Optional[Credentials] = None,
|
|
||||||
timeout: Optional[int] = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(host=host, port=self.DEFAULT_PORT)
|
"""Create a protocol object."""
|
||||||
|
super().__init__(host, transport=transport)
|
||||||
self._credentials: Credentials = credentials or Credentials(
|
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
|
||||||
username="", password=""
|
|
||||||
)
|
|
||||||
self._transport: BaseTransport = transport or AesTransport(
|
|
||||||
host, credentials=self._credentials, timeout=timeout
|
|
||||||
)
|
|
||||||
self._terminal_uuid: Optional[str] = None
|
|
||||||
self._request_id_generator = SnowflakeId(1, 1)
|
self._request_id_generator = SnowflakeId(1, 1)
|
||||||
self._query_lock = asyncio.Lock()
|
self._query_lock = asyncio.Lock()
|
||||||
|
|
||||||
@ -79,7 +68,7 @@ class SmartProtocol(TPLinkProtocol):
|
|||||||
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||||
) != SmartErrorCode.SUCCESS:
|
) != SmartErrorCode.SUCCESS:
|
||||||
msg = (
|
msg = (
|
||||||
f"Error querying device: {self.host}: "
|
f"Error querying device: {self._host}: "
|
||||||
+ f"{error_code.name}({error_code.value})"
|
+ f"{error_code.name}({error_code.value})"
|
||||||
)
|
)
|
||||||
if error_code in SMART_TIMEOUT_ERRORS:
|
if error_code in SMART_TIMEOUT_ERRORS:
|
||||||
@ -101,51 +90,53 @@ class SmartProtocol(TPLinkProtocol):
|
|||||||
except httpx.CloseError as sdex:
|
except httpx.CloseError as sdex:
|
||||||
await self.close()
|
await self.close()
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Unable to connect to the device: {self.host}: {sdex}"
|
f"Unable to connect to the device: {self._host}: {sdex}"
|
||||||
) from sdex
|
) from sdex
|
||||||
continue
|
continue
|
||||||
except httpx.ConnectError as cex:
|
except httpx.ConnectError as cex:
|
||||||
await self.close()
|
await self.close()
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Unable to connect to the device: {self.host}: {cex}"
|
f"Unable to connect to the device: {self._host}: {cex}"
|
||||||
) from cex
|
) from cex
|
||||||
except TimeoutError as tex:
|
except TimeoutError as tex:
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
await self.close()
|
await self.close()
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
"Unable to connect to the device, "
|
"Unable to connect to the device, "
|
||||||
+ f"timed out: {self.host}: {tex}"
|
+ f"timed out: {self._host}: {tex}"
|
||||||
) from tex
|
) from tex
|
||||||
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
||||||
continue
|
continue
|
||||||
except AuthenticationException as auex:
|
except AuthenticationException as auex:
|
||||||
await self.close()
|
await self.close()
|
||||||
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
|
_LOGGER.debug(
|
||||||
|
"Unable to authenticate with %s, not retrying", self._host
|
||||||
|
)
|
||||||
raise auex
|
raise auex
|
||||||
except RetryableException as ex:
|
except RetryableException as ex:
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
await self.close()
|
await self.close()
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
raise ex
|
raise ex
|
||||||
continue
|
continue
|
||||||
except TimeoutException as ex:
|
except TimeoutException as ex:
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
await self.close()
|
await self.close()
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
raise ex
|
raise ex
|
||||||
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
||||||
continue
|
continue
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
await self.close()
|
await self.close()
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Unable to query the device {self.host}:{self.port}: {ex}"
|
f"Unable to connect to the device: {self._host}: {ex}"
|
||||||
) from ex
|
) from ex
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Unable to query the device %s, retrying: %s", self.host, ex
|
"Unable to query the device %s, retrying: %s", self._host, ex
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -160,27 +151,17 @@ class SmartProtocol(TPLinkProtocol):
|
|||||||
smart_method = request
|
smart_method = request
|
||||||
smart_params = None
|
smart_params = None
|
||||||
|
|
||||||
if self._transport.needs_handshake:
|
|
||||||
await self._transport.handshake()
|
|
||||||
|
|
||||||
if self._transport.needs_login:
|
|
||||||
self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode(
|
|
||||||
"UTF-8"
|
|
||||||
)
|
|
||||||
login_request = self.get_smart_request("login_device")
|
|
||||||
await self._transport.login(login_request)
|
|
||||||
|
|
||||||
smart_request = self.get_smart_request(smart_method, smart_params)
|
smart_request = self.get_smart_request(smart_method, smart_params)
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s >> %s",
|
"%s >> %s",
|
||||||
self.host,
|
self._host,
|
||||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request),
|
_LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request),
|
||||||
)
|
)
|
||||||
response_data = await self._transport.send(smart_request)
|
response_data = await self._transport.send(smart_request)
|
||||||
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s << %s",
|
"%s << %s",
|
||||||
self.host,
|
self._host,
|
||||||
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
|
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import logging
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Dict, Optional, Set, cast
|
from typing import Any, Dict, Optional, Set, cast
|
||||||
|
|
||||||
|
from ..aestransport import AesTransport
|
||||||
from ..credentials import Credentials
|
from ..credentials import Credentials
|
||||||
from ..exceptions import AuthenticationException
|
from ..exceptions import AuthenticationException
|
||||||
from ..smartdevice import SmartDevice
|
from ..smartdevice import SmartDevice
|
||||||
@ -27,7 +28,12 @@ class TapoDevice(SmartDevice):
|
|||||||
self._components: Optional[Dict[str, Any]] = None
|
self._components: Optional[Dict[str, Any]] = None
|
||||||
self._state_information: Dict[str, Any] = {}
|
self._state_information: Dict[str, Any] = {}
|
||||||
self._discovery_info: Optional[Dict[str, Any]] = None
|
self._discovery_info: Optional[Dict[str, Any]] = None
|
||||||
self.protocol = SmartProtocol(host, credentials=credentials, timeout=timeout)
|
self.protocol = SmartProtocol(
|
||||||
|
host,
|
||||||
|
transport=AesTransport(
|
||||||
|
host, credentials=credentials, timeout=timeout, port=port
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
async def update(self, update_children: bool = True):
|
async def update(self, update_children: bool = True):
|
||||||
"""Update the device."""
|
"""Update the device."""
|
||||||
|
@ -301,6 +301,9 @@ class FakeSmartProtocol(SmartProtocol):
|
|||||||
|
|
||||||
class FakeSmartTransport(BaseTransport):
|
class FakeSmartTransport(BaseTransport):
|
||||||
def __init__(self, info):
|
def __init__(self, info):
|
||||||
|
super().__init__(
|
||||||
|
"127.0.0.123",
|
||||||
|
)
|
||||||
self.info = info
|
self.info = info
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
174
kasa/tests/test_aestransport.py
Normal file
174
kasa/tests/test_aestransport.py
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from contextlib import nullcontext as does_not_raise
|
||||||
|
from json import dumps as json_dumps
|
||||||
|
from json import loads as json_loads
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||||
|
|
||||||
|
from ..aestransport import AesEncyptionSession, AesTransport
|
||||||
|
from ..credentials import Credentials
|
||||||
|
from ..exceptions import SmartDeviceException
|
||||||
|
|
||||||
|
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||||
|
|
||||||
|
key = b"8\x89\x02\xfa\xf5Xs\x1c\xa1 H\x9a\x82\xc7\xd9\t"
|
||||||
|
iv = b"9=\xf8\x1bS\xcd0\xb5\x89i\xba\xfd^9\x9f\xfa"
|
||||||
|
KEY_IV = key + iv
|
||||||
|
|
||||||
|
|
||||||
|
def test_encrypt():
|
||||||
|
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
|
||||||
|
|
||||||
|
d = json.dumps({"foo": 1, "bar": 2})
|
||||||
|
encrypted = encryption_session.encrypt(d.encode())
|
||||||
|
assert d == encryption_session.decrypt(encrypted)
|
||||||
|
|
||||||
|
# test encrypt unicode
|
||||||
|
d = "{'snowman': '\u2603'}"
|
||||||
|
encrypted = encryption_session.encrypt(d.encode())
|
||||||
|
assert d == encryption_session.decrypt(encrypted)
|
||||||
|
|
||||||
|
|
||||||
|
status_parameters = pytest.mark.parametrize(
|
||||||
|
"status_code, error_code, inner_error_code, expectation",
|
||||||
|
[
|
||||||
|
(200, 0, 0, does_not_raise()),
|
||||||
|
(400, 0, 0, pytest.raises(SmartDeviceException)),
|
||||||
|
(200, -1, 0, pytest.raises(SmartDeviceException)),
|
||||||
|
],
|
||||||
|
ids=("success", "status_code", "error_code"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@status_parameters
|
||||||
|
async def test_handshake(
|
||||||
|
mocker, status_code, error_code, inner_error_code, expectation
|
||||||
|
):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||||
|
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||||
|
|
||||||
|
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||||
|
|
||||||
|
assert transport._encryption_session is None
|
||||||
|
assert transport._handshake_done is False
|
||||||
|
with expectation:
|
||||||
|
await transport.perform_handshake()
|
||||||
|
assert transport._encryption_session is not None
|
||||||
|
assert transport._handshake_done is True
|
||||||
|
|
||||||
|
|
||||||
|
@status_parameters
|
||||||
|
async def test_login(mocker, status_code, error_code, inner_error_code, expectation):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||||
|
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||||
|
|
||||||
|
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||||
|
transport._handshake_done = True
|
||||||
|
transport._session_expire_at = time.time() + 86400
|
||||||
|
transport._encryption_session = mock_aes_device.encryption_session
|
||||||
|
|
||||||
|
assert transport._login_token is None
|
||||||
|
with expectation:
|
||||||
|
await transport.perform_login()
|
||||||
|
assert transport._login_token == mock_aes_device.token
|
||||||
|
|
||||||
|
|
||||||
|
@status_parameters
|
||||||
|
async def test_send(mocker, status_code, error_code, inner_error_code, expectation):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||||
|
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||||
|
|
||||||
|
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||||
|
transport._handshake_done = True
|
||||||
|
transport._session_expire_at = time.time() + 86400
|
||||||
|
transport._encryption_session = mock_aes_device.encryption_session
|
||||||
|
transport._login_token = mock_aes_device.token
|
||||||
|
|
||||||
|
un, pw = transport.hash_credentials(True)
|
||||||
|
request = {
|
||||||
|
"method": "get_device_info",
|
||||||
|
"params": None,
|
||||||
|
"request_time_milis": round(time.time() * 1000),
|
||||||
|
"requestID": 1,
|
||||||
|
"terminal_uuid": "foobar",
|
||||||
|
}
|
||||||
|
with expectation:
|
||||||
|
res = await transport.send(json_dumps(request))
|
||||||
|
assert "result" in res
|
||||||
|
|
||||||
|
|
||||||
|
class MockAesDevice:
|
||||||
|
class _mock_response:
|
||||||
|
def __init__(self, status_code, json: dict):
|
||||||
|
self.status_code = status_code
|
||||||
|
self._json = json
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return self._json
|
||||||
|
|
||||||
|
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
|
||||||
|
token = "test_token" # noqa
|
||||||
|
|
||||||
|
def __init__(self, host, status_code=200, error_code=0, inner_error_code=0):
|
||||||
|
self.host = host
|
||||||
|
self.status_code = status_code
|
||||||
|
self.error_code = error_code
|
||||||
|
self.inner_error_code = inner_error_code
|
||||||
|
|
||||||
|
async def post(self, url, params=None, json=None, *_, **__):
|
||||||
|
return await self._post(url, json)
|
||||||
|
|
||||||
|
async def _post(self, url, json):
|
||||||
|
if json["method"] == "handshake":
|
||||||
|
return await self._return_handshake_response(url, json)
|
||||||
|
elif json["method"] == "securePassthrough":
|
||||||
|
return await self._return_secure_passthrough_response(url, json)
|
||||||
|
elif json["method"] == "login_device":
|
||||||
|
return await self._return_login_response(url, json)
|
||||||
|
else:
|
||||||
|
assert url == f"http://{self.host}/app?token={self.token}"
|
||||||
|
return await self._return_send_response(url, json)
|
||||||
|
|
||||||
|
async def _return_handshake_response(self, url, json):
|
||||||
|
start = len("-----BEGIN PUBLIC KEY-----\n")
|
||||||
|
end = len("\n-----END PUBLIC KEY-----\n")
|
||||||
|
client_pub_key = json["params"]["key"][start:-end]
|
||||||
|
|
||||||
|
client_pub_key_data = base64.b64decode(client_pub_key.encode())
|
||||||
|
client_pub_key = serialization.load_der_public_key(client_pub_key_data, None)
|
||||||
|
encrypted_key = client_pub_key.encrypt(KEY_IV, asymmetric_padding.PKCS1v15())
|
||||||
|
key_64 = base64.b64encode(encrypted_key).decode()
|
||||||
|
return self._mock_response(
|
||||||
|
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _return_secure_passthrough_response(self, url, json):
|
||||||
|
encrypted_request = json["params"]["request"]
|
||||||
|
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
|
||||||
|
decrypted_request_dict = json_loads(decrypted_request)
|
||||||
|
decrypted_response = await self._post(url, decrypted_request_dict)
|
||||||
|
decrypted_response_dict = decrypted_response.json()
|
||||||
|
encrypted_response = self.encryption_session.encrypt(
|
||||||
|
json_dumps(decrypted_response_dict).encode()
|
||||||
|
)
|
||||||
|
result = {
|
||||||
|
"result": {"response": encrypted_response.decode()},
|
||||||
|
"error_code": self.error_code,
|
||||||
|
}
|
||||||
|
return self._mock_response(self.status_code, result)
|
||||||
|
|
||||||
|
async def _return_login_response(self, url, json):
|
||||||
|
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
|
||||||
|
return self._mock_response(self.status_code, result)
|
||||||
|
|
||||||
|
async def _return_send_response(self, url, json):
|
||||||
|
result = {"result": {"method": None}, "error_code": self.inner_error_code}
|
||||||
|
return self._mock_response(self.status_code, result)
|
@ -96,10 +96,9 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
|
|||||||
|
|
||||||
return mock_response
|
return mock_response
|
||||||
|
|
||||||
mocker.patch.object(
|
mocker.patch.object(transport_class, "perform_handshake")
|
||||||
transport_class, "needs_handshake", property(lambda self: False)
|
if hasattr(transport_class, "perform_login"):
|
||||||
)
|
mocker.patch.object(transport_class, "perform_login")
|
||||||
mocker.patch.object(transport_class, "needs_login", property(lambda self: False))
|
|
||||||
|
|
||||||
send_mock = mocker.patch.object(
|
send_mock = mocker.patch.object(
|
||||||
transport_class,
|
transport_class,
|
||||||
@ -128,7 +127,7 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
|||||||
seed = secrets.token_bytes(16)
|
seed = secrets.token_bytes(16)
|
||||||
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
|
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
|
||||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||||
protocol = IotProtocol("127.0.0.1")
|
protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1"))
|
||||||
|
|
||||||
protocol._transport._handshake_done = True
|
protocol._transport._handshake_done = True
|
||||||
protocol._transport._session_expire_at = time.time() + 86400
|
protocol._transport._session_expire_at = time.time() + 86400
|
||||||
@ -206,7 +205,10 @@ async def test_handshake1(mocker, device_credentials, expectation):
|
|||||||
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
|
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
|
||||||
)
|
)
|
||||||
|
|
||||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
protocol = IotProtocol(
|
||||||
|
"127.0.0.1",
|
||||||
|
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||||
|
)
|
||||||
|
|
||||||
protocol._transport.http_client = httpx.AsyncClient()
|
protocol._transport.http_client = httpx.AsyncClient()
|
||||||
with expectation:
|
with expectation:
|
||||||
@ -243,7 +245,10 @@ async def test_handshake(mocker):
|
|||||||
httpx.AsyncClient, "post", side_effect=_return_handshake_response
|
httpx.AsyncClient, "post", side_effect=_return_handshake_response
|
||||||
)
|
)
|
||||||
|
|
||||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
protocol = IotProtocol(
|
||||||
|
"127.0.0.1",
|
||||||
|
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||||
|
)
|
||||||
protocol._transport.http_client = httpx.AsyncClient()
|
protocol._transport.http_client = httpx.AsyncClient()
|
||||||
|
|
||||||
response_status = 200
|
response_status = 200
|
||||||
@ -289,7 +294,10 @@ async def test_query(mocker):
|
|||||||
|
|
||||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||||
|
|
||||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
protocol = IotProtocol(
|
||||||
|
"127.0.0.1",
|
||||||
|
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||||
|
)
|
||||||
|
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
resp = await protocol.query({})
|
resp = await protocol.query({})
|
||||||
@ -333,7 +341,10 @@ async def test_authentication_failures(mocker, response_status, expectation):
|
|||||||
|
|
||||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||||
|
|
||||||
protocol = IotProtocol("127.0.0.1", credentials=client_credentials)
|
protocol = IotProtocol(
|
||||||
|
"127.0.0.1",
|
||||||
|
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||||
|
)
|
||||||
|
|
||||||
with expectation:
|
with expectation:
|
||||||
await protocol.query({})
|
await protocol.query({})
|
||||||
|
@ -1,13 +1,21 @@
|
|||||||
import errno
|
import errno
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import pkgutil
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ..exceptions import SmartDeviceException
|
from ..exceptions import SmartDeviceException
|
||||||
from ..protocol import TPLinkSmartHomeProtocol
|
from ..protocol import (
|
||||||
|
BaseTransport,
|
||||||
|
TPLinkProtocol,
|
||||||
|
TPLinkSmartHomeProtocol,
|
||||||
|
_XorTransport,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
@pytest.mark.parametrize("retry_count", [1, 3, 5])
|
||||||
@ -24,7 +32,9 @@ async def test_protocol_retries(mocker, retry_count):
|
|||||||
|
|
||||||
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count)
|
await TPLinkSmartHomeProtocol(
|
||||||
|
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||||
|
).query({}, retry_count=retry_count)
|
||||||
|
|
||||||
assert conn.call_count == retry_count + 1
|
assert conn.call_count == retry_count + 1
|
||||||
|
|
||||||
@ -35,7 +45,9 @@ async def test_protocol_no_retry_on_unreachable(mocker):
|
|||||||
side_effect=OSError(errno.EHOSTUNREACH, "No route to host"),
|
side_effect=OSError(errno.EHOSTUNREACH, "No route to host"),
|
||||||
)
|
)
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)
|
await TPLinkSmartHomeProtocol(
|
||||||
|
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||||
|
).query({}, retry_count=5)
|
||||||
|
|
||||||
assert conn.call_count == 1
|
assert conn.call_count == 1
|
||||||
|
|
||||||
@ -46,7 +58,9 @@ async def test_protocol_no_retry_connection_refused(mocker):
|
|||||||
side_effect=ConnectionRefusedError,
|
side_effect=ConnectionRefusedError,
|
||||||
)
|
)
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)
|
await TPLinkSmartHomeProtocol(
|
||||||
|
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||||
|
).query({}, retry_count=5)
|
||||||
|
|
||||||
assert conn.call_count == 1
|
assert conn.call_count == 1
|
||||||
|
|
||||||
@ -57,7 +71,9 @@ async def test_protocol_retry_recoverable_error(mocker):
|
|||||||
side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"),
|
side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"),
|
||||||
)
|
)
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)
|
await TPLinkSmartHomeProtocol(
|
||||||
|
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||||
|
).query({}, retry_count=5)
|
||||||
|
|
||||||
assert conn.call_count == 6
|
assert conn.call_count == 6
|
||||||
|
|
||||||
@ -91,7 +107,9 @@ async def test_protocol_reconnect(mocker, retry_count):
|
|||||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||||
return reader, writer
|
return reader, writer
|
||||||
|
|
||||||
protocol = TPLinkSmartHomeProtocol("127.0.0.1")
|
protocol = TPLinkSmartHomeProtocol(
|
||||||
|
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||||
|
)
|
||||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||||
response = await protocol.query({}, retry_count=retry_count)
|
response = await protocol.query({}, retry_count=retry_count)
|
||||||
assert response == {"great": "success"}
|
assert response == {"great": "success"}
|
||||||
@ -119,7 +137,9 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
|||||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||||
return reader, writer
|
return reader, writer
|
||||||
|
|
||||||
protocol = TPLinkSmartHomeProtocol("127.0.0.1")
|
protocol = TPLinkSmartHomeProtocol(
|
||||||
|
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||||
|
)
|
||||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||||
response = await protocol.query({})
|
response = await protocol.query({})
|
||||||
assert response == {"great": "success"}
|
assert response == {"great": "success"}
|
||||||
@ -153,7 +173,9 @@ async def test_protocol_custom_port(mocker, custom_port):
|
|||||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||||
return reader, writer
|
return reader, writer
|
||||||
|
|
||||||
protocol = TPLinkSmartHomeProtocol("127.0.0.1", port=custom_port)
|
protocol = TPLinkSmartHomeProtocol(
|
||||||
|
"127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port)
|
||||||
|
)
|
||||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||||
response = await protocol.query({})
|
response = await protocol.query({})
|
||||||
assert response == {"great": "success"}
|
assert response == {"great": "success"}
|
||||||
@ -227,3 +249,63 @@ def test_decrypt_unicode():
|
|||||||
d = "{'snowman': '\u2603'}"
|
d = "{'snowman': '\u2603'}"
|
||||||
|
|
||||||
assert d == TPLinkSmartHomeProtocol.decrypt(e)
|
assert d == TPLinkSmartHomeProtocol.decrypt(e)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_subclasses(of_class):
|
||||||
|
import kasa
|
||||||
|
|
||||||
|
package = sys.modules["kasa"]
|
||||||
|
subclasses = set()
|
||||||
|
for _, modname, _ in pkgutil.iter_modules(package.__path__):
|
||||||
|
importlib.import_module("." + modname, package="kasa")
|
||||||
|
module = sys.modules["kasa." + modname]
|
||||||
|
for name, obj in inspect.getmembers(module):
|
||||||
|
if inspect.isclass(obj) and issubclass(obj, of_class):
|
||||||
|
subclasses.add((name, obj))
|
||||||
|
return subclasses
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"class_name_obj", _get_subclasses(TPLinkProtocol), ids=lambda t: t[0]
|
||||||
|
)
|
||||||
|
def test_protocol_init_signature(class_name_obj):
|
||||||
|
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
|
||||||
|
|
||||||
|
assert len(params) == 3
|
||||||
|
assert (
|
||||||
|
params[0].name == "self"
|
||||||
|
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
params[1].name == "host"
|
||||||
|
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
params[2].name == "transport"
|
||||||
|
and params[2].kind == inspect.Parameter.KEYWORD_ONLY
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"class_name_obj", _get_subclasses(BaseTransport), ids=lambda t: t[0]
|
||||||
|
)
|
||||||
|
def test_transport_init_signature(class_name_obj):
|
||||||
|
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
|
||||||
|
|
||||||
|
assert len(params) == 5
|
||||||
|
assert (
|
||||||
|
params[0].name == "self"
|
||||||
|
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
params[1].name == "host"
|
||||||
|
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||||
|
)
|
||||||
|
assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY
|
||||||
|
assert (
|
||||||
|
params[3].name == "credentials"
|
||||||
|
and params[3].kind == inspect.Parameter.KEYWORD_ONLY
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY
|
||||||
|
)
|
||||||
|
@ -232,7 +232,7 @@ async def test_modules_preserved(dev: SmartDevice):
|
|||||||
async def test_create_smart_device_with_timeout():
|
async def test_create_smart_device_with_timeout():
|
||||||
"""Make sure timeout is passed to the protocol."""
|
"""Make sure timeout is passed to the protocol."""
|
||||||
dev = SmartDevice(host="127.0.0.1", timeout=100)
|
dev = SmartDevice(host="127.0.0.1", timeout=100)
|
||||||
assert dev.protocol.timeout == 100
|
assert dev.protocol._transport._timeout == 100
|
||||||
|
|
||||||
|
|
||||||
async def test_create_thin_wrapper():
|
async def test_create_thin_wrapper():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user