Make device port configurable (#471)

This commit is contained in:
Viktar Karpach 2023-07-09 18:55:27 -05:00 committed by GitHub
parent 6199521269
commit 9b039d8374
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 67 additions and 26 deletions

View File

@ -99,6 +99,12 @@ def json_formatter_cb(result, **kwargs):
required=False,
help="The host name or IP address of the device to connect to.",
)
@click.option(
"--port",
envvar="KASA_PORT",
required=False,
help="The port of the device to connect to.",
)
@click.option(
"--alias",
envvar="KASA_NAME",
@ -125,7 +131,7 @@ def json_formatter_cb(result, **kwargs):
)
@click.version_option(package_name="python-kasa")
@click.pass_context
async def cli(ctx, host, alias, target, debug, type, json):
async def cli(ctx, host, port, alias, target, debug, type, json):
"""A tool for controlling TP-Link smart home devices.""" # noqa
# no need to perform any checks if we are just displaying the help
if sys.argv[-1] == "--help":
@ -179,7 +185,7 @@ async def cli(ctx, host, alias, target, debug, type, json):
dev = TYPE_TO_CLASS[type](host)
else:
echo("No --type defined, discovering..")
dev = await Discover.discover_single(host)
dev = await Discover.discover_single(host, port=port)
await dev.update()
ctx.obj = dev
@ -275,6 +281,7 @@ async def state(dev: SmartDevice):
"""Print out device state and versions."""
echo(f"[bold]== {dev.alias} - {dev.model} ==[/bold]")
echo(f"\tHost: {dev.host}")
echo(f"\tPort: {dev.port}")
echo(f"\tDevice state: {dev.is_on}")
if dev.is_strip:
echo("\t[bold]== Plugs ==[/bold]")

View File

@ -193,19 +193,19 @@ class Discover:
return protocol.discovered_devices
@staticmethod
async def discover_single(host: str) -> SmartDevice:
async def discover_single(host: str, *, port: Optional[int] = None) -> SmartDevice:
"""Discover a single device by the given IP address.
:param host: Hostname of device to query
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
protocol = TPLinkSmartHomeProtocol(host)
protocol = TPLinkSmartHomeProtocol(host, port=port)
info = await protocol.query(Discover.DISCOVERY_QUERY)
device_class = Discover._get_device_class(info)
dev = device_class(host)
dev = device_class(host, port=port)
await dev.update()
return dev

View File

@ -33,9 +33,10 @@ class TPLinkSmartHomeProtocol:
DEFAULT_TIMEOUT = 5
BLOCK_SIZE = 4
def __init__(self, host: str) -> None:
def __init__(self, host: str, *, port: Optional[int] = None) -> None:
"""Create a protocol object."""
self.host = host
self.port = port or TPLinkSmartHomeProtocol.DEFAULT_PORT
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
self.query_lock: Optional[asyncio.Lock] = None
@ -78,7 +79,7 @@ class TPLinkSmartHomeProtocol:
if self.writer:
return
self.reader = self.writer = None
task = asyncio.open_connection(self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT)
task = asyncio.open_connection(self.host, self.port)
self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
async def _execute_query(self, request: str) -> Dict:
@ -133,13 +134,13 @@ class TPLinkSmartHomeProtocol:
except ConnectionRefusedError as ex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
f"Unable to connect to the device: {self.host}:{self.port}: {ex}"
)
except OSError as ex:
await self.close()
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
f"Unable to connect to the device: {self.host}:{self.port}: {ex}"
)
continue
except Exception as ex:
@ -147,7 +148,7 @@ class TPLinkSmartHomeProtocol:
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
f"Unable to connect to the device: {self.host}:{self.port}: {ex}"
)
continue
@ -162,7 +163,7 @@ class TPLinkSmartHomeProtocol:
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to query the device {self.host}: {ex}"
f"Unable to query the device {self.host}:{self.port}: {ex}"
) from ex
_LOGGER.debug(

View File

@ -199,8 +199,8 @@ class SmartBulb(SmartDevice):
SET_LIGHT_METHOD = "transition_light_state"
emeter_type = "smartlife.iot.common.emeter"
def __init__(self, host: str) -> None:
super().__init__(host=host)
def __init__(self, host: str, *, port: Optional[int] = None) -> None:
super().__init__(host=host, port=port)
self._device_type = DeviceType.Bulb
self.add_module("schedule", Schedule(self, "smartlife.iot.common.schedule"))
self.add_module("usage", Usage(self, "smartlife.iot.common.schedule"))

View File

@ -191,14 +191,15 @@ class SmartDevice:
emeter_type = "emeter"
def __init__(self, host: str) -> None:
def __init__(self, host: str, *, port: Optional[int] = None) -> None:
"""Create a new SmartDevice instance.
:param str host: host name or ip address on which the device listens
"""
self.host = host
self.port = port
self.protocol = TPLinkSmartHomeProtocol(host)
self.protocol = TPLinkSmartHomeProtocol(host, port=port)
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
self._device_type = DeviceType.Unknown
# TODO: typing Any is just as using Optional[Dict] would require separate checks in

View File

@ -62,8 +62,8 @@ class SmartDimmer(SmartPlug):
DIMMER_SERVICE = "smartlife.iot.dimmer"
def __init__(self, host: str) -> None:
super().__init__(host)
def __init__(self, host: str, *, port: Optional[int] = None) -> None:
super().__init__(host, port=port)
self._device_type = DeviceType.Dimmer
# TODO: need to be verified if it's okay to call these on HS220 w/o these
# TODO: need to be figured out what's the best approach to detect support for these

View File

@ -41,8 +41,8 @@ class SmartLightStrip(SmartBulb):
LIGHT_SERVICE = "smartlife.iot.lightStrip"
SET_LIGHT_METHOD = "set_light_state"
def __init__(self, host: str) -> None:
super().__init__(host)
def __init__(self, host: str, *, port: Optional[int] = None) -> None:
super().__init__(host, port=port)
self._device_type = DeviceType.LightStrip
@property # type: ignore

View File

@ -1,6 +1,6 @@
"""Module for smart plugs (HS100, HS110, ..)."""
import logging
from typing import Any, Dict
from typing import Any, Dict, Optional
from kasa.modules import Antitheft, Cloud, Schedule, Time, Usage
from kasa.smartdevice import DeviceType, SmartDevice, requires_update
@ -37,8 +37,8 @@ class SmartPlug(SmartDevice):
For more examples, see the :class:`SmartDevice` class.
"""
def __init__(self, host: str) -> None:
super().__init__(host)
def __init__(self, host: str, *, port: Optional[int] = None) -> None:
super().__init__(host, port=port)
self._device_type = DeviceType.Plug
self.add_module("schedule", Schedule(self, "schedule"))
self.add_module("usage", Usage(self, "schedule"))

View File

@ -79,8 +79,8 @@ class SmartStrip(SmartDevice):
For more examples, see the :class:`SmartDevice` class.
"""
def __init__(self, host: str) -> None:
super().__init__(host=host)
def __init__(self, host: str, *, port: Optional[int] = None) -> None:
super().__init__(host=host, port=port)
self.emeter_type = "emeter"
self._device_type = DeviceType.Strip
self.add_module("antitheft", Antitheft(self, "anti_theft"))

View File

@ -52,12 +52,14 @@ async def test_type_unknown():
Discover._get_device_class(invalid_info)
async def test_discover_single(discovery_data: dict, mocker):
@pytest.mark.parametrize("custom_port", [123, None])
async def test_discover_single(discovery_data: dict, mocker, custom_port):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
x = await Discover.discover_single("127.0.0.1")
x = await Discover.discover_single("127.0.0.1", port=custom_port)
assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None
assert x.port == custom_port
INVALIDS = [

View File

@ -129,6 +129,36 @@ async def test_protocol_logging(mocker, caplog, log_level):
assert "success" not in caplog.text
@pytest.mark.parametrize("custom_port", [123, None])
async def test_protocol_custom_port(mocker, custom_port):
encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[
TPLinkSmartHomeProtocol.BLOCK_SIZE :
]
async def _mock_read(byte_count):
nonlocal encrypted
if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE:
return struct.pack(">I", len(encrypted))
if byte_count == len(encrypted):
return encrypted
raise ValueError(f"No mock for {byte_count}")
def aio_mock_writer(_, port):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")
if custom_port is None:
assert port == 9999
else:
assert port == custom_port
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
protocol = TPLinkSmartHomeProtocol("127.0.0.1", port=custom_port)
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({})
assert response == {"great": "success"}
def test_encrypt():
d = json.dumps({"foo": 1, "bar": 2})
encrypted = TPLinkSmartHomeProtocol.encrypt(d)