Update connect_single to allow passing in the device type

This commit is contained in:
J. Nick Koston 2023-10-31 16:11:23 -05:00
parent 805e4b8588
commit e638c7b189
No known key found for this signature in database
4 changed files with 91 additions and 21 deletions

View File

@ -15,7 +15,7 @@ from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads from kasa.json import loads as json_loads
from kasa.protocol import TPLinkSmartHomeProtocol from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb from kasa.smartbulb import SmartBulb
from kasa.smartdevice import SmartDevice, SmartDeviceException from kasa.smartdevice import DeviceType, SmartDevice, SmartDeviceException
from kasa.smartdimmer import SmartDimmer from kasa.smartdimmer import SmartDimmer
from kasa.smartlightstrip import SmartLightStrip from kasa.smartlightstrip import SmartLightStrip
from kasa.smartplug import SmartPlug from kasa.smartplug import SmartPlug
@ -27,6 +27,14 @@ _LOGGER = logging.getLogger(__name__)
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]] OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
DeviceDict = Dict[str, SmartDevice] DeviceDict = Dict[str, SmartDevice]
DEVICE_TYPE_TO_CLASS = {
DeviceType.Plug: SmartPlug,
DeviceType.Bulb: SmartBulb,
DeviceType.Strip: SmartStrip,
DeviceType.Dimmer: SmartDimmer,
DeviceType.LightStrip: SmartLightStrip,
}
class _DiscoverProtocol(asyncio.DatagramProtocol): class _DiscoverProtocol(asyncio.DatagramProtocol):
"""Implementation of the discovery protocol handler. """Implementation of the discovery protocol handler.
@ -317,6 +325,7 @@ class Discover:
port: Optional[int] = None, port: Optional[int] = None,
timeout=5, timeout=5,
credentials: Optional[Credentials] = None, credentials: Optional[Credentials] = None,
device_type: Optional[DeviceType] = None,
) -> SmartDevice: ) -> SmartDevice:
"""Connect to a single device by the given IP address. """Connect to a single device by the given IP address.
@ -334,17 +343,21 @@ class Discover:
:rtype: SmartDevice :rtype: SmartDevice
:return: Object for querying/controlling found device. :return: Object for querying/controlling found device.
""" """
unknown_dev = SmartDevice( if device_type and (klass := DEVICE_TYPE_TO_CLASS.get(device_type)):
host=host, port=port, credentials=credentials, timeout=timeout dev = klass(host=host, port=port, credentials=credentials, timeout=timeout)
) else:
await unknown_dev.update() unknown_dev = SmartDevice(
device_class = Discover._get_device_class(unknown_dev.internal_state) host=host, port=port, credentials=credentials, timeout=timeout
dev = device_class( )
host=host, port=port, credentials=credentials, timeout=timeout await unknown_dev.update()
) device_class = Discover._get_device_class(unknown_dev.internal_state)
# Reuse the connection from the unknown device dev = device_class(
# so we don't have to reconnect host=host, port=port, credentials=credentials, timeout=timeout
dev.protocol = unknown_dev.protocol )
# Reuse the connection from the unknown device
# so we don't have to reconnect
dev.protocol = unknown_dev.protocol
await dev.update()
return dev return dev
@staticmethod @staticmethod

View File

@ -17,7 +17,7 @@ import inspect
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum, auto from enum import Enum
from typing import Any, Dict, List, Optional, Set from typing import Any, Dict, List, Optional, Set
from .credentials import Credentials from .credentials import Credentials
@ -32,13 +32,21 @@ _LOGGER = logging.getLogger(__name__)
class DeviceType(Enum): class DeviceType(Enum):
"""Device type enum.""" """Device type enum."""
Plug = auto() Plug = "Plug"
Bulb = auto() Bulb = "Bulb"
Strip = auto() Strip = "Strip"
StripSocket = auto() StripSocket = "StripSocket"
Dimmer = auto() Dimmer = "Dimmer"
LightStrip = auto() LightStrip = "LightStrip"
Unknown = -1 Unknown = "Unknown"
@staticmethod
def from_value(name: str) -> "DeviceType":
"""Return device type from string value."""
for device_type in DeviceType:
if device_type.value == name:
return device_type
return DeviceType.Unknown
@dataclass @dataclass

View File

@ -4,7 +4,18 @@ import sys
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol from kasa import (
DeviceType,
Discover,
SmartBulb,
SmartDevice,
SmartDeviceException,
SmartDimmer,
SmartLightStrip,
SmartPlug,
SmartStrip,
protocol,
)
from kasa.discover import _DiscoverProtocol, json_dumps from kasa.discover import _DiscoverProtocol, json_dumps
from kasa.exceptions import UnsupportedDeviceException from kasa.exceptions import UnsupportedDeviceException
@ -85,6 +96,33 @@ async def test_connect_single(discovery_data: dict, mocker, custom_port):
assert dev.port == custom_port or dev.port == 9999 assert dev.port == custom_port or dev.port == 9999
@pytest.mark.parametrize("custom_port", [123, None])
@pytest.mark.parametrize(
("device_type", "klass"),
(
(DeviceType.Plug, SmartPlug),
(DeviceType.Bulb, SmartBulb),
(DeviceType.Dimmer, SmartDimmer),
(DeviceType.LightStrip, SmartLightStrip),
(DeviceType.Unknown, SmartDevice),
),
)
async def test_connect_single_passed_device_type(
discovery_data: dict,
mocker,
device_type: DeviceType,
klass: type[SmartDevice],
custom_port,
):
"""Make sure that connect_single with a passed device type."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await Discover.connect_single(host, port=custom_port, device_type=device_type)
assert isinstance(dev, klass)
assert dev.port == custom_port or dev.port == 9999
async def test_connect_single_query_fails(discovery_data: dict, mocker): async def test_connect_single_query_fails(discovery_data: dict, mocker):
"""Make sure that connect_single fails when query fails.""" """Make sure that connect_single fails when query fails."""
host = "127.0.0.1" host = "127.0.0.1"

View File

@ -6,6 +6,7 @@ import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
import kasa import kasa
from kasa import Credentials, SmartDevice, SmartDeviceException from kasa import Credentials, SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType
from kasa.smartstrip import SmartStripPlug from kasa.smartstrip import SmartStripPlug
from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on
@ -58,6 +59,16 @@ async def test_initial_update_no_emeter(dev, mocker):
assert spy.call_count == 2 assert spy.call_count == 2
async def test_smart_device_from_value():
"""Make sure that every device type can be created from its value."""
for name in DeviceType:
assert DeviceType.from_value(name.value) is not None
assert DeviceType.from_value("nonexistent") is DeviceType.Unknown
assert DeviceType.from_value("Plug") is DeviceType.Plug
assert DeviceType.Plug.value == "Plug"
async def test_query_helper(dev): async def test_query_helper(dev):
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await dev._query_helper("test", "testcmd", {}) await dev._query_helper("test", "testcmd", {})