Mirror of https://github.com/python-kasa/python-kasa.git (synced 2025-08-06 10:44:04 +00:00)
Move protocol modules into protocols package (#1254)
kasa/protocols/__init__.py (new file, 12 lines added)
@@ -0,0 +1,12 @@
"""Package containing all supported protocols."""

from .iotprotocol import IotProtocol
from .protocol import BaseProtocol
from .smartprotocol import SmartErrorCode, SmartProtocol

__all__ = [
    "BaseProtocol",
    "IotProtocol",
    "SmartErrorCode",
    "SmartProtocol",
]
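With the move, the protocol classes are importable from the new kasa.protocols package. A minimal sketch, assuming the package from this branch is installed (whether the top-level kasa package still re-exports these names is not shown in this diff):

from kasa.protocols import BaseProtocol, IotProtocol, SmartErrorCode, SmartProtocol

# All concrete protocols derive from the shared abstract base.
assert issubclass(IotProtocol, BaseProtocol)
assert issubclass(SmartProtocol, BaseProtocol)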
kasa/protocols/iotprotocol.py (new executable file, 170 lines added)
@@ -0,0 +1,170 @@
"""Module for the legacy IOT KASA protocol."""

from __future__ import annotations

import asyncio
import logging
from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable

from ..deviceconfig import DeviceConfig
from ..exceptions import (
    AuthenticationError,
    KasaException,
    TimeoutError,
    _ConnectionError,
    _RetryableError,
)
from ..json import dumps as json_dumps
from ..transports import XorEncryption, XorTransport
from .protocol import BaseProtocol, mask_mac, redact_data

if TYPE_CHECKING:
    from ..transports import BaseTransport

_LOGGER = logging.getLogger(__name__)

REDACTORS: dict[str, Callable[[Any], Any] | None] = {
    "latitude": lambda x: 0,
    "longitude": lambda x: 0,
    "latitude_i": lambda x: 0,
    "longitude_i": lambda x: 0,
    "deviceId": lambda x: "REDACTED_" + x[9::],
    "id": lambda x: "REDACTED_" + x[9::],
    "alias": lambda x: "#MASKED_NAME#" if x else "",
    "mac": mask_mac,
    "mic_mac": mask_mac,
    "ssid": lambda x: "#MASKED_SSID#" if x else "",
    "oemId": lambda x: "REDACTED_" + x[9::],
    "username": lambda _: "user@example.com",  # cnCloud
}


class IotProtocol(BaseProtocol):
    """Class for the legacy TPLink IOT KASA Protocol."""

    BACKOFF_SECONDS_AFTER_TIMEOUT = 1

    def __init__(
        self,
        *,
        transport: BaseTransport,
    ) -> None:
        """Create a protocol object."""
        super().__init__(transport=transport)

        self._query_lock = asyncio.Lock()
        self._redact_data = True

    async def query(self, request: str | dict, retry_count: int = 3) -> dict:
        """Query the device retrying for retry_count on failure."""
        if isinstance(request, dict):
            request = json_dumps(request)
            assert isinstance(request, str)  # noqa: S101

        async with self._query_lock:
            return await self._query(request, retry_count)

    async def _query(self, request: str, retry_count: int = 3) -> dict:
        for retry in range(retry_count + 1):
            try:
                return await self._execute_query(request, retry)
            except _ConnectionError as sdex:
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
                    raise sdex
                continue
            except AuthenticationError as auex:
                await self._transport.reset()
                _LOGGER.debug(
                    "Unable to authenticate with %s, not retrying", self._host
                )
                raise auex
            except _RetryableError as ex:
                await self._transport.reset()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
                    raise ex
                continue
            except TimeoutError as ex:
                await self._transport.reset()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
                    raise ex
                await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
                continue
            except KasaException as ex:
                await self._transport.reset()
                _LOGGER.debug(
                    "Unable to query the device: %s, not retrying: %s",
                    self._host,
                    ex,
                )
                raise ex

        # make mypy happy, this should never be reached..
        raise KasaException("Query reached somehow to unreachable")

    async def _execute_query(self, request: str, retry_count: int) -> dict:
        debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)

        if debug_enabled:
            _LOGGER.debug(
                "%s >> %s",
                self._host,
                request,
            )
        resp = await self._transport.send(request)

        if debug_enabled:
            data = redact_data(resp, REDACTORS) if self._redact_data else resp
            _LOGGER.debug(
                "%s << %s",
                self._host,
                pf(data),
            )
        return resp

    async def close(self) -> None:
        """Close the underlying transport."""
        await self._transport.close()
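As a usage sketch for the class above: the protocol wraps a transport and is queried with either a JSON string or a dict. The host address and the sysinfo payload below are illustrative assumptions, not part of this change; DeviceConfig and XorTransport come from the sibling kasa.deviceconfig and kasa.transports modules referenced by the imports above:

import asyncio

from kasa.deviceconfig import DeviceConfig
from kasa.protocols import IotProtocol
from kasa.transports import XorTransport


async def main() -> None:
    # Hypothetical device address; remaining DeviceConfig fields use their defaults.
    config = DeviceConfig(host="192.168.1.10")
    protocol = IotProtocol(transport=XorTransport(config=config))
    try:
        # Illustrative legacy-style sysinfo request; dicts are serialized by query().
        info = await protocol.query({"system": {"get_sysinfo": None}})
        print(info)
    finally:
        await protocol.close()


asyncio.run(main())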
class _deprecated_TPLinkSmartHomeProtocol(IotProtocol):
    def __init__(
        self,
        host: str | None = None,
        *,
        port: int | None = None,
        timeout: int | None = None,
        transport: BaseTransport | None = None,
    ) -> None:
        """Create a protocol object."""
        if not host and not transport:
            raise KasaException("host or transport must be supplied")
        if not transport:
            config = DeviceConfig(
                host=host,  # type: ignore[arg-type]
                port_override=port,
                timeout=timeout or XorTransport.DEFAULT_TIMEOUT,
            )
            transport = XorTransport(config=config)
        super().__init__(transport=transport)

    @staticmethod
    def encrypt(request: str) -> bytes:
        """Encrypt a request for a TP-Link Smart Home Device.

        :param request: plaintext request data
        :return: ciphertext to be sent over the wire, in bytes
        """
        return XorEncryption.encrypt(request)

    @staticmethod
    def decrypt(ciphertext: bytes) -> str:
        """Decrypt a response of a TP-Link Smart Home Device.

        :param ciphertext: encrypted response data
        :return: plaintext response
        """
        return XorEncryption.decrypt(ciphertext)
kasa/protocols/protocol.py (new executable file, 106 lines added)
@@ -0,0 +1,106 @@
"""Implementation of the TP-Link Smart Home Protocol.

Encryption/Decryption methods based on the works of
Lubomir Stroetmann and Tobias Esser

https://www.softscheck.com/en/reverse-engineering-tp-link-hs110/
https://github.com/softScheck/tplink-smartplug/

which are licensed under the Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0
"""

from __future__ import annotations

import errno
import hashlib
import logging
import struct
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast

# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from ..deviceconfig import DeviceConfig

_LOGGER = logging.getLogger(__name__)
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}
_UNSIGNED_INT_NETWORK_ORDER = struct.Struct(">I")

_T = TypeVar("_T")


if TYPE_CHECKING:
    from ..transports import BaseTransport


def redact_data(data: _T, redactors: dict[str, Callable[[Any], Any] | None]) -> _T:
    """Redact sensitive data for logging."""
    if not isinstance(data, (dict, list)):
        return data

    if isinstance(data, list):
        return cast(_T, [redact_data(val, redactors) for val in data])

    redacted = {**data}

    for key, value in redacted.items():
        if value is None:
            continue
        if isinstance(value, str) and not value:
            continue
        if key in redactors:
            if redactor := redactors[key]:
                try:
                    redacted[key] = redactor(value)
                except:  # noqa: E722
                    redacted[key] = "**REDACTEX**"
            else:
                redacted[key] = "**REDACTED**"
        elif isinstance(value, dict):
            redacted[key] = redact_data(value, redactors)
        elif isinstance(value, list):
            redacted[key] = [redact_data(item, redactors) for item in value]

    return cast(_T, redacted)


def mask_mac(mac: str) -> str:
    """Return mac address with the last three octets blanked."""
    delim = ":" if ":" in mac else "-"
    rest = delim.join(format(s, "02x") for s in bytes.fromhex("000000"))
    return f"{mac[:8]}{delim}{rest}"
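A short sketch of how the two helpers above behave; the sample payload and redactor mapping are illustrative, modelled on the REDACTORS tables used by the protocol classes:

from kasa.protocols.protocol import mask_mac, redact_data

redactors = {
    "mac": mask_mac,
    "alias": lambda x: "#MASKED_NAME#" if x else "",
    "id": lambda x: "REDACTED_" + x[9::],
}
sample = {
    "mac": "AA:BB:CC:DD:EE:FF",
    "alias": "Living Room",
    "children": [{"id": "0123456789abcdef"}],
}
# Nested dicts and lists are walked recursively; keys without a redactor pass through.
print(redact_data(sample, redactors))
# {'mac': 'AA:BB:CC:00:00:00', 'alias': '#MASKED_NAME#',
#  'children': [{'id': 'REDACTED_9abcdef'}]}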
def md5(payload: bytes) -> bytes:
    """Return the MD5 hash of the payload."""
    return hashlib.md5(payload).digest()  # noqa: S324


class BaseProtocol(ABC):
    """Base class for all TP-Link Smart Home communication."""

    def __init__(
        self,
        *,
        transport: BaseTransport,
    ) -> None:
        """Create a protocol object."""
        self._transport = transport

    @property
    def _host(self) -> str:
        return self._transport._host

    @property
    def config(self) -> DeviceConfig:
        """Return the connection parameters the device is using."""
        return self._transport._config

    @abstractmethod
    async def query(self, request: str | dict, retry_count: int = 3) -> dict:
        """Query the device for the protocol. Abstract method to be overridden."""

    @abstractmethod
    async def close(self) -> None:
        """Close the protocol. Abstract method to be overridden."""
kasa/protocols/smartprotocol.py (new file, 445 lines added)
@@ -0,0 +1,445 @@
"""Implementation of the TP-Link AES Protocol.

Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""

from __future__ import annotations

import asyncio
import base64
import logging
import time
import uuid
from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable

from ..exceptions import (
    SMART_AUTHENTICATION_ERRORS,
    SMART_RETRYABLE_ERRORS,
    AuthenticationError,
    DeviceError,
    KasaException,
    SmartErrorCode,
    TimeoutError,
    _ConnectionError,
    _RetryableError,
)
from ..json import dumps as json_dumps
from .protocol import BaseProtocol, mask_mac, md5, redact_data

if TYPE_CHECKING:
    from ..transports import BaseTransport


_LOGGER = logging.getLogger(__name__)

REDACTORS: dict[str, Callable[[Any], Any] | None] = {
    "latitude": lambda x: 0,
    "longitude": lambda x: 0,
    "la": lambda x: 0,  # lat on ks240
    "lo": lambda x: 0,  # lon on ks240
    "device_id": lambda x: "REDACTED_" + x[9::],
    "parent_device_id": lambda x: "REDACTED_" + x[9::],  # Hub attached children
    "original_device_id": lambda x: "REDACTED_" + x[9::],  # Strip children
    "nickname": lambda x: "I01BU0tFRF9OQU1FIw==" if x else "",
    "mac": mask_mac,
    "ssid": lambda x: "I01BU0tFRF9TU0lEIw=" if x else "",
    "bssid": lambda _: "000000000000",
    "oem_id": lambda x: "REDACTED_" + x[9::],
    "setup_code": None,  # matter
    "setup_payload": None,  # matter
    "mfi_setup_code": None,  # mfi_ for homekit
    "mfi_setup_id": None,
    "mfi_token_token": None,
    "mfi_token_uuid": None,
}


class SmartProtocol(BaseProtocol):
    """Class for the new TPLink SMART protocol."""

    BACKOFF_SECONDS_AFTER_TIMEOUT = 1
    DEFAULT_MULTI_REQUEST_BATCH_SIZE = 5

    def __init__(
        self,
        *,
        transport: BaseTransport,
    ) -> None:
        """Create a protocol object."""
        super().__init__(transport=transport)
        self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
        self._query_lock = asyncio.Lock()
        self._multi_request_batch_size = (
            self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
        )
        self._redact_data = True

    def get_smart_request(self, method: str, params: dict | None = None) -> str:
        """Get a request message as a string."""
        request = {
            "method": method,
            "request_time_milis": round(time.time() * 1000),
            "terminal_uuid": self._terminal_uuid,
        }
        if params:
            request["params"] = params
        return json_dumps(request)
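For reference, the envelope built by get_smart_request serializes to roughly the following JSON; the method name and field values are illustrative, but the field names match the code above:

import json
import time

# "params" is added only when parameters are passed to get_smart_request().
envelope = {
    "method": "get_device_info",  # example method name
    "request_time_milis": round(time.time() * 1000),  # field name as used on the wire
    "terminal_uuid": "MDEyMzQ1Njc4OWFiY2RlZg==",  # example value; real one derives from uuid4
}
print(json.dumps(envelope))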
    async def query(self, request: str | dict, retry_count: int = 3) -> dict:
        """Query the device retrying for retry_count on failure."""
        async with self._query_lock:
            return await self._query(request, retry_count)

    async def _query(self, request: str | dict, retry_count: int = 3) -> dict:
        for retry in range(retry_count + 1):
            try:
                return await self._execute_query(
                    request, retry_count=retry, iterate_list_pages=True
                )
            except _ConnectionError as ex:
                if retry == 0:
                    _LOGGER.debug(
                        "Device %s got a connection error, will retry %s times: %s",
                        self._host,
                        retry_count,
                        ex,
                    )
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
                    raise ex
                continue
            except AuthenticationError as ex:
                await self._transport.reset()
                _LOGGER.debug(
                    "Unable to authenticate with %s, not retrying: %s", self._host, ex
                )
                raise ex
            except _RetryableError as ex:
                if retry == 0:
                    _LOGGER.debug(
                        "Device %s got a retryable error, will retry %s times: %s",
                        self._host,
                        retry_count,
                        ex,
                    )
                await self._transport.reset()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
                    raise ex
                await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
                continue
            except TimeoutError as ex:
                if retry == 0:
                    _LOGGER.debug(
                        "Device %s got a timeout error, will retry %s times: %s",
                        self._host,
                        retry_count,
                        ex,
                    )
                await self._transport.reset()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
                    raise ex
                await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
                continue
            except KasaException as ex:
                await self._transport.reset()
                _LOGGER.debug(
                    "Unable to query the device: %s, not retrying: %s",
                    self._host,
                    ex,
                )
                raise ex

        # make mypy happy, this should never be reached..
        raise KasaException("Query reached somehow to unreachable")

    async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dict:
        debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
        multi_result: dict[str, Any] = {}
        smart_method = "multipleRequest"

        multi_requests = [
            {"method": method, "params": params} if params else {"method": method}
            for method, params in requests.items()
        ]

        end = len(multi_requests)
        # The SmartCameraProtocol sends requests with a length 1 as a
        # multipleRequest. The SmartProtocol doesn't, so it will never
        # raise_on_error
        raise_on_error = end == 1

        # Break the requests down as there can be a size limit
        step = self._multi_request_batch_size
        if step == 1:
            # If step is 1 do not send request batches
            for request in multi_requests:
                method = request["method"]
                req = self.get_smart_request(method, request.get("params"))
                resp = await self._transport.send(req)
                self._handle_response_error_code(
                    resp, method, raise_on_error=raise_on_error
                )
                multi_result[method] = resp["result"]
            return multi_result

        for batch_num, i in enumerate(range(0, end, step)):
            requests_step = multi_requests[i : i + step]

            smart_params = {"requests": requests_step}
            smart_request = self.get_smart_request(smart_method, smart_params)
            batch_name = f"multi-request-batch-{batch_num+1}-of-{int(end/step)+1}"
            if debug_enabled:
                _LOGGER.debug(
                    "%s %s >> %s",
                    self._host,
                    batch_name,
                    pf(smart_request),
                )
            response_step = await self._transport.send(smart_request)
            if debug_enabled:
                if self._redact_data:
                    data = redact_data(response_step, REDACTORS)
                else:
                    data = response_step
                _LOGGER.debug(
                    "%s %s << %s",
                    self._host,
                    batch_name,
                    pf(data),
                )
            try:
                self._handle_response_error_code(response_step, batch_name)
            except DeviceError as ex:
                # P100 sometimes raises JSON_DECODE_FAIL_ERROR or INTERNAL_UNKNOWN_ERROR
                # on batched requests, so disable batching
                if (
                    ex.error_code
                    in {
                        SmartErrorCode.JSON_DECODE_FAIL_ERROR,
                        SmartErrorCode.INTERNAL_UNKNOWN_ERROR,
                    }
                    and self._multi_request_batch_size != 1
                ):
                    self._multi_request_batch_size = 1
                    raise _RetryableError(
                        "JSON Decode failure, multi requests disabled"
                    ) from ex
                raise ex

            responses = response_step["result"]["responses"]
            for response in responses:
                method = response["method"]
                self._handle_response_error_code(
                    response, method, raise_on_error=raise_on_error
                )
                result = response.get("result", None)
                await self._handle_response_lists(
                    result, method, retry_count=retry_count
                )
                multi_result[method] = result
        # Multi requests don't continue after errors so requery any missing
        for method, params in requests.items():
            if method not in multi_result:
                resp = await self._transport.send(
                    self.get_smart_request(method, params)
                )
                self._handle_response_error_code(resp, method, raise_on_error=False)
                multi_result[method] = resp.get("result")
        return multi_result

    async def _execute_query(
        self, request: str | dict, *, retry_count: int, iterate_list_pages: bool = True
    ) -> dict:
        debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)

        if isinstance(request, dict):
            if len(request) == 1:
                smart_method = next(iter(request))
                smart_params = request[smart_method]
            else:
                return await self._execute_multiple_query(request, retry_count)
        else:
            smart_method = request
            smart_params = None

        smart_request = self.get_smart_request(smart_method, smart_params)
        if debug_enabled:
            _LOGGER.debug(
                "%s >> %s",
                self._host,
                pf(smart_request),
            )
        response_data = await self._transport.send(smart_request)

        if debug_enabled:
            _LOGGER.debug(
                "%s << %s",
                self._host,
                pf(response_data),
            )

        self._handle_response_error_code(response_data, smart_method)

        # Single set_ requests do not return a result
        result = response_data.get("result")
        if iterate_list_pages and result:
            await self._handle_response_lists(
                result, smart_method, retry_count=retry_count
            )
        return {smart_method: result}

    async def _handle_response_lists(
        self, response_result: dict[str, Any], method: str, retry_count: int
    ) -> None:
        if (
            response_result is None
            or isinstance(response_result, SmartErrorCode)
            or "start_index" not in response_result
            or (list_sum := response_result.get("sum")) is None
        ):
            return

        response_list_name = next(
            iter(
                [
                    key
                    for key in response_result
                    if isinstance(response_result[key], list)
                ]
            )
        )
        while (list_length := len(response_result[response_list_name])) < list_sum:
            response = await self._execute_query(
                {method: {"start_index": list_length}},
                retry_count=retry_count,
                iterate_list_pages=False,
            )
            next_batch = response[method]
            # In case the device returns empty lists avoid infinite looping
            if not next_batch[response_list_name]:
                _LOGGER.error(
                    "Device %s returned empty results list for method %s",
                    self._host,
                    method,
                )
                break
            response_result[response_list_name].extend(next_batch[response_list_name])
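To make the paging above concrete, here is an assumed exchange for an illustrative list method; the method and key names are hypothetical, only the start_index/sum mechanics mirror the code:

# 1. Initial response for {"get_child_device_list": None}:
#    {"child_device_list": [dev1, dev2, dev3], "start_index": 0, "sum": 5}
# 2. _handle_response_lists sees 3 < 5 and issues the follow-up query:
#    {"get_child_device_list": {"start_index": 3}}
# 3. The returned items are appended to child_device_list until it holds 5 entries,
#    or an empty page comes back, which logs an error and stops the loop.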
    def _handle_response_error_code(
        self, resp_dict: dict, method: str, raise_on_error: bool = True
    ) -> None:
        error_code_raw = resp_dict.get("error_code")
        try:
            error_code = SmartErrorCode.from_int(error_code_raw)
        except ValueError:
            _LOGGER.warning(
                "Device %s received unknown error code: %s", self._host, error_code_raw
            )
            error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR

        if error_code is SmartErrorCode.SUCCESS:
            return

        if not raise_on_error:
            resp_dict["result"] = error_code
            return

        msg = (
            f"Error querying device: {self._host}: "
            + f"{error_code.name}({error_code.value})"
            + f" for method: {method}"
        )
        if error_code in SMART_RETRYABLE_ERRORS:
            raise _RetryableError(msg, error_code=error_code)
        if error_code in SMART_AUTHENTICATION_ERRORS:
            raise AuthenticationError(msg, error_code=error_code)
        raise DeviceError(msg, error_code=error_code)

    async def close(self) -> None:
        """Close the underlying transport."""
        await self._transport.close()
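A minimal usage sketch for SmartProtocol, assuming an already-configured AES transport instance; the method names in the query dict are illustrative SMART requests:

from kasa.protocols import SmartProtocol


async def fetch_state(transport) -> dict:
    protocol = SmartProtocol(transport=transport)
    try:
        # A dict with several methods is sent as a batched multipleRequest.
        return await protocol.query(
            {"get_device_info": None, "get_device_usage": None}
        )
    finally:
        await protocol.close()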
class _ChildProtocolWrapper(SmartProtocol):
    """Protocol wrapper for controlling child devices.

    This is an internal class used to communicate with child devices,
    and should not be used directly.

    This class overrides the query() method of the protocol to modify all
    outgoing queries to use the ``control_child`` command, and unwraps the
    device responses before returning them to the caller.
    """

    def __init__(self, device_id: str, base_protocol: SmartProtocol) -> None:
        self._device_id = device_id
        self._protocol = base_protocol
        self._transport = base_protocol._transport

    def _get_method_and_params_for_request(self, request: dict[str, Any] | str) -> Any:
        """Return payload for wrapping.

        TODO: this does not support batches and requires refactoring in the future.
        """
        if isinstance(request, dict):
            if len(request) == 1:
                smart_method = next(iter(request))
                smart_params = request[smart_method]
            else:
                smart_method = "multipleRequest"
                requests = [
                    {"method": method, "params": params}
                    if params
                    else {"method": method}
                    for method, params in request.items()
                ]
                smart_params = {"requests": requests}
        else:
            smart_method = request
            smart_params = None

        return smart_method, smart_params

    async def query(self, request: str | dict, retry_count: int = 3) -> dict:
        """Wrap request inside control_child envelope."""
        return await self._query(request, retry_count)

    async def _query(self, request: str | dict, retry_count: int = 3) -> dict:
        """Wrap request inside control_child envelope."""
        method, params = self._get_method_and_params_for_request(request)
        request_data = {
            "method": method,
            "params": params,
        }
        wrapped_payload = {
            "control_child": {
                "device_id": self._device_id,
                "requestData": request_data,
            }
        }

        response = await self._protocol.query(wrapped_payload, retry_count)
        result = response.get("control_child")
        # Unwrap responseData for control_child
        if result and (response_data := result.get("responseData")):
            result = response_data.get("result")
            if result and (multi_responses := result.get("responses")):
                ret_val = {}
                for multi_response in multi_responses:
                    method = multi_response["method"]
                    self._handle_response_error_code(
                        multi_response, method, raise_on_error=False
                    )
                    ret_val[method] = multi_response.get("result")
                return ret_val

        self._handle_response_error_code(response_data, "control_child")

        return {method: result}

    async def close(self) -> None:
        """Do nothing as the parent owns the protocol."""