mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-10-27 10:31:59 +00:00
Move TAPO smartcamera out of experimental package (#1255)
Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
@@ -1,15 +0,0 @@
|
||||
"""Modules for SMARTCAMERA devices."""
|
||||
|
||||
from .camera import Camera
|
||||
from .childdevice import ChildDevice
|
||||
from .device import DeviceModule
|
||||
from .led import Led
|
||||
from .time import Time
|
||||
|
||||
__all__ = [
|
||||
"Camera",
|
||||
"ChildDevice",
|
||||
"DeviceModule",
|
||||
"Led",
|
||||
"Time",
|
||||
]
|
||||
@@ -1,71 +0,0 @@
|
||||
"""Implementation of device module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from ...credentials import Credentials
|
||||
from ...device_type import DeviceType
|
||||
from ...feature import Feature
|
||||
from ..smartcameramodule import SmartCameraModule
|
||||
|
||||
LOCAL_STREAMING_PORT = 554
|
||||
|
||||
|
||||
class Camera(SmartCameraModule):
|
||||
"""Implementation of device module."""
|
||||
|
||||
QUERY_GETTER_NAME = "getLensMaskConfig"
|
||||
QUERY_MODULE_NAME = "lens_mask"
|
||||
QUERY_SECTION_NAMES = "lens_mask_info"
|
||||
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
self._device,
|
||||
id="state",
|
||||
name="State",
|
||||
attribute_getter="is_on",
|
||||
attribute_setter="set_state",
|
||||
type=Feature.Type.Switch,
|
||||
category=Feature.Category.Primary,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def is_on(self) -> bool:
|
||||
"""Return the device id."""
|
||||
return self.data["lens_mask_info"]["enabled"] == "off"
|
||||
|
||||
def stream_rtsp_url(self, credentials: Credentials | None = None) -> str | None:
|
||||
"""Return the local rtsp streaming url.
|
||||
|
||||
:param credentials: Credentials for camera account.
|
||||
These could be different credentials to tplink cloud credentials.
|
||||
If not provided will use tplink credentials if available
|
||||
:return: rtsp url with escaped credentials or None if no credentials or
|
||||
camera is off.
|
||||
"""
|
||||
if not self.is_on:
|
||||
return None
|
||||
dev = self._device
|
||||
if not credentials:
|
||||
credentials = dev.credentials
|
||||
if not credentials or not credentials.username or not credentials.password:
|
||||
return None
|
||||
username = quote_plus(credentials.username)
|
||||
password = quote_plus(credentials.password)
|
||||
return f"rtsp://{username}:{password}@{dev.host}:{LOCAL_STREAMING_PORT}/stream1"
|
||||
|
||||
async def set_state(self, on: bool) -> dict:
|
||||
"""Set the device state."""
|
||||
# Turning off enables the privacy mask which is why value is reversed.
|
||||
params = {"enabled": "off" if on else "on"}
|
||||
return await self._device._query_setter_helper(
|
||||
"setLensMaskConfig", self.QUERY_MODULE_NAME, "lens_mask_info", params
|
||||
)
|
||||
|
||||
async def _check_supported(self) -> bool:
|
||||
"""Additional check to see if the module is supported by the device."""
|
||||
return self._device.device_type is DeviceType.Camera
|
||||
@@ -1,26 +0,0 @@
|
||||
"""Module for child devices."""
|
||||
|
||||
from ...device_type import DeviceType
|
||||
from ..smartcameramodule import SmartCameraModule
|
||||
|
||||
|
||||
class ChildDevice(SmartCameraModule):
|
||||
"""Implementation for child devices."""
|
||||
|
||||
REQUIRED_COMPONENT = "childControl"
|
||||
NAME = "childdevice"
|
||||
QUERY_GETTER_NAME = "getChildDeviceList"
|
||||
# This module is unusual in that QUERY_MODULE_NAME in the response is not
|
||||
# the same one used in the request.
|
||||
QUERY_MODULE_NAME = "child_device_list"
|
||||
|
||||
def query(self) -> dict:
|
||||
"""Query to execute during the update cycle.
|
||||
|
||||
Default implementation uses the raw query getter w/o parameters.
|
||||
"""
|
||||
return {self.QUERY_GETTER_NAME: {"childControl": {"start_index": 0}}}
|
||||
|
||||
async def _check_supported(self) -> bool:
|
||||
"""Additional check to see if the module is supported by the device."""
|
||||
return self._device.device_type is DeviceType.Hub
|
||||
@@ -1,40 +0,0 @@
|
||||
"""Implementation of device module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ...feature import Feature
|
||||
from ..smartcameramodule import SmartCameraModule
|
||||
|
||||
|
||||
class DeviceModule(SmartCameraModule):
|
||||
"""Implementation of device module."""
|
||||
|
||||
NAME = "devicemodule"
|
||||
QUERY_GETTER_NAME = "getDeviceInfo"
|
||||
QUERY_MODULE_NAME = "device_info"
|
||||
QUERY_SECTION_NAMES = ["basic_info", "info"]
|
||||
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
self._device,
|
||||
id="device_id",
|
||||
name="Device ID",
|
||||
attribute_getter="device_id",
|
||||
category=Feature.Category.Debug,
|
||||
type=Feature.Type.Sensor,
|
||||
)
|
||||
)
|
||||
|
||||
async def _post_update_hook(self) -> None:
|
||||
"""Overriden to prevent module disabling.
|
||||
|
||||
Overrides the default behaviour to disable a module if the query returns
|
||||
an error because this module is critical.
|
||||
"""
|
||||
|
||||
@property
|
||||
def device_id(self) -> str:
|
||||
"""Return the device id."""
|
||||
return self.data["basic_info"]["dev_id"]
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Module for led controls."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ...interfaces.led import Led as LedInterface
|
||||
from ..smartcameramodule import SmartCameraModule
|
||||
|
||||
|
||||
class Led(SmartCameraModule, LedInterface):
|
||||
"""Implementation of led controls."""
|
||||
|
||||
REQUIRED_COMPONENT = "led"
|
||||
QUERY_GETTER_NAME = "getLedStatus"
|
||||
QUERY_MODULE_NAME = "led"
|
||||
QUERY_SECTION_NAMES = "config"
|
||||
|
||||
@property
|
||||
def led(self) -> bool:
|
||||
"""Return current led status."""
|
||||
return self.data["config"]["enabled"] == "on"
|
||||
|
||||
async def set_led(self, enable: bool) -> dict:
|
||||
"""Set led.
|
||||
|
||||
This should probably be a select with always/never/nightmode.
|
||||
"""
|
||||
params = {"enabled": "on"} if enable else {"enabled": "off"}
|
||||
return await self.call("setLedStatus", {"led": {"config": params}})
|
||||
@@ -1,91 +0,0 @@
|
||||
"""Implementation of time module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone, tzinfo
|
||||
from typing import cast
|
||||
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from ...cachedzoneinfo import CachedZoneInfo
|
||||
from ...feature import Feature
|
||||
from ...interfaces import Time as TimeInterface
|
||||
from ..smartcameramodule import SmartCameraModule
|
||||
|
||||
|
||||
class Time(SmartCameraModule, TimeInterface):
|
||||
"""Implementation of device_local_time."""
|
||||
|
||||
QUERY_GETTER_NAME = "getTimezone"
|
||||
QUERY_MODULE_NAME = "system"
|
||||
QUERY_SECTION_NAMES = "basic"
|
||||
|
||||
_timezone: tzinfo = timezone.utc
|
||||
_time: datetime
|
||||
|
||||
def _initialize_features(self) -> None:
|
||||
"""Initialize features after the initial update."""
|
||||
self._add_feature(
|
||||
Feature(
|
||||
device=self._device,
|
||||
id="device_time",
|
||||
name="Device time",
|
||||
attribute_getter="time",
|
||||
container=self,
|
||||
category=Feature.Category.Debug,
|
||||
type=Feature.Type.Sensor,
|
||||
)
|
||||
)
|
||||
|
||||
def query(self) -> dict:
|
||||
"""Query to execute during the update cycle."""
|
||||
q = super().query()
|
||||
q["getClockStatus"] = {self.QUERY_MODULE_NAME: {"name": "clock_status"}}
|
||||
|
||||
return q
|
||||
|
||||
async def _post_update_hook(self) -> None:
|
||||
"""Perform actions after a device update."""
|
||||
time_data = self.data["getClockStatus"]["system"]["clock_status"]
|
||||
timezone_data = self.data["getTimezone"]["system"]["basic"]
|
||||
zone_id = timezone_data["zone_id"]
|
||||
timestamp = time_data["seconds_from_1970"]
|
||||
try:
|
||||
# Zoneinfo will return a DST aware object
|
||||
tz: tzinfo = await CachedZoneInfo.get_cached_zone_info(zone_id)
|
||||
except ZoneInfoNotFoundError:
|
||||
# timezone string like: UTC+10:00
|
||||
timezone_str = timezone_data["timezone"]
|
||||
tz = cast(tzinfo, datetime.strptime(timezone_str[-6:], "%z").tzinfo)
|
||||
|
||||
self._timezone = tz
|
||||
self._time = datetime.fromtimestamp(
|
||||
cast(float, timestamp),
|
||||
tz=tz,
|
||||
)
|
||||
|
||||
@property
|
||||
def timezone(self) -> tzinfo:
|
||||
"""Return current timezone."""
|
||||
return self._timezone
|
||||
|
||||
@property
|
||||
def time(self) -> datetime:
|
||||
"""Return device's current datetime."""
|
||||
return self._time
|
||||
|
||||
async def set_time(self, dt: datetime) -> dict:
|
||||
"""Set device time."""
|
||||
if not dt.tzinfo:
|
||||
timestamp = dt.replace(tzinfo=self.timezone).timestamp()
|
||||
else:
|
||||
timestamp = dt.timestamp()
|
||||
|
||||
lt = datetime.fromtimestamp(timestamp).isoformat().replace("T", " ")
|
||||
params = {"seconds_from_1970": int(timestamp), "local_time": lt}
|
||||
# Doesn't seem to update the time, perhaps because timing_mode is ntp
|
||||
res = await self.call("setTimezone", {"system": {"clock_status": params}})
|
||||
if (zinfo := dt.tzinfo) and isinstance(zinfo, ZoneInfo):
|
||||
tz_params = {"zone_id": zinfo.key}
|
||||
res = await self.call("setTimezone", {"system": {"basic": tz_params}})
|
||||
return res
|
||||
@@ -1,217 +0,0 @@
|
||||
"""Module for smartcamera."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from ..device_type import DeviceType
|
||||
from ..module import Module
|
||||
from ..smart import SmartChildDevice, SmartDevice
|
||||
from .modules.childdevice import ChildDevice
|
||||
from .modules.device import DeviceModule
|
||||
from .smartcameramodule import SmartCameraModule
|
||||
from .smartcameraprotocol import _ChildCameraProtocolWrapper
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SmartCamera(SmartDevice):
|
||||
"""Class for smart cameras."""
|
||||
|
||||
# Modules that are called as part of the init procedure on first update
|
||||
FIRST_UPDATE_MODULES = {DeviceModule, ChildDevice}
|
||||
|
||||
@staticmethod
|
||||
def _get_device_type_from_sysinfo(sysinfo: dict[str, Any]) -> DeviceType:
|
||||
"""Find type to be displayed as a supported device category."""
|
||||
device_type = sysinfo["device_type"]
|
||||
if device_type.endswith("HUB"):
|
||||
return DeviceType.Hub
|
||||
return DeviceType.Camera
|
||||
|
||||
def _update_internal_info(self, info_resp: dict) -> None:
|
||||
"""Update the internal device info."""
|
||||
info = self._try_get_response(info_resp, "getDeviceInfo")
|
||||
self._info = self._map_info(info["device_info"])
|
||||
|
||||
def _update_children_info(self) -> None:
|
||||
"""Update the internal child device info from the parent info."""
|
||||
if child_info := self._try_get_response(
|
||||
self._last_update, "getChildDeviceList", {}
|
||||
):
|
||||
for info in child_info["child_device_list"]:
|
||||
self._children[info["device_id"]]._update_internal_state(info)
|
||||
|
||||
async def _initialize_smart_child(
|
||||
self, info: dict, child_components: dict
|
||||
) -> SmartDevice:
|
||||
"""Initialize a smart child device attached to a smartcamera."""
|
||||
child_id = info["device_id"]
|
||||
child_protocol = _ChildCameraProtocolWrapper(child_id, self.protocol)
|
||||
try:
|
||||
initial_response = await child_protocol.query(
|
||||
{"get_connect_cloud_state": None}
|
||||
)
|
||||
except Exception as ex:
|
||||
_LOGGER.exception("Error initialising child %s: %s", child_id, ex)
|
||||
|
||||
return await SmartChildDevice.create(
|
||||
parent=self,
|
||||
child_info=info,
|
||||
child_components=child_components,
|
||||
protocol=child_protocol,
|
||||
last_update=initial_response,
|
||||
)
|
||||
|
||||
async def _initialize_children(self) -> None:
|
||||
"""Initialize children for hubs."""
|
||||
child_info_query = {
|
||||
"getChildDeviceList": {"childControl": {"start_index": 0}},
|
||||
"getChildDeviceComponentList": {"childControl": {"start_index": 0}},
|
||||
}
|
||||
resp = await self.protocol.query(child_info_query)
|
||||
self.internal_state.update(resp)
|
||||
|
||||
children_components = {
|
||||
child["device_id"]: {
|
||||
comp["id"]: int(comp["ver_code"]) for comp in child["component_list"]
|
||||
}
|
||||
for child in resp["getChildDeviceComponentList"]["child_component_list"]
|
||||
}
|
||||
children = {}
|
||||
for info in resp["getChildDeviceList"]["child_device_list"]:
|
||||
if (
|
||||
category := info.get("category")
|
||||
) and category in SmartChildDevice.CHILD_DEVICE_TYPE_MAP:
|
||||
child_id = info["device_id"]
|
||||
children[child_id] = await self._initialize_smart_child(
|
||||
info, children_components[child_id]
|
||||
)
|
||||
else:
|
||||
_LOGGER.debug("Child device type not supported: %s", info)
|
||||
|
||||
self._children = children
|
||||
|
||||
async def _initialize_modules(self) -> None:
|
||||
"""Initialize modules based on component negotiation response."""
|
||||
for mod in SmartCameraModule.REGISTERED_MODULES.values():
|
||||
if (
|
||||
mod.REQUIRED_COMPONENT
|
||||
and mod.REQUIRED_COMPONENT not in self._components
|
||||
):
|
||||
continue
|
||||
module = mod(self, mod._module_name())
|
||||
if await module._check_supported():
|
||||
self._modules[module.name] = module
|
||||
|
||||
async def _initialize_features(self) -> None:
|
||||
"""Initialize device features."""
|
||||
for module in self.modules.values():
|
||||
module._initialize_features()
|
||||
for feat in module._module_features.values():
|
||||
self._add_feature(feat)
|
||||
|
||||
for child in self._children.values():
|
||||
await child._initialize_features()
|
||||
|
||||
async def _query_setter_helper(
|
||||
self, method: str, module: str, section: str, params: dict | None = None
|
||||
) -> dict:
|
||||
res = await self.protocol.query({method: {module: {section: params}}})
|
||||
|
||||
return res
|
||||
|
||||
async def _query_getter_helper(
|
||||
self, method: str, module: str, sections: str | list[str]
|
||||
) -> Any:
|
||||
res = await self.protocol.query({method: {module: {"name": sections}}})
|
||||
|
||||
return res
|
||||
|
||||
async def _negotiate(self) -> None:
|
||||
"""Perform initialization.
|
||||
|
||||
We fetch the device info and the available components as early as possible.
|
||||
If the device reports supporting child devices, they are also initialized.
|
||||
"""
|
||||
initial_query = {
|
||||
"getDeviceInfo": {"device_info": {"name": ["basic_info", "info"]}},
|
||||
"getAppComponentList": {"app_component": {"name": "app_component_list"}},
|
||||
}
|
||||
resp = await self.protocol.query(initial_query)
|
||||
self._last_update.update(resp)
|
||||
self._update_internal_info(resp)
|
||||
|
||||
self._components = {
|
||||
comp["name"]: int(comp["version"])
|
||||
for comp in resp["getAppComponentList"]["app_component"][
|
||||
"app_component_list"
|
||||
]
|
||||
}
|
||||
|
||||
if "childControl" in self._components and not self.children:
|
||||
await self._initialize_children()
|
||||
|
||||
def _map_info(self, device_info: dict) -> dict:
|
||||
basic_info = device_info["basic_info"]
|
||||
return {
|
||||
"model": basic_info["device_model"],
|
||||
"device_type": basic_info["device_type"],
|
||||
"alias": basic_info["device_alias"],
|
||||
"fw_ver": basic_info["sw_version"],
|
||||
"hw_ver": basic_info["hw_version"],
|
||||
"mac": basic_info["mac"],
|
||||
"hwId": basic_info.get("hw_id"),
|
||||
"oem_id": basic_info["oem_id"],
|
||||
}
|
||||
|
||||
@property
|
||||
def is_on(self) -> bool:
|
||||
"""Return true if the device is on."""
|
||||
if (camera := self.modules.get(Module.Camera)) and not camera.disabled:
|
||||
return camera.is_on
|
||||
|
||||
return True
|
||||
|
||||
async def set_state(self, on: bool) -> dict:
|
||||
"""Set the device state."""
|
||||
if (camera := self.modules.get(Module.Camera)) and not camera.disabled:
|
||||
return await camera.set_state(on)
|
||||
|
||||
return {}
|
||||
|
||||
@property
|
||||
def device_type(self) -> DeviceType:
|
||||
"""Return the device type."""
|
||||
if self._device_type == DeviceType.Unknown:
|
||||
self._device_type = self._get_device_type_from_sysinfo(self._info)
|
||||
return self._device_type
|
||||
|
||||
@property
|
||||
def alias(self) -> str | None:
|
||||
"""Returns the device alias or nickname."""
|
||||
if self._info:
|
||||
return self._info.get("alias")
|
||||
return None
|
||||
|
||||
async def set_alias(self, alias: str) -> dict:
|
||||
"""Set the device name (alias)."""
|
||||
return await self.protocol.query(
|
||||
{
|
||||
"setDeviceAlias": {"system": {"sys": {"dev_alias": alias}}},
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def hw_info(self) -> dict:
|
||||
"""Return hardware info for the device."""
|
||||
return {
|
||||
"sw_ver": self._info.get("hw_ver"),
|
||||
"hw_ver": self._info.get("fw_ver"),
|
||||
"mac": self._info.get("mac"),
|
||||
"type": self._info.get("type"),
|
||||
"hwId": self._info.get("hwId"),
|
||||
"dev_name": self.alias,
|
||||
"oemId": self._info.get("oem_id"),
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
"""Base implementation for SMART modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from ..exceptions import DeviceError, KasaException, SmartErrorCode
|
||||
from ..smart.smartmodule import SmartModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .smartcamera import SmartCamera
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SmartCameraModule(SmartModule):
|
||||
"""Base class for SMARTCAMERA modules."""
|
||||
|
||||
#: Query to execute during the main update cycle
|
||||
QUERY_GETTER_NAME: str
|
||||
#: Module name to be queried
|
||||
QUERY_MODULE_NAME: str
|
||||
#: Section name or names to be queried
|
||||
QUERY_SECTION_NAMES: str | list[str]
|
||||
|
||||
REGISTERED_MODULES = {}
|
||||
|
||||
_device: SmartCamera
|
||||
|
||||
def query(self) -> dict:
|
||||
"""Query to execute during the update cycle.
|
||||
|
||||
Default implementation uses the raw query getter w/o parameters.
|
||||
"""
|
||||
return {
|
||||
self.QUERY_GETTER_NAME: {
|
||||
self.QUERY_MODULE_NAME: {"name": self.QUERY_SECTION_NAMES}
|
||||
}
|
||||
}
|
||||
|
||||
async def call(self, method: str, params: dict | None = None) -> dict:
|
||||
"""Call a method.
|
||||
|
||||
Just a helper method.
|
||||
"""
|
||||
if params:
|
||||
module = next(iter(params))
|
||||
section = next(iter(params[module]))
|
||||
else:
|
||||
module = "system"
|
||||
section = "null"
|
||||
|
||||
if method[:3] == "get":
|
||||
return await self._device._query_getter_helper(method, module, section)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
params = cast(dict[str, dict[str, Any]], params)
|
||||
return await self._device._query_setter_helper(
|
||||
method, module, section, params[module][section]
|
||||
)
|
||||
|
||||
@property
|
||||
def data(self) -> dict:
|
||||
"""Return response data for the module."""
|
||||
dev = self._device
|
||||
q = self.query()
|
||||
|
||||
if not q:
|
||||
return dev.sys_info
|
||||
|
||||
if len(q) == 1:
|
||||
query_resp = dev._last_update.get(self.QUERY_GETTER_NAME, {})
|
||||
if isinstance(query_resp, SmartErrorCode):
|
||||
raise DeviceError(
|
||||
f"Error accessing module data in {self._module}",
|
||||
error_code=query_resp,
|
||||
)
|
||||
|
||||
if not query_resp:
|
||||
raise KasaException(
|
||||
f"You need to call update() prior accessing module data"
|
||||
f" for '{self._module}'"
|
||||
)
|
||||
|
||||
return query_resp.get(self.QUERY_MODULE_NAME)
|
||||
else:
|
||||
found = {key: val for key, val in dev._last_update.items() if key in q}
|
||||
for key in q:
|
||||
if key not in found:
|
||||
raise KasaException(
|
||||
f"{key} not found, you need to call update() prior accessing"
|
||||
f" module data for '{self._module}'"
|
||||
)
|
||||
if isinstance(found[key], SmartErrorCode):
|
||||
raise DeviceError(
|
||||
f"Error accessing module data {key} in {self._module}",
|
||||
error_code=found[key],
|
||||
)
|
||||
return found
|
||||
@@ -1,253 +0,0 @@
|
||||
"""Module for SmartCamera Protocol."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pprint import pformat as pf
|
||||
from typing import Any
|
||||
|
||||
from ..exceptions import (
|
||||
AuthenticationError,
|
||||
DeviceError,
|
||||
KasaException,
|
||||
_RetryableError,
|
||||
)
|
||||
from ..json import dumps as json_dumps
|
||||
from ..protocols import SmartProtocol
|
||||
from .sslaestransport import (
|
||||
SMART_AUTHENTICATION_ERRORS,
|
||||
SMART_RETRYABLE_ERRORS,
|
||||
SmartErrorCode,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# List of getMethodNames that should be sent as {"method":"do"}
|
||||
# https://md.depau.eu/s/r1Ys_oWoP#Modules
|
||||
GET_METHODS_AS_DO = {
|
||||
"getSdCardFormatStatus",
|
||||
"getConnectionType",
|
||||
"getUserID",
|
||||
"getP2PSharePassword",
|
||||
"getAESEncryptKey",
|
||||
"getFirmwareAFResult",
|
||||
"getWhitelampStatus",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SingleRequest:
|
||||
"""Class for returning single request details from helper functions."""
|
||||
|
||||
method_type: str
|
||||
method_name: str
|
||||
param_name: str
|
||||
request: dict[str, Any]
|
||||
|
||||
|
||||
class SmartCameraProtocol(SmartProtocol):
|
||||
"""Class for SmartCamera Protocol."""
|
||||
|
||||
async def _handle_response_lists(
|
||||
self, response_result: dict[str, Any], method: str, retry_count: int
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def _handle_response_error_code(
|
||||
self, resp_dict: dict, method: str, raise_on_error: bool = True
|
||||
) -> None:
|
||||
error_code_raw = resp_dict.get("error_code")
|
||||
try:
|
||||
error_code = SmartErrorCode.from_int(error_code_raw)
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"Device %s received unknown error code: %s", self._host, error_code_raw
|
||||
)
|
||||
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
|
||||
|
||||
if error_code is SmartErrorCode.SUCCESS:
|
||||
return
|
||||
|
||||
if not raise_on_error:
|
||||
resp_dict["result"] = error_code
|
||||
return
|
||||
|
||||
msg = (
|
||||
f"Error querying device: {self._host}: "
|
||||
+ f"{error_code.name}({error_code.value})"
|
||||
+ f" for method: {method}"
|
||||
)
|
||||
if error_code in SMART_RETRYABLE_ERRORS:
|
||||
raise _RetryableError(msg, error_code=error_code)
|
||||
if error_code in SMART_AUTHENTICATION_ERRORS:
|
||||
raise AuthenticationError(msg, error_code=error_code)
|
||||
raise DeviceError(msg, error_code=error_code)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying transport."""
|
||||
await self._transport.close()
|
||||
|
||||
@staticmethod
|
||||
def _get_smart_camera_single_request(
|
||||
request: dict[str, dict[str, Any]],
|
||||
) -> SingleRequest:
|
||||
method = next(iter(request))
|
||||
if method == "multipleRequest":
|
||||
method_type = "multi"
|
||||
params = request["multipleRequest"]
|
||||
req = {"method": "multipleRequest", "params": params}
|
||||
return SingleRequest("multi", "multipleRequest", "", req)
|
||||
|
||||
param = next(iter(request[method]))
|
||||
method_type = method
|
||||
req = {
|
||||
"method": method,
|
||||
param: request[method][param],
|
||||
}
|
||||
return SingleRequest(method_type, method, param, req)
|
||||
|
||||
@staticmethod
|
||||
def _make_snake_name(name: str) -> str:
|
||||
"""Convert camel or pascal case to snake name."""
|
||||
sn = "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_")
|
||||
return sn
|
||||
|
||||
@staticmethod
|
||||
def _make_smart_camera_single_request(
|
||||
request: str,
|
||||
) -> SingleRequest:
|
||||
"""Make a single request given a method name and no params.
|
||||
|
||||
If method like getSomeThing then module will be some_thing.
|
||||
"""
|
||||
method = request
|
||||
method_type = request[:3]
|
||||
snake_name = SmartCameraProtocol._make_snake_name(request)
|
||||
param = snake_name[4:]
|
||||
if (
|
||||
(short_method := method[:3])
|
||||
and short_method in {"get", "set"}
|
||||
and method not in GET_METHODS_AS_DO
|
||||
):
|
||||
method_type = short_method
|
||||
param = snake_name[4:]
|
||||
else:
|
||||
method_type = "do"
|
||||
param = snake_name
|
||||
req = {"method": method_type, param: {}}
|
||||
return SingleRequest(method_type, method, param, req)
|
||||
|
||||
async def _execute_query(
|
||||
self, request: str | dict, *, retry_count: int, iterate_list_pages: bool = True
|
||||
) -> dict:
|
||||
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
if isinstance(request, dict):
|
||||
method = next(iter(request))
|
||||
if len(request) == 1 and method in {"get", "set", "do", "multipleRequest"}:
|
||||
single_request = self._get_smart_camera_single_request(request)
|
||||
else:
|
||||
return await self._execute_multiple_query(request, retry_count)
|
||||
else:
|
||||
single_request = self._make_smart_camera_single_request(request)
|
||||
|
||||
smart_request = json_dumps(single_request.request)
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s >> %s",
|
||||
self._host,
|
||||
pf(smart_request),
|
||||
)
|
||||
response_data = await self._transport.send(smart_request)
|
||||
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"%s << %s",
|
||||
self._host,
|
||||
pf(response_data),
|
||||
)
|
||||
|
||||
if "error_code" in response_data:
|
||||
# H200 does not return an error code
|
||||
self._handle_response_error_code(response_data, single_request.method_name)
|
||||
# Requests that are invalid and raise PROTOCOL_FORMAT_ERROR when sent
|
||||
# as a multipleRequest will return {} when sent as a single request.
|
||||
if single_request.method_type == "get" and (
|
||||
not (section := next(iter(response_data))) or response_data[section] == {}
|
||||
):
|
||||
raise DeviceError(
|
||||
f"No results for get request {single_request.method_name}"
|
||||
)
|
||||
|
||||
# TODO need to update handle response lists
|
||||
|
||||
if single_request.method_type == "do":
|
||||
return {single_request.method_name: response_data}
|
||||
if single_request.method_type == "set":
|
||||
return {}
|
||||
if single_request.method_type == "multi":
|
||||
return {single_request.method_name: response_data["result"]}
|
||||
return {
|
||||
single_request.method_name: {
|
||||
single_request.param_name: response_data[single_request.param_name]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class _ChildCameraProtocolWrapper(SmartProtocol):
|
||||
"""Protocol wrapper for controlling child devices.
|
||||
|
||||
This is an internal class used to communicate with child devices,
|
||||
and should not be used directly.
|
||||
|
||||
This class overrides query() method of the protocol to modify all
|
||||
outgoing queries to use ``controlChild`` command, and unwraps the
|
||||
device responses before returning to the caller.
|
||||
"""
|
||||
|
||||
def __init__(self, device_id: str, base_protocol: SmartProtocol) -> None:
|
||||
self._device_id = device_id
|
||||
self._protocol = base_protocol
|
||||
self._transport = base_protocol._transport
|
||||
|
||||
async def query(self, request: str | dict, retry_count: int = 3) -> dict:
|
||||
"""Wrap request inside controlChild envelope."""
|
||||
return await self._query(request, retry_count)
|
||||
|
||||
async def _query(self, request: str | dict, retry_count: int = 3) -> dict:
|
||||
"""Wrap request inside controlChild envelope."""
|
||||
if not isinstance(request, dict):
|
||||
raise KasaException("Child requests must be dictionaries.")
|
||||
requests = []
|
||||
methods = []
|
||||
for key, val in request.items():
|
||||
request = {
|
||||
"method": "controlChild",
|
||||
"params": {
|
||||
"childControl": {
|
||||
"device_id": self._device_id,
|
||||
"request_data": {"method": key, "params": val},
|
||||
}
|
||||
},
|
||||
}
|
||||
methods.append(key)
|
||||
requests.append(request)
|
||||
|
||||
multipleRequest = {"multipleRequest": {"requests": requests}}
|
||||
|
||||
response = await self._protocol.query(multipleRequest, retry_count)
|
||||
|
||||
responses = response["multipleRequest"]["responses"]
|
||||
response_dict = {}
|
||||
for index_id, response in enumerate(responses):
|
||||
response_data = response["result"]["response_data"]
|
||||
method = methods[index_id]
|
||||
self._handle_response_error_code(
|
||||
response_data, method, raise_on_error=False
|
||||
)
|
||||
response_dict[method] = response_data.get("result")
|
||||
|
||||
return response_dict
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Do nothing as the parent owns the protocol."""
|
||||
@@ -1,478 +0,0 @@
|
||||
"""Implementation of the TP-Link SSL AES transport."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import ssl
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Any, Dict, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from ..credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import (
|
||||
SMART_AUTHENTICATION_ERRORS,
|
||||
SMART_RETRYABLE_ERRORS,
|
||||
AuthenticationError,
|
||||
DeviceError,
|
||||
KasaException,
|
||||
SmartErrorCode,
|
||||
_RetryableError,
|
||||
)
|
||||
from ..httpclient import HttpClient
|
||||
from ..json import dumps as json_dumps
|
||||
from ..json import loads as json_loads
|
||||
from ..transports import AesEncyptionSession, BaseTransport
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ONE_DAY_SECONDS = 86400
|
||||
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
|
||||
|
||||
|
||||
def _sha256(payload: bytes) -> bytes:
|
||||
return hashlib.sha256(payload).digest() # noqa: S324
|
||||
|
||||
|
||||
def _md5_hash(payload: bytes) -> str:
|
||||
return hashlib.md5(payload).hexdigest().upper() # noqa: S324
|
||||
|
||||
|
||||
def _sha256_hash(payload: bytes) -> str:
|
||||
return hashlib.sha256(payload).hexdigest().upper() # noqa: S324
|
||||
|
||||
|
||||
class TransportState(Enum):
|
||||
"""Enum for AES state."""
|
||||
|
||||
HANDSHAKE_REQUIRED = auto() # Handshake needed
|
||||
ESTABLISHED = auto() # Ready to send requests
|
||||
|
||||
|
||||
class SslAesTransport(BaseTransport):
|
||||
"""Implementation of the AES encryption protocol.
|
||||
|
||||
AES is the name used in device discovery for TP-Link's TAPO encryption
|
||||
protocol, sometimes used by newer firmware versions on kasa devices.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT: int = 443
|
||||
COMMON_HEADERS = {
|
||||
"Content-Type": "application/json; charset=UTF-8",
|
||||
"requestByApp": "true",
|
||||
"Accept": "application/json",
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
"User-Agent": "Tapo CameraClient Android",
|
||||
}
|
||||
CIPHERS = ":".join(
|
||||
[
|
||||
"AES256-GCM-SHA384",
|
||||
"AES256-SHA256",
|
||||
"AES128-GCM-SHA256",
|
||||
"AES128-SHA256",
|
||||
"AES256-SHA",
|
||||
]
|
||||
)
|
||||
DEFAULT_TIMEOUT = 10
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: DeviceConfig,
|
||||
) -> None:
|
||||
super().__init__(config=config)
|
||||
|
||||
self._login_version = config.connection_type.login_version
|
||||
if (
|
||||
not self._credentials or self._credentials.username is None
|
||||
) and not self._credentials_hash:
|
||||
self._credentials = Credentials()
|
||||
self._default_credentials: Credentials = get_default_credentials(
|
||||
DEFAULT_CREDENTIALS["TAPOCAMERA"]
|
||||
)
|
||||
self._http_client: HttpClient = HttpClient(config)
|
||||
|
||||
self._state = TransportState.HANDSHAKE_REQUIRED
|
||||
|
||||
self._encryption_session: AesEncyptionSession | None = None
|
||||
self._session_expire_at: float | None = None
|
||||
|
||||
self._host_port = f"{self._host}:{self._port}"
|
||||
self._app_url = URL(f"https://{self._host_port}")
|
||||
self._token_url: URL | None = None
|
||||
self._ssl_context: ssl.SSLContext | None = None
|
||||
ref = str(self._token_url) if self._token_url else str(self._app_url)
|
||||
self._headers = {
|
||||
**self.COMMON_HEADERS,
|
||||
"Host": self._host_port,
|
||||
"Referer": ref,
|
||||
}
|
||||
self._seq: int | None = None
|
||||
self._pwd_hash: str | None = None
|
||||
self._username: str | None = None
|
||||
self._password: str | None = None
|
||||
if self._credentials != Credentials() and self._credentials:
|
||||
self._username = self._credentials.username
|
||||
self._password = self._credentials.password
|
||||
elif self._credentials_hash:
|
||||
ch = json_loads(base64.b64decode(self._credentials_hash.encode()))
|
||||
self._password = ch["pwd"]
|
||||
self._username = ch["un"]
|
||||
self._local_nonce: str | None = None
|
||||
|
||||
_LOGGER.debug("Created AES transport for %s", self._host)
|
||||
|
||||
@property
|
||||
def default_port(self) -> int:
|
||||
"""Default port for the transport."""
|
||||
return self.DEFAULT_PORT
|
||||
|
||||
@staticmethod
|
||||
def _create_b64_credentials(credentials: Credentials) -> str:
|
||||
ch = {"un": credentials.username, "pwd": credentials.password}
|
||||
return base64.b64encode(json_dumps(ch).encode()).decode()
|
||||
|
||||
@property
|
||||
def credentials_hash(self) -> str | None:
|
||||
"""The hashed credentials used by the transport."""
|
||||
if self._credentials == Credentials():
|
||||
return None
|
||||
if not self._credentials and self._credentials_hash:
|
||||
return self._credentials_hash
|
||||
if (cred := self._credentials) and cred.password and cred.username:
|
||||
return self._create_b64_credentials(cred)
|
||||
return None
|
||||
|
||||
def _get_response_error(self, resp_dict: Any) -> SmartErrorCode:
|
||||
error_code_raw = resp_dict.get("error_code")
|
||||
try:
|
||||
error_code = SmartErrorCode.from_int(error_code_raw)
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"Device %s received unknown error code: %s", self._host, error_code_raw
|
||||
)
|
||||
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
|
||||
return error_code
|
||||
|
||||
def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
|
||||
error_code = self._get_response_error(resp_dict)
|
||||
if error_code is SmartErrorCode.SUCCESS:
|
||||
return
|
||||
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
|
||||
if error_code in SMART_RETRYABLE_ERRORS:
|
||||
raise _RetryableError(msg, error_code=error_code)
|
||||
if error_code in SMART_AUTHENTICATION_ERRORS:
|
||||
self._state = TransportState.HANDSHAKE_REQUIRED
|
||||
raise AuthenticationError(msg, error_code=error_code)
|
||||
raise DeviceError(msg, error_code=error_code)
|
||||
|
||||
def _create_ssl_context(self) -> ssl.SSLContext:
|
||||
context = ssl.SSLContext()
|
||||
context.set_ciphers(self.CIPHERS)
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
return context
|
||||
|
||||
async def _get_ssl_context(self) -> ssl.SSLContext:
|
||||
if not self._ssl_context:
|
||||
loop = asyncio.get_running_loop()
|
||||
self._ssl_context = await loop.run_in_executor(
|
||||
None, self._create_ssl_context
|
||||
)
|
||||
return self._ssl_context
|
||||
|
||||
async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
|
||||
"""Send encrypted message as passthrough."""
|
||||
if self._state is TransportState.ESTABLISHED and self._token_url:
|
||||
url = self._token_url
|
||||
else:
|
||||
url = self._app_url
|
||||
|
||||
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
|
||||
passthrough_request = {
|
||||
"method": "securePassthrough",
|
||||
"params": {"request": encrypted_payload.decode()},
|
||||
}
|
||||
passthrough_request_str = json_dumps(passthrough_request)
|
||||
if TYPE_CHECKING:
|
||||
assert self._pwd_hash
|
||||
assert self._local_nonce
|
||||
assert self._seq
|
||||
tag = self.generate_tag(
|
||||
passthrough_request_str, self._local_nonce, self._pwd_hash, self._seq
|
||||
)
|
||||
headers = {**self._headers, "Seq": str(self._seq), "Tapo_tag": tag}
|
||||
self._seq += 1
|
||||
status_code, resp_dict = await self._http_client.post(
|
||||
url,
|
||||
json=passthrough_request_str,
|
||||
headers=headers,
|
||||
ssl=await self._get_ssl_context(),
|
||||
)
|
||||
|
||||
if status_code != 200:
|
||||
raise KasaException(
|
||||
f"{self._host} responded with an unexpected "
|
||||
+ f"status code {status_code} to passthrough"
|
||||
)
|
||||
|
||||
self._handle_response_error_code(
|
||||
resp_dict, "Error sending secure_passthrough message"
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
resp_dict = cast(Dict[str, Any], resp_dict)
|
||||
assert self._encryption_session is not None
|
||||
|
||||
if "result" in resp_dict and "response" in resp_dict["result"]:
|
||||
raw_response: str = resp_dict["result"]["response"]
|
||||
else:
|
||||
# Tapo Cameras respond unencrypted to single requests.
|
||||
return resp_dict
|
||||
|
||||
try:
|
||||
response = self._encryption_session.decrypt(raw_response.encode())
|
||||
ret_val = json_loads(response)
|
||||
except Exception as ex:
|
||||
try:
|
||||
ret_val = json_loads(raw_response)
|
||||
_LOGGER.debug(
|
||||
"Received unencrypted response over secure passthrough from %s",
|
||||
self._host,
|
||||
)
|
||||
except Exception:
|
||||
raise KasaException(
|
||||
f"Unable to decrypt response from {self._host}, "
|
||||
+ f"error: {ex}, response: {raw_response}",
|
||||
ex,
|
||||
) from ex
|
||||
return ret_val # type: ignore[return-value]
|
||||
|
||||
@staticmethod
|
||||
def generate_confirm_hash(
|
||||
local_nonce: str, server_nonce: str, pwd_hash: str
|
||||
) -> str:
|
||||
"""Generate an auth hash for the protocol on the supplied credentials."""
|
||||
expected_confirm_bytes = _sha256_hash(
|
||||
local_nonce.encode() + pwd_hash.encode() + server_nonce.encode()
|
||||
)
|
||||
return expected_confirm_bytes + server_nonce + local_nonce
|
||||
|
||||
@staticmethod
|
||||
def generate_digest_password(
|
||||
local_nonce: str, server_nonce: str, pwd_hash: str
|
||||
) -> str:
|
||||
"""Generate an auth hash for the protocol on the supplied credentials."""
|
||||
digest_password_hash = _sha256_hash(
|
||||
pwd_hash.encode() + local_nonce.encode() + server_nonce.encode()
|
||||
)
|
||||
return (
|
||||
digest_password_hash.encode() + local_nonce.encode() + server_nonce.encode()
|
||||
).decode()
|
||||
|
||||
@staticmethod
|
||||
def generate_encryption_token(
|
||||
token_type: str, local_nonce: str, server_nonce: str, pwd_hash: str
|
||||
) -> bytes:
|
||||
"""Generate encryption token."""
|
||||
hashedKey = _sha256_hash(
|
||||
local_nonce.encode() + pwd_hash.encode() + server_nonce.encode()
|
||||
)
|
||||
return _sha256(
|
||||
token_type.encode()
|
||||
+ local_nonce.encode()
|
||||
+ server_nonce.encode()
|
||||
+ hashedKey.encode()
|
||||
)[:16]
|
||||
|
||||
@staticmethod
|
||||
def generate_tag(request: str, local_nonce: str, pwd_hash: str, seq: int) -> str:
|
||||
"""Generate the tag header from the request for the header."""
|
||||
pwd_nonce_hash = _sha256_hash(pwd_hash.encode() + local_nonce.encode())
|
||||
tag = _sha256_hash(
|
||||
pwd_nonce_hash.encode() + request.encode() + str(seq).encode()
|
||||
)
|
||||
return tag
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform the handshake."""
|
||||
local_nonce, server_nonce, pwd_hash = await self.perform_handshake1()
|
||||
await self.perform_handshake2(local_nonce, server_nonce, pwd_hash)
|
||||
|
||||
async def perform_handshake2(
|
||||
self, local_nonce: str, server_nonce: str, pwd_hash: str
|
||||
) -> None:
|
||||
"""Perform the handshake."""
|
||||
_LOGGER.debug("Performing handshake2 ...")
|
||||
digest_password = self.generate_digest_password(
|
||||
local_nonce, server_nonce, pwd_hash
|
||||
)
|
||||
body = {
|
||||
"method": "login",
|
||||
"params": {
|
||||
"cnonce": local_nonce,
|
||||
"encrypt_type": "3",
|
||||
"digest_passwd": digest_password,
|
||||
"username": self._username,
|
||||
},
|
||||
}
|
||||
http_client = self._http_client
|
||||
status_code, resp_dict = await http_client.post(
|
||||
self._app_url,
|
||||
json=body,
|
||||
headers=self._headers,
|
||||
ssl=await self._get_ssl_context(),
|
||||
)
|
||||
if status_code != 200:
|
||||
raise KasaException(
|
||||
f"{self._host} responded with an unexpected "
|
||||
+ f"status code {status_code} to handshake2"
|
||||
)
|
||||
resp_dict = cast(dict, resp_dict)
|
||||
if (
|
||||
error_code := self._get_response_error(resp_dict)
|
||||
) and error_code is SmartErrorCode.INVALID_NONCE:
|
||||
raise AuthenticationError(
|
||||
f"Invalid password hash in handshake2 for {self._host}"
|
||||
)
|
||||
|
||||
self._handle_response_error_code(resp_dict, "Error in handshake2")
|
||||
|
||||
self._seq = resp_dict["result"]["start_seq"]
|
||||
stok = resp_dict["result"]["stok"]
|
||||
self._token_url = URL(f"{str(self._app_url)}/stok={stok}/ds")
|
||||
self._pwd_hash = pwd_hash
|
||||
self._local_nonce = local_nonce
|
||||
lsk = self.generate_encryption_token("lsk", local_nonce, server_nonce, pwd_hash)
|
||||
ivb = self.generate_encryption_token("ivb", local_nonce, server_nonce, pwd_hash)
|
||||
self._encryption_session = AesEncyptionSession(lsk, ivb)
|
||||
self._state = TransportState.ESTABLISHED
|
||||
_LOGGER.debug("Handshake2 complete ...")
|
||||
|
||||
async def perform_handshake1(self) -> tuple[str, str, str]:
|
||||
"""Perform the handshake1."""
|
||||
resp_dict = None
|
||||
if self._username:
|
||||
local_nonce = secrets.token_bytes(8).hex().upper()
|
||||
resp_dict = await self.try_send_handshake1(self._username, local_nonce)
|
||||
|
||||
# Try the default username. If it fails raise the original error_code
|
||||
if (
|
||||
not resp_dict
|
||||
or (error_code := self._get_response_error(resp_dict))
|
||||
is not SmartErrorCode.INVALID_NONCE
|
||||
or "nonce" not in resp_dict["result"].get("data", {})
|
||||
):
|
||||
local_nonce = secrets.token_bytes(8).hex().upper()
|
||||
default_resp_dict = await self.try_send_handshake1(
|
||||
self._default_credentials.username, local_nonce
|
||||
)
|
||||
if (
|
||||
default_error_code := self._get_response_error(default_resp_dict)
|
||||
) is SmartErrorCode.INVALID_NONCE and "nonce" in default_resp_dict[
|
||||
"result"
|
||||
].get("data", {}):
|
||||
_LOGGER.debug("Connected to {self._host} with default username")
|
||||
self._username = self._default_credentials.username
|
||||
error_code = default_error_code
|
||||
resp_dict = default_resp_dict
|
||||
|
||||
if not self._username:
|
||||
raise AuthenticationError(
|
||||
f"Credentials must be supplied to connect to {self._host}"
|
||||
)
|
||||
if error_code is not SmartErrorCode.INVALID_NONCE or (
|
||||
resp_dict and "nonce" not in resp_dict["result"].get("data", {})
|
||||
):
|
||||
raise AuthenticationError(f"Error trying handshake1: {resp_dict}")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
resp_dict = cast(Dict[str, Any], resp_dict)
|
||||
|
||||
server_nonce = resp_dict["result"]["data"]["nonce"]
|
||||
device_confirm = resp_dict["result"]["data"]["device_confirm"]
|
||||
if self._credentials and self._credentials != Credentials():
|
||||
pwd_hash = _sha256_hash(self._credentials.password.encode())
|
||||
elif self._username and self._password:
|
||||
pwd_hash = _sha256_hash(self._password.encode())
|
||||
else:
|
||||
pwd_hash = _sha256_hash(self._default_credentials.password.encode())
|
||||
|
||||
expected_confirm_sha256 = self.generate_confirm_hash(
|
||||
local_nonce, server_nonce, pwd_hash
|
||||
)
|
||||
if device_confirm == expected_confirm_sha256:
|
||||
_LOGGER.debug("Credentials match")
|
||||
return local_nonce, server_nonce, pwd_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert self._credentials
|
||||
assert self._credentials.password
|
||||
pwd_hash = _md5_hash(self._credentials.password.encode())
|
||||
expected_confirm_md5 = self.generate_confirm_hash(
|
||||
local_nonce, server_nonce, pwd_hash
|
||||
)
|
||||
if device_confirm == expected_confirm_md5:
|
||||
_LOGGER.debug("Credentials match")
|
||||
return local_nonce, server_nonce, pwd_hash
|
||||
|
||||
msg = f"Server response doesn't match our challenge on ip {self._host}"
|
||||
_LOGGER.debug(msg)
|
||||
raise AuthenticationError(msg)
|
||||
|
||||
async def try_send_handshake1(self, username: str, local_nonce: str) -> dict:
|
||||
"""Perform the handshake."""
|
||||
_LOGGER.debug("Will to send handshake1...")
|
||||
|
||||
body = {
|
||||
"method": "login",
|
||||
"params": {
|
||||
"cnonce": local_nonce,
|
||||
"encrypt_type": "3",
|
||||
"username": username,
|
||||
},
|
||||
}
|
||||
http_client = self._http_client
|
||||
|
||||
status_code, resp_dict = await http_client.post(
|
||||
self._app_url,
|
||||
json=body,
|
||||
headers=self._headers,
|
||||
ssl=await self._get_ssl_context(),
|
||||
)
|
||||
|
||||
_LOGGER.debug("Device responded with: %s", resp_dict)
|
||||
|
||||
if status_code != 200:
|
||||
raise KasaException(
|
||||
f"{self._host} responded with an unexpected "
|
||||
+ f"status code {status_code} to handshake1"
|
||||
)
|
||||
|
||||
return cast(dict, resp_dict)
|
||||
|
||||
async def send(self, request: str) -> dict[str, Any]:
|
||||
"""Send the request."""
|
||||
if self._state is TransportState.HANDSHAKE_REQUIRED:
|
||||
await self.perform_handshake()
|
||||
|
||||
return await self.send_secure_passthrough(request)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the http client and reset internal state."""
|
||||
await self.reset()
|
||||
await self._http_client.close()
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset internal handshake state."""
|
||||
self._state = TransportState.HANDSHAKE_REQUIRED
|
||||
self._encryption_session = None
|
||||
self._seq = 0
|
||||
self._pwd_hash = None
|
||||
self._local_nonce = None
|
||||
Reference in New Issue
Block a user