Merge branch 'master' into experimental

This commit is contained in:
Steven B.
2024-11-05 13:36:34 +00:00
committed by GitHub
17 changed files with 422 additions and 224 deletions

View File

@@ -41,7 +41,7 @@ from kasa.iotprotocol import (
_deprecated_TPLinkSmartHomeProtocol, # noqa: F401
)
from kasa.module import Module
from kasa.protocol import BaseProtocol
from kasa.protocol import BaseProtocol, BaseTransport
from kasa.smartprotocol import SmartProtocol
__version__ = version("python-kasa")
@@ -50,6 +50,7 @@ __version__ = version("python-kasa")
__all__ = [
"Discover",
"BaseProtocol",
"BaseTransport",
"IotProtocol",
"SmartProtocol",
"LightState",

View File

@@ -15,7 +15,7 @@ from kasa import (
Discover,
UnsupportedDeviceError,
)
from kasa.discover import DiscoveryResult
from kasa.discover import ConnectAttempt, DiscoveryResult
from .common import echo, error
@@ -165,8 +165,17 @@ async def config(ctx):
credentials = Credentials(username, password) if username and password else None
host_port = host + (f":{port}" if port else "")
def on_attempt(connect_attempt: ConnectAttempt, success: bool) -> None:
prot, tran, dev = connect_attempt
key_str = f"{prot.__name__} + {tran.__name__} + {dev.__name__}"
result = "succeeded" if success else "failed"
msg = f"Attempt to connect to {host_port} with {key_str} {result}"
echo(msg)
dev = await Discover.try_connect_all(
host, credentials=credentials, timeout=timeout, port=port
host, credentials=credentials, timeout=timeout, port=port, on_attempt=on_attempt
)
if dev:
cparams = dev.config.connection_type

View File

@@ -167,7 +167,7 @@ def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]:
def get_device_class_from_family(
device_type: str, *, https: bool
device_type: str, *, https: bool, require_exact: bool = False
) -> type[Device] | None:
"""Return the device class from the type name."""
supported_device_types: dict[str, type[Device]] = {
@@ -185,8 +185,10 @@ def get_device_class_from_family(
}
lookup_key = f"{device_type}{'.HTTPS' if https else ''}"
if (
cls := supported_device_types.get(lookup_key)
) is None and device_type.startswith("SMART."):
(cls := supported_device_types.get(lookup_key)) is None
and device_type.startswith("SMART.")
and not require_exact
):
_LOGGER.warning("Unknown SMART device with %s, using SmartDevice", device_type)
cls = SmartDevice

View File

@@ -91,7 +91,7 @@ import socket
import struct
from collections.abc import Awaitable
from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast
from aiohttp import ClientSession
@@ -118,6 +118,7 @@ from kasa.exceptions import (
TimeoutError,
UnsupportedDeviceError,
)
from kasa.experimental import Experimental
from kasa.iot.iotdevice import IotDevice
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
from kasa.json import dumps as json_dumps
@@ -127,9 +128,21 @@ from kasa.xortransport import XorEncryption
_LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from kasa import BaseProtocol, BaseTransport
class ConnectAttempt(NamedTuple):
"""Try to connect attempt."""
protocol: type
transport: type
device: type
OnDiscoveredCallable = Callable[[Device], Awaitable[None]]
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]]
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
DeviceDict = Dict[str, Device]
NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
@@ -535,6 +548,7 @@ class Discover:
timeout: int | None = None,
credentials: Credentials | None = None,
http_client: ClientSession | None = None,
on_attempt: OnConnectAttemptCallable | None = None,
) -> Device | None:
"""Try to connect directly to a device with all possible parameters.
@@ -551,13 +565,22 @@ class Discover:
"""
from .device_factory import _connect
candidates = {
main_device_families = {
Device.Family.SmartTapoPlug,
Device.Family.IotSmartPlugSwitch,
}
if Experimental.enabled():
main_device_families.add(Device.Family.SmartIpCamera)
candidates: dict[
tuple[type[BaseProtocol], type[BaseTransport], type[Device]],
tuple[BaseProtocol, DeviceConfig],
] = {
(type(protocol), type(protocol._transport), device_class): (
protocol,
config,
)
for encrypt in Device.EncryptionType
for device_family in Device.Family
for device_family in main_device_families
for https in (True, False)
if (
conn_params := DeviceConnectionParameters(
@@ -580,19 +603,26 @@ class Discover:
and (protocol := get_protocol(config))
and (
device_class := get_device_class_from_family(
device_family.value, https=https
device_family.value, https=https, require_exact=True
)
)
}
for protocol, config in candidates.values():
for key, val in candidates.items():
try:
dev = await _connect(config, protocol)
prot, config = val
dev = await _connect(config, prot)
except Exception:
_LOGGER.debug("Unable to connect with %s", protocol)
_LOGGER.debug("Unable to connect with %s", prot)
if on_attempt:
ca = tuple.__new__(ConnectAttempt, key)
on_attempt(ca, False)
else:
if on_attempt:
ca = tuple.__new__(ConnectAttempt, key)
on_attempt(ca, True)
return dev
finally:
await protocol.close()
await prot.close()
return None
@staticmethod

View File

@@ -9,7 +9,7 @@ from ..deviceconfig import DeviceConfig
from ..module import Module
from ..protocol import BaseProtocol
from .iotdevice import IotDevice, requires_update
from .modules import Antitheft, Cloud, Led, Schedule, Time, Usage
from .modules import AmbientLight, Antitheft, Cloud, Led, Motion, Schedule, Time, Usage
_LOGGER = logging.getLogger(__name__)
@@ -92,3 +92,12 @@ class IotWallSwitch(IotPlug):
) -> None:
super().__init__(host=host, config=config, protocol=protocol)
self._device_type = DeviceType.WallSwitch
async def _initialize_modules(self) -> None:
"""Initialize modules."""
await super()._initialize_modules()
if (dev_name := self.sys_info["dev_name"]) and "PIR" in dev_name:
self.add_module(Module.IotMotion, Motion(self, "smartlife.iot.PIR"))
self.add_module(
Module.IotAmbientLight, AmbientLight(self, "smartlife.iot.LAS")
)

View File

@@ -127,6 +127,9 @@ class Module(ABC):
WaterleakSensor: Final[ModuleName[smart.WaterleakSensor]] = ModuleName(
"WaterleakSensor"
)
ChildProtection: Final[ModuleName[smart.ChildProtection]] = ModuleName(
"ChildProtection"
)
TriggerLogs: Final[ModuleName[smart.TriggerLogs]] = ModuleName("TriggerLogs")
# SMARTCAMERA only modules

View File

@@ -6,6 +6,7 @@ from .autooff import AutoOff
from .batterysensor import BatterySensor
from .brightness import Brightness
from .childdevice import ChildDevice
from .childprotection import ChildProtection
from .cloud import Cloud
from .color import Color
from .colortemperature import ColorTemperature
@@ -40,6 +41,7 @@ __all__ = [
"HumiditySensor",
"TemperatureSensor",
"TemperatureControl",
"ChildProtection",
"ReportMode",
"AutoOff",
"Led",

View File

@@ -0,0 +1,41 @@
"""Child lock module."""
from __future__ import annotations
from ...feature import Feature
from ..smartmodule import SmartModule
class ChildProtection(SmartModule):
"""Implementation for child_protection."""
REQUIRED_COMPONENT = "child_protection"
QUERY_GETTER_NAME = "get_child_protection"
def _initialize_features(self):
"""Initialize features after the initial update."""
self._add_feature(
Feature(
device=self._device,
id="child_lock",
name="Child lock",
container=self,
attribute_getter="enabled",
attribute_setter="set_enabled",
type=Feature.Type.Switch,
category=Feature.Category.Config,
)
)
def query(self) -> dict:
"""Query to execute during the update cycle."""
return {}
@property
def enabled(self) -> bool:
"""Return True if child protection is enabled."""
return self.data["child_protection"]
async def set_enabled(self, enabled: bool) -> dict:
"""Set child protection."""
return await self.call("set_child_protection", {"enable": enabled})

View File

@@ -430,6 +430,16 @@ class FakeSmartTransport(BaseTransport):
info["get_preset_rules"]["states"][params["index"]] = params["state"]
return {"error_code": 0}
def _update_sysinfo_key(self, info: dict, key: str, value: str) -> dict:
"""Update a single key in the main system info.
This is used to implement child device setters that change the main sysinfo state.
"""
sys_info = info.get("get_device_info", info)
sys_info[key] = value
return {"error_code": 0}
async def _send_request(self, request_dict: dict):
method = request_dict["method"]
@@ -437,7 +447,7 @@ class FakeSmartTransport(BaseTransport):
if method == "control_child":
return await self._handle_control_child(request_dict["params"])
params = request_dict.get("params")
params = request_dict.get("params", {})
if method == "component_nego" or method[:4] == "get_":
if method in info:
result = copy.deepcopy(info[method])
@@ -518,6 +528,8 @@ class FakeSmartTransport(BaseTransport):
return self._edit_preset_rules(info, params)
elif method == "set_on_off_gradually_info":
return self._set_on_off_gradually_info(info, params)
elif method == "set_child_protection":
return self._update_sysinfo_key(info, "child_protection", params["enable"])
elif method[:4] == "set_":
target_method = f"get_{method[4:]}"
info[target_method].update(params)

View File

@@ -0,0 +1,9 @@
from kasa.tests.device_fixtures import wallswitch_iot
@wallswitch_iot
def test_wallswitch_motion(dev):
"""Check that wallswitches with motion sensor get modules enabled."""
has_motion = "PIR" in dev.sys_info["dev_name"]
assert "motion" in dev.modules if has_motion else True
assert "ambient" in dev.modules if has_motion else True

View File

@@ -0,0 +1,43 @@
import pytest
from kasa import Module
from kasa.smart.modules import ChildProtection
from kasa.tests.device_fixtures import parametrize
child_protection = parametrize(
"has child protection",
component_filter="child_protection",
protocol_filter={"SMART.CHILD"},
)
@child_protection
@pytest.mark.parametrize(
("feature", "prop_name", "type"),
[
("child_lock", "enabled", bool),
],
)
async def test_features(dev, feature, prop_name, type):
"""Test that features are registered and work as expected."""
protect: ChildProtection = dev.modules[Module.ChildProtection]
assert protect is not None
prop = getattr(protect, prop_name)
assert isinstance(prop, type)
feat = protect._device.features[feature]
assert feat.value == prop
assert isinstance(feat.value, type)
@child_protection
async def test_enabled(dev):
"""Test the API."""
protect: ChildProtection = dev.modules[Module.ChildProtection]
assert protect is not None
assert isinstance(protect.enabled, bool)
await protect.set_enabled(False)
await dev.update()
assert protect.enabled is False

View File

@@ -1162,7 +1162,7 @@ async def test_cli_child_commands(
async def test_discover_config(dev: Device, mocker, runner):
"""Test that device config is returned."""
host = "127.0.0.1"
mocker.patch("kasa.discover.Discover.try_connect_all", return_value=dev)
mocker.patch("kasa.device_factory._connect", side_effect=[Exception, dev])
res = await runner.invoke(
cli,
@@ -1182,6 +1182,14 @@ async def test_discover_config(dev: Device, mocker, runner):
cparam = dev.config.connection_type
expected = f"--device-family {cparam.device_family.value} --encrypt-type {cparam.encryption_type.value} {'--https' if cparam.https else '--no-https'}"
assert expected in res.output
assert re.search(
r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ failed",
res.output.replace("\n", ""),
)
assert re.search(
r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ succeeded",
res.output.replace("\n", ""),
)
async def test_discover_config_invalid(mocker, runner):