diff --git a/docs/source/design.rst b/docs/source/design.rst index 8acbfea6..5679943d 100644 --- a/docs/source/design.rst +++ b/docs/source/design.rst @@ -12,6 +12,21 @@ or if you are just looking to access some information that is not currently expo .. contents:: Contents :local: +.. _initialization: + +Initialization +************** + +Use :func:`~kasa.Discover.discover` to perform udp-based broadcast discovery on the network. +This will return you a list of device instances based on the discovery replies. + +If the device's host is already known, you can use to construct a device instance with +:meth:`~kasa.SmartDevice.connect()`. + +When connecting a device with the :meth:`~kasa.SmartDevice.connect()` method, it is recommended to +pass the device type as well as this allows the library to use the correct device class for the +device without having to query the device. + .. _update_cycle: Update Cycle diff --git a/kasa/cli.py b/kasa/cli.py index 7280dd33..e71c7b9f 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -13,14 +13,13 @@ import asyncclick as click from kasa import ( AuthenticationException, Credentials, + DeviceType, Discover, SmartBulb, SmartDevice, - SmartDimmer, - SmartLightStrip, - SmartPlug, SmartStrip, ) +from kasa.device_factory import DEVICE_TYPE_TO_CLASS try: from rich import print as _do_echo @@ -43,13 +42,11 @@ except ImportError: # --json has set it to _nop_echo echo = _do_echo -TYPE_TO_CLASS = { - "plug": SmartPlug, - "bulb": SmartBulb, - "dimmer": SmartDimmer, - "strip": SmartStrip, - "lightstrip": SmartLightStrip, -} +DEVICE_TYPES = [ + device_type.value + for device_type in DeviceType + if device_type in DEVICE_TYPE_TO_CLASS +] click.anyio_backend = "asyncio" @@ -129,7 +126,7 @@ def json_formatter_cb(result, **kwargs): "--type", envvar="KASA_TYPE", default=None, - type=click.Choice(list(TYPE_TO_CLASS), case_sensitive=False), + type=click.Choice(DEVICE_TYPES, case_sensitive=False), ) @click.option( "--json", default=False, is_flag=True, help="Output raw device response as JSON." @@ -235,7 +232,10 @@ async def cli( return await ctx.invoke(discover, timeout=discovery_timeout) if type is not None: - dev = TYPE_TO_CLASS[type](host, credentials=credentials) + device_type = DeviceType.from_value(type) + dev = await SmartDevice.connect( + host, credentials=credentials, device_type=device_type + ) else: echo("No --type defined, discovering..") dev = await Discover.discover_single( @@ -243,8 +243,8 @@ async def cli( port=port, credentials=credentials, ) + await dev.update() - await dev.update() ctx.obj = dev if ctx.invoked_subcommand is None: diff --git a/kasa/device_factory.py b/kasa/device_factory.py new file mode 100755 index 00000000..c3ed4de3 --- /dev/null +++ b/kasa/device_factory.py @@ -0,0 +1,121 @@ +"""Device creation by type.""" + +import logging +import time +from typing import Any, Dict, Optional, Type + +from .credentials import Credentials +from .device_type import DeviceType +from .exceptions import UnsupportedDeviceException +from .smartbulb import SmartBulb +from .smartdevice import SmartDevice, SmartDeviceException +from .smartdimmer import SmartDimmer +from .smartlightstrip import SmartLightStrip +from .smartplug import SmartPlug +from .smartstrip import SmartStrip + +DEVICE_TYPE_TO_CLASS = { + DeviceType.Plug: SmartPlug, + DeviceType.Bulb: SmartBulb, + DeviceType.Strip: SmartStrip, + DeviceType.Dimmer: SmartDimmer, + DeviceType.LightStrip: SmartLightStrip, +} + +_LOGGER = logging.getLogger(__name__) + + +async def connect( + host: str, + *, + port: Optional[int] = None, + timeout=5, + credentials: Optional[Credentials] = None, + device_type: Optional[DeviceType] = None, +) -> "SmartDevice": + """Connect to a single device by the given IP address. + + This method avoids the UDP based discovery process and + will connect directly to the device to query its type. + + It is generally preferred to avoid :func:`discover_single()` and + use this function instead as it should perform better when + the WiFi network is congested or the device is not responding + to discovery requests. + + The device type is discovered by querying the device. + + :param host: Hostname of device to query + :param device_type: Device type to use for the device. + If not given, the device type is discovered by querying the device. + If the device type is already known, it is preferred to pass it + to avoid the extra query to the device to discover its type. + :rtype: SmartDevice + :return: Object for querying/controlling found device. + """ + debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) + + if debug_enabled: + start_time = time.perf_counter() + + if device_type and (klass := DEVICE_TYPE_TO_CLASS.get(device_type)): + dev: SmartDevice = klass( + host=host, port=port, credentials=credentials, timeout=timeout + ) + await dev.update() + if debug_enabled: + end_time = time.perf_counter() + _LOGGER.debug( + "Device %s with known type (%s) took %.2f seconds to connect", + host, + device_type.value, + end_time - start_time, + ) + return dev + + unknown_dev = SmartDevice( + host=host, port=port, credentials=credentials, timeout=timeout + ) + await unknown_dev.update() + device_class = get_device_class_from_info(unknown_dev.internal_state) + dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout) + # Reuse the connection from the unknown device + # so we don't have to reconnect + dev.protocol = unknown_dev.protocol + await dev.update() + if debug_enabled: + end_time = time.perf_counter() + _LOGGER.debug( + "Device %s with unknown type (%s) took %.2f seconds to connect", + host, + dev.device_type.value, + end_time - start_time, + ) + return dev + + +def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]: + """Find SmartDevice subclass for device described by passed data.""" + if "system" not in info or "get_sysinfo" not in info["system"]: + raise SmartDeviceException("No 'system' or 'get_sysinfo' in response") + + sysinfo: Dict[str, Any] = info["system"]["get_sysinfo"] + type_: Optional[str] = sysinfo.get("type", sysinfo.get("mic_type")) + if type_ is None: + raise SmartDeviceException("Unable to find the device type field!") + + if "dev_name" in sysinfo and "Dimmer" in sysinfo["dev_name"]: + return SmartDimmer + + if "smartplug" in type_.lower(): + if "children" in sysinfo: + return SmartStrip + + return SmartPlug + + if "smartbulb" in type_.lower(): + if "length" in sysinfo: # strips have length + return SmartLightStrip + + return SmartBulb + raise UnsupportedDeviceException("Unknown device type: %s" % type_) diff --git a/kasa/device_type.py b/kasa/device_type.py new file mode 100755 index 00000000..162fc4f2 --- /dev/null +++ b/kasa/device_type.py @@ -0,0 +1,25 @@ +"""TP-Link device types.""" + + +from enum import Enum + + +class DeviceType(Enum): + """Device type enum.""" + + # The values match what the cli has historically used + Plug = "plug" + Bulb = "bulb" + Strip = "strip" + StripSocket = "stripsocket" + Dimmer = "dimmer" + LightStrip = "lightstrip" + Unknown = "unknown" + + @staticmethod + def from_value(name: str) -> "DeviceType": + """Return device type from string value.""" + for device_type in DeviceType: + if device_type.value == name: + return device_type + return DeviceType.Unknown diff --git a/kasa/discover.py b/kasa/discover.py index 151aae82..2523ba1a 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -20,13 +20,11 @@ from kasa.exceptions import UnsupportedDeviceException from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads from kasa.klapprotocol import TPLinkKlap -from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol -from kasa.smartbulb import SmartBulb +from kasa.protocol import TPLinkSmartHomeProtocol from kasa.smartdevice import SmartDevice, SmartDeviceException -from kasa.smartdimmer import SmartDimmer -from kasa.smartlightstrip import SmartLightStrip from kasa.smartplug import SmartPlug -from kasa.smartstrip import SmartStrip + +from .device_factory import get_device_class_from_info _LOGGER = logging.getLogger(__name__) @@ -345,78 +343,10 @@ class Discover: else: raise SmartDeviceException(f"Unable to get discovery response for {host}") - @staticmethod - async def connect_single( - host: str, - *, - port: Optional[int] = None, - timeout=5, - credentials: Optional[Credentials] = None, - protocol_class: Optional[Type[TPLinkProtocol]] = None, - ) -> SmartDevice: - """Connect to a single device by the given IP address. - - This method avoids the UDP based discovery process and - will connect directly to the device to query its type. - - It is generally preferred to avoid :func:`discover_single()` and - use this function instead as it should perform better when - the WiFi network is congested or the device is not responding - to discovery requests. - - The device type is discovered by querying the device. - - :param host: Hostname of device to query - :param port: Optionally set a different port for the device - :param timeout: Timeout for discovery - :param credentials: Optionally provide credentials for - devices requiring them - :param protocol_class: Optionally provide the protocol class - to use. - :rtype: SmartDevice - :return: Object for querying/controlling found device. - """ - unknown_dev = SmartDevice( - host=host, port=port, credentials=credentials, timeout=timeout - ) - if protocol_class is not None: - unknown_dev.protocol = protocol_class(host, credentials=credentials) - await unknown_dev.update() - device_class = Discover._get_device_class(unknown_dev.internal_state) - dev = device_class( - host=host, port=port, credentials=credentials, timeout=timeout - ) - # Reuse the connection from the unknown device - # so we don't have to reconnect - dev.protocol = unknown_dev.protocol - return dev - @staticmethod def _get_device_class(info: dict) -> Type[SmartDevice]: """Find SmartDevice subclass for device described by passed data.""" - if "system" not in info or "get_sysinfo" not in info["system"]: - raise SmartDeviceException("No 'system' or 'get_sysinfo' in response") - - sysinfo = info["system"]["get_sysinfo"] - type_ = sysinfo.get("type", sysinfo.get("mic_type")) - if type_ is None: - raise SmartDeviceException("Unable to find the device type field!") - - if "dev_name" in sysinfo and "Dimmer" in sysinfo["dev_name"]: - return SmartDimmer - - if "smartplug" in type_.lower(): - if "children" in sysinfo: - return SmartStrip - - return SmartPlug - - if "smartbulb" in type_.lower(): - if "length" in sysinfo: # strips have length - return SmartLightStrip - - return SmartBulb - raise UnsupportedDeviceException("Unknown device type: %s" % type_) + return get_device_class_from_info(info) @staticmethod def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice: diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index b081ac3f..f1995db8 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -17,10 +17,10 @@ import inspect import logging from dataclasses import dataclass from datetime import datetime, timedelta -from enum import Enum, auto from typing import Any, Dict, List, Optional, Set from .credentials import Credentials +from .device_type import DeviceType from .emeterstatus import EmeterStatus from .exceptions import SmartDeviceException from .modules import Emeter, Module @@ -29,18 +29,6 @@ from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol _LOGGER = logging.getLogger(__name__) -class DeviceType(Enum): - """Device type enum.""" - - Plug = auto() - Bulb = auto() - Strip = auto() - StripSocket = auto() - Dimmer = auto() - LightStrip = auto() - Unknown = -1 - - @dataclass class WifiNetwork: """Wifi network container.""" @@ -780,3 +768,42 @@ class SmartDevice: f" ({self.alias}), is_on: {self.is_on}" f" - dev specific: {self.state_information}>" ) + + @staticmethod + async def connect( + host: str, + *, + port: Optional[int] = None, + timeout=5, + credentials: Optional[Credentials] = None, + device_type: Optional[DeviceType] = None, + ) -> "SmartDevice": + """Connect to a single device by the given IP address. + + This method avoids the UDP based discovery process and + will connect directly to the device to query its type. + + It is generally preferred to avoid :func:`discover_single()` and + use this function instead as it should perform better when + the WiFi network is congested or the device is not responding + to discovery requests. + + The device type is discovered by querying the device. + + :param host: Hostname of device to query + :param device_type: Device type to use for the device. + If not given, the device type is discovered by querying the device. + If the device type is already known, it is preferred to pass it + to avoid the extra query to the device to discover its type. + :rtype: SmartDevice + :return: Object for querying/controlling found device. + """ + from .device_factory import connect # pylint: disable=import-outside-toplevel + + return await connect( + host=host, + port=port, + timeout=timeout, + credentials=credentials, + device_type=device_type, + ) diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index 009632d7..f590808f 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -1,22 +1,11 @@ import json -import sys import asyncclick as click import pytest from asyncclick.testing import CliRunner from kasa import SmartDevice, TPLinkSmartHomeProtocol -from kasa.cli import ( - TYPE_TO_CLASS, - alias, - brightness, - cli, - emeter, - raw_command, - state, - sysinfo, - toggle, -) +from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle from kasa.discover import Discover from .conftest import handle_turn_on, turn_on @@ -154,14 +143,9 @@ async def test_credentials(discovery_data: dict, mocker): ) mocker.patch("kasa.cli.state", new=_state) - - # Get the type string parameter from the discovery_info - for cli_device_type in { # noqa: B007 - i - for i in TYPE_TO_CLASS - if TYPE_TO_CLASS[i] == Discover._get_device_class(discovery_data) - }: - break + cli_device_type = Discover._get_device_class(discovery_data)( + "any" + ).device_type.value runner = CliRunner() res = await runner.invoke( @@ -181,6 +165,24 @@ async def test_credentials(discovery_data: dict, mocker): assert res.output == "Username:foo Password:bar\n" +async def test_without_device_type(discovery_data: dict, dev, mocker): + """Test connecting without the device type.""" + runner = CliRunner() + mocker.patch("kasa.discover.Discover.discover_single", return_value=dev) + res = await runner.invoke( + cli, + [ + "--host", + "127.0.0.1", + "--username", + "foo", + "--password", + "bar", + ], + ) + assert res.exit_code == 0 + + @pytest.mark.parametrize("auth_param", ["--username", "--password"]) async def test_invalid_credential_params(auth_param): """Test for handling only one of username or password supplied.""" diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py new file mode 100644 index 00000000..3a08857a --- /dev/null +++ b/kasa/tests/test_device_factory.py @@ -0,0 +1,74 @@ +# type: ignore +import logging +from typing import Type + +import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 + +from kasa import ( + DeviceType, + SmartBulb, + SmartDevice, + SmartDeviceException, + SmartDimmer, + SmartLightStrip, + SmartPlug, +) +from kasa.device_factory import connect + + +@pytest.mark.parametrize("custom_port", [123, None]) +async def test_connect(discovery_data: dict, mocker, custom_port): + """Make sure that connect returns an initialized SmartDevice instance.""" + host = "127.0.0.1" + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) + + dev = await connect(host, port=custom_port) + assert issubclass(dev.__class__, SmartDevice) + assert dev.port == custom_port or dev.port == 9999 + + +@pytest.mark.parametrize("custom_port", [123, None]) +@pytest.mark.parametrize( + ("device_type", "klass"), + ( + (DeviceType.Plug, SmartPlug), + (DeviceType.Bulb, SmartBulb), + (DeviceType.Dimmer, SmartDimmer), + (DeviceType.LightStrip, SmartLightStrip), + (DeviceType.Unknown, SmartDevice), + ), +) +async def test_connect_passed_device_type( + discovery_data: dict, + mocker, + device_type: DeviceType, + klass: Type[SmartDevice], + custom_port, +): + """Make sure that connect with a passed device type.""" + host = "127.0.0.1" + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) + + dev = await connect(host, port=custom_port, device_type=device_type) + assert isinstance(dev, klass) + assert dev.port == custom_port or dev.port == 9999 + + +async def test_connect_query_fails(discovery_data: dict, mocker): + """Make sure that connect fails when query fails.""" + host = "127.0.0.1" + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException) + + with pytest.raises(SmartDeviceException): + await connect(host) + + +async def test_connect_logs_connect_time( + discovery_data: dict, caplog: pytest.LogCaptureFixture, mocker +): + """Test that the connect time is logged when debug logging is enabled.""" + host = "127.0.0.1" + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) + logging.getLogger("kasa").setLevel(logging.DEBUG) + await connect(host) + assert "seconds to connect" in caplog.text diff --git a/kasa/tests/test_device_type.py b/kasa/tests/test_device_type.py new file mode 100644 index 00000000..da1707dc --- /dev/null +++ b/kasa/tests/test_device_type.py @@ -0,0 +1,23 @@ +from kasa.smartdevice import DeviceType + + +async def test_device_type_from_value(): + """Make sure that every device type can be created from its value.""" + for name in DeviceType: + assert DeviceType.from_value(name.value) is not None + + assert DeviceType.from_value("nonexistent") is DeviceType.Unknown + assert DeviceType.from_value("plug") is DeviceType.Plug + assert DeviceType.Plug.value == "plug" + + assert DeviceType.from_value("bulb") is DeviceType.Bulb + assert DeviceType.Bulb.value == "bulb" + + assert DeviceType.from_value("dimmer") is DeviceType.Dimmer + assert DeviceType.Dimmer.value == "dimmer" + + assert DeviceType.from_value("strip") is DeviceType.Strip + assert DeviceType.Strip.value == "strip" + + assert DeviceType.from_value("lightstrip") is DeviceType.LightStrip + assert DeviceType.LightStrip.value == "lightstrip" diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index f2e12599..7e1dabc0 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -1,7 +1,6 @@ # type: ignore import re import socket -import sys import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 @@ -111,27 +110,6 @@ async def test_discover_single_hostname(discovery_data: dict, mocker): x = await Discover.discover_single(host) -@pytest.mark.parametrize("custom_port", [123, None]) -async def test_connect_single(discovery_data: dict, mocker, custom_port): - """Make sure that connect_single returns an initialized SmartDevice instance.""" - host = "127.0.0.1" - info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}} - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info) - - dev = await Discover.connect_single(host, port=custom_port) - assert issubclass(dev.__class__, SmartDevice) - assert dev.port == custom_port or dev.port == 9999 - - -async def test_connect_single_query_fails(mocker): - """Make sure that connect_single fails when query fails.""" - host = "127.0.0.1" - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException) - - with pytest.raises(SmartDeviceException): - await Discover.connect_single(host) - - UNSUPPORTED = { "result": { "device_id": "xx", diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index ae6886b8..85dc358d 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -1,12 +1,12 @@ import inspect from datetime import datetime -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import kasa from kasa import Credentials, SmartDevice, SmartDeviceException -from kasa.smartstrip import SmartStripPlug +from kasa.smartdevice import DeviceType from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol @@ -215,6 +215,28 @@ async def test_create_smart_device_with_timeout(): assert dev.protocol.timeout == 100 +async def test_create_thin_wrapper(): + """Make sure thin wrapper is created with the correct device type.""" + mock = Mock() + with patch("kasa.device_factory.connect", return_value=mock) as connect: + dev = await SmartDevice.connect( + host="test_host", + port=1234, + timeout=100, + credentials=Credentials("username", "password"), + device_type=DeviceType.Strip, + ) + assert dev is mock + + connect.assert_called_once_with( + host="test_host", + port=1234, + timeout=100, + credentials=Credentials("username", "password"), + device_type=DeviceType.Strip, + ) + + async def test_modules_not_supported(dev: SmartDevice): """Test that unsupported modules do not break the device.""" for module in dev.modules.values():