Merge remote-tracking branch 'upstream/master' into feat/light_module_feats

This commit is contained in:
Steven B
2024-12-11 13:21:46 +00:00
78 changed files with 5065 additions and 854 deletions

View File

@@ -36,9 +36,11 @@ from kasa.exceptions import (
)
from kasa.feature import Feature
from kasa.interfaces.light import HSV, ColorTempRange, Light, LightState
from kasa.interfaces.thermostat import Thermostat, ThermostatState
from kasa.module import Module
from kasa.protocols import BaseProtocol, IotProtocol, SmartProtocol
from kasa.protocols.iotprotocol import _deprecated_TPLinkSmartHomeProtocol # noqa: F401
from kasa.smartcam.modules.camera import StreamResolution
from kasa.transports import BaseTransport
__version__ = version("python-kasa")
@@ -72,6 +74,9 @@ __all__ = [
"DeviceConnectionParameters",
"DeviceEncryptionType",
"DeviceFamily",
"ThermostatState",
"Thermostat",
"StreamResolution",
]
from . import iot

View File

@@ -14,8 +14,17 @@ from kasa import (
Discover,
UnsupportedDeviceError,
)
from kasa.discover import ConnectAttempt, DiscoveryResult
from kasa.discover import (
NEW_DISCOVERY_REDACTORS,
ConnectAttempt,
DiscoveredRaw,
DiscoveryResult,
)
from kasa.iot.iotdevice import _extract_sys_info
from kasa.protocols.iotprotocol import REDACTORS as IOT_REDACTORS
from kasa.protocols.protocol import redact_data
from ..json import dumps as json_dumps
from .common import echo, error
@@ -63,7 +72,9 @@ async def detail(ctx):
await ctx.parent.invoke(state)
echo()
discovered = await _discover(ctx, print_discovered, print_unsupported)
discovered = await _discover(
ctx, print_discovered=print_discovered, print_unsupported=print_unsupported
)
if ctx.parent.parent.params["host"]:
return discovered
@@ -76,6 +87,33 @@ async def detail(ctx):
return discovered
@discover.command()
@click.option(
"--redact/--no-redact",
default=False,
is_flag=True,
type=bool,
help="Set flag to redact sensitive data from raw output.",
)
@click.pass_context
async def raw(ctx, redact: bool):
"""Return raw discovery data returned from devices."""
def print_raw(discovered: DiscoveredRaw):
if redact:
redactors = (
NEW_DISCOVERY_REDACTORS
if discovered["meta"]["port"] == Discover.DISCOVERY_PORT_2
else IOT_REDACTORS
)
discovered["discovery_response"] = redact_data(
discovered["discovery_response"], redactors
)
echo(json_dumps(discovered, indent=True))
return await _discover(ctx, print_raw=print_raw, do_echo=False)
@discover.command()
@click.pass_context
async def list(ctx):
@@ -101,10 +139,17 @@ async def list(ctx):
echo(f"{host:<15} UNSUPPORTED DEVICE")
echo(f"{'HOST':<15} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} {'ALIAS'}")
return await _discover(ctx, print_discovered, print_unsupported, do_echo=False)
return await _discover(
ctx,
print_discovered=print_discovered,
print_unsupported=print_unsupported,
do_echo=False,
)
async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True):
async def _discover(
ctx, *, print_discovered=None, print_unsupported=None, print_raw=None, do_echo=True
):
params = ctx.parent.parent.params
target = params["target"]
username = params["username"]
@@ -125,6 +170,7 @@ async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True):
timeout=timeout,
discovery_timeout=discovery_timeout,
on_unsupported=print_unsupported,
on_discovered_raw=print_raw,
)
if do_echo:
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
@@ -136,6 +182,7 @@ async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True):
port=port,
timeout=timeout,
credentials=credentials,
on_discovered_raw=print_raw,
)
for device in discovered_devices.values():
@@ -201,8 +248,8 @@ def _echo_discovery_info(discovery_info) -> None:
if discovery_info is None:
return
if "system" in discovery_info and "get_sysinfo" in discovery_info["system"]:
_echo_dictionary(discovery_info["system"]["get_sysinfo"])
if sysinfo := _extract_sys_info(discovery_info):
_echo_dictionary(sysinfo)
return
try:
@@ -230,10 +277,12 @@ def _echo_discovery_info(discovery_info) -> None:
_conditional_echo("Supports IOT Cloud", dr.is_support_iot_cloud)
_conditional_echo("OBD Src", dr.owner)
_conditional_echo("Factory Default", dr.factory_default)
_conditional_echo("Encrypt Type", dr.mgt_encrypt_schm.encrypt_type)
_conditional_echo("Encrypt Type", dr.encrypt_type)
_conditional_echo("Supports HTTPS", dr.mgt_encrypt_schm.is_support_https)
_conditional_echo("HTTP Port", dr.mgt_encrypt_schm.http_port)
if mgt_encrypt_schm := dr.mgt_encrypt_schm:
_conditional_echo("Encrypt Type", mgt_encrypt_schm.encrypt_type)
_conditional_echo("Supports HTTPS", mgt_encrypt_schm.is_support_https)
_conditional_echo("HTTP Port", mgt_encrypt_schm.http_port)
_conditional_echo("Login version", mgt_encrypt_schm.lv)
_conditional_echo("Encrypt info", pf(dr.encrypt_info) if dr.encrypt_info else None)
_conditional_echo("Decrypted", pf(dr.decrypted_data) if dr.decrypted_data else None)

View File

@@ -75,6 +75,7 @@ def _legacy_type_to_class(_type: str) -> Any:
"time": None,
"schedule": None,
"usage": None,
"energy": "usage",
# device commands runnnable at top level
"state": "device",
"on": "device",
@@ -307,6 +308,7 @@ async def cli(
if type == "camera":
encrypt_type = "AES"
https = True
login_version = 2
device_family = "SMART.IPCAMERA"
from kasa.device import Device

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
import logging
from typing import cast
import asyncclick as click
@@ -21,21 +20,6 @@ from .common import (
)
@click.command()
@click.option("--index", type=int, required=False)
@click.option("--name", type=str, required=False)
@click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False)
@click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False)
@click.option("--erase", is_flag=True)
@click.pass_context
async def emeter(ctx: click.Context, index, name, year, month, erase):
"""Query emeter for historical consumption."""
logging.warning("Deprecated, use 'kasa energy'")
return await ctx.invoke(
energy, child_index=index, child=name, year=year, month=month, erase=erase
)
@click.command()
@click.option("--year", type=click.DateTime(["%Y"]), default=None, required=False)
@click.option("--month", type=click.DateTime(["%Y-%m"]), default=None, required=False)
@@ -46,7 +30,7 @@ async def energy(dev: Device, year, month, erase):
Daily and monthly data provided in CSV format.
"""
echo("[bold]== Emeter ==[/bold]")
echo("[bold]== Energy ==[/bold]")
if not (energy := dev.modules.get(Module.Energy)):
error("Device has no energy module.")
return
@@ -71,7 +55,7 @@ async def energy(dev: Device, year, month, erase):
usage_data = await energy.get_daily_stats(year=month.year, month=month.month)
else:
# Call with no argument outputs summary data and returns
emeter_status = await energy.get_status()
emeter_status = energy.status
echo("Current: {} A".format(emeter_status["current"]))
echo("Voltage: {} V".format(emeter_status["voltage"]))

View File

@@ -25,6 +25,7 @@ def get_default_credentials(tuple: tuple[str, str]) -> Credentials:
DEFAULT_CREDENTIALS = {
"KASA": ("a2FzYUB0cC1saW5rLm5ldA==", "a2FzYVNldHVw"),
"KASACAMERA": ("YWRtaW4=", "MjEyMzJmMjk3YTU3YTVhNzQzODk0YTBlNGE4MDFmYzM="),
"TAPO": ("dGVzdEB0cC1saW5rLm5ldA==", "dGVzdA=="),
"TAPOCAMERA": ("YWRtaW4=", "YWRtaW4="),
}

21
kasa/device_factory.py Executable file → Normal file
View File

@@ -12,6 +12,7 @@ from .deviceconfig import DeviceConfig
from .exceptions import KasaException, UnsupportedDeviceError
from .iot import (
IotBulb,
IotCamera,
IotDevice,
IotDimmer,
IotLightStrip,
@@ -32,6 +33,8 @@ from .transports import (
BaseTransport,
KlapTransport,
KlapTransportV2,
LinkieTransportV2,
SslTransport,
XorTransport,
)
from .transports.sslaestransport import SslAesTransport
@@ -137,6 +140,7 @@ def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]:
DeviceType.Strip: IotStrip,
DeviceType.WallSwitch: IotWallSwitch,
DeviceType.LightStrip: IotLightStrip,
DeviceType.Camera: IotCamera,
}
return TYPE_TO_CLASS[IotDevice._get_device_type_from_sys_info(sysinfo)]
@@ -155,8 +159,10 @@ def get_device_class_from_family(
"SMART.KASAHUB": SmartDevice,
"SMART.KASASWITCH": SmartDevice,
"SMART.IPCAMERA.HTTPS": SmartCamDevice,
"SMART.TAPOROBOVAC": SmartDevice,
"IOT.SMARTPLUGSWITCH": IotPlug,
"IOT.SMARTBULB": IotBulb,
"IOT.IPCAMERA": IotCamera,
}
lookup_key = f"{device_type}{'.HTTPS' if https else ''}"
if (
@@ -176,20 +182,31 @@ def get_protocol(
"""Return the protocol from the connection name."""
protocol_name = config.connection_type.device_family.value.split(".")[0]
ctype = config.connection_type
protocol_transport_key = (
protocol_name
+ "."
+ ctype.encryption_type.value
+ (".HTTPS" if ctype.https else "")
+ (
f".{ctype.login_version}"
if ctype.login_version and ctype.login_version > 1
else ""
)
)
_LOGGER.debug("Finding transport for %s", protocol_transport_key)
supported_device_protocols: dict[
str, tuple[type[BaseProtocol], type[BaseTransport]]
] = {
"IOT.XOR": (IotProtocol, XorTransport),
"IOT.KLAP": (IotProtocol, KlapTransport),
"IOT.XOR.HTTPS.2": (IotProtocol, LinkieTransportV2),
"SMART.AES": (SmartProtocol, AesTransport),
"SMART.KLAP": (SmartProtocol, KlapTransportV2),
"SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport),
"SMART.AES.2": (SmartProtocol, AesTransport),
"SMART.KLAP.2": (SmartProtocol, KlapTransportV2),
"SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport),
"SMART.AES.HTTPS": (SmartProtocol, SslTransport),
}
if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)):
return None

View File

@@ -21,6 +21,7 @@ class DeviceType(Enum):
Hub = "hub"
Fan = "fan"
Thermostat = "thermostat"
Vacuum = "vacuum"
Unknown = "unknown"
@staticmethod

View File

@@ -69,6 +69,7 @@ class DeviceFamily(Enum):
IotSmartPlugSwitch = "IOT.SMARTPLUGSWITCH"
IotSmartBulb = "IOT.SMARTBULB"
IotIpCamera = "IOT.IPCAMERA"
SmartKasaPlug = "SMART.KASAPLUG"
SmartKasaSwitch = "SMART.KASASWITCH"
SmartTapoPlug = "SMART.TAPOPLUG"
@@ -77,6 +78,7 @@ class DeviceFamily(Enum):
SmartTapoHub = "SMART.TAPOHUB"
SmartKasaHub = "SMART.KASAHUB"
SmartIpCamera = "SMART.IPCAMERA"
SmartTapoRobovac = "SMART.TAPOROBOVAC"
class _DeviceConfigBaseMixin(DataClassJSONMixin):

View File

@@ -99,6 +99,7 @@ from typing import (
Annotated,
Any,
NamedTuple,
TypedDict,
cast,
)
@@ -123,7 +124,7 @@ from kasa.exceptions import (
TimeoutError,
UnsupportedDeviceError,
)
from kasa.iot.iotdevice import IotDevice
from kasa.iot.iotdevice import IotDevice, _extract_sys_info
from kasa.json import DataClassJSONMixin
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
@@ -147,15 +148,35 @@ class ConnectAttempt(NamedTuple):
device: type
class DiscoveredMeta(TypedDict):
"""Meta info about discovery response."""
ip: str
port: int
class DiscoveredRaw(TypedDict):
"""Try to connect attempt."""
meta: DiscoveredMeta
discovery_response: dict
OnDiscoveredCallable = Callable[[Device], Coroutine]
OnDiscoveredRawCallable = Callable[[DiscoveredRaw], None]
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Coroutine]
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
DeviceDict = dict[str, Device]
NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
"device_id": lambda x: "REDACTED_" + x[9::],
"device_name": lambda x: "#MASKED_NAME#" if x else "",
"owner": lambda x: "REDACTED_" + x[9::],
"mac": mask_mac,
"master_device_id": lambda x: "REDACTED_" + x[9::],
"group_id": lambda x: "REDACTED_" + x[9::],
"group_name": lambda x: "I01BU0tFRF9TU0lEIw==",
"encrypt_info": lambda x: {**x, "key": "", "data": ""},
}
@@ -213,6 +234,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self,
*,
on_discovered: OnDiscoveredCallable | None = None,
on_discovered_raw: OnDiscoveredRawCallable | None = None,
target: str = "255.255.255.255",
discovery_packets: int = 3,
discovery_timeout: int = 5,
@@ -237,6 +259,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.unsupported_device_exceptions: dict = {}
self.invalid_device_exceptions: dict = {}
self.on_unsupported = on_unsupported
self.on_discovered_raw = on_discovered_raw
self.credentials = credentials
self.timeout = timeout
self.discovery_timeout = discovery_timeout
@@ -326,12 +349,23 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
config.timeout = self.timeout
try:
if port == self.discovery_port:
device = Discover._get_device_instance_legacy(data, config)
json_func = Discover._get_discovery_json_legacy
device_func = Discover._get_device_instance_legacy
elif port == Discover.DISCOVERY_PORT_2:
config.uses_http = True
device = Discover._get_device_instance(data, config)
json_func = Discover._get_discovery_json
device_func = Discover._get_device_instance
else:
return
info = json_func(data, ip)
if self.on_discovered_raw is not None:
self.on_discovered_raw(
{
"discovery_response": info,
"meta": {"ip": ip, "port": port},
}
)
device = device_func(info, config)
except UnsupportedDeviceError as udex:
_LOGGER.debug("Unsupported device found at %s << %s", ip, udex)
self.unsupported_device_exceptions[ip] = udex
@@ -388,6 +422,7 @@ class Discover:
*,
target: str = "255.255.255.255",
on_discovered: OnDiscoveredCallable | None = None,
on_discovered_raw: OnDiscoveredRawCallable | None = None,
discovery_timeout: int = 5,
discovery_packets: int = 3,
interface: str | None = None,
@@ -418,6 +453,8 @@ class Discover:
:param target: The target address where to send the broadcast discovery
queries if multi-homing (e.g. 192.168.xxx.255).
:param on_discovered: coroutine to execute on discovery
:param on_discovered_raw: Optional callback once discovered json is loaded
before any attempt to deserialize it and create devices
:param discovery_timeout: Seconds to wait for responses, defaults to 5
:param discovery_packets: Number of discovery packets to broadcast
:param interface: Bind to specific interface
@@ -440,6 +477,7 @@ class Discover:
discovery_packets=discovery_packets,
interface=interface,
on_unsupported=on_unsupported,
on_discovered_raw=on_discovered_raw,
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
@@ -473,6 +511,7 @@ class Discover:
credentials: Credentials | None = None,
username: str | None = None,
password: str | None = None,
on_discovered_raw: OnDiscoveredRawCallable | None = None,
on_unsupported: OnUnsupportedCallable | None = None,
) -> Device | None:
"""Discover a single device by the given IP address.
@@ -490,6 +529,9 @@ class Discover:
username and password are ignored if provided.
:param username: Username for devices that require authentication
:param password: Password for devices that require authentication
:param on_discovered_raw: Optional callback once discovered json is loaded
before any attempt to deserialize it and create devices
:param on_unsupported: Optional callback when unsupported devices are discovered
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
@@ -526,6 +568,7 @@ class Discover:
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
on_discovered_raw=on_discovered_raw,
),
local_addr=("0.0.0.0", 0), # noqa: S104
)
@@ -595,10 +638,12 @@ class Discover:
for encrypt in Device.EncryptionType
for device_family in main_device_families
for https in (True, False)
for login_version in (None, 2)
if (
conn_params := DeviceConnectionParameters(
device_family=device_family,
encryption_type=encrypt,
login_version=login_version,
https=https,
)
)
@@ -643,7 +688,11 @@ class Discover:
"""Find SmartDevice subclass for device described by passed data."""
if "result" in info:
discovery_result = DiscoveryResult.from_dict(info["result"])
https = discovery_result.mgt_encrypt_schm.is_support_https
https = (
discovery_result.mgt_encrypt_schm.is_support_https
if discovery_result.mgt_encrypt_schm
else False
)
dev_class = get_device_class_from_family(
discovery_result.device_type, https=https
)
@@ -657,27 +706,36 @@ class Discover:
return get_device_class_from_sys_info(info)
@staticmethod
def _get_device_instance_legacy(data: bytes, config: DeviceConfig) -> IotDevice:
"""Get SmartDevice from legacy 9999 response."""
def _get_discovery_json_legacy(data: bytes, ip: str) -> dict:
"""Get discovery json from legacy 9999 response."""
try:
info = json_loads(XorEncryption.decrypt(data))
except Exception as ex:
raise KasaException(
f"Unable to read response from device: {config.host}: {ex}"
f"Unable to read response from device: {ip}: {ex}"
) from ex
return info
@staticmethod
def _get_device_instance_legacy(info: dict, config: DeviceConfig) -> Device:
"""Get IotDevice from legacy 9999 response."""
if _LOGGER.isEnabledFor(logging.DEBUG):
data = redact_data(info, IOT_REDACTORS) if Discover._redact_data else info
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data))
device_class = cast(type[IotDevice], Discover._get_device_class(info))
device = device_class(config.host, config=config)
sys_info = info["system"]["get_sysinfo"]
if device_type := sys_info.get("mic_type", sys_info.get("type")):
config.connection_type = DeviceConnectionParameters.from_values(
device_family=device_type,
encryption_type=DeviceEncryptionType.Xor.value,
)
sys_info = _extract_sys_info(info)
device_type = sys_info.get("mic_type", sys_info.get("type"))
login_version = (
sys_info.get("stream_version") if device_type == "IOT.IPCAMERA" else None
)
config.connection_type = DeviceConnectionParameters.from_values(
device_family=device_type,
encryption_type=DeviceEncryptionType.Xor.value,
https=device_type == "IOT.IPCAMERA",
login_version=login_version,
)
device.protocol = get_protocol(config) # type: ignore[assignment]
device.update_from_discover_info(info)
return device
@@ -701,20 +759,25 @@ class Discover:
discovery_result.decrypted_data = json_loads(decrypted_data)
@staticmethod
def _get_discovery_json(data: bytes, ip: str) -> dict:
"""Get discovery json from the new 20002 response."""
try:
info = json_loads(data[16:])
except Exception as ex:
_LOGGER.debug("Got invalid response from device %s: %s", ip, data)
raise KasaException(
f"Unable to read response from device: {ip}: {ex}"
) from ex
return info
@staticmethod
def _get_device_instance(
data: bytes,
info: dict,
config: DeviceConfig,
) -> Device:
"""Get SmartDevice from the new 20002 response."""
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
try:
info = json_loads(data[16:])
except Exception as ex:
_LOGGER.debug("Got invalid response from device %s: %s", config.host, data)
raise KasaException(
f"Unable to read response from device: {config.host}: {ex}"
) from ex
try:
discovery_result = DiscoveryResult.from_dict(info["result"])
@@ -743,11 +806,19 @@ class Discover:
Discover._decrypt_discovery_data(discovery_result)
except Exception:
_LOGGER.exception(
"Unable to decrypt discovery data %s: %s", config.host, data
"Unable to decrypt discovery data %s: %s",
config.host,
redact_data(info, NEW_DISCOVERY_REDACTORS),
)
type_ = discovery_result.device_type
encrypt_schm = discovery_result.mgt_encrypt_schm
if (encrypt_schm := discovery_result.mgt_encrypt_schm) is None:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
"with no mgt_encrypt_schm",
discovery_result=discovery_result.to_dict(),
host=config.host,
)
try:
if not (encrypt_type := encrypt_schm.encrypt_type) and (
@@ -755,6 +826,13 @@ class Discover:
):
encrypt_type = encrypt_info.sym_schm
if (
not (login_version := encrypt_schm.lv)
and (et := discovery_result.encrypt_type)
and et == ["3"]
):
login_version = 2
if not encrypt_type:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
@@ -765,13 +843,13 @@ class Discover:
config.connection_type = DeviceConnectionParameters.from_values(
type_,
encrypt_type,
discovery_result.mgt_encrypt_schm.lv,
discovery_result.mgt_encrypt_schm.is_support_https,
login_version,
encrypt_schm.is_support_https,
)
except KasaException as ex:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}",
+ f"with encrypt_type {encrypt_schm.encrypt_type}",
discovery_result=discovery_result.to_dict(),
host=config.host,
) from ex
@@ -854,7 +932,7 @@ class DiscoveryResult(_DiscoveryBaseMixin):
device_id: str
ip: str
mac: str
mgt_encrypt_schm: EncryptionScheme
mgt_encrypt_schm: EncryptionScheme | None = None
device_name: str | None = None
encrypt_info: EncryptionInfo | None = None
encrypt_type: list[str] | None = None

View File

@@ -24,7 +24,6 @@ State (state): True
Signal Level (signal_level): 2
RSSI (rssi): -52
SSID (ssid): #MASKED_SSID#
Overheated (overheated): False
Reboot (reboot): <Action>
Brightness (brightness): 100
Cloud connection (cloud_connection): True
@@ -39,6 +38,7 @@ Light effect (light_effect): Off
Light preset (light_preset): Not set
Smooth transition on (smooth_transition_on): 2
Smooth transition off (smooth_transition_off): 2
Overheated (overheated): False
Device time (device_time): 2024-02-23 02:40:15+01:00
To see whether a device supports a feature, check for the existence of it:

View File

@@ -6,6 +6,7 @@ from .led import Led
from .light import Light, LightState
from .lighteffect import LightEffect
from .lightpreset import LightPreset
from .thermostat import Thermostat, ThermostatState
from .time import Time
__all__ = [
@@ -16,5 +17,7 @@ __all__ = [
"LightEffect",
"LightState",
"LightPreset",
"Thermostat",
"ThermostatState",
"Time",
]

View File

@@ -0,0 +1,65 @@
"""Interact with a TPLink Thermostat."""
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Annotated, Literal
from ..module import FeatureAttribute, Module
class ThermostatState(Enum):
"""Thermostat state."""
Heating = "heating"
Calibrating = "progress_calibration"
Idle = "idle"
Off = "off"
Unknown = "unknown"
class Thermostat(Module, ABC):
"""Base class for TP-Link Thermostat."""
@property
@abstractmethod
def state(self) -> bool:
"""Return thermostat state."""
@abstractmethod
async def set_state(self, enabled: bool) -> dict:
"""Set thermostat state."""
@property
@abstractmethod
def mode(self) -> ThermostatState:
"""Return thermostat state."""
@property
@abstractmethod
def target_temperature(self) -> Annotated[float, FeatureAttribute()]:
"""Return target temperature."""
@abstractmethod
async def set_target_temperature(
self, target: float
) -> Annotated[dict, FeatureAttribute()]:
"""Set target temperature."""
@property
@abstractmethod
def temperature(self) -> Annotated[float, FeatureAttribute()]:
"""Return current humidity in percentage."""
return self._device.sys_info["current_temp"]
@property
@abstractmethod
def temperature_unit(self) -> Literal["celsius", "fahrenheit"]:
"""Return current temperature unit."""
@abstractmethod
async def set_temperature_unit(
self, unit: Literal["celsius", "fahrenheit"]
) -> dict:
"""Set the device temperature unit."""

View File

@@ -1,6 +1,7 @@
"""Package for supporting legacy kasa devices."""
from .iotbulb import IotBulb
from .iotcamera import IotCamera
from .iotdevice import IotDevice
from .iotdimmer import IotDimmer
from .iotlightstrip import IotLightStrip
@@ -15,4 +16,5 @@ __all__ = [
"IotDimmer",
"IotLightStrip",
"IotWallSwitch",
"IotCamera",
]

42
kasa/iot/iotcamera.py Normal file
View File

@@ -0,0 +1,42 @@
"""Module for cameras."""
from __future__ import annotations
import logging
from datetime import datetime, tzinfo
from ..device_type import DeviceType
from ..deviceconfig import DeviceConfig
from ..protocols import BaseProtocol
from .iotdevice import IotDevice
_LOGGER = logging.getLogger(__name__)
class IotCamera(IotDevice):
"""Representation of a TP-Link Camera."""
def __init__(
self,
host: str,
*,
config: DeviceConfig | None = None,
protocol: BaseProtocol | None = None,
) -> None:
super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.Camera
@property
def time(self) -> datetime:
"""Get the camera's time."""
return datetime.fromtimestamp(self.sys_info["system_time"])
@property
def timezone(self) -> tzinfo:
"""Get the camera's timezone."""
return None # type: ignore
@property # type: ignore
def is_on(self) -> bool:
"""Return whether device is on."""
return True

View File

@@ -70,6 +70,16 @@ def _parse_features(features: str) -> set[str]:
return set(features.split(":"))
def _extract_sys_info(info: dict[str, Any]) -> dict[str, Any]:
"""Return the system info structure."""
sysinfo_default = info.get("system", {}).get("get_sysinfo", {})
sysinfo_nest = sysinfo_default.get("system", {})
if len(sysinfo_nest) > len(sysinfo_default) and isinstance(sysinfo_nest, dict):
return sysinfo_nest
return sysinfo_default
class IotDevice(Device):
"""Base class for all supported device types.
@@ -304,14 +314,14 @@ class IotDevice(Device):
_LOGGER.debug("Performing the initial update to obtain sysinfo")
response = await self.protocol.query(req)
self._last_update = response
self._set_sys_info(response["system"]["get_sysinfo"])
self._set_sys_info(_extract_sys_info(response))
if not self._modules:
await self._initialize_modules()
await self._modular_update(req)
self._set_sys_info(self._last_update["system"]["get_sysinfo"])
self._set_sys_info(_extract_sys_info(self._last_update))
for module in self._modules.values():
await module._post_update_hook()
@@ -705,10 +715,13 @@ class IotDevice(Device):
@staticmethod
def _get_device_type_from_sys_info(info: dict[str, Any]) -> DeviceType:
"""Find SmartDevice subclass for device described by passed data."""
if "system" in info.get("system", {}).get("get_sysinfo", {}):
return DeviceType.Camera
if "system" not in info or "get_sysinfo" not in info["system"]:
raise KasaException("No 'system' or 'get_sysinfo' in response")
sysinfo: dict[str, Any] = info["system"]["get_sysinfo"]
sysinfo: dict[str, Any] = _extract_sys_info(info)
type_: str | None = sysinfo.get("type", sysinfo.get("mic_type"))
if type_ is None:
raise KasaException("Unable to find the device type field!")
@@ -728,6 +741,7 @@ class IotDevice(Device):
return DeviceType.LightStrip
return DeviceType.Bulb
_LOGGER.warning("Unknown device type %s, falling back to plug", type_)
return DeviceType.Plug
@@ -736,7 +750,7 @@ class IotDevice(Device):
info: dict[str, Any], discovery_info: dict[str, Any] | None
) -> _DeviceInfo:
"""Get model information for a device."""
sys_info = info["system"]["get_sysinfo"]
sys_info = _extract_sys_info(info)
# Get model and region info
region = None

View File

@@ -8,18 +8,24 @@ from typing import Any
try:
import orjson
def dumps(obj: Any, *, default: Callable | None = None) -> str:
def dumps(
obj: Any, *, default: Callable | None = None, indent: bool = False
) -> str:
"""Dump JSON."""
return orjson.dumps(obj).decode()
return orjson.dumps(
obj, option=orjson.OPT_INDENT_2 if indent else None
).decode()
loads = orjson.loads
except ImportError:
import json
def dumps(obj: Any, *, default: Callable | None = None) -> str:
def dumps(
obj: Any, *, default: Callable | None = None, indent: bool = False
) -> str:
"""Dump JSON."""
# Separators specified for consistency with orjson
return json.dumps(obj, separators=(",", ":"))
return json.dumps(obj, separators=(",", ":"), indent=2 if indent else None)
loads = json.loads

View File

@@ -14,9 +14,17 @@ Light, AutoOff, Firmware etc.
>>> print(dev.alias)
Living Room Bulb
To see whether a device supports functionality check for the existence of the module:
To see whether a device supports a group of functionality
check for the existence of the module:
>>> if light := dev.modules.get("Light"):
>>> print(light.brightness)
100
To see whether a device supports specific functionality, you can check whether the
module has that feature:
>>> if light.has_feature("hsv"):
>>> print(light.hsv)
HSV(hue=0, saturation=100, value=100)
@@ -70,6 +78,9 @@ ModuleT = TypeVar("ModuleT", bound="Module")
class FeatureAttribute:
"""Class for annotating attributes bound to feature."""
def __repr__(self) -> str:
return "FeatureAttribute"
class Module(ABC):
"""Base class implemention for all modules.
@@ -85,6 +96,7 @@ class Module(ABC):
Led: Final[ModuleName[interfaces.Led]] = ModuleName("Led")
Light: Final[ModuleName[interfaces.Light]] = ModuleName("Light")
LightPreset: Final[ModuleName[interfaces.LightPreset]] = ModuleName("LightPreset")
Thermostat: Final[ModuleName[interfaces.Thermostat]] = ModuleName("Thermostat")
Time: Final[ModuleName[interfaces.Time]] = ModuleName("Time")
# IOT only Modules

View File

@@ -24,9 +24,11 @@ from .lightpreset import LightPreset
from .lightstripeffect import LightStripEffect
from .lighttransition import LightTransition
from .motionsensor import MotionSensor
from .overheatprotection import OverheatProtection
from .reportmode import ReportMode
from .temperaturecontrol import TemperatureControl
from .temperaturesensor import TemperatureSensor
from .thermostat import Thermostat
from .time import Time
from .triggerlogs import TriggerLogs
from .waterleaksensor import WaterleakSensor
@@ -61,5 +63,7 @@ __all__ = [
"MotionSensor",
"TriggerLogs",
"FrostProtection",
"Thermostat",
"SmartLightEffect",
"OverheatProtection",
]

View File

@@ -10,7 +10,7 @@ class ContactSensor(SmartModule):
"""Implementation of contact sensor module."""
REQUIRED_COMPONENT = None # we depend on availability of key
REQUIRED_KEY_ON_PARENT = "open"
SYSINFO_LOOKUP_KEYS = ["open"]
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""

View File

@@ -75,8 +75,12 @@ class Energy(SmartModule, EnergyInterface):
async def get_status(self) -> EmeterStatus:
"""Return real-time statistics."""
res = await self.call("get_energy_usage")
return self._get_status_from_energy(res["get_energy_usage"])
if "get_emeter_data" in self.data:
res = await self.call("get_emeter_data")
return EmeterStatus(res["get_emeter_data"])
else:
res = await self.call("get_energy_usage")
return self._get_status_from_energy(res["get_energy_usage"])
@property
@raise_if_update_error

View File

@@ -24,6 +24,7 @@ class LightTransition(SmartModule):
REQUIRED_COMPONENT = "on_off_gradually"
QUERY_GETTER_NAME = "get_on_off_gradually_info"
MINIMUM_UPDATE_INTERVAL_SECS = 60
# v3 added max_duration, we default to 60 when it's not available
MAXIMUM_DURATION = 60
# Key in sysinfo that indicates state can be retrieved from there.
@@ -144,10 +145,22 @@ class LightTransition(SmartModule):
return await self.call("set_on_off_gradually_info", {"enable": enable})
else:
on = await self.call(
"set_on_off_gradually_info", {"on_state": {"enable": enable}}
"set_on_off_gradually_info",
{
"on_state": {
"enable": enable,
"duration": self._on_state["duration"],
}
},
)
off = await self.call(
"set_on_off_gradually_info", {"off_state": {"enable": enable}}
"set_on_off_gradually_info",
{
"off_state": {
"enable": enable,
"duration": self._off_state["duration"],
}
},
)
return {**on, **off}
@@ -167,7 +180,6 @@ class LightTransition(SmartModule):
@property
def _turn_on_transition_max(self) -> int:
"""Maximum turn on duration."""
# v3 added max_duration, we default to 60 when it's not available
return self._on_state["max_duration"]
@allow_update_after
@@ -184,7 +196,7 @@ class LightTransition(SmartModule):
if seconds <= 0:
return await self.call(
"set_on_off_gradually_info",
{"on_state": {"enable": False}},
{"on_state": {"enable": False, "duration": self._on_state["duration"]}},
)
return await self.call(
@@ -220,7 +232,12 @@ class LightTransition(SmartModule):
if seconds <= 0:
return await self.call(
"set_on_off_gradually_info",
{"off_state": {"enable": False}},
{
"off_state": {
"enable": False,
"duration": self._off_state["duration"],
}
},
)
return await self.call(

View File

@@ -0,0 +1,41 @@
"""Overheat module."""
from __future__ import annotations
from ...feature import Feature
from ..smartmodule import SmartModule
class OverheatProtection(SmartModule):
"""Implementation for overheat_protection."""
SYSINFO_LOOKUP_KEYS = ["overheated", "overheat_status"]
def _initialize_features(self) -> None:
"""Initialize features after the initial update."""
self._add_feature(
Feature(
self._device,
container=self,
id="overheated",
name="Overheated",
attribute_getter="overheated",
icon="mdi:heat-wave",
type=Feature.Type.BinarySensor,
category=Feature.Category.Info,
)
)
@property
def overheated(self) -> bool:
"""Return True if device reports overheating."""
if (value := self._device.sys_info.get("overheat_status")) is not None:
# Value can be normal, cooldown, or overheated.
# We report all but normal as overheated.
return value != "normal"
return self._device.sys_info["overheated"]
def query(self) -> dict:
"""Query to execute during the update cycle."""
return {}

View File

@@ -3,24 +3,14 @@
from __future__ import annotations
import logging
from enum import Enum
from ...feature import Feature
from ...interfaces.thermostat import ThermostatState
from ..smartmodule import SmartModule
_LOGGER = logging.getLogger(__name__)
class ThermostatState(Enum):
"""Thermostat state."""
Heating = "heating"
Calibrating = "progress_calibration"
Idle = "idle"
Off = "off"
Unknown = "unknown"
class TemperatureControl(SmartModule):
"""Implementation of temperature module."""
@@ -56,7 +46,6 @@ class TemperatureControl(SmartModule):
category=Feature.Category.Config,
)
)
self._add_feature(
Feature(
self._device,
@@ -69,7 +58,6 @@ class TemperatureControl(SmartModule):
type=Feature.Type.Switch,
)
)
self._add_feature(
Feature(
self._device,

View File

@@ -0,0 +1,74 @@
"""Module for a Thermostat."""
from __future__ import annotations
from typing import Annotated, Literal
from ...feature import Feature
from ...interfaces.thermostat import Thermostat as ThermostatInterface
from ...interfaces.thermostat import ThermostatState
from ...module import FeatureAttribute, Module
from ..smartmodule import SmartModule
class Thermostat(SmartModule, ThermostatInterface):
"""Implementation of a Thermostat."""
@property
def _all_features(self) -> dict[str, Feature]:
"""Get the features for this module and any sub modules."""
ret: dict[str, Feature] = {}
if temp_control := self._device.modules.get(Module.TemperatureControl):
ret.update(**temp_control._module_features)
if temp_sensor := self._device.modules.get(Module.TemperatureSensor):
ret.update(**temp_sensor._module_features)
return ret
def query(self) -> dict:
"""Query to execute during the update cycle."""
return {}
@property
def state(self) -> bool:
"""Return thermostat state."""
return self._device.modules[Module.TemperatureControl].state
async def set_state(self, enabled: bool) -> dict:
"""Set thermostat state."""
return await self._device.modules[Module.TemperatureControl].set_state(enabled)
@property
def mode(self) -> ThermostatState:
"""Return thermostat state."""
return self._device.modules[Module.TemperatureControl].mode
@property
def target_temperature(self) -> Annotated[float, FeatureAttribute()]:
"""Return target temperature."""
return self._device.modules[Module.TemperatureControl].target_temperature
async def set_target_temperature(
self, target: float
) -> Annotated[dict, FeatureAttribute()]:
"""Set target temperature."""
return await self._device.modules[
Module.TemperatureControl
].set_target_temperature(target)
@property
def temperature(self) -> Annotated[float, FeatureAttribute()]:
"""Return current humidity in percentage."""
return self._device.modules[Module.TemperatureSensor].temperature
@property
def temperature_unit(self) -> Literal["celsius", "fahrenheit"]:
"""Return current temperature unit."""
return self._device.modules[Module.TemperatureSensor].temperature_unit
async def set_temperature_unit(
self, unit: Literal["celsius", "fahrenheit"]
) -> dict:
"""Set the device temperature unit."""
return await self._device.modules[
Module.TemperatureSensor
].set_temperature_unit(unit)

View File

@@ -24,6 +24,7 @@ from .modules import (
DeviceModule,
Firmware,
Light,
Thermostat,
Time,
)
from .smartmodule import SmartModule
@@ -166,7 +167,14 @@ class SmartDevice(Device):
self._last_update, "get_child_device_list", {}
):
for info in child_info["child_device_list"]:
self._children[info["device_id"]]._update_internal_state(info)
child_id = info["device_id"]
if child_id not in self._children:
_LOGGER.debug(
"Skipping child update for %s, probably unsupported device",
child_id,
)
continue
self._children[child_id]._update_internal_state(info)
def _update_internal_info(self, info_resp: dict) -> None:
"""Update the internal device info."""
@@ -341,9 +349,8 @@ class SmartDevice(Device):
) or mod.__name__ in child_modules_to_skip:
continue
required_component = cast(str, mod.REQUIRED_COMPONENT)
if required_component in self._components or (
mod.REQUIRED_KEY_ON_PARENT
and self.sys_info.get(mod.REQUIRED_KEY_ON_PARENT) is not None
if required_component in self._components or any(
self.sys_info.get(key) is not None for key in mod.SYSINFO_LOOKUP_KEYS
):
_LOGGER.debug(
"Device %s, found required %s, adding %s to modules.",
@@ -361,6 +368,11 @@ class SmartDevice(Device):
or Module.ColorTemperature in self._modules
):
self._modules[Light.__name__] = Light(self, "light")
if (
Module.TemperatureControl in self._modules
and Module.TemperatureSensor in self._modules
):
self._modules[Thermostat.__name__] = Thermostat(self, "thermostat")
async def _initialize_features(self) -> None:
"""Initialize device features."""
@@ -427,19 +439,6 @@ class SmartDevice(Device):
)
)
if "overheated" in self._info:
self._add_feature(
Feature(
self,
id="overheated",
name="Overheated",
attribute_getter=lambda x: x._info["overheated"],
icon="mdi:heat-wave",
type=Feature.Type.BinarySensor,
category=Feature.Category.Info,
)
)
# We check for the key available, and not for the property truthiness,
# as the value is falsy when the device is off.
if "on_time" in self._info:
@@ -759,10 +758,11 @@ class SmartDevice(Device):
if self._device_type is not DeviceType.Unknown:
return self._device_type
# Fallback to device_type (from disco info)
type_str = self._info.get("type", self._info.get("device_type"))
if not type_str: # no update or discovery info
if (
not (type_str := self._info.get("type", self._info.get("device_type")))
or not self._components
):
# no update or discovery info
return self._device_type
self._device_type = self._get_device_type_from_components(
@@ -796,6 +796,8 @@ class SmartDevice(Device):
return DeviceType.Sensor
if "ENERGY" in device_type:
return DeviceType.Thermostat
if "ROBOVAC" in device_type:
return DeviceType.Vacuum
_LOGGER.warning("Unknown device type, falling back to plug")
return DeviceType.Plug

View File

@@ -54,8 +54,8 @@ class SmartModule(Module):
NAME: str
#: Module is initialized, if the given component is available
REQUIRED_COMPONENT: str | None = None
#: Module is initialized, if the given key available in the main sysinfo
REQUIRED_KEY_ON_PARENT: str | None = None
#: Module is initialized, if any of the given keys exists in the sysinfo
SYSINFO_LOOKUP_KEYS: list[str] = []
#: Query to execute during the main update cycle
QUERY_GETTER_NAME: str

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import base64
import logging
from enum import StrEnum
from urllib.parse import quote_plus
from ...credentials import Credentials
@@ -15,6 +16,14 @@ from ..smartcammodule import SmartCamModule
_LOGGER = logging.getLogger(__name__)
LOCAL_STREAMING_PORT = 554
ONVIF_PORT = 2020
class StreamResolution(StrEnum):
"""Class for stream resolution."""
HD = "HD"
SD = "SD"
class Camera(SmartCamModule):
@@ -64,7 +73,12 @@ class Camera(SmartCamModule):
return None
def stream_rtsp_url(self, credentials: Credentials | None = None) -> str | None:
def stream_rtsp_url(
self,
credentials: Credentials | None = None,
*,
stream_resolution: StreamResolution = StreamResolution.HD,
) -> str | None:
"""Return the local rtsp streaming url.
:param credentials: Credentials for camera account.
@@ -73,17 +87,27 @@ class Camera(SmartCamModule):
:return: rtsp url with escaped credentials or None if no credentials or
camera is off.
"""
if not self.is_on:
streams = {
StreamResolution.HD: "stream1",
StreamResolution.SD: "stream2",
}
if (stream := streams.get(stream_resolution)) is None:
return None
dev = self._device
if not credentials:
credentials = self._get_credentials()
if not credentials or not credentials.username or not credentials.password:
return None
username = quote_plus(credentials.username)
password = quote_plus(credentials.password)
return f"rtsp://{username}:{password}@{dev.host}:{LOCAL_STREAMING_PORT}/stream1"
return f"rtsp://{username}:{password}@{self._device.host}:{LOCAL_STREAMING_PORT}/{stream}"
def onvif_url(self) -> str | None:
"""Return the onvif url."""
return f"http://{self._device.host}:{ONVIF_PORT}/onvif/device_service"
async def set_state(self, on: bool) -> dict:
"""Set the device state."""

View File

@@ -68,7 +68,14 @@ class SmartCamDevice(SmartDevice):
self._last_update, "getChildDeviceList", {}
):
for info in child_info["child_device_list"]:
self._children[info["device_id"]]._update_internal_state(info)
child_id = info["device_id"]
if child_id not in self._children:
_LOGGER.debug(
"Skipping child update for %s, probably unsupported device",
child_id,
)
continue
self._children[child_id]._update_internal_state(info)
async def _initialize_smart_child(
self, info: dict, child_components: dict
@@ -100,20 +107,29 @@ class SmartCamDevice(SmartDevice):
resp = await self.protocol.query(child_info_query)
self.internal_state.update(resp)
children_components = {
smart_children_components = {
child["device_id"]: {
comp["id"]: int(comp["ver_code"]) for comp in child["component_list"]
comp["id"]: int(comp["ver_code"]) for comp in component_list
}
for child in resp["getChildDeviceComponentList"]["child_component_list"]
if (component_list := child.get("component_list"))
# Child camera devices will have a different component schema so only
# extract smart values.
and (first_comp := next(iter(component_list), None))
and isinstance(first_comp, dict)
and "id" in first_comp
and "ver_code" in first_comp
}
children = {}
for info in resp["getChildDeviceList"]["child_device_list"]:
if (
category := info.get("category")
) and category in SmartChildDevice.CHILD_DEVICE_TYPE_MAP:
child_id = info["device_id"]
(category := info.get("category"))
and category in SmartChildDevice.CHILD_DEVICE_TYPE_MAP
and (child_id := info.get("device_id"))
and (child_components := smart_children_components.get(child_id))
):
children[child_id] = await self._initialize_smart_child(
info, children_components[child_id]
info, child_components
)
else:
_LOGGER.debug("Child device type not supported: %s", info)
@@ -191,6 +207,7 @@ class SmartCamDevice(SmartDevice):
"mac": basic_info["mac"],
"hwId": basic_info.get("hw_id"),
"oem_id": basic_info["oem_id"],
"device_id": basic_info["dev_id"],
}
@property

View File

@@ -3,14 +3,18 @@
from .aestransport import AesEncyptionSession, AesTransport
from .basetransport import BaseTransport
from .klaptransport import KlapTransport, KlapTransportV2
from .linkietransport import LinkieTransportV2
from .ssltransport import SslTransport
from .xortransport import XorEncryption, XorTransport
__all__ = [
"AesTransport",
"AesEncyptionSession",
"SslTransport",
"BaseTransport",
"KlapTransport",
"KlapTransportV2",
"LinkieTransportV2",
"XorTransport",
"XorEncryption",
]

View File

@@ -0,0 +1,143 @@
"""Implementation of the linkie kasa camera transport."""
from __future__ import annotations
import asyncio
import base64
import logging
import ssl
from typing import TYPE_CHECKING, cast
from urllib.parse import quote
from yarl import URL
from kasa.credentials import DEFAULT_CREDENTIALS, get_default_credentials
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import KasaException, _RetryableError
from kasa.httpclient import HttpClient
from kasa.json import loads as json_loads
from kasa.transports.xortransport import XorEncryption
from .basetransport import BaseTransport
_LOGGER = logging.getLogger(__name__)
class LinkieTransportV2(BaseTransport):
"""Implementation of the Linkie encryption protocol.
Linkie is used as the endpoint for TP-Link's camera encryption
protocol, used by newer firmware versions.
"""
DEFAULT_PORT: int = 10443
CIPHERS = ":".join(
[
"AES256-GCM-SHA384",
"AES256-SHA256",
"AES128-GCM-SHA256",
"AES128-SHA256",
"AES256-SHA",
]
)
def __init__(self, *, config: DeviceConfig) -> None:
super().__init__(config=config)
self._http_client = HttpClient(config)
self._ssl_context: ssl.SSLContext | None = None
self._app_url = URL(f"https://{self._host}:{self._port}/data/LINKIE2.json")
self._headers = {
"Authorization": f"Basic {self.credentials_hash}",
"Content-Type": "application/x-www-form-urlencoded",
}
@property
def default_port(self) -> int:
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str | None:
"""The hashed credentials used by the transport."""
creds = get_default_credentials(DEFAULT_CREDENTIALS["KASACAMERA"])
creds_combined = f"{creds.username}:{creds.password}"
return base64.b64encode(creds_combined.encode()).decode()
async def _execute_send(self, request: str) -> dict:
"""Execute a query on the device and wait for the response."""
_LOGGER.debug("%s >> %s", self._host, request)
encrypted_cmd = XorEncryption.encrypt(request)[4:]
b64_cmd = base64.b64encode(encrypted_cmd).decode()
url_safe_cmd = quote(b64_cmd, safe="!~*'()")
status_code, response = await self._http_client.post(
self._app_url,
headers=self._headers,
data=f"content={url_safe_cmd}".encode(),
ssl=await self._get_ssl_context(),
)
if TYPE_CHECKING:
response = cast(bytes, response)
if status_code != 200:
raise KasaException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to passthrough"
)
# Expected response
try:
json_payload: dict = json_loads(
XorEncryption.decrypt(base64.b64decode(response))
)
_LOGGER.debug("%s << %s", self._host, json_payload)
return json_payload
except Exception: # noqa: S110
pass
# Device returned error as json plaintext
to_raise: KasaException | None = None
try:
error_payload: dict = json_loads(response)
to_raise = KasaException(f"Device {self._host} send error: {error_payload}")
except Exception as ex:
raise KasaException("Unable to read response") from ex
raise to_raise
async def close(self) -> None:
"""Close the http client and reset internal state."""
await self._http_client.close()
async def reset(self) -> None:
"""Reset the transport.
NOOP for this transport.
"""
async def send(self, request: str) -> dict:
"""Send a message to the device and return a response."""
try:
return await self._execute_send(request)
except Exception as ex:
await self.reset()
raise _RetryableError(
f"Unable to query the device {self._host}:{self._port}: {ex}"
) from ex
async def _get_ssl_context(self) -> ssl.SSLContext:
if not self._ssl_context:
loop = asyncio.get_running_loop()
self._ssl_context = await loop.run_in_executor(
None, self._create_ssl_context
)
return self._ssl_context
def _create_ssl_context(self) -> ssl.SSLContext:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.set_ciphers(self.CIPHERS)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
return context

View File

@@ -0,0 +1,233 @@
"""Implementation of the clear-text passthrough ssl transport.
This transport does not encrypt the passthrough payloads at all, but requires a login.
This has been seen on some devices (like robovacs).
"""
from __future__ import annotations
import asyncio
import base64
import hashlib
import logging
import time
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, cast
from yarl import URL
from kasa.credentials import DEFAULT_CREDENTIALS, Credentials, get_default_credentials
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
AuthenticationError,
DeviceError,
KasaException,
SmartErrorCode,
_RetryableError,
)
from kasa.httpclient import HttpClient
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.transports import BaseTransport
_LOGGER = logging.getLogger(__name__)
ONE_DAY_SECONDS = 86400
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
def _md5_hash(payload: bytes) -> str:
return hashlib.md5(payload).hexdigest().upper() # noqa: S324
class TransportState(Enum):
"""Enum for transport state."""
LOGIN_REQUIRED = auto() # Login needed
ESTABLISHED = auto() # Ready to send requests
class SslTransport(BaseTransport):
"""Implementation of the cleartext transport protocol.
This transport uses HTTPS without any further payload encryption.
"""
DEFAULT_PORT: int = 4433
COMMON_HEADERS = {
"Content-Type": "application/json",
}
BACKOFF_SECONDS_AFTER_LOGIN_ERROR = 1
def __init__(
self,
*,
config: DeviceConfig,
) -> None:
super().__init__(config=config)
if (
not self._credentials or self._credentials.username is None
) and not self._credentials_hash:
self._credentials = Credentials()
if self._credentials:
self._login_params = self._get_login_params(self._credentials)
else:
self._login_params = json_loads(
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
)
self._default_credentials: Credentials | None = None
self._http_client: HttpClient = HttpClient(config)
self._state = TransportState.LOGIN_REQUIRED
self._session_expire_at: float | None = None
self._app_url = URL(f"https://{self._host}:{self._port}/app")
_LOGGER.debug("Created ssltransport for %s", self._host)
@property
def default_port(self) -> int:
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
def _get_login_params(self, credentials: Credentials) -> dict[str, str]:
"""Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(credentials)
return {"password": pw, "username": un}
@staticmethod
def hash_credentials(credentials: Credentials) -> tuple[str, str]:
"""Hash the credentials."""
un = credentials.username
pw = _md5_hash(credentials.password.encode())
return un, pw
async def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
"""Handle response errors to request reauth etc."""
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:
return
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
if error_code in SMART_RETRYABLE_ERRORS:
raise _RetryableError(msg, error_code=error_code)
if error_code in SMART_AUTHENTICATION_ERRORS:
await self.reset()
raise AuthenticationError(msg, error_code=error_code)
raise DeviceError(msg, error_code=error_code)
async def send_request(self, request: str) -> dict[str, Any]:
"""Send request."""
url = self._app_url
_LOGGER.debug("Sending %s to %s", request, url)
status_code, resp_dict = await self._http_client.post(
url,
json=request,
headers=self.COMMON_HEADERS,
)
if status_code != 200:
raise KasaException(
f"{self._host} responded with an unexpected "
+ f"status code {status_code}"
)
_LOGGER.debug("Response with %s: %r", status_code, resp_dict)
await self._handle_response_error_code(resp_dict, "Error sending request")
if TYPE_CHECKING:
resp_dict = cast(dict[str, Any], resp_dict)
return resp_dict
async def perform_login(self) -> None:
"""Login to the device."""
try:
await self.try_login(self._login_params)
except AuthenticationError as aex:
try:
if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
raise aex
_LOGGER.debug("Login failed, going to try default credentials")
if self._default_credentials is None:
self._default_credentials = get_default_credentials(
DEFAULT_CREDENTIALS["TAPO"]
)
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_LOGIN_ERROR)
await self.try_login(self._get_login_params(self._default_credentials))
_LOGGER.debug(
"%s: logged in with default credentials",
self._host,
)
except AuthenticationError:
raise
except Exception as ex:
raise KasaException(
"Unable to login and trying default "
+ f"login raised another exception: {ex}",
ex,
) from ex
async def try_login(self, login_params: dict[str, Any]) -> None:
"""Try to login with supplied login_params."""
login_request = {
"method": "login",
"params": login_params,
}
request = json_dumps(login_request)
_LOGGER.debug("Going to send login request")
resp_dict = await self.send_request(request)
await self._handle_response_error_code(resp_dict, "Error logging in")
login_token = resp_dict["result"]["token"]
self._app_url = self._app_url.with_query(f"token={login_token}")
self._state = TransportState.ESTABLISHED
self._session_expire_at = (
time.time() + ONE_DAY_SECONDS - SESSION_EXPIRE_BUFFER_SECONDS
)
def _session_expired(self) -> bool:
"""Return true if session has expired."""
return (
self._session_expire_at is None
or self._session_expire_at - time.time() <= 0
)
async def send(self, request: str) -> dict[str, Any]:
"""Send the request."""
_LOGGER.info("Going to send %s", request)
if self._state is not TransportState.ESTABLISHED or self._session_expired():
_LOGGER.debug("Transport not established or session expired, logging in")
await self.perform_login()
return await self.send_request(request)
async def close(self) -> None:
"""Close the http client and reset internal state."""
await self.reset()
await self._http_client.close()
async def reset(self) -> None:
"""Reset internal login state."""
self._state = TransportState.LOGIN_REQUIRED
self._app_url = URL(f"https://{self._host}:{self._port}/app")