diff --git a/kasa/cli.py b/kasa/cli.py index 48ce039c..b2d9d91a 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -99,6 +99,12 @@ def json_formatter_cb(result, **kwargs): required=False, help="The host name or IP address of the device to connect to.", ) +@click.option( + "--port", + envvar="KASA_PORT", + required=False, + help="The port of the device to connect to.", +) @click.option( "--alias", envvar="KASA_NAME", @@ -125,7 +131,7 @@ def json_formatter_cb(result, **kwargs): ) @click.version_option(package_name="python-kasa") @click.pass_context -async def cli(ctx, host, alias, target, debug, type, json): +async def cli(ctx, host, port, alias, target, debug, type, json): """A tool for controlling TP-Link smart home devices.""" # noqa # no need to perform any checks if we are just displaying the help if sys.argv[-1] == "--help": @@ -179,7 +185,7 @@ async def cli(ctx, host, alias, target, debug, type, json): dev = TYPE_TO_CLASS[type](host) else: echo("No --type defined, discovering..") - dev = await Discover.discover_single(host) + dev = await Discover.discover_single(host, port=port) await dev.update() ctx.obj = dev @@ -275,6 +281,7 @@ async def state(dev: SmartDevice): """Print out device state and versions.""" echo(f"[bold]== {dev.alias} - {dev.model} ==[/bold]") echo(f"\tHost: {dev.host}") + echo(f"\tPort: {dev.port}") echo(f"\tDevice state: {dev.is_on}") if dev.is_strip: echo("\t[bold]== Plugs ==[/bold]") diff --git a/kasa/discover.py b/kasa/discover.py index 217ec32c..f7b5fbbf 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -193,19 +193,19 @@ class Discover: return protocol.discovered_devices @staticmethod - async def discover_single(host: str) -> SmartDevice: + async def discover_single(host: str, *, port: Optional[int] = None) -> SmartDevice: """Discover a single device by the given IP address. :param host: Hostname of device to query :rtype: SmartDevice :return: Object for querying/controlling found device. """ - protocol = TPLinkSmartHomeProtocol(host) + protocol = TPLinkSmartHomeProtocol(host, port=port) info = await protocol.query(Discover.DISCOVERY_QUERY) device_class = Discover._get_device_class(info) - dev = device_class(host) + dev = device_class(host, port=port) await dev.update() return dev diff --git a/kasa/protocol.py b/kasa/protocol.py index e21866e3..cd9066c6 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -33,9 +33,10 @@ class TPLinkSmartHomeProtocol: DEFAULT_TIMEOUT = 5 BLOCK_SIZE = 4 - def __init__(self, host: str) -> None: + def __init__(self, host: str, *, port: Optional[int] = None) -> None: """Create a protocol object.""" self.host = host + self.port = port or TPLinkSmartHomeProtocol.DEFAULT_PORT self.reader: Optional[asyncio.StreamReader] = None self.writer: Optional[asyncio.StreamWriter] = None self.query_lock: Optional[asyncio.Lock] = None @@ -78,7 +79,7 @@ class TPLinkSmartHomeProtocol: if self.writer: return self.reader = self.writer = None - task = asyncio.open_connection(self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT) + task = asyncio.open_connection(self.host, self.port) self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout) async def _execute_query(self, request: str) -> Dict: @@ -133,13 +134,13 @@ class TPLinkSmartHomeProtocol: except ConnectionRefusedError as ex: await self.close() raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {ex}" + f"Unable to connect to the device: {self.host}:{self.port}: {ex}" ) except OSError as ex: await self.close() if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count: raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {ex}" + f"Unable to connect to the device: {self.host}:{self.port}: {ex}" ) continue except Exception as ex: @@ -147,7 +148,7 @@ class TPLinkSmartHomeProtocol: if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {ex}" + f"Unable to connect to the device: {self.host}:{self.port}: {ex}" ) continue @@ -162,7 +163,7 @@ class TPLinkSmartHomeProtocol: if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) raise SmartDeviceException( - f"Unable to query the device {self.host}: {ex}" + f"Unable to query the device {self.host}:{self.port}: {ex}" ) from ex _LOGGER.debug( diff --git a/kasa/smartbulb.py b/kasa/smartbulb.py index b28edab1..1d6ba31e 100644 --- a/kasa/smartbulb.py +++ b/kasa/smartbulb.py @@ -199,8 +199,8 @@ class SmartBulb(SmartDevice): SET_LIGHT_METHOD = "transition_light_state" emeter_type = "smartlife.iot.common.emeter" - def __init__(self, host: str) -> None: - super().__init__(host=host) + def __init__(self, host: str, *, port: Optional[int] = None) -> None: + super().__init__(host=host, port=port) self._device_type = DeviceType.Bulb self.add_module("schedule", Schedule(self, "smartlife.iot.common.schedule")) self.add_module("usage", Usage(self, "smartlife.iot.common.schedule")) diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 75efcded..fd8d3768 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -191,14 +191,15 @@ class SmartDevice: emeter_type = "emeter" - def __init__(self, host: str) -> None: + def __init__(self, host: str, *, port: Optional[int] = None) -> None: """Create a new SmartDevice instance. :param str host: host name or ip address on which the device listens """ self.host = host + self.port = port - self.protocol = TPLinkSmartHomeProtocol(host) + self.protocol = TPLinkSmartHomeProtocol(host, port=port) _LOGGER.debug("Initializing %s of type %s", self.host, type(self)) self._device_type = DeviceType.Unknown # TODO: typing Any is just as using Optional[Dict] would require separate checks in diff --git a/kasa/smartdimmer.py b/kasa/smartdimmer.py index 9565437d..247455e3 100644 --- a/kasa/smartdimmer.py +++ b/kasa/smartdimmer.py @@ -62,8 +62,8 @@ class SmartDimmer(SmartPlug): DIMMER_SERVICE = "smartlife.iot.dimmer" - def __init__(self, host: str) -> None: - super().__init__(host) + def __init__(self, host: str, *, port: Optional[int] = None) -> None: + super().__init__(host, port=port) self._device_type = DeviceType.Dimmer # TODO: need to be verified if it's okay to call these on HS220 w/o these # TODO: need to be figured out what's the best approach to detect support for these diff --git a/kasa/smartlightstrip.py b/kasa/smartlightstrip.py index 566bf0a7..6afe5d11 100644 --- a/kasa/smartlightstrip.py +++ b/kasa/smartlightstrip.py @@ -41,8 +41,8 @@ class SmartLightStrip(SmartBulb): LIGHT_SERVICE = "smartlife.iot.lightStrip" SET_LIGHT_METHOD = "set_light_state" - def __init__(self, host: str) -> None: - super().__init__(host) + def __init__(self, host: str, *, port: Optional[int] = None) -> None: + super().__init__(host, port=port) self._device_type = DeviceType.LightStrip @property # type: ignore diff --git a/kasa/smartplug.py b/kasa/smartplug.py index d49e4054..94a5e350 100644 --- a/kasa/smartplug.py +++ b/kasa/smartplug.py @@ -1,6 +1,6 @@ """Module for smart plugs (HS100, HS110, ..).""" import logging -from typing import Any, Dict +from typing import Any, Dict, Optional from kasa.modules import Antitheft, Cloud, Schedule, Time, Usage from kasa.smartdevice import DeviceType, SmartDevice, requires_update @@ -37,8 +37,8 @@ class SmartPlug(SmartDevice): For more examples, see the :class:`SmartDevice` class. """ - def __init__(self, host: str) -> None: - super().__init__(host) + def __init__(self, host: str, *, port: Optional[int] = None) -> None: + super().__init__(host, port=port) self._device_type = DeviceType.Plug self.add_module("schedule", Schedule(self, "schedule")) self.add_module("usage", Usage(self, "schedule")) diff --git a/kasa/smartstrip.py b/kasa/smartstrip.py index 69ea03e5..a970925b 100755 --- a/kasa/smartstrip.py +++ b/kasa/smartstrip.py @@ -79,8 +79,8 @@ class SmartStrip(SmartDevice): For more examples, see the :class:`SmartDevice` class. """ - def __init__(self, host: str) -> None: - super().__init__(host=host) + def __init__(self, host: str, *, port: Optional[int] = None) -> None: + super().__init__(host=host, port=port) self.emeter_type = "emeter" self._device_type = DeviceType.Strip self.add_module("antitheft", Antitheft(self, "anti_theft")) diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 52325f39..bbdaf8a8 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -52,12 +52,14 @@ async def test_type_unknown(): Discover._get_device_class(invalid_info) -async def test_discover_single(discovery_data: dict, mocker): +@pytest.mark.parametrize("custom_port", [123, None]) +async def test_discover_single(discovery_data: dict, mocker, custom_port): """Make sure that discover_single returns an initialized SmartDevice instance.""" mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - x = await Discover.discover_single("127.0.0.1") + x = await Discover.discover_single("127.0.0.1", port=custom_port) assert issubclass(x.__class__, SmartDevice) assert x._sys_info is not None + assert x.port == custom_port INVALIDS = [ diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 7aa95a5a..b438f498 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -129,6 +129,36 @@ async def test_protocol_logging(mocker, caplog, log_level): assert "success" not in caplog.text +@pytest.mark.parametrize("custom_port", [123, None]) +async def test_protocol_custom_port(mocker, custom_port): + encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ + TPLinkSmartHomeProtocol.BLOCK_SIZE : + ] + + async def _mock_read(byte_count): + nonlocal encrypted + if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: + return struct.pack(">I", len(encrypted)) + if byte_count == len(encrypted): + return encrypted + raise ValueError(f"No mock for {byte_count}") + + def aio_mock_writer(_, port): + reader = mocker.patch("asyncio.StreamReader") + writer = mocker.patch("asyncio.StreamWriter") + if custom_port is None: + assert port == 9999 + else: + assert port == custom_port + mocker.patch.object(reader, "readexactly", _mock_read) + return reader, writer + + protocol = TPLinkSmartHomeProtocol("127.0.0.1", port=custom_port) + mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + response = await protocol.query({}) + assert response == {"great": "success"} + + def test_encrypt(): d = json.dumps({"foo": 1, "bar": 2}) encrypted = TPLinkSmartHomeProtocol.encrypt(d)