Update connect_single to allow passing in the device type

This commit is contained in:
J. Nick Koston 2023-10-31 16:11:23 -05:00
parent 805e4b8588
commit e638c7b189
No known key found for this signature in database
4 changed files with 91 additions and 21 deletions

View File

@ -15,7 +15,7 @@ from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb
from kasa.smartdevice import SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType, SmartDevice, SmartDeviceException
from kasa.smartdimmer import SmartDimmer
from kasa.smartlightstrip import SmartLightStrip
from kasa.smartplug import SmartPlug
@ -27,6 +27,14 @@ _LOGGER = logging.getLogger(__name__)
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
DeviceDict = Dict[str, SmartDevice]
DEVICE_TYPE_TO_CLASS = {
DeviceType.Plug: SmartPlug,
DeviceType.Bulb: SmartBulb,
DeviceType.Strip: SmartStrip,
DeviceType.Dimmer: SmartDimmer,
DeviceType.LightStrip: SmartLightStrip,
}
class _DiscoverProtocol(asyncio.DatagramProtocol):
"""Implementation of the discovery protocol handler.
@ -317,6 +325,7 @@ class Discover:
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
device_type: Optional[DeviceType] = None,
) -> SmartDevice:
"""Connect to a single device by the given IP address.
@ -334,17 +343,21 @@ class Discover:
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
unknown_dev = SmartDevice(
host=host, port=port, credentials=credentials, timeout=timeout
)
await unknown_dev.update()
device_class = Discover._get_device_class(unknown_dev.internal_state)
dev = device_class(
host=host, port=port, credentials=credentials, timeout=timeout
)
# Reuse the connection from the unknown device
# so we don't have to reconnect
dev.protocol = unknown_dev.protocol
if device_type and (klass := DEVICE_TYPE_TO_CLASS.get(device_type)):
dev = klass(host=host, port=port, credentials=credentials, timeout=timeout)
else:
unknown_dev = SmartDevice(
host=host, port=port, credentials=credentials, timeout=timeout
)
await unknown_dev.update()
device_class = Discover._get_device_class(unknown_dev.internal_state)
dev = device_class(
host=host, port=port, credentials=credentials, timeout=timeout
)
# Reuse the connection from the unknown device
# so we don't have to reconnect
dev.protocol = unknown_dev.protocol
await dev.update()
return dev
@staticmethod

View File

@ -17,7 +17,7 @@ import inspect
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum, auto
from enum import Enum
from typing import Any, Dict, List, Optional, Set
from .credentials import Credentials
@ -32,13 +32,21 @@ _LOGGER = logging.getLogger(__name__)
class DeviceType(Enum):
"""Device type enum."""
Plug = auto()
Bulb = auto()
Strip = auto()
StripSocket = auto()
Dimmer = auto()
LightStrip = auto()
Unknown = -1
Plug = "Plug"
Bulb = "Bulb"
Strip = "Strip"
StripSocket = "StripSocket"
Dimmer = "Dimmer"
LightStrip = "LightStrip"
Unknown = "Unknown"
@staticmethod
def from_value(name: str) -> "DeviceType":
"""Return device type from string value."""
for device_type in DeviceType:
if device_type.value == name:
return device_type
return DeviceType.Unknown
@dataclass

View File

@ -4,7 +4,18 @@ import sys
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
from kasa import (
DeviceType,
Discover,
SmartBulb,
SmartDevice,
SmartDeviceException,
SmartDimmer,
SmartLightStrip,
SmartPlug,
SmartStrip,
protocol,
)
from kasa.discover import _DiscoverProtocol, json_dumps
from kasa.exceptions import UnsupportedDeviceException
@ -85,6 +96,33 @@ async def test_connect_single(discovery_data: dict, mocker, custom_port):
assert dev.port == custom_port or dev.port == 9999
@pytest.mark.parametrize("custom_port", [123, None])
@pytest.mark.parametrize(
("device_type", "klass"),
(
(DeviceType.Plug, SmartPlug),
(DeviceType.Bulb, SmartBulb),
(DeviceType.Dimmer, SmartDimmer),
(DeviceType.LightStrip, SmartLightStrip),
(DeviceType.Unknown, SmartDevice),
),
)
async def test_connect_single_passed_device_type(
discovery_data: dict,
mocker,
device_type: DeviceType,
klass: type[SmartDevice],
custom_port,
):
"""Make sure that connect_single with a passed device type."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await Discover.connect_single(host, port=custom_port, device_type=device_type)
assert isinstance(dev, klass)
assert dev.port == custom_port or dev.port == 9999
async def test_connect_single_query_fails(discovery_data: dict, mocker):
"""Make sure that connect_single fails when query fails."""
host = "127.0.0.1"

View File

@ -6,6 +6,7 @@ import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
import kasa
from kasa import Credentials, SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType
from kasa.smartstrip import SmartStripPlug
from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on
@ -58,6 +59,16 @@ async def test_initial_update_no_emeter(dev, mocker):
assert spy.call_count == 2
async def test_smart_device_from_value():
"""Make sure that every device type can be created from its value."""
for name in DeviceType:
assert DeviceType.from_value(name.value) is not None
assert DeviceType.from_value("nonexistent") is DeviceType.Unknown
assert DeviceType.from_value("Plug") is DeviceType.Plug
assert DeviceType.Plug.value == "Plug"
async def test_query_helper(dev):
with pytest.raises(SmartDeviceException):
await dev._query_helper("test", "testcmd", {})