mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-04-28 09:36:25 +00:00
Update connect_single to allow passing in the device type
This commit is contained in:
parent
805e4b8588
commit
e638c7b189
@ -15,7 +15,7 @@ from kasa.json import dumps as json_dumps
|
|||||||
from kasa.json import loads as json_loads
|
from kasa.json import loads as json_loads
|
||||||
from kasa.protocol import TPLinkSmartHomeProtocol
|
from kasa.protocol import TPLinkSmartHomeProtocol
|
||||||
from kasa.smartbulb import SmartBulb
|
from kasa.smartbulb import SmartBulb
|
||||||
from kasa.smartdevice import SmartDevice, SmartDeviceException
|
from kasa.smartdevice import DeviceType, SmartDevice, SmartDeviceException
|
||||||
from kasa.smartdimmer import SmartDimmer
|
from kasa.smartdimmer import SmartDimmer
|
||||||
from kasa.smartlightstrip import SmartLightStrip
|
from kasa.smartlightstrip import SmartLightStrip
|
||||||
from kasa.smartplug import SmartPlug
|
from kasa.smartplug import SmartPlug
|
||||||
@ -27,6 +27,14 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
|
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
|
||||||
DeviceDict = Dict[str, SmartDevice]
|
DeviceDict = Dict[str, SmartDevice]
|
||||||
|
|
||||||
|
DEVICE_TYPE_TO_CLASS = {
|
||||||
|
DeviceType.Plug: SmartPlug,
|
||||||
|
DeviceType.Bulb: SmartBulb,
|
||||||
|
DeviceType.Strip: SmartStrip,
|
||||||
|
DeviceType.Dimmer: SmartDimmer,
|
||||||
|
DeviceType.LightStrip: SmartLightStrip,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class _DiscoverProtocol(asyncio.DatagramProtocol):
|
class _DiscoverProtocol(asyncio.DatagramProtocol):
|
||||||
"""Implementation of the discovery protocol handler.
|
"""Implementation of the discovery protocol handler.
|
||||||
@ -317,6 +325,7 @@ class Discover:
|
|||||||
port: Optional[int] = None,
|
port: Optional[int] = None,
|
||||||
timeout=5,
|
timeout=5,
|
||||||
credentials: Optional[Credentials] = None,
|
credentials: Optional[Credentials] = None,
|
||||||
|
device_type: Optional[DeviceType] = None,
|
||||||
) -> SmartDevice:
|
) -> SmartDevice:
|
||||||
"""Connect to a single device by the given IP address.
|
"""Connect to a single device by the given IP address.
|
||||||
|
|
||||||
@ -334,17 +343,21 @@ class Discover:
|
|||||||
:rtype: SmartDevice
|
:rtype: SmartDevice
|
||||||
:return: Object for querying/controlling found device.
|
:return: Object for querying/controlling found device.
|
||||||
"""
|
"""
|
||||||
unknown_dev = SmartDevice(
|
if device_type and (klass := DEVICE_TYPE_TO_CLASS.get(device_type)):
|
||||||
host=host, port=port, credentials=credentials, timeout=timeout
|
dev = klass(host=host, port=port, credentials=credentials, timeout=timeout)
|
||||||
)
|
else:
|
||||||
await unknown_dev.update()
|
unknown_dev = SmartDevice(
|
||||||
device_class = Discover._get_device_class(unknown_dev.internal_state)
|
host=host, port=port, credentials=credentials, timeout=timeout
|
||||||
dev = device_class(
|
)
|
||||||
host=host, port=port, credentials=credentials, timeout=timeout
|
await unknown_dev.update()
|
||||||
)
|
device_class = Discover._get_device_class(unknown_dev.internal_state)
|
||||||
# Reuse the connection from the unknown device
|
dev = device_class(
|
||||||
# so we don't have to reconnect
|
host=host, port=port, credentials=credentials, timeout=timeout
|
||||||
dev.protocol = unknown_dev.protocol
|
)
|
||||||
|
# Reuse the connection from the unknown device
|
||||||
|
# so we don't have to reconnect
|
||||||
|
dev.protocol = unknown_dev.protocol
|
||||||
|
await dev.update()
|
||||||
return dev
|
return dev
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -17,7 +17,7 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from enum import Enum, auto
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Set
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
from .credentials import Credentials
|
from .credentials import Credentials
|
||||||
@ -32,13 +32,21 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
class DeviceType(Enum):
|
class DeviceType(Enum):
|
||||||
"""Device type enum."""
|
"""Device type enum."""
|
||||||
|
|
||||||
Plug = auto()
|
Plug = "Plug"
|
||||||
Bulb = auto()
|
Bulb = "Bulb"
|
||||||
Strip = auto()
|
Strip = "Strip"
|
||||||
StripSocket = auto()
|
StripSocket = "StripSocket"
|
||||||
Dimmer = auto()
|
Dimmer = "Dimmer"
|
||||||
LightStrip = auto()
|
LightStrip = "LightStrip"
|
||||||
Unknown = -1
|
Unknown = "Unknown"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_value(name: str) -> "DeviceType":
|
||||||
|
"""Return device type from string value."""
|
||||||
|
for device_type in DeviceType:
|
||||||
|
if device_type.value == name:
|
||||||
|
return device_type
|
||||||
|
return DeviceType.Unknown
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -4,7 +4,18 @@ import sys
|
|||||||
|
|
||||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||||
|
|
||||||
from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
|
from kasa import (
|
||||||
|
DeviceType,
|
||||||
|
Discover,
|
||||||
|
SmartBulb,
|
||||||
|
SmartDevice,
|
||||||
|
SmartDeviceException,
|
||||||
|
SmartDimmer,
|
||||||
|
SmartLightStrip,
|
||||||
|
SmartPlug,
|
||||||
|
SmartStrip,
|
||||||
|
protocol,
|
||||||
|
)
|
||||||
from kasa.discover import _DiscoverProtocol, json_dumps
|
from kasa.discover import _DiscoverProtocol, json_dumps
|
||||||
from kasa.exceptions import UnsupportedDeviceException
|
from kasa.exceptions import UnsupportedDeviceException
|
||||||
|
|
||||||
@ -85,6 +96,33 @@ async def test_connect_single(discovery_data: dict, mocker, custom_port):
|
|||||||
assert dev.port == custom_port or dev.port == 9999
|
assert dev.port == custom_port or dev.port == 9999
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("custom_port", [123, None])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("device_type", "klass"),
|
||||||
|
(
|
||||||
|
(DeviceType.Plug, SmartPlug),
|
||||||
|
(DeviceType.Bulb, SmartBulb),
|
||||||
|
(DeviceType.Dimmer, SmartDimmer),
|
||||||
|
(DeviceType.LightStrip, SmartLightStrip),
|
||||||
|
(DeviceType.Unknown, SmartDevice),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async def test_connect_single_passed_device_type(
|
||||||
|
discovery_data: dict,
|
||||||
|
mocker,
|
||||||
|
device_type: DeviceType,
|
||||||
|
klass: type[SmartDevice],
|
||||||
|
custom_port,
|
||||||
|
):
|
||||||
|
"""Make sure that connect_single with a passed device type."""
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||||
|
|
||||||
|
dev = await Discover.connect_single(host, port=custom_port, device_type=device_type)
|
||||||
|
assert isinstance(dev, klass)
|
||||||
|
assert dev.port == custom_port or dev.port == 9999
|
||||||
|
|
||||||
|
|
||||||
async def test_connect_single_query_fails(discovery_data: dict, mocker):
|
async def test_connect_single_query_fails(discovery_data: dict, mocker):
|
||||||
"""Make sure that connect_single fails when query fails."""
|
"""Make sure that connect_single fails when query fails."""
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
|
@ -6,6 +6,7 @@ import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
|||||||
|
|
||||||
import kasa
|
import kasa
|
||||||
from kasa import Credentials, SmartDevice, SmartDeviceException
|
from kasa import Credentials, SmartDevice, SmartDeviceException
|
||||||
|
from kasa.smartdevice import DeviceType
|
||||||
from kasa.smartstrip import SmartStripPlug
|
from kasa.smartstrip import SmartStripPlug
|
||||||
|
|
||||||
from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on
|
from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on
|
||||||
@ -58,6 +59,16 @@ async def test_initial_update_no_emeter(dev, mocker):
|
|||||||
assert spy.call_count == 2
|
assert spy.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_smart_device_from_value():
|
||||||
|
"""Make sure that every device type can be created from its value."""
|
||||||
|
for name in DeviceType:
|
||||||
|
assert DeviceType.from_value(name.value) is not None
|
||||||
|
|
||||||
|
assert DeviceType.from_value("nonexistent") is DeviceType.Unknown
|
||||||
|
assert DeviceType.from_value("Plug") is DeviceType.Plug
|
||||||
|
assert DeviceType.Plug.value == "Plug"
|
||||||
|
|
||||||
|
|
||||||
async def test_query_helper(dev):
|
async def test_query_helper(dev):
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
await dev._query_helper("test", "testcmd", {})
|
await dev._query_helper("test", "testcmd", {})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user