Move protocol modules into protocols package (#1254)

This commit is contained in:
Steven B.
2024-11-13 17:50:21 +00:00
committed by GitHub
parent 1eaae37c55
commit e55731c110
32 changed files with 94 additions and 94 deletions

View File

@@ -0,0 +1,12 @@
"""Package containing all supported protocols."""
from .iotprotocol import IotProtocol
from .protocol import BaseProtocol
from .smartprotocol import SmartErrorCode, SmartProtocol
__all__ = [
"BaseProtocol",
"IotProtocol",
"SmartErrorCode",
"SmartProtocol",
]

170
kasa/protocols/iotprotocol.py Executable file
View File

@@ -0,0 +1,170 @@
"""Module for the IOT legacy IOT KASA protocol."""
from __future__ import annotations
import asyncio
import logging
from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable
from ..deviceconfig import DeviceConfig
from ..exceptions import (
AuthenticationError,
KasaException,
TimeoutError,
_ConnectionError,
_RetryableError,
)
from ..json import dumps as json_dumps
from ..transports import XorEncryption, XorTransport
from .protocol import BaseProtocol, mask_mac, redact_data
if TYPE_CHECKING:
from ..transports import BaseTransport
_LOGGER = logging.getLogger(__name__)
REDACTORS: dict[str, Callable[[Any], Any] | None] = {
"latitude": lambda x: 0,
"longitude": lambda x: 0,
"latitude_i": lambda x: 0,
"longitude_i": lambda x: 0,
"deviceId": lambda x: "REDACTED_" + x[9::],
"id": lambda x: "REDACTED_" + x[9::],
"alias": lambda x: "#MASKED_NAME#" if x else "",
"mac": mask_mac,
"mic_mac": mask_mac,
"ssid": lambda x: "#MASKED_SSID#" if x else "",
"oemId": lambda x: "REDACTED_" + x[9::],
"username": lambda _: "user@example.com", # cnCloud
}
class IotProtocol(BaseProtocol):
"""Class for the legacy TPLink IOT KASA Protocol."""
BACKOFF_SECONDS_AFTER_TIMEOUT = 1
def __init__(
self,
*,
transport: BaseTransport,
) -> None:
"""Create a protocol object."""
super().__init__(transport=transport)
self._query_lock = asyncio.Lock()
self._redact_data = True
async def query(self, request: str | dict, retry_count: int = 3) -> dict:
"""Query the device retrying for retry_count on failure."""
if isinstance(request, dict):
request = json_dumps(request)
assert isinstance(request, str) # noqa: S101
async with self._query_lock:
return await self._query(request, retry_count)
async def _query(self, request: str, retry_count: int = 3) -> dict:
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except _ConnectionError as sdex:
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise sdex
continue
except AuthenticationError as auex:
await self._transport.reset()
_LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex
except _RetryableError as ex:
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
continue
except TimeoutError as ex:
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except KasaException as ex:
await self._transport.reset()
_LOGGER.debug(
"Unable to query the device: %s, not retrying: %s",
self._host,
ex,
)
raise ex
# make mypy happy, this should never be reached..
raise KasaException("Query reached somehow to unreachable")
async def _execute_query(self, request: str, retry_count: int) -> dict:
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_enabled:
_LOGGER.debug(
"%s >> %s",
self._host,
request,
)
resp = await self._transport.send(request)
if debug_enabled:
data = redact_data(resp, REDACTORS) if self._redact_data else resp
_LOGGER.debug(
"%s << %s",
self._host,
pf(data),
)
return resp
async def close(self) -> None:
"""Close the underlying transport."""
await self._transport.close()
class _deprecated_TPLinkSmartHomeProtocol(IotProtocol):
def __init__(
self,
host: str | None = None,
*,
port: int | None = None,
timeout: int | None = None,
transport: BaseTransport | None = None,
) -> None:
"""Create a protocol object."""
if not host and not transport:
raise KasaException("host or transport must be supplied")
if not transport:
config = DeviceConfig(
host=host, # type: ignore[arg-type]
port_override=port,
timeout=timeout or XorTransport.DEFAULT_TIMEOUT,
)
transport = XorTransport(config=config)
super().__init__(transport=transport)
@staticmethod
def encrypt(request: str) -> bytes:
"""Encrypt a request for a TP-Link Smart Home Device.
:param request: plaintext request data
:return: ciphertext to be send over wire, in bytes
"""
return XorEncryption.encrypt(request)
@staticmethod
def decrypt(ciphertext: bytes) -> str:
"""Decrypt a response of a TP-Link Smart Home Device.
:param ciphertext: encrypted response data
:return: plaintext response
"""
return XorEncryption.decrypt(ciphertext)

106
kasa/protocols/protocol.py Executable file
View File

@@ -0,0 +1,106 @@
"""Implementation of the TP-Link Smart Home Protocol.
Encryption/Decryption methods based on the works of
Lubomir Stroetmann and Tobias Esser
https://www.softscheck.com/en/reverse-engineering-tp-link-hs110/
https://github.com/softScheck/tplink-smartplug/
which are licensed under the Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0
"""
from __future__ import annotations
import errno
import hashlib
import logging
import struct
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from ..deviceconfig import DeviceConfig
_LOGGER = logging.getLogger(__name__)
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}
_UNSIGNED_INT_NETWORK_ORDER = struct.Struct(">I")
_T = TypeVar("_T")
if TYPE_CHECKING:
from ..transports import BaseTransport
def redact_data(data: _T, redactors: dict[str, Callable[[Any], Any] | None]) -> _T:
"""Redact sensitive data for logging."""
if not isinstance(data, (dict, list)):
return data
if isinstance(data, list):
return cast(_T, [redact_data(val, redactors) for val in data])
redacted = {**data}
for key, value in redacted.items():
if value is None:
continue
if isinstance(value, str) and not value:
continue
if key in redactors:
if redactor := redactors[key]:
try:
redacted[key] = redactor(value)
except: # noqa: E722
redacted[key] = "**REDACTEX**"
else:
redacted[key] = "**REDACTED**"
elif isinstance(value, dict):
redacted[key] = redact_data(value, redactors)
elif isinstance(value, list):
redacted[key] = [redact_data(item, redactors) for item in value]
return cast(_T, redacted)
def mask_mac(mac: str) -> str:
"""Return mac address with last two octects blanked."""
delim = ":" if ":" in mac else "-"
rest = delim.join(format(s, "02x") for s in bytes.fromhex("000000"))
return f"{mac[:8]}{delim}{rest}"
def md5(payload: bytes) -> bytes:
"""Return the MD5 hash of the payload."""
return hashlib.md5(payload).digest() # noqa: S324
class BaseProtocol(ABC):
"""Base class for all TP-Link Smart Home communication."""
def __init__(
self,
*,
transport: BaseTransport,
) -> None:
"""Create a protocol object."""
self._transport = transport
@property
def _host(self) -> str:
return self._transport._host
@property
def config(self) -> DeviceConfig:
"""Return the connection parameters the device is using."""
return self._transport._config
@abstractmethod
async def query(self, request: str | dict, retry_count: int = 3) -> dict:
"""Query the device for the protocol. Abstract method to be overriden."""
@abstractmethod
async def close(self) -> None:
"""Close the protocol. Abstract method to be overriden."""

View File

@@ -0,0 +1,445 @@
"""Implementation of the TP-Link AES Protocol.
Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""
from __future__ import annotations
import asyncio
import base64
import logging
import time
import uuid
from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable
from ..exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
AuthenticationError,
DeviceError,
KasaException,
SmartErrorCode,
TimeoutError,
_ConnectionError,
_RetryableError,
)
from ..json import dumps as json_dumps
from .protocol import BaseProtocol, mask_mac, md5, redact_data
if TYPE_CHECKING:
from ..transports import BaseTransport
_LOGGER = logging.getLogger(__name__)
REDACTORS: dict[str, Callable[[Any], Any] | None] = {
"latitude": lambda x: 0,
"longitude": lambda x: 0,
"la": lambda x: 0, # lat on ks240
"lo": lambda x: 0, # lon on ks240
"device_id": lambda x: "REDACTED_" + x[9::],
"parent_device_id": lambda x: "REDACTED_" + x[9::], # Hub attached children
"original_device_id": lambda x: "REDACTED_" + x[9::], # Strip children
"nickname": lambda x: "I01BU0tFRF9OQU1FIw==" if x else "",
"mac": mask_mac,
"ssid": lambda x: "I01BU0tFRF9TU0lEIw=" if x else "",
"bssid": lambda _: "000000000000",
"oem_id": lambda x: "REDACTED_" + x[9::],
"setup_code": None, # matter
"setup_payload": None, # matter
"mfi_setup_code": None, # mfi_ for homekit
"mfi_setup_id": None,
"mfi_token_token": None,
"mfi_token_uuid": None,
}
class SmartProtocol(BaseProtocol):
"""Class for the new TPLink SMART protocol."""
BACKOFF_SECONDS_AFTER_TIMEOUT = 1
DEFAULT_MULTI_REQUEST_BATCH_SIZE = 5
def __init__(
self,
*,
transport: BaseTransport,
) -> None:
"""Create a protocol object."""
super().__init__(transport=transport)
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
self._query_lock = asyncio.Lock()
self._multi_request_batch_size = (
self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
)
self._redact_data = True
def get_smart_request(self, method: str, params: dict | None = None) -> str:
"""Get a request message as a string."""
request = {
"method": method,
"request_time_milis": round(time.time() * 1000),
"terminal_uuid": self._terminal_uuid,
}
if params:
request["params"] = params
return json_dumps(request)
async def query(self, request: str | dict, retry_count: int = 3) -> dict:
"""Query the device retrying for retry_count on failure."""
async with self._query_lock:
return await self._query(request, retry_count)
async def _query(self, request: str | dict, retry_count: int = 3) -> dict:
for retry in range(retry_count + 1):
try:
return await self._execute_query(
request, retry_count=retry, iterate_list_pages=True
)
except _ConnectionError as ex:
if retry == 0:
_LOGGER.debug(
"Device %s got a connection error, will retry %s times: %s",
self._host,
retry_count,
ex,
)
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
continue
except AuthenticationError as ex:
await self._transport.reset()
_LOGGER.debug(
"Unable to authenticate with %s, not retrying: %s", self._host, ex
)
raise ex
except _RetryableError as ex:
if retry == 0:
_LOGGER.debug(
"Device %s got a retryable error, will retry %s times: %s",
self._host,
retry_count,
ex,
)
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except TimeoutError as ex:
if retry == 0:
_LOGGER.debug(
"Device %s got a timeout error, will retry %s times: %s",
self._host,
retry_count,
ex,
)
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except KasaException as ex:
await self._transport.reset()
_LOGGER.debug(
"Unable to query the device: %s, not retrying: %s",
self._host,
ex,
)
raise ex
# make mypy happy, this should never be reached..
raise KasaException("Query reached somehow to unreachable")
async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dict:
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
multi_result: dict[str, Any] = {}
smart_method = "multipleRequest"
multi_requests = [
{"method": method, "params": params} if params else {"method": method}
for method, params in requests.items()
]
end = len(multi_requests)
# The SmartCameraProtocol sends requests with a length 1 as a
# multipleRequest. The SmartProtocol doesn't so will never
# raise_on_error
raise_on_error = end == 1
# Break the requests down as there can be a size limit
step = self._multi_request_batch_size
if step == 1:
# If step is 1 do not send request batches
for request in multi_requests:
method = request["method"]
req = self.get_smart_request(method, request.get("params"))
resp = await self._transport.send(req)
self._handle_response_error_code(
resp, method, raise_on_error=raise_on_error
)
multi_result[method] = resp["result"]
return multi_result
for batch_num, i in enumerate(range(0, end, step)):
requests_step = multi_requests[i : i + step]
smart_params = {"requests": requests_step}
smart_request = self.get_smart_request(smart_method, smart_params)
batch_name = f"multi-request-batch-{batch_num+1}-of-{int(end/step)+1}"
if debug_enabled:
_LOGGER.debug(
"%s %s >> %s",
self._host,
batch_name,
pf(smart_request),
)
response_step = await self._transport.send(smart_request)
if debug_enabled:
if self._redact_data:
data = redact_data(response_step, REDACTORS)
else:
data = response_step
_LOGGER.debug(
"%s %s << %s",
self._host,
batch_name,
pf(data),
)
try:
self._handle_response_error_code(response_step, batch_name)
except DeviceError as ex:
# P100 sometimes raises JSON_DECODE_FAIL_ERROR or INTERNAL_UNKNOWN_ERROR
# on batched request so disable batching
if (
ex.error_code
in {
SmartErrorCode.JSON_DECODE_FAIL_ERROR,
SmartErrorCode.INTERNAL_UNKNOWN_ERROR,
}
and self._multi_request_batch_size != 1
):
self._multi_request_batch_size = 1
raise _RetryableError(
"JSON Decode failure, multi requests disabled"
) from ex
raise ex
responses = response_step["result"]["responses"]
for response in responses:
method = response["method"]
self._handle_response_error_code(
response, method, raise_on_error=raise_on_error
)
result = response.get("result", None)
await self._handle_response_lists(
result, method, retry_count=retry_count
)
multi_result[method] = result
# Multi requests don't continue after errors so requery any missing
for method, params in requests.items():
if method not in multi_result:
resp = await self._transport.send(
self.get_smart_request(method, params)
)
self._handle_response_error_code(resp, method, raise_on_error=False)
multi_result[method] = resp.get("result")
return multi_result
async def _execute_query(
self, request: str | dict, *, retry_count: int, iterate_list_pages: bool = True
) -> dict:
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if isinstance(request, dict):
if len(request) == 1:
smart_method = next(iter(request))
smart_params = request[smart_method]
else:
return await self._execute_multiple_query(request, retry_count)
else:
smart_method = request
smart_params = None
smart_request = self.get_smart_request(smart_method, smart_params)
if debug_enabled:
_LOGGER.debug(
"%s >> %s",
self._host,
pf(smart_request),
)
response_data = await self._transport.send(smart_request)
if debug_enabled:
_LOGGER.debug(
"%s << %s",
self._host,
pf(response_data),
)
self._handle_response_error_code(response_data, smart_method)
# Single set_ requests do not return a result
result = response_data.get("result")
if iterate_list_pages and result:
await self._handle_response_lists(
result, smart_method, retry_count=retry_count
)
return {smart_method: result}
async def _handle_response_lists(
self, response_result: dict[str, Any], method: str, retry_count: int
) -> None:
if (
response_result is None
or isinstance(response_result, SmartErrorCode)
or "start_index" not in response_result
or (list_sum := response_result.get("sum")) is None
):
return
response_list_name = next(
iter(
[
key
for key in response_result
if isinstance(response_result[key], list)
]
)
)
while (list_length := len(response_result[response_list_name])) < list_sum:
response = await self._execute_query(
{method: {"start_index": list_length}},
retry_count=retry_count,
iterate_list_pages=False,
)
next_batch = response[method]
# In case the device returns empty lists avoid infinite looping
if not next_batch[response_list_name]:
_LOGGER.error(
"Device %s returned empty results list for method %s",
self._host,
method,
)
break
response_result[response_list_name].extend(next_batch[response_list_name])
def _handle_response_error_code(
self, resp_dict: dict, method: str, raise_on_error: bool = True
) -> None:
error_code_raw = resp_dict.get("error_code")
try:
error_code = SmartErrorCode.from_int(error_code_raw)
except ValueError:
_LOGGER.warning(
"Device %s received unknown error code: %s", self._host, error_code_raw
)
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
if error_code is SmartErrorCode.SUCCESS:
return
if not raise_on_error:
resp_dict["result"] = error_code
return
msg = (
f"Error querying device: {self._host}: "
+ f"{error_code.name}({error_code.value})"
+ f" for method: {method}"
)
if error_code in SMART_RETRYABLE_ERRORS:
raise _RetryableError(msg, error_code=error_code)
if error_code in SMART_AUTHENTICATION_ERRORS:
raise AuthenticationError(msg, error_code=error_code)
raise DeviceError(msg, error_code=error_code)
async def close(self) -> None:
"""Close the underlying transport."""
await self._transport.close()
class _ChildProtocolWrapper(SmartProtocol):
"""Protocol wrapper for controlling child devices.
This is an internal class used to communicate with child devices,
and should not be used directly.
This class overrides query() method of the protocol to modify all
outgoing queries to use ``control_child`` command, and unwraps the
device responses before returning to the caller.
"""
def __init__(self, device_id: str, base_protocol: SmartProtocol) -> None:
self._device_id = device_id
self._protocol = base_protocol
self._transport = base_protocol._transport
def _get_method_and_params_for_request(self, request: dict[str, Any] | str) -> Any:
"""Return payload for wrapping.
TODO: this does not support batches and requires refactoring in the future.
"""
if isinstance(request, dict):
if len(request) == 1:
smart_method = next(iter(request))
smart_params = request[smart_method]
else:
smart_method = "multipleRequest"
requests = [
{"method": method, "params": params}
if params
else {"method": method}
for method, params in request.items()
]
smart_params = {"requests": requests}
else:
smart_method = request
smart_params = None
return smart_method, smart_params
async def query(self, request: str | dict, retry_count: int = 3) -> dict:
"""Wrap request inside control_child envelope."""
return await self._query(request, retry_count)
async def _query(self, request: str | dict, retry_count: int = 3) -> dict:
"""Wrap request inside control_child envelope."""
method, params = self._get_method_and_params_for_request(request)
request_data = {
"method": method,
"params": params,
}
wrapped_payload = {
"control_child": {
"device_id": self._device_id,
"requestData": request_data,
}
}
response = await self._protocol.query(wrapped_payload, retry_count)
result = response.get("control_child")
# Unwrap responseData for control_child
if result and (response_data := result.get("responseData")):
result = response_data.get("result")
if result and (multi_responses := result.get("responses")):
ret_val = {}
for multi_response in multi_responses:
method = multi_response["method"]
self._handle_response_error_code(
multi_response, method, raise_on_error=False
)
ret_val[method] = multi_response.get("result")
return ret_val
self._handle_response_error_code(response_data, "control_child")
return {method: result}
async def close(self) -> None:
"""Do nothing as the parent owns the protocol."""