Move TAPO smartcamera out of experimental package (#1255)

Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
Steven B.
2024-11-13 19:59:42 +00:00
committed by GitHub
parent e55731c110
commit 6213b90f62
21 changed files with 59 additions and 36 deletions

View File

@@ -1,15 +0,0 @@
"""Modules for SMARTCAMERA devices."""
from .camera import Camera
from .childdevice import ChildDevice
from .device import DeviceModule
from .led import Led
from .time import Time
__all__ = [
"Camera",
"ChildDevice",
"DeviceModule",
"Led",
"Time",
]

View File

@@ -1,71 +0,0 @@
"""Implementation of device module."""
from __future__ import annotations
from urllib.parse import quote_plus
from ...credentials import Credentials
from ...device_type import DeviceType
from ...feature import Feature
from ..smartcameramodule import SmartCameraModule
LOCAL_STREAMING_PORT = 554
class Camera(SmartCameraModule):
"""Implementation of device module."""
QUERY_GETTER_NAME = "getLensMaskConfig"
QUERY_MODULE_NAME = "lens_mask"
QUERY_SECTION_NAMES = "lens_mask_info"
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
self._device,
id="state",
name="State",
attribute_getter="is_on",
attribute_setter="set_state",
type=Feature.Type.Switch,
category=Feature.Category.Primary,
)
)
@property
def is_on(self) -> bool:
"""Return the device id."""
return self.data["lens_mask_info"]["enabled"] == "off"
def stream_rtsp_url(self, credentials: Credentials | None = None) -> str | None:
"""Return the local rtsp streaming url.
:param credentials: Credentials for camera account.
These could be different credentials to tplink cloud credentials.
If not provided will use tplink credentials if available
:return: rtsp url with escaped credentials or None if no credentials or
camera is off.
"""
if not self.is_on:
return None
dev = self._device
if not credentials:
credentials = dev.credentials
if not credentials or not credentials.username or not credentials.password:
return None
username = quote_plus(credentials.username)
password = quote_plus(credentials.password)
return f"rtsp://{username}:{password}@{dev.host}:{LOCAL_STREAMING_PORT}/stream1"
async def set_state(self, on: bool) -> dict:
"""Set the device state."""
# Turning off enables the privacy mask which is why value is reversed.
params = {"enabled": "off" if on else "on"}
return await self._device._query_setter_helper(
"setLensMaskConfig", self.QUERY_MODULE_NAME, "lens_mask_info", params
)
async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device."""
return self._device.device_type is DeviceType.Camera

View File

@@ -1,26 +0,0 @@
"""Module for child devices."""
from ...device_type import DeviceType
from ..smartcameramodule import SmartCameraModule
class ChildDevice(SmartCameraModule):
"""Implementation for child devices."""
REQUIRED_COMPONENT = "childControl"
NAME = "childdevice"
QUERY_GETTER_NAME = "getChildDeviceList"
# This module is unusual in that QUERY_MODULE_NAME in the response is not
# the same one used in the request.
QUERY_MODULE_NAME = "child_device_list"
def query(self) -> dict:
"""Query to execute during the update cycle.
Default implementation uses the raw query getter w/o parameters.
"""
return {self.QUERY_GETTER_NAME: {"childControl": {"start_index": 0}}}
async def _check_supported(self) -> bool:
"""Additional check to see if the module is supported by the device."""
return self._device.device_type is DeviceType.Hub

View File

@@ -1,40 +0,0 @@
"""Implementation of device module."""
from __future__ import annotations
from ...feature import Feature
from ..smartcameramodule import SmartCameraModule
class DeviceModule(SmartCameraModule):
"""Implementation of device module."""
NAME = "devicemodule"
QUERY_GETTER_NAME = "getDeviceInfo"
QUERY_MODULE_NAME = "device_info"
QUERY_SECTION_NAMES = ["basic_info", "info"]
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
self._device,
id="device_id",
name="Device ID",
attribute_getter="device_id",
category=Feature.Category.Debug,
type=Feature.Type.Sensor,
)
)
async def _post_update_hook(self) -> None:
"""Overriden to prevent module disabling.
Overrides the default behaviour to disable a module if the query returns
an error because this module is critical.
"""
@property
def device_id(self) -> str:
"""Return the device id."""
return self.data["basic_info"]["dev_id"]

View File

@@ -1,28 +0,0 @@
"""Module for led controls."""
from __future__ import annotations
from ...interfaces.led import Led as LedInterface
from ..smartcameramodule import SmartCameraModule
class Led(SmartCameraModule, LedInterface):
"""Implementation of led controls."""
REQUIRED_COMPONENT = "led"
QUERY_GETTER_NAME = "getLedStatus"
QUERY_MODULE_NAME = "led"
QUERY_SECTION_NAMES = "config"
@property
def led(self) -> bool:
"""Return current led status."""
return self.data["config"]["enabled"] == "on"
async def set_led(self, enable: bool) -> dict:
"""Set led.
This should probably be a select with always/never/nightmode.
"""
params = {"enabled": "on"} if enable else {"enabled": "off"}
return await self.call("setLedStatus", {"led": {"config": params}})

View File

@@ -1,91 +0,0 @@
"""Implementation of time module."""
from __future__ import annotations
from datetime import datetime, timezone, tzinfo
from typing import cast
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from ...cachedzoneinfo import CachedZoneInfo
from ...feature import Feature
from ...interfaces import Time as TimeInterface
from ..smartcameramodule import SmartCameraModule
class Time(SmartCameraModule, TimeInterface):
"""Implementation of device_local_time."""
QUERY_GETTER_NAME = "getTimezone"
QUERY_MODULE_NAME = "system"
QUERY_SECTION_NAMES = "basic"
_timezone: tzinfo = timezone.utc
_time: datetime
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
device=self._device,
id="device_time",
name="Device time",
attribute_getter="time",
container=self,
category=Feature.Category.Debug,
type=Feature.Type.Sensor,
)
)
def query(self) -> dict:
"""Query to execute during the update cycle."""
q = super().query()
q["getClockStatus"] = {self.QUERY_MODULE_NAME: {"name": "clock_status"}}
return q
async def _post_update_hook(self) -> None:
"""Perform actions after a device update."""
time_data = self.data["getClockStatus"]["system"]["clock_status"]
timezone_data = self.data["getTimezone"]["system"]["basic"]
zone_id = timezone_data["zone_id"]
timestamp = time_data["seconds_from_1970"]
try:
# Zoneinfo will return a DST aware object
tz: tzinfo = await CachedZoneInfo.get_cached_zone_info(zone_id)
except ZoneInfoNotFoundError:
# timezone string like: UTC+10:00
timezone_str = timezone_data["timezone"]
tz = cast(tzinfo, datetime.strptime(timezone_str[-6:], "%z").tzinfo)
self._timezone = tz
self._time = datetime.fromtimestamp(
cast(float, timestamp),
tz=tz,
)
@property
def timezone(self) -> tzinfo:
"""Return current timezone."""
return self._timezone
@property
def time(self) -> datetime:
"""Return device's current datetime."""
return self._time
async def set_time(self, dt: datetime) -> dict:
"""Set device time."""
if not dt.tzinfo:
timestamp = dt.replace(tzinfo=self.timezone).timestamp()
else:
timestamp = dt.timestamp()
lt = datetime.fromtimestamp(timestamp).isoformat().replace("T", " ")
params = {"seconds_from_1970": int(timestamp), "local_time": lt}
# Doesn't seem to update the time, perhaps because timing_mode is ntp
res = await self.call("setTimezone", {"system": {"clock_status": params}})
if (zinfo := dt.tzinfo) and isinstance(zinfo, ZoneInfo):
tz_params = {"zone_id": zinfo.key}
res = await self.call("setTimezone", {"system": {"basic": tz_params}})
return res

View File

@@ -1,217 +0,0 @@
"""Module for smartcamera."""
from __future__ import annotations
import logging
from typing import Any
from ..device_type import DeviceType
from ..module import Module
from ..smart import SmartChildDevice, SmartDevice
from .modules.childdevice import ChildDevice
from .modules.device import DeviceModule
from .smartcameramodule import SmartCameraModule
from .smartcameraprotocol import _ChildCameraProtocolWrapper
_LOGGER = logging.getLogger(__name__)
class SmartCamera(SmartDevice):
"""Class for smart cameras."""
# Modules that are called as part of the init procedure on first update
FIRST_UPDATE_MODULES = {DeviceModule, ChildDevice}
@staticmethod
def _get_device_type_from_sysinfo(sysinfo: dict[str, Any]) -> DeviceType:
"""Find type to be displayed as a supported device category."""
device_type = sysinfo["device_type"]
if device_type.endswith("HUB"):
return DeviceType.Hub
return DeviceType.Camera
def _update_internal_info(self, info_resp: dict) -> None:
"""Update the internal device info."""
info = self._try_get_response(info_resp, "getDeviceInfo")
self._info = self._map_info(info["device_info"])
def _update_children_info(self) -> None:
"""Update the internal child device info from the parent info."""
if child_info := self._try_get_response(
self._last_update, "getChildDeviceList", {}
):
for info in child_info["child_device_list"]:
self._children[info["device_id"]]._update_internal_state(info)
async def _initialize_smart_child(
self, info: dict, child_components: dict
) -> SmartDevice:
"""Initialize a smart child device attached to a smartcamera."""
child_id = info["device_id"]
child_protocol = _ChildCameraProtocolWrapper(child_id, self.protocol)
try:
initial_response = await child_protocol.query(
{"get_connect_cloud_state": None}
)
except Exception as ex:
_LOGGER.exception("Error initialising child %s: %s", child_id, ex)
return await SmartChildDevice.create(
parent=self,
child_info=info,
child_components=child_components,
protocol=child_protocol,
last_update=initial_response,
)
async def _initialize_children(self) -> None:
"""Initialize children for hubs."""
child_info_query = {
"getChildDeviceList": {"childControl": {"start_index": 0}},
"getChildDeviceComponentList": {"childControl": {"start_index": 0}},
}
resp = await self.protocol.query(child_info_query)
self.internal_state.update(resp)
children_components = {
child["device_id"]: {
comp["id"]: int(comp["ver_code"]) for comp in child["component_list"]
}
for child in resp["getChildDeviceComponentList"]["child_component_list"]
}
children = {}
for info in resp["getChildDeviceList"]["child_device_list"]:
if (
category := info.get("category")
) and category in SmartChildDevice.CHILD_DEVICE_TYPE_MAP:
child_id = info["device_id"]
children[child_id] = await self._initialize_smart_child(
info, children_components[child_id]
)
else:
_LOGGER.debug("Child device type not supported: %s", info)
self._children = children
async def _initialize_modules(self) -> None:
"""Initialize modules based on component negotiation response."""
for mod in SmartCameraModule.REGISTERED_MODULES.values():
if (
mod.REQUIRED_COMPONENT
and mod.REQUIRED_COMPONENT not in self._components
):
continue
module = mod(self, mod._module_name())
if await module._check_supported():
self._modules[module.name] = module
async def _initialize_features(self) -> None:
"""Initialize device features."""
for module in self.modules.values():
module._initialize_features()
for feat in module._module_features.values():
self._add_feature(feat)
for child in self._children.values():
await child._initialize_features()
async def _query_setter_helper(
self, method: str, module: str, section: str, params: dict | None = None
) -> dict:
res = await self.protocol.query({method: {module: {section: params}}})
return res
async def _query_getter_helper(
self, method: str, module: str, sections: str | list[str]
) -> Any:
res = await self.protocol.query({method: {module: {"name": sections}}})
return res
async def _negotiate(self) -> None:
"""Perform initialization.
We fetch the device info and the available components as early as possible.
If the device reports supporting child devices, they are also initialized.
"""
initial_query = {
"getDeviceInfo": {"device_info": {"name": ["basic_info", "info"]}},
"getAppComponentList": {"app_component": {"name": "app_component_list"}},
}
resp = await self.protocol.query(initial_query)
self._last_update.update(resp)
self._update_internal_info(resp)
self._components = {
comp["name"]: int(comp["version"])
for comp in resp["getAppComponentList"]["app_component"][
"app_component_list"
]
}
if "childControl" in self._components and not self.children:
await self._initialize_children()
def _map_info(self, device_info: dict) -> dict:
basic_info = device_info["basic_info"]
return {
"model": basic_info["device_model"],
"device_type": basic_info["device_type"],
"alias": basic_info["device_alias"],
"fw_ver": basic_info["sw_version"],
"hw_ver": basic_info["hw_version"],
"mac": basic_info["mac"],
"hwId": basic_info.get("hw_id"),
"oem_id": basic_info["oem_id"],
}
@property
def is_on(self) -> bool:
"""Return true if the device is on."""
if (camera := self.modules.get(Module.Camera)) and not camera.disabled:
return camera.is_on
return True
async def set_state(self, on: bool) -> dict:
"""Set the device state."""
if (camera := self.modules.get(Module.Camera)) and not camera.disabled:
return await camera.set_state(on)
return {}
@property
def device_type(self) -> DeviceType:
"""Return the device type."""
if self._device_type == DeviceType.Unknown:
self._device_type = self._get_device_type_from_sysinfo(self._info)
return self._device_type
@property
def alias(self) -> str | None:
"""Returns the device alias or nickname."""
if self._info:
return self._info.get("alias")
return None
async def set_alias(self, alias: str) -> dict:
"""Set the device name (alias)."""
return await self.protocol.query(
{
"setDeviceAlias": {"system": {"sys": {"dev_alias": alias}}},
}
)
@property
def hw_info(self) -> dict:
"""Return hardware info for the device."""
return {
"sw_ver": self._info.get("hw_ver"),
"hw_ver": self._info.get("fw_ver"),
"mac": self._info.get("mac"),
"type": self._info.get("type"),
"hwId": self._info.get("hwId"),
"dev_name": self.alias,
"oemId": self._info.get("oem_id"),
}

View File

@@ -1,100 +0,0 @@
"""Base implementation for SMART modules."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, cast
from ..exceptions import DeviceError, KasaException, SmartErrorCode
from ..smart.smartmodule import SmartModule
if TYPE_CHECKING:
from .smartcamera import SmartCamera
_LOGGER = logging.getLogger(__name__)
class SmartCameraModule(SmartModule):
"""Base class for SMARTCAMERA modules."""
#: Query to execute during the main update cycle
QUERY_GETTER_NAME: str
#: Module name to be queried
QUERY_MODULE_NAME: str
#: Section name or names to be queried
QUERY_SECTION_NAMES: str | list[str]
REGISTERED_MODULES = {}
_device: SmartCamera
def query(self) -> dict:
"""Query to execute during the update cycle.
Default implementation uses the raw query getter w/o parameters.
"""
return {
self.QUERY_GETTER_NAME: {
self.QUERY_MODULE_NAME: {"name": self.QUERY_SECTION_NAMES}
}
}
async def call(self, method: str, params: dict | None = None) -> dict:
"""Call a method.
Just a helper method.
"""
if params:
module = next(iter(params))
section = next(iter(params[module]))
else:
module = "system"
section = "null"
if method[:3] == "get":
return await self._device._query_getter_helper(method, module, section)
if TYPE_CHECKING:
params = cast(dict[str, dict[str, Any]], params)
return await self._device._query_setter_helper(
method, module, section, params[module][section]
)
@property
def data(self) -> dict:
"""Return response data for the module."""
dev = self._device
q = self.query()
if not q:
return dev.sys_info
if len(q) == 1:
query_resp = dev._last_update.get(self.QUERY_GETTER_NAME, {})
if isinstance(query_resp, SmartErrorCode):
raise DeviceError(
f"Error accessing module data in {self._module}",
error_code=query_resp,
)
if not query_resp:
raise KasaException(
f"You need to call update() prior accessing module data"
f" for '{self._module}'"
)
return query_resp.get(self.QUERY_MODULE_NAME)
else:
found = {key: val for key, val in dev._last_update.items() if key in q}
for key in q:
if key not in found:
raise KasaException(
f"{key} not found, you need to call update() prior accessing"
f" module data for '{self._module}'"
)
if isinstance(found[key], SmartErrorCode):
raise DeviceError(
f"Error accessing module data {key} in {self._module}",
error_code=found[key],
)
return found

View File

@@ -1,253 +0,0 @@
"""Module for SmartCamera Protocol."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pprint import pformat as pf
from typing import Any
from ..exceptions import (
AuthenticationError,
DeviceError,
KasaException,
_RetryableError,
)
from ..json import dumps as json_dumps
from ..protocols import SmartProtocol
from .sslaestransport import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
SmartErrorCode,
)
_LOGGER = logging.getLogger(__name__)
# List of getMethodNames that should be sent as {"method":"do"}
# https://md.depau.eu/s/r1Ys_oWoP#Modules
GET_METHODS_AS_DO = {
"getSdCardFormatStatus",
"getConnectionType",
"getUserID",
"getP2PSharePassword",
"getAESEncryptKey",
"getFirmwareAFResult",
"getWhitelampStatus",
}
@dataclass
class SingleRequest:
"""Class for returning single request details from helper functions."""
method_type: str
method_name: str
param_name: str
request: dict[str, Any]
class SmartCameraProtocol(SmartProtocol):
"""Class for SmartCamera Protocol."""
async def _handle_response_lists(
self, response_result: dict[str, Any], method: str, retry_count: int
) -> None:
pass
def _handle_response_error_code(
self, resp_dict: dict, method: str, raise_on_error: bool = True
) -> None:
error_code_raw = resp_dict.get("error_code")
try:
error_code = SmartErrorCode.from_int(error_code_raw)
except ValueError:
_LOGGER.warning(
"Device %s received unknown error code: %s", self._host, error_code_raw
)
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
if error_code is SmartErrorCode.SUCCESS:
return
if not raise_on_error:
resp_dict["result"] = error_code
return
msg = (
f"Error querying device: {self._host}: "
+ f"{error_code.name}({error_code.value})"
+ f" for method: {method}"
)
if error_code in SMART_RETRYABLE_ERRORS:
raise _RetryableError(msg, error_code=error_code)
if error_code in SMART_AUTHENTICATION_ERRORS:
raise AuthenticationError(msg, error_code=error_code)
raise DeviceError(msg, error_code=error_code)
async def close(self) -> None:
"""Close the underlying transport."""
await self._transport.close()
@staticmethod
def _get_smart_camera_single_request(
request: dict[str, dict[str, Any]],
) -> SingleRequest:
method = next(iter(request))
if method == "multipleRequest":
method_type = "multi"
params = request["multipleRequest"]
req = {"method": "multipleRequest", "params": params}
return SingleRequest("multi", "multipleRequest", "", req)
param = next(iter(request[method]))
method_type = method
req = {
"method": method,
param: request[method][param],
}
return SingleRequest(method_type, method, param, req)
@staticmethod
def _make_snake_name(name: str) -> str:
"""Convert camel or pascal case to snake name."""
sn = "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_")
return sn
@staticmethod
def _make_smart_camera_single_request(
request: str,
) -> SingleRequest:
"""Make a single request given a method name and no params.
If method like getSomeThing then module will be some_thing.
"""
method = request
method_type = request[:3]
snake_name = SmartCameraProtocol._make_snake_name(request)
param = snake_name[4:]
if (
(short_method := method[:3])
and short_method in {"get", "set"}
and method not in GET_METHODS_AS_DO
):
method_type = short_method
param = snake_name[4:]
else:
method_type = "do"
param = snake_name
req = {"method": method_type, param: {}}
return SingleRequest(method_type, method, param, req)
async def _execute_query(
self, request: str | dict, *, retry_count: int, iterate_list_pages: bool = True
) -> dict:
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if isinstance(request, dict):
method = next(iter(request))
if len(request) == 1 and method in {"get", "set", "do", "multipleRequest"}:
single_request = self._get_smart_camera_single_request(request)
else:
return await self._execute_multiple_query(request, retry_count)
else:
single_request = self._make_smart_camera_single_request(request)
smart_request = json_dumps(single_request.request)
if debug_enabled:
_LOGGER.debug(
"%s >> %s",
self._host,
pf(smart_request),
)
response_data = await self._transport.send(smart_request)
if debug_enabled:
_LOGGER.debug(
"%s << %s",
self._host,
pf(response_data),
)
if "error_code" in response_data:
# H200 does not return an error code
self._handle_response_error_code(response_data, single_request.method_name)
# Requests that are invalid and raise PROTOCOL_FORMAT_ERROR when sent
# as a multipleRequest will return {} when sent as a single request.
if single_request.method_type == "get" and (
not (section := next(iter(response_data))) or response_data[section] == {}
):
raise DeviceError(
f"No results for get request {single_request.method_name}"
)
# TODO need to update handle response lists
if single_request.method_type == "do":
return {single_request.method_name: response_data}
if single_request.method_type == "set":
return {}
if single_request.method_type == "multi":
return {single_request.method_name: response_data["result"]}
return {
single_request.method_name: {
single_request.param_name: response_data[single_request.param_name]
}
}
class _ChildCameraProtocolWrapper(SmartProtocol):
"""Protocol wrapper for controlling child devices.
This is an internal class used to communicate with child devices,
and should not be used directly.
This class overrides query() method of the protocol to modify all
outgoing queries to use ``controlChild`` command, and unwraps the
device responses before returning to the caller.
"""
def __init__(self, device_id: str, base_protocol: SmartProtocol) -> None:
self._device_id = device_id
self._protocol = base_protocol
self._transport = base_protocol._transport
async def query(self, request: str | dict, retry_count: int = 3) -> dict:
"""Wrap request inside controlChild envelope."""
return await self._query(request, retry_count)
async def _query(self, request: str | dict, retry_count: int = 3) -> dict:
"""Wrap request inside controlChild envelope."""
if not isinstance(request, dict):
raise KasaException("Child requests must be dictionaries.")
requests = []
methods = []
for key, val in request.items():
request = {
"method": "controlChild",
"params": {
"childControl": {
"device_id": self._device_id,
"request_data": {"method": key, "params": val},
}
},
}
methods.append(key)
requests.append(request)
multipleRequest = {"multipleRequest": {"requests": requests}}
response = await self._protocol.query(multipleRequest, retry_count)
responses = response["multipleRequest"]["responses"]
response_dict = {}
for index_id, response in enumerate(responses):
response_data = response["result"]["response_data"]
method = methods[index_id]
self._handle_response_error_code(
response_data, method, raise_on_error=False
)
response_dict[method] = response_data.get("result")
return response_dict
async def close(self) -> None:
"""Do nothing as the parent owns the protocol."""

View File

@@ -1,478 +0,0 @@
"""Implementation of the TP-Link SSL AES transport."""
from __future__ import annotations
import asyncio
import base64
import hashlib
import logging
import secrets
import ssl
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Dict, cast
from yarl import URL
from ..credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
AuthenticationError,
DeviceError,
KasaException,
SmartErrorCode,
_RetryableError,
)
from ..httpclient import HttpClient
from ..json import dumps as json_dumps
from ..json import loads as json_loads
from ..transports import AesEncyptionSession, BaseTransport
_LOGGER = logging.getLogger(__name__)
ONE_DAY_SECONDS = 86400
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
def _sha256(payload: bytes) -> bytes:
return hashlib.sha256(payload).digest() # noqa: S324
def _md5_hash(payload: bytes) -> str:
return hashlib.md5(payload).hexdigest().upper() # noqa: S324
def _sha256_hash(payload: bytes) -> str:
return hashlib.sha256(payload).hexdigest().upper() # noqa: S324
class TransportState(Enum):
"""Enum for AES state."""
HANDSHAKE_REQUIRED = auto() # Handshake needed
ESTABLISHED = auto() # Ready to send requests
class SslAesTransport(BaseTransport):
"""Implementation of the AES encryption protocol.
AES is the name used in device discovery for TP-Link's TAPO encryption
protocol, sometimes used by newer firmware versions on kasa devices.
"""
DEFAULT_PORT: int = 443
COMMON_HEADERS = {
"Content-Type": "application/json; charset=UTF-8",
"requestByApp": "true",
"Accept": "application/json",
"Accept-Encoding": "gzip, deflate",
"User-Agent": "Tapo CameraClient Android",
}
CIPHERS = ":".join(
[
"AES256-GCM-SHA384",
"AES256-SHA256",
"AES128-GCM-SHA256",
"AES128-SHA256",
"AES256-SHA",
]
)
DEFAULT_TIMEOUT = 10
def __init__(
self,
*,
config: DeviceConfig,
) -> None:
super().__init__(config=config)
self._login_version = config.connection_type.login_version
if (
not self._credentials or self._credentials.username is None
) and not self._credentials_hash:
self._credentials = Credentials()
self._default_credentials: Credentials = get_default_credentials(
DEFAULT_CREDENTIALS["TAPOCAMERA"]
)
self._http_client: HttpClient = HttpClient(config)
self._state = TransportState.HANDSHAKE_REQUIRED
self._encryption_session: AesEncyptionSession | None = None
self._session_expire_at: float | None = None
self._host_port = f"{self._host}:{self._port}"
self._app_url = URL(f"https://{self._host_port}")
self._token_url: URL | None = None
self._ssl_context: ssl.SSLContext | None = None
ref = str(self._token_url) if self._token_url else str(self._app_url)
self._headers = {
**self.COMMON_HEADERS,
"Host": self._host_port,
"Referer": ref,
}
self._seq: int | None = None
self._pwd_hash: str | None = None
self._username: str | None = None
self._password: str | None = None
if self._credentials != Credentials() and self._credentials:
self._username = self._credentials.username
self._password = self._credentials.password
elif self._credentials_hash:
ch = json_loads(base64.b64decode(self._credentials_hash.encode()))
self._password = ch["pwd"]
self._username = ch["un"]
self._local_nonce: str | None = None
_LOGGER.debug("Created AES transport for %s", self._host)
@property
def default_port(self) -> int:
"""Default port for the transport."""
return self.DEFAULT_PORT
@staticmethod
def _create_b64_credentials(credentials: Credentials) -> str:
ch = {"un": credentials.username, "pwd": credentials.password}
return base64.b64encode(json_dumps(ch).encode()).decode()
@property
def credentials_hash(self) -> str | None:
"""The hashed credentials used by the transport."""
if self._credentials == Credentials():
return None
if not self._credentials and self._credentials_hash:
return self._credentials_hash
if (cred := self._credentials) and cred.password and cred.username:
return self._create_b64_credentials(cred)
return None
def _get_response_error(self, resp_dict: Any) -> SmartErrorCode:
error_code_raw = resp_dict.get("error_code")
try:
error_code = SmartErrorCode.from_int(error_code_raw)
except ValueError:
_LOGGER.warning(
"Device %s received unknown error code: %s", self._host, error_code_raw
)
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
return error_code
def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
error_code = self._get_response_error(resp_dict)
if error_code is SmartErrorCode.SUCCESS:
return
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
if error_code in SMART_RETRYABLE_ERRORS:
raise _RetryableError(msg, error_code=error_code)
if error_code in SMART_AUTHENTICATION_ERRORS:
self._state = TransportState.HANDSHAKE_REQUIRED
raise AuthenticationError(msg, error_code=error_code)
raise DeviceError(msg, error_code=error_code)
def _create_ssl_context(self) -> ssl.SSLContext:
context = ssl.SSLContext()
context.set_ciphers(self.CIPHERS)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
return context
async def _get_ssl_context(self) -> ssl.SSLContext:
if not self._ssl_context:
loop = asyncio.get_running_loop()
self._ssl_context = await loop.run_in_executor(
None, self._create_ssl_context
)
return self._ssl_context
async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
"""Send encrypted message as passthrough."""
if self._state is TransportState.ESTABLISHED and self._token_url:
url = self._token_url
else:
url = self._app_url
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
passthrough_request = {
"method": "securePassthrough",
"params": {"request": encrypted_payload.decode()},
}
passthrough_request_str = json_dumps(passthrough_request)
if TYPE_CHECKING:
assert self._pwd_hash
assert self._local_nonce
assert self._seq
tag = self.generate_tag(
passthrough_request_str, self._local_nonce, self._pwd_hash, self._seq
)
headers = {**self._headers, "Seq": str(self._seq), "Tapo_tag": tag}
self._seq += 1
status_code, resp_dict = await self._http_client.post(
url,
json=passthrough_request_str,
headers=headers,
ssl=await self._get_ssl_context(),
)
if status_code != 200:
raise KasaException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to passthrough"
)
self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)
if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)
assert self._encryption_session is not None
if "result" in resp_dict and "response" in resp_dict["result"]:
raw_response: str = resp_dict["result"]["response"]
else:
# Tapo Cameras respond unencrypted to single requests.
return resp_dict
try:
response = self._encryption_session.decrypt(raw_response.encode())
ret_val = json_loads(response)
except Exception as ex:
try:
ret_val = json_loads(raw_response)
_LOGGER.debug(
"Received unencrypted response over secure passthrough from %s",
self._host,
)
except Exception:
raise KasaException(
f"Unable to decrypt response from {self._host}, "
+ f"error: {ex}, response: {raw_response}",
ex,
) from ex
return ret_val # type: ignore[return-value]
@staticmethod
def generate_confirm_hash(
local_nonce: str, server_nonce: str, pwd_hash: str
) -> str:
"""Generate an auth hash for the protocol on the supplied credentials."""
expected_confirm_bytes = _sha256_hash(
local_nonce.encode() + pwd_hash.encode() + server_nonce.encode()
)
return expected_confirm_bytes + server_nonce + local_nonce
@staticmethod
def generate_digest_password(
local_nonce: str, server_nonce: str, pwd_hash: str
) -> str:
"""Generate an auth hash for the protocol on the supplied credentials."""
digest_password_hash = _sha256_hash(
pwd_hash.encode() + local_nonce.encode() + server_nonce.encode()
)
return (
digest_password_hash.encode() + local_nonce.encode() + server_nonce.encode()
).decode()
@staticmethod
def generate_encryption_token(
token_type: str, local_nonce: str, server_nonce: str, pwd_hash: str
) -> bytes:
"""Generate encryption token."""
hashedKey = _sha256_hash(
local_nonce.encode() + pwd_hash.encode() + server_nonce.encode()
)
return _sha256(
token_type.encode()
+ local_nonce.encode()
+ server_nonce.encode()
+ hashedKey.encode()
)[:16]
@staticmethod
def generate_tag(request: str, local_nonce: str, pwd_hash: str, seq: int) -> str:
"""Generate the tag header from the request for the header."""
pwd_nonce_hash = _sha256_hash(pwd_hash.encode() + local_nonce.encode())
tag = _sha256_hash(
pwd_nonce_hash.encode() + request.encode() + str(seq).encode()
)
return tag
async def perform_handshake(self) -> None:
"""Perform the handshake."""
local_nonce, server_nonce, pwd_hash = await self.perform_handshake1()
await self.perform_handshake2(local_nonce, server_nonce, pwd_hash)
async def perform_handshake2(
self, local_nonce: str, server_nonce: str, pwd_hash: str
) -> None:
"""Perform the handshake."""
_LOGGER.debug("Performing handshake2 ...")
digest_password = self.generate_digest_password(
local_nonce, server_nonce, pwd_hash
)
body = {
"method": "login",
"params": {
"cnonce": local_nonce,
"encrypt_type": "3",
"digest_passwd": digest_password,
"username": self._username,
},
}
http_client = self._http_client
status_code, resp_dict = await http_client.post(
self._app_url,
json=body,
headers=self._headers,
ssl=await self._get_ssl_context(),
)
if status_code != 200:
raise KasaException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to handshake2"
)
resp_dict = cast(dict, resp_dict)
if (
error_code := self._get_response_error(resp_dict)
) and error_code is SmartErrorCode.INVALID_NONCE:
raise AuthenticationError(
f"Invalid password hash in handshake2 for {self._host}"
)
self._handle_response_error_code(resp_dict, "Error in handshake2")
self._seq = resp_dict["result"]["start_seq"]
stok = resp_dict["result"]["stok"]
self._token_url = URL(f"{str(self._app_url)}/stok={stok}/ds")
self._pwd_hash = pwd_hash
self._local_nonce = local_nonce
lsk = self.generate_encryption_token("lsk", local_nonce, server_nonce, pwd_hash)
ivb = self.generate_encryption_token("ivb", local_nonce, server_nonce, pwd_hash)
self._encryption_session = AesEncyptionSession(lsk, ivb)
self._state = TransportState.ESTABLISHED
_LOGGER.debug("Handshake2 complete ...")
async def perform_handshake1(self) -> tuple[str, str, str]:
"""Perform the handshake1."""
resp_dict = None
if self._username:
local_nonce = secrets.token_bytes(8).hex().upper()
resp_dict = await self.try_send_handshake1(self._username, local_nonce)
# Try the default username. If it fails raise the original error_code
if (
not resp_dict
or (error_code := self._get_response_error(resp_dict))
is not SmartErrorCode.INVALID_NONCE
or "nonce" not in resp_dict["result"].get("data", {})
):
local_nonce = secrets.token_bytes(8).hex().upper()
default_resp_dict = await self.try_send_handshake1(
self._default_credentials.username, local_nonce
)
if (
default_error_code := self._get_response_error(default_resp_dict)
) is SmartErrorCode.INVALID_NONCE and "nonce" in default_resp_dict[
"result"
].get("data", {}):
_LOGGER.debug("Connected to {self._host} with default username")
self._username = self._default_credentials.username
error_code = default_error_code
resp_dict = default_resp_dict
if not self._username:
raise AuthenticationError(
f"Credentials must be supplied to connect to {self._host}"
)
if error_code is not SmartErrorCode.INVALID_NONCE or (
resp_dict and "nonce" not in resp_dict["result"].get("data", {})
):
raise AuthenticationError(f"Error trying handshake1: {resp_dict}")
if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)
server_nonce = resp_dict["result"]["data"]["nonce"]
device_confirm = resp_dict["result"]["data"]["device_confirm"]
if self._credentials and self._credentials != Credentials():
pwd_hash = _sha256_hash(self._credentials.password.encode())
elif self._username and self._password:
pwd_hash = _sha256_hash(self._password.encode())
else:
pwd_hash = _sha256_hash(self._default_credentials.password.encode())
expected_confirm_sha256 = self.generate_confirm_hash(
local_nonce, server_nonce, pwd_hash
)
if device_confirm == expected_confirm_sha256:
_LOGGER.debug("Credentials match")
return local_nonce, server_nonce, pwd_hash
if TYPE_CHECKING:
assert self._credentials
assert self._credentials.password
pwd_hash = _md5_hash(self._credentials.password.encode())
expected_confirm_md5 = self.generate_confirm_hash(
local_nonce, server_nonce, pwd_hash
)
if device_confirm == expected_confirm_md5:
_LOGGER.debug("Credentials match")
return local_nonce, server_nonce, pwd_hash
msg = f"Server response doesn't match our challenge on ip {self._host}"
_LOGGER.debug(msg)
raise AuthenticationError(msg)
async def try_send_handshake1(self, username: str, local_nonce: str) -> dict:
"""Perform the handshake."""
_LOGGER.debug("Will to send handshake1...")
body = {
"method": "login",
"params": {
"cnonce": local_nonce,
"encrypt_type": "3",
"username": username,
},
}
http_client = self._http_client
status_code, resp_dict = await http_client.post(
self._app_url,
json=body,
headers=self._headers,
ssl=await self._get_ssl_context(),
)
_LOGGER.debug("Device responded with: %s", resp_dict)
if status_code != 200:
raise KasaException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to handshake1"
)
return cast(dict, resp_dict)
async def send(self, request: str) -> dict[str, Any]:
"""Send the request."""
if self._state is TransportState.HANDSHAKE_REQUIRED:
await self.perform_handshake()
return await self.send_secure_passthrough(request)
async def close(self) -> None:
"""Close the http client and reset internal state."""
await self.reset()
await self._http_client.close()
async def reset(self) -> None:
"""Reset internal handshake state."""
self._state = TransportState.HANDSHAKE_REQUIRED
self._encryption_session = None
self._seq = 0
self._pwd_hash = None
self._local_nonce = None