Move connect_single to SmartDevice.connect (#538)

This refactors `Discover.connect_single` by moving device instance construction into a separate device factory module.
New `SmartDevice.connect(host, *, port, timeout, credentials, device_type)` class method replaces the functionality of `connect_single`,
and also now allows constructing device instances without relying on UDP discovery for type discovery if `device_type` parameter is set.

---------

Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
J. Nick Koston 2023-11-21 23:48:53 +01:00 committed by GitHub
parent 27c4799adc
commit e98252ff17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 361 additions and 144 deletions

View File

@ -12,6 +12,21 @@ or if you are just looking to access some information that is not currently expo
.. contents:: Contents .. contents:: Contents
:local: :local:
.. _initialization:
Initialization
**************
Use :func:`~kasa.Discover.discover` to perform udp-based broadcast discovery on the network.
This will return you a list of device instances based on the discovery replies.
If the device's host is already known, you can use to construct a device instance with
:meth:`~kasa.SmartDevice.connect()`.
When connecting a device with the :meth:`~kasa.SmartDevice.connect()` method, it is recommended to
pass the device type as well as this allows the library to use the correct device class for the
device without having to query the device.
.. _update_cycle: .. _update_cycle:
Update Cycle Update Cycle

View File

@ -13,14 +13,13 @@ import asyncclick as click
from kasa import ( from kasa import (
AuthenticationException, AuthenticationException,
Credentials, Credentials,
DeviceType,
Discover, Discover,
SmartBulb, SmartBulb,
SmartDevice, SmartDevice,
SmartDimmer,
SmartLightStrip,
SmartPlug,
SmartStrip, SmartStrip,
) )
from kasa.device_factory import DEVICE_TYPE_TO_CLASS
try: try:
from rich import print as _do_echo from rich import print as _do_echo
@ -43,13 +42,11 @@ except ImportError:
# --json has set it to _nop_echo # --json has set it to _nop_echo
echo = _do_echo echo = _do_echo
TYPE_TO_CLASS = { DEVICE_TYPES = [
"plug": SmartPlug, device_type.value
"bulb": SmartBulb, for device_type in DeviceType
"dimmer": SmartDimmer, if device_type in DEVICE_TYPE_TO_CLASS
"strip": SmartStrip, ]
"lightstrip": SmartLightStrip,
}
click.anyio_backend = "asyncio" click.anyio_backend = "asyncio"
@ -129,7 +126,7 @@ def json_formatter_cb(result, **kwargs):
"--type", "--type",
envvar="KASA_TYPE", envvar="KASA_TYPE",
default=None, default=None,
type=click.Choice(list(TYPE_TO_CLASS), case_sensitive=False), type=click.Choice(DEVICE_TYPES, case_sensitive=False),
) )
@click.option( @click.option(
"--json", default=False, is_flag=True, help="Output raw device response as JSON." "--json", default=False, is_flag=True, help="Output raw device response as JSON."
@ -235,7 +232,10 @@ async def cli(
return await ctx.invoke(discover, timeout=discovery_timeout) return await ctx.invoke(discover, timeout=discovery_timeout)
if type is not None: if type is not None:
dev = TYPE_TO_CLASS[type](host, credentials=credentials) device_type = DeviceType.from_value(type)
dev = await SmartDevice.connect(
host, credentials=credentials, device_type=device_type
)
else: else:
echo("No --type defined, discovering..") echo("No --type defined, discovering..")
dev = await Discover.discover_single( dev = await Discover.discover_single(
@ -243,8 +243,8 @@ async def cli(
port=port, port=port,
credentials=credentials, credentials=credentials,
) )
await dev.update()
await dev.update()
ctx.obj = dev ctx.obj = dev
if ctx.invoked_subcommand is None: if ctx.invoked_subcommand is None:

121
kasa/device_factory.py Executable file
View File

@ -0,0 +1,121 @@
"""Device creation by type."""
import logging
import time
from typing import Any, Dict, Optional, Type
from .credentials import Credentials
from .device_type import DeviceType
from .exceptions import UnsupportedDeviceException
from .smartbulb import SmartBulb
from .smartdevice import SmartDevice, SmartDeviceException
from .smartdimmer import SmartDimmer
from .smartlightstrip import SmartLightStrip
from .smartplug import SmartPlug
from .smartstrip import SmartStrip
DEVICE_TYPE_TO_CLASS = {
DeviceType.Plug: SmartPlug,
DeviceType.Bulb: SmartBulb,
DeviceType.Strip: SmartStrip,
DeviceType.Dimmer: SmartDimmer,
DeviceType.LightStrip: SmartLightStrip,
}
_LOGGER = logging.getLogger(__name__)
async def connect(
host: str,
*,
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
device_type: Optional[DeviceType] = None,
) -> "SmartDevice":
"""Connect to a single device by the given IP address.
This method avoids the UDP based discovery process and
will connect directly to the device to query its type.
It is generally preferred to avoid :func:`discover_single()` and
use this function instead as it should perform better when
the WiFi network is congested or the device is not responding
to discovery requests.
The device type is discovered by querying the device.
:param host: Hostname of device to query
:param device_type: Device type to use for the device.
If not given, the device type is discovered by querying the device.
If the device type is already known, it is preferred to pass it
to avoid the extra query to the device to discover its type.
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_enabled:
start_time = time.perf_counter()
if device_type and (klass := DEVICE_TYPE_TO_CLASS.get(device_type)):
dev: SmartDevice = klass(
host=host, port=port, credentials=credentials, timeout=timeout
)
await dev.update()
if debug_enabled:
end_time = time.perf_counter()
_LOGGER.debug(
"Device %s with known type (%s) took %.2f seconds to connect",
host,
device_type.value,
end_time - start_time,
)
return dev
unknown_dev = SmartDevice(
host=host, port=port, credentials=credentials, timeout=timeout
)
await unknown_dev.update()
device_class = get_device_class_from_info(unknown_dev.internal_state)
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
# Reuse the connection from the unknown device
# so we don't have to reconnect
dev.protocol = unknown_dev.protocol
await dev.update()
if debug_enabled:
end_time = time.perf_counter()
_LOGGER.debug(
"Device %s with unknown type (%s) took %.2f seconds to connect",
host,
dev.device_type.value,
end_time - start_time,
)
return dev
def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]:
"""Find SmartDevice subclass for device described by passed data."""
if "system" not in info or "get_sysinfo" not in info["system"]:
raise SmartDeviceException("No 'system' or 'get_sysinfo' in response")
sysinfo: Dict[str, Any] = info["system"]["get_sysinfo"]
type_: Optional[str] = sysinfo.get("type", sysinfo.get("mic_type"))
if type_ is None:
raise SmartDeviceException("Unable to find the device type field!")
if "dev_name" in sysinfo and "Dimmer" in sysinfo["dev_name"]:
return SmartDimmer
if "smartplug" in type_.lower():
if "children" in sysinfo:
return SmartStrip
return SmartPlug
if "smartbulb" in type_.lower():
if "length" in sysinfo: # strips have length
return SmartLightStrip
return SmartBulb
raise UnsupportedDeviceException("Unknown device type: %s" % type_)

25
kasa/device_type.py Executable file
View File

@ -0,0 +1,25 @@
"""TP-Link device types."""
from enum import Enum
class DeviceType(Enum):
"""Device type enum."""
# The values match what the cli has historically used
Plug = "plug"
Bulb = "bulb"
Strip = "strip"
StripSocket = "stripsocket"
Dimmer = "dimmer"
LightStrip = "lightstrip"
Unknown = "unknown"
@staticmethod
def from_value(name: str) -> "DeviceType":
"""Return device type from string value."""
for device_type in DeviceType:
if device_type.value == name:
return device_type
return DeviceType.Unknown

View File

@ -20,13 +20,11 @@ from kasa.exceptions import UnsupportedDeviceException
from kasa.json import dumps as json_dumps from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads from kasa.json import loads as json_loads
from kasa.klapprotocol import TPLinkKlap from kasa.klapprotocol import TPLinkKlap
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb
from kasa.smartdevice import SmartDevice, SmartDeviceException from kasa.smartdevice import SmartDevice, SmartDeviceException
from kasa.smartdimmer import SmartDimmer
from kasa.smartlightstrip import SmartLightStrip
from kasa.smartplug import SmartPlug from kasa.smartplug import SmartPlug
from kasa.smartstrip import SmartStrip
from .device_factory import get_device_class_from_info
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -345,78 +343,10 @@ class Discover:
else: else:
raise SmartDeviceException(f"Unable to get discovery response for {host}") raise SmartDeviceException(f"Unable to get discovery response for {host}")
@staticmethod
async def connect_single(
host: str,
*,
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
protocol_class: Optional[Type[TPLinkProtocol]] = None,
) -> SmartDevice:
"""Connect to a single device by the given IP address.
This method avoids the UDP based discovery process and
will connect directly to the device to query its type.
It is generally preferred to avoid :func:`discover_single()` and
use this function instead as it should perform better when
the WiFi network is congested or the device is not responding
to discovery requests.
The device type is discovered by querying the device.
:param host: Hostname of device to query
:param port: Optionally set a different port for the device
:param timeout: Timeout for discovery
:param credentials: Optionally provide credentials for
devices requiring them
:param protocol_class: Optionally provide the protocol class
to use.
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
unknown_dev = SmartDevice(
host=host, port=port, credentials=credentials, timeout=timeout
)
if protocol_class is not None:
unknown_dev.protocol = protocol_class(host, credentials=credentials)
await unknown_dev.update()
device_class = Discover._get_device_class(unknown_dev.internal_state)
dev = device_class(
host=host, port=port, credentials=credentials, timeout=timeout
)
# Reuse the connection from the unknown device
# so we don't have to reconnect
dev.protocol = unknown_dev.protocol
return dev
@staticmethod @staticmethod
def _get_device_class(info: dict) -> Type[SmartDevice]: def _get_device_class(info: dict) -> Type[SmartDevice]:
"""Find SmartDevice subclass for device described by passed data.""" """Find SmartDevice subclass for device described by passed data."""
if "system" not in info or "get_sysinfo" not in info["system"]: return get_device_class_from_info(info)
raise SmartDeviceException("No 'system' or 'get_sysinfo' in response")
sysinfo = info["system"]["get_sysinfo"]
type_ = sysinfo.get("type", sysinfo.get("mic_type"))
if type_ is None:
raise SmartDeviceException("Unable to find the device type field!")
if "dev_name" in sysinfo and "Dimmer" in sysinfo["dev_name"]:
return SmartDimmer
if "smartplug" in type_.lower():
if "children" in sysinfo:
return SmartStrip
return SmartPlug
if "smartbulb" in type_.lower():
if "length" in sysinfo: # strips have length
return SmartLightStrip
return SmartBulb
raise UnsupportedDeviceException("Unknown device type: %s" % type_)
@staticmethod @staticmethod
def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice: def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice:

View File

@ -17,10 +17,10 @@ import inspect
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Set from typing import Any, Dict, List, Optional, Set
from .credentials import Credentials from .credentials import Credentials
from .device_type import DeviceType
from .emeterstatus import EmeterStatus from .emeterstatus import EmeterStatus
from .exceptions import SmartDeviceException from .exceptions import SmartDeviceException
from .modules import Emeter, Module from .modules import Emeter, Module
@ -29,18 +29,6 @@ from .protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class DeviceType(Enum):
"""Device type enum."""
Plug = auto()
Bulb = auto()
Strip = auto()
StripSocket = auto()
Dimmer = auto()
LightStrip = auto()
Unknown = -1
@dataclass @dataclass
class WifiNetwork: class WifiNetwork:
"""Wifi network container.""" """Wifi network container."""
@ -780,3 +768,42 @@ class SmartDevice:
f" ({self.alias}), is_on: {self.is_on}" f" ({self.alias}), is_on: {self.is_on}"
f" - dev specific: {self.state_information}>" f" - dev specific: {self.state_information}>"
) )
@staticmethod
async def connect(
host: str,
*,
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
device_type: Optional[DeviceType] = None,
) -> "SmartDevice":
"""Connect to a single device by the given IP address.
This method avoids the UDP based discovery process and
will connect directly to the device to query its type.
It is generally preferred to avoid :func:`discover_single()` and
use this function instead as it should perform better when
the WiFi network is congested or the device is not responding
to discovery requests.
The device type is discovered by querying the device.
:param host: Hostname of device to query
:param device_type: Device type to use for the device.
If not given, the device type is discovered by querying the device.
If the device type is already known, it is preferred to pass it
to avoid the extra query to the device to discover its type.
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
from .device_factory import connect # pylint: disable=import-outside-toplevel
return await connect(
host=host,
port=port,
timeout=timeout,
credentials=credentials,
device_type=device_type,
)

View File

@ -1,22 +1,11 @@
import json import json
import sys
import asyncclick as click import asyncclick as click
import pytest import pytest
from asyncclick.testing import CliRunner from asyncclick.testing import CliRunner
from kasa import SmartDevice, TPLinkSmartHomeProtocol from kasa import SmartDevice, TPLinkSmartHomeProtocol
from kasa.cli import ( from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle
TYPE_TO_CLASS,
alias,
brightness,
cli,
emeter,
raw_command,
state,
sysinfo,
toggle,
)
from kasa.discover import Discover from kasa.discover import Discover
from .conftest import handle_turn_on, turn_on from .conftest import handle_turn_on, turn_on
@ -154,14 +143,9 @@ async def test_credentials(discovery_data: dict, mocker):
) )
mocker.patch("kasa.cli.state", new=_state) mocker.patch("kasa.cli.state", new=_state)
cli_device_type = Discover._get_device_class(discovery_data)(
# Get the type string parameter from the discovery_info "any"
for cli_device_type in { # noqa: B007 ).device_type.value
i
for i in TYPE_TO_CLASS
if TYPE_TO_CLASS[i] == Discover._get_device_class(discovery_data)
}:
break
runner = CliRunner() runner = CliRunner()
res = await runner.invoke( res = await runner.invoke(
@ -181,6 +165,24 @@ async def test_credentials(discovery_data: dict, mocker):
assert res.output == "Username:foo Password:bar\n" assert res.output == "Username:foo Password:bar\n"
async def test_without_device_type(discovery_data: dict, dev, mocker):
"""Test connecting without the device type."""
runner = CliRunner()
mocker.patch("kasa.discover.Discover.discover_single", return_value=dev)
res = await runner.invoke(
cli,
[
"--host",
"127.0.0.1",
"--username",
"foo",
"--password",
"bar",
],
)
assert res.exit_code == 0
@pytest.mark.parametrize("auth_param", ["--username", "--password"]) @pytest.mark.parametrize("auth_param", ["--username", "--password"])
async def test_invalid_credential_params(auth_param): async def test_invalid_credential_params(auth_param):
"""Test for handling only one of username or password supplied.""" """Test for handling only one of username or password supplied."""

View File

@ -0,0 +1,74 @@
# type: ignore
import logging
from typing import Type
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import (
DeviceType,
SmartBulb,
SmartDevice,
SmartDeviceException,
SmartDimmer,
SmartLightStrip,
SmartPlug,
)
from kasa.device_factory import connect
@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect(discovery_data: dict, mocker, custom_port):
"""Make sure that connect returns an initialized SmartDevice instance."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await connect(host, port=custom_port)
assert issubclass(dev.__class__, SmartDevice)
assert dev.port == custom_port or dev.port == 9999
@pytest.mark.parametrize("custom_port", [123, None])
@pytest.mark.parametrize(
("device_type", "klass"),
(
(DeviceType.Plug, SmartPlug),
(DeviceType.Bulb, SmartBulb),
(DeviceType.Dimmer, SmartDimmer),
(DeviceType.LightStrip, SmartLightStrip),
(DeviceType.Unknown, SmartDevice),
),
)
async def test_connect_passed_device_type(
discovery_data: dict,
mocker,
device_type: DeviceType,
klass: Type[SmartDevice],
custom_port,
):
"""Make sure that connect with a passed device type."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await connect(host, port=custom_port, device_type=device_type)
assert isinstance(dev, klass)
assert dev.port == custom_port or dev.port == 9999
async def test_connect_query_fails(discovery_data: dict, mocker):
"""Make sure that connect fails when query fails."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
with pytest.raises(SmartDeviceException):
await connect(host)
async def test_connect_logs_connect_time(
discovery_data: dict, caplog: pytest.LogCaptureFixture, mocker
):
"""Test that the connect time is logged when debug logging is enabled."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
logging.getLogger("kasa").setLevel(logging.DEBUG)
await connect(host)
assert "seconds to connect" in caplog.text

View File

@ -0,0 +1,23 @@
from kasa.smartdevice import DeviceType
async def test_device_type_from_value():
"""Make sure that every device type can be created from its value."""
for name in DeviceType:
assert DeviceType.from_value(name.value) is not None
assert DeviceType.from_value("nonexistent") is DeviceType.Unknown
assert DeviceType.from_value("plug") is DeviceType.Plug
assert DeviceType.Plug.value == "plug"
assert DeviceType.from_value("bulb") is DeviceType.Bulb
assert DeviceType.Bulb.value == "bulb"
assert DeviceType.from_value("dimmer") is DeviceType.Dimmer
assert DeviceType.Dimmer.value == "dimmer"
assert DeviceType.from_value("strip") is DeviceType.Strip
assert DeviceType.Strip.value == "strip"
assert DeviceType.from_value("lightstrip") is DeviceType.LightStrip
assert DeviceType.LightStrip.value == "lightstrip"

View File

@ -1,7 +1,6 @@
# type: ignore # type: ignore
import re import re
import socket import socket
import sys
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
@ -111,27 +110,6 @@ async def test_discover_single_hostname(discovery_data: dict, mocker):
x = await Discover.discover_single(host) x = await Discover.discover_single(host)
@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect_single(discovery_data: dict, mocker, custom_port):
"""Make sure that connect_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info)
dev = await Discover.connect_single(host, port=custom_port)
assert issubclass(dev.__class__, SmartDevice)
assert dev.port == custom_port or dev.port == 9999
async def test_connect_single_query_fails(mocker):
"""Make sure that connect_single fails when query fails."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
with pytest.raises(SmartDeviceException):
await Discover.connect_single(host)
UNSUPPORTED = { UNSUPPORTED = {
"result": { "result": {
"device_id": "xx", "device_id": "xx",

View File

@ -1,12 +1,12 @@
import inspect import inspect
from datetime import datetime from datetime import datetime
from unittest.mock import patch from unittest.mock import Mock, patch
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
import kasa import kasa
from kasa import Credentials, SmartDevice, SmartDeviceException from kasa import Credentials, SmartDevice, SmartDeviceException
from kasa.smartstrip import SmartStripPlug from kasa.smartdevice import DeviceType
from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
@ -215,6 +215,28 @@ async def test_create_smart_device_with_timeout():
assert dev.protocol.timeout == 100 assert dev.protocol.timeout == 100
async def test_create_thin_wrapper():
"""Make sure thin wrapper is created with the correct device type."""
mock = Mock()
with patch("kasa.device_factory.connect", return_value=mock) as connect:
dev = await SmartDevice.connect(
host="test_host",
port=1234,
timeout=100,
credentials=Credentials("username", "password"),
device_type=DeviceType.Strip,
)
assert dev is mock
connect.assert_called_once_with(
host="test_host",
port=1234,
timeout=100,
credentials=Credentials("username", "password"),
device_type=DeviceType.Strip,
)
async def test_modules_not_supported(dev: SmartDevice): async def test_modules_not_supported(dev: SmartDevice):
"""Test that unsupported modules do not break the device.""" """Test that unsupported modules do not break the device."""
for module in dev.modules.values(): for module in dev.modules.values():