diff --git a/kasa/discover.py b/kasa/discover.py index b43df57b..90904d30 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -15,7 +15,7 @@ from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads from kasa.protocol import TPLinkSmartHomeProtocol from kasa.smartbulb import SmartBulb -from kasa.smartdevice import SmartDevice, SmartDeviceException +from kasa.smartdevice import DeviceType, SmartDevice, SmartDeviceException from kasa.smartdimmer import SmartDimmer from kasa.smartlightstrip import SmartLightStrip from kasa.smartplug import SmartPlug @@ -27,6 +27,14 @@ _LOGGER = logging.getLogger(__name__) OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]] DeviceDict = Dict[str, SmartDevice] +DEVICE_TYPE_TO_CLASS = { + DeviceType.Plug: SmartPlug, + DeviceType.Bulb: SmartBulb, + DeviceType.Strip: SmartStrip, + DeviceType.Dimmer: SmartDimmer, + DeviceType.LightStrip: SmartLightStrip, +} + class _DiscoverProtocol(asyncio.DatagramProtocol): """Implementation of the discovery protocol handler. @@ -317,6 +325,7 @@ class Discover: port: Optional[int] = None, timeout=5, credentials: Optional[Credentials] = None, + device_type: Optional[DeviceType] = None, ) -> SmartDevice: """Connect to a single device by the given IP address. @@ -334,17 +343,21 @@ class Discover: :rtype: SmartDevice :return: Object for querying/controlling found device. """ - unknown_dev = SmartDevice( - host=host, port=port, credentials=credentials, timeout=timeout - ) - await unknown_dev.update() - device_class = Discover._get_device_class(unknown_dev.internal_state) - dev = device_class( - host=host, port=port, credentials=credentials, timeout=timeout - ) - # Reuse the connection from the unknown device - # so we don't have to reconnect - dev.protocol = unknown_dev.protocol + if device_type and (klass := DEVICE_TYPE_TO_CLASS.get(device_type)): + dev = klass(host=host, port=port, credentials=credentials, timeout=timeout) + else: + unknown_dev = SmartDevice( + host=host, port=port, credentials=credentials, timeout=timeout + ) + await unknown_dev.update() + device_class = Discover._get_device_class(unknown_dev.internal_state) + dev = device_class( + host=host, port=port, credentials=credentials, timeout=timeout + ) + # Reuse the connection from the unknown device + # so we don't have to reconnect + dev.protocol = unknown_dev.protocol + await dev.update() return dev @staticmethod diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 3e9bd953..342bc27c 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -17,7 +17,7 @@ import inspect import logging from dataclasses import dataclass from datetime import datetime, timedelta -from enum import Enum, auto +from enum import Enum from typing import Any, Dict, List, Optional, Set from .credentials import Credentials @@ -32,13 +32,21 @@ _LOGGER = logging.getLogger(__name__) class DeviceType(Enum): """Device type enum.""" - Plug = auto() - Bulb = auto() - Strip = auto() - StripSocket = auto() - Dimmer = auto() - LightStrip = auto() - Unknown = -1 + Plug = "Plug" + Bulb = "Bulb" + Strip = "Strip" + StripSocket = "StripSocket" + Dimmer = "Dimmer" + LightStrip = "LightStrip" + Unknown = "Unknown" + + @staticmethod + def from_value(name: str) -> "DeviceType": + """Return device type from string value.""" + for device_type in DeviceType: + if device_type.value == name: + return device_type + return DeviceType.Unknown @dataclass diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 3039f30c..d9bc9a1b 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -4,7 +4,18 @@ import sys import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 -from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol +from kasa import ( + DeviceType, + Discover, + SmartBulb, + SmartDevice, + SmartDeviceException, + SmartDimmer, + SmartLightStrip, + SmartPlug, + SmartStrip, + protocol, +) from kasa.discover import _DiscoverProtocol, json_dumps from kasa.exceptions import UnsupportedDeviceException @@ -85,6 +96,33 @@ async def test_connect_single(discovery_data: dict, mocker, custom_port): assert dev.port == custom_port or dev.port == 9999 +@pytest.mark.parametrize("custom_port", [123, None]) +@pytest.mark.parametrize( + ("device_type", "klass"), + ( + (DeviceType.Plug, SmartPlug), + (DeviceType.Bulb, SmartBulb), + (DeviceType.Dimmer, SmartDimmer), + (DeviceType.LightStrip, SmartLightStrip), + (DeviceType.Unknown, SmartDevice), + ), +) +async def test_connect_single_passed_device_type( + discovery_data: dict, + mocker, + device_type: DeviceType, + klass: type[SmartDevice], + custom_port, +): + """Make sure that connect_single with a passed device type.""" + host = "127.0.0.1" + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) + + dev = await Discover.connect_single(host, port=custom_port, device_type=device_type) + assert isinstance(dev, klass) + assert dev.port == custom_port or dev.port == 9999 + + async def test_connect_single_query_fails(discovery_data: dict, mocker): """Make sure that connect_single fails when query fails.""" host = "127.0.0.1" diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index f6f470b8..24a3dd0f 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -6,6 +6,7 @@ import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import kasa from kasa import Credentials, SmartDevice, SmartDeviceException +from kasa.smartdevice import DeviceType from kasa.smartstrip import SmartStripPlug from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on @@ -58,6 +59,16 @@ async def test_initial_update_no_emeter(dev, mocker): assert spy.call_count == 2 +async def test_smart_device_from_value(): + """Make sure that every device type can be created from its value.""" + for name in DeviceType: + assert DeviceType.from_value(name.value) is not None + + assert DeviceType.from_value("nonexistent") is DeviceType.Unknown + assert DeviceType.from_value("Plug") is DeviceType.Plug + assert DeviceType.Plug.value == "Plug" + + async def test_query_helper(dev): with pytest.raises(SmartDeviceException): await dev._query_helper("test", "testcmd", {})