Move connect_single to SmartDevice.connect (#538)

This refactors `Discover.connect_single` by moving device instance construction into a separate device factory module.
New `SmartDevice.connect(host, *, port, timeout, credentials, device_type)` class method replaces the functionality of `connect_single`,
and also now allows constructing device instances without relying on UDP discovery for type discovery if `device_type` parameter is set.

---------

Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
J. Nick Koston 2023-11-21 23:48:53 +01:00 committed by GitHub
parent 27c4799adc
commit e98252ff17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 361 additions and 144 deletions

View File

@ -12,6 +12,21 @@ or if you are just looking to access some information that is not currently expo
.. contents:: Contents
:local:
.. _initialization:
Initialization
**************
Use :func:`~kasa.Discover.discover` to perform udp-based broadcast discovery on the network.
This will return you a list of device instances based on the discovery replies.
If the device's host is already known, you can use to construct a device instance with
:meth:`~kasa.SmartDevice.connect()`.
When connecting a device with the :meth:`~kasa.SmartDevice.connect()` method, it is recommended to
pass the device type as well as this allows the library to use the correct device class for the
device without having to query the device.
.. _update_cycle:
Update Cycle

View File

@ -13,14 +13,13 @@ import asyncclick as click
from kasa import (
AuthenticationException,
Credentials,
DeviceType,
Discover,
SmartBulb,
SmartDevice,
SmartDimmer,
SmartLightStrip,
SmartPlug,
SmartStrip,
)
from kasa.device_factory import DEVICE_TYPE_TO_CLASS
try:
from rich import print as _do_echo
@ -43,13 +42,11 @@ except ImportError:
# --json has set it to _nop_echo
echo = _do_echo
TYPE_TO_CLASS = {
"plug": SmartPlug,
"bulb": SmartBulb,
"dimmer": SmartDimmer,
"strip": SmartStrip,
"lightstrip": SmartLightStrip,
}
DEVICE_TYPES = [
device_type.value
for device_type in DeviceType
if device_type in DEVICE_TYPE_TO_CLASS
]
click.anyio_backend = "asyncio"
@ -129,7 +126,7 @@ def json_formatter_cb(result, **kwargs):
"--type",
envvar="KASA_TYPE",
default=None,
type=click.Choice(list(TYPE_TO_CLASS), case_sensitive=False),
type=click.Choice(DEVICE_TYPES, case_sensitive=False),
)
@click.option(
"--json", default=False, is_flag=True, help="Output raw device response as JSON."
@ -235,7 +232,10 @@ async def cli(
return await ctx.invoke(discover, timeout=discovery_timeout)
if type is not None:
dev = TYPE_TO_CLASS[type](host, credentials=credentials)
device_type = DeviceType.from_value(type)
dev = await SmartDevice.connect(
host, credentials=credentials, device_type=device_type
)
else:
echo("No --type defined, discovering..")
dev = await Discover.discover_single(
@ -243,8 +243,8 @@ async def cli(
port=port,
credentials=credentials,
)
await dev.update()
await dev.update()
ctx.obj = dev
if ctx.invoked_subcommand is None:

121
kasa/device_factory.py Executable file
View File

@ -0,0 +1,121 @@
"""Device creation by type."""
import logging
import time
from typing import Any, Dict, Optional, Type
from .credentials import Credentials
from .device_type import DeviceType
from .exceptions import UnsupportedDeviceException
from .smartbulb import SmartBulb
from .smartdevice import SmartDevice, SmartDeviceException
from .smartdimmer import SmartDimmer
from .smartlightstrip import SmartLightStrip
from .smartplug import SmartPlug
from .smartstrip import SmartStrip
DEVICE_TYPE_TO_CLASS = {
DeviceType.Plug: SmartPlug,
DeviceType.Bulb: SmartBulb,
DeviceType.Strip: SmartStrip,
DeviceType.Dimmer: SmartDimmer,
DeviceType.LightStrip: SmartLightStrip,
}
_LOGGER = logging.getLogger(__name__)
async def connect(
host: str,
*,
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
device_type: Optional[DeviceType] = None,
) -> "SmartDevice":
"""Connect to a single device by the given IP address.
This method avoids the UDP based discovery process and
will connect directly to the device to query its type.
It is generally preferred to avoid :func:`discover_single()` and
use this function instead as it should perform better when
the WiFi network is congested or the device is not responding
to discovery requests.
The device type is discovered by querying the device.
:param host: Hostname of device to query
:param device_type: Device type to use for the device.
If not given, the device type is discovered by querying the device.
If the device type is already known, it is preferred to pass it
to avoid the extra query to the device to discover its type.
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_enabled:
start_time = time.perf_counter()
if device_type and (klass := DEVICE_TYPE_TO_CLASS.get(device_type)):
dev: SmartDevice = klass(
host=host, port=port, credentials=credentials, timeout=timeout
)
await dev.update()
if debug_enabled:
end_time = time.perf_counter()
_LOGGER.debug(
"Device %s with known type (%s) took %.2f seconds to connect",
host,
device_type.value,
end_time - start_time,
)
return dev
unknown_dev = SmartDevice(
host=host, port=port, credentials=credentials, timeout=timeout
)
await unknown_dev.update()
device_class = get_device_class_from_info(unknown_dev.internal_state)
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
# Reuse the connection from the unknown device
# so we don't have to reconnect
dev.protocol = unknown_dev.protocol
await dev.update()
if debug_enabled:
end_time = time.perf_counter()
_LOGGER.debug(
"Device %s with unknown type (%s) took %.2f seconds to connect",
host,
dev.device_type.value,
end_time - start_time,
)
return dev
def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]:
"""Find SmartDevice subclass for device described by passed data."""
if "system" not in info or "get_sysinfo" not in info["system"]:
raise SmartDeviceException("No 'system' or 'get_sysinfo' in response")
sysinfo: Dict[str, Any] = info["system"]["get_sysinfo"]
type_: Optional[str] = sysinfo.get("type", sysinfo.get("mic_type"))
if type_ is None:
raise SmartDeviceException("Unable to find the device type field!")
if "dev_name" in sysinfo and "Dimmer" in sysinfo["dev_name"]:
return SmartDimmer
if "smartplug" in type_.lower():
if "children" in sysinfo:
return SmartStrip
return SmartPlug
if "smartbulb" in type_.lower():
if "length" in sysinfo: # strips have length
return SmartLightStrip
return SmartBulb
raise UnsupportedDeviceException("Unknown device type: %s" % type_)

25
kasa/device_type.py Executable file
View File

@ -0,0 +1,25 @@
"""TP-Link device types."""
from enum import Enum
class DeviceType(Enum):
"""Device type enum."""
# The values match what the cli has historically used
Plug = "plug"
Bulb = "bulb"
Strip = "strip"
StripSocket = "stripsocket"
Dimmer = "dimmer"
LightStrip = "lightstrip"
Unknown = "unknown"
@staticmethod
def from_value(name: str) -> "DeviceType":
"""Return device type from string value."""
for device_type in DeviceType:
if device_type.value == name:
return device_type
return DeviceType.Unknown

View File

@ -20,13 +20,11 @@ from kasa.exceptions import UnsupportedDeviceException
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.klapprotocol import TPLinkKlap
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb
from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartdevice import SmartDevice, SmartDeviceException
from kasa.smartdimmer import SmartDimmer
from kasa.smartlightstrip import SmartLightStrip
from kasa.smartplug import SmartPlug
from kasa.smartstrip import SmartStrip
from .device_factory import get_device_class_from_info
_LOGGER = logging.getLogger(__name__)
@ -345,78 +343,10 @@ class Discover:
else:
raise SmartDeviceException(f"Unable to get discovery response for {host}")
@staticmethod
async def connect_single(
host: str,
*,
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
protocol_class: Optional[Type[TPLinkProtocol]] = None,
) -> SmartDevice:
"""Connect to a single device by the given IP address.
This method avoids the UDP based discovery process and
will connect directly to the device to query its type.
It is generally preferred to avoid :func:`discover_single()` and
use this function instead as it should perform better when
the WiFi network is congested or the device is not responding
to discovery requests.
The device type is discovered by querying the device.
:param host: Hostname of device to query
:param port: Optionally set a different port for the device
:param timeout: Timeout for discovery
:param credentials: Optionally provide credentials for
devices requiring them
:param protocol_class: Optionally provide the protocol class
to use.
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
unknown_dev = SmartDevice(
host=host, port=port, credentials=credentials, timeout=timeout
)
if protocol_class is not None:
unknown_dev.protocol = protocol_class(host, credentials=credentials)
await unknown_dev.update()
device_class = Discover._get_device_class(unknown_dev.internal_state)
dev = device_class(
host=host, port=port, credentials=credentials, timeout=timeout
)
# Reuse the connection from the unknown device
# so we don't have to reconnect
dev.protocol = unknown_dev.protocol
return dev
@staticmethod
def _get_device_class(info: dict) -> Type[SmartDevice]:
"""Find SmartDevice subclass for device described by passed data."""
if "system" not in info or "get_sysinfo" not in info["system"]:
raise SmartDeviceException("No 'system' or 'get_sysinfo' in response")
sysinfo = info["system"]["get_sysinfo"]
type_ = sysinfo.get("type", sysinfo.get("mic_type"))
if type_ is None:
raise SmartDeviceException("Unable to find the device type field!")
if "dev_name" in sysinfo and "Dimmer" in sysinfo["dev_name"]:
return SmartDimmer
if "smartplug" in type_.lower():
if "children" in sysinfo:
return SmartStrip
return SmartPlug
if "smartbulb" in type_.lower():
if "length" in sysinfo: # strips have length
return SmartLightStrip
return SmartBulb
raise UnsupportedDeviceException("Unknown device type: %s" % type_)
return get_device_class_from_info(info)
@staticmethod
def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice:

View File

@ -17,10 +17,10 @@ import inspect
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Set
from .credentials import Credentials
from .device_type import DeviceType
from .emeterstatus import EmeterStatus
from .exceptions import SmartDeviceException
from .modules import Emeter, Module
@ -29,18 +29,6 @@ from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
_LOGGER = logging.getLogger(__name__)
class DeviceType(Enum):
"""Device type enum."""
Plug = auto()
Bulb = auto()
Strip = auto()
StripSocket = auto()
Dimmer = auto()
LightStrip = auto()
Unknown = -1
@dataclass
class WifiNetwork:
"""Wifi network container."""
@ -780,3 +768,42 @@ class SmartDevice:
f" ({self.alias}), is_on: {self.is_on}"
f" - dev specific: {self.state_information}>"
)
@staticmethod
async def connect(
host: str,
*,
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
device_type: Optional[DeviceType] = None,
) -> "SmartDevice":
"""Connect to a single device by the given IP address.
This method avoids the UDP based discovery process and
will connect directly to the device to query its type.
It is generally preferred to avoid :func:`discover_single()` and
use this function instead as it should perform better when
the WiFi network is congested or the device is not responding
to discovery requests.
The device type is discovered by querying the device.
:param host: Hostname of device to query
:param device_type: Device type to use for the device.
If not given, the device type is discovered by querying the device.
If the device type is already known, it is preferred to pass it
to avoid the extra query to the device to discover its type.
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
from .device_factory import connect # pylint: disable=import-outside-toplevel
return await connect(
host=host,
port=port,
timeout=timeout,
credentials=credentials,
device_type=device_type,
)

View File

@ -1,22 +1,11 @@
import json
import sys
import asyncclick as click
import pytest
from asyncclick.testing import CliRunner
from kasa import SmartDevice, TPLinkSmartHomeProtocol
from kasa.cli import (
TYPE_TO_CLASS,
alias,
brightness,
cli,
emeter,
raw_command,
state,
sysinfo,
toggle,
)
from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle
from kasa.discover import Discover
from .conftest import handle_turn_on, turn_on
@ -154,14 +143,9 @@ async def test_credentials(discovery_data: dict, mocker):
)
mocker.patch("kasa.cli.state", new=_state)
# Get the type string parameter from the discovery_info
for cli_device_type in { # noqa: B007
i
for i in TYPE_TO_CLASS
if TYPE_TO_CLASS[i] == Discover._get_device_class(discovery_data)
}:
break
cli_device_type = Discover._get_device_class(discovery_data)(
"any"
).device_type.value
runner = CliRunner()
res = await runner.invoke(
@ -181,6 +165,24 @@ async def test_credentials(discovery_data: dict, mocker):
assert res.output == "Username:foo Password:bar\n"
async def test_without_device_type(discovery_data: dict, dev, mocker):
"""Test connecting without the device type."""
runner = CliRunner()
mocker.patch("kasa.discover.Discover.discover_single", return_value=dev)
res = await runner.invoke(
cli,
[
"--host",
"127.0.0.1",
"--username",
"foo",
"--password",
"bar",
],
)
assert res.exit_code == 0
@pytest.mark.parametrize("auth_param", ["--username", "--password"])
async def test_invalid_credential_params(auth_param):
"""Test for handling only one of username or password supplied."""

View File

@ -0,0 +1,74 @@
# type: ignore
import logging
from typing import Type
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import (
DeviceType,
SmartBulb,
SmartDevice,
SmartDeviceException,
SmartDimmer,
SmartLightStrip,
SmartPlug,
)
from kasa.device_factory import connect
@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect(discovery_data: dict, mocker, custom_port):
"""Make sure that connect returns an initialized SmartDevice instance."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await connect(host, port=custom_port)
assert issubclass(dev.__class__, SmartDevice)
assert dev.port == custom_port or dev.port == 9999
@pytest.mark.parametrize("custom_port", [123, None])
@pytest.mark.parametrize(
("device_type", "klass"),
(
(DeviceType.Plug, SmartPlug),
(DeviceType.Bulb, SmartBulb),
(DeviceType.Dimmer, SmartDimmer),
(DeviceType.LightStrip, SmartLightStrip),
(DeviceType.Unknown, SmartDevice),
),
)
async def test_connect_passed_device_type(
discovery_data: dict,
mocker,
device_type: DeviceType,
klass: Type[SmartDevice],
custom_port,
):
"""Make sure that connect with a passed device type."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await connect(host, port=custom_port, device_type=device_type)
assert isinstance(dev, klass)
assert dev.port == custom_port or dev.port == 9999
async def test_connect_query_fails(discovery_data: dict, mocker):
"""Make sure that connect fails when query fails."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
with pytest.raises(SmartDeviceException):
await connect(host)
async def test_connect_logs_connect_time(
discovery_data: dict, caplog: pytest.LogCaptureFixture, mocker
):
"""Test that the connect time is logged when debug logging is enabled."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
logging.getLogger("kasa").setLevel(logging.DEBUG)
await connect(host)
assert "seconds to connect" in caplog.text

View File

@ -0,0 +1,23 @@
from kasa.smartdevice import DeviceType
async def test_device_type_from_value():
"""Make sure that every device type can be created from its value."""
for name in DeviceType:
assert DeviceType.from_value(name.value) is not None
assert DeviceType.from_value("nonexistent") is DeviceType.Unknown
assert DeviceType.from_value("plug") is DeviceType.Plug
assert DeviceType.Plug.value == "plug"
assert DeviceType.from_value("bulb") is DeviceType.Bulb
assert DeviceType.Bulb.value == "bulb"
assert DeviceType.from_value("dimmer") is DeviceType.Dimmer
assert DeviceType.Dimmer.value == "dimmer"
assert DeviceType.from_value("strip") is DeviceType.Strip
assert DeviceType.Strip.value == "strip"
assert DeviceType.from_value("lightstrip") is DeviceType.LightStrip
assert DeviceType.LightStrip.value == "lightstrip"

View File

@ -1,7 +1,6 @@
# type: ignore
import re
import socket
import sys
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
@ -111,27 +110,6 @@ async def test_discover_single_hostname(discovery_data: dict, mocker):
x = await Discover.discover_single(host)
@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect_single(discovery_data: dict, mocker, custom_port):
"""Make sure that connect_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info)
dev = await Discover.connect_single(host, port=custom_port)
assert issubclass(dev.__class__, SmartDevice)
assert dev.port == custom_port or dev.port == 9999
async def test_connect_single_query_fails(mocker):
"""Make sure that connect_single fails when query fails."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
with pytest.raises(SmartDeviceException):
await Discover.connect_single(host)
UNSUPPORTED = {
"result": {
"device_id": "xx",

View File

@ -1,12 +1,12 @@
import inspect
from datetime import datetime
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
import kasa
from kasa import Credentials, SmartDevice, SmartDeviceException
from kasa.smartstrip import SmartStripPlug
from kasa.smartdevice import DeviceType
from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
@ -215,6 +215,28 @@ async def test_create_smart_device_with_timeout():
assert dev.protocol.timeout == 100
async def test_create_thin_wrapper():
"""Make sure thin wrapper is created with the correct device type."""
mock = Mock()
with patch("kasa.device_factory.connect", return_value=mock) as connect:
dev = await SmartDevice.connect(
host="test_host",
port=1234,
timeout=100,
credentials=Credentials("username", "password"),
device_type=DeviceType.Strip,
)
assert dev is mock
connect.assert_called_once_with(
host="test_host",
port=1234,
timeout=100,
credentials=Credentials("username", "password"),
device_type=DeviceType.Strip,
)
async def test_modules_not_supported(dev: SmartDevice):
"""Test that unsupported modules do not break the device."""
for module in dev.modules.values():