diff --git a/kasa/__init__.py b/kasa/__init__.py index d8cb0825..4ccf6286 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -13,9 +13,14 @@ to be handled by the user of the library. """ from importlib.metadata import version +from kasa.credentials import Credentials from kasa.discover import Discover from kasa.emeterstatus import EmeterStatus -from kasa.exceptions import SmartDeviceException +from kasa.exceptions import ( + AuthenticationException, + SmartDeviceException, + UnsupportedDeviceException, +) from kasa.protocol import TPLinkSmartHomeProtocol from kasa.smartbulb import SmartBulb, SmartBulbPreset, TurnOnBehavior, TurnOnBehaviors from kasa.smartdevice import DeviceType, SmartDevice @@ -42,4 +47,7 @@ __all__ = [ "SmartStrip", "SmartDimmer", "SmartLightStrip", + "AuthenticationException", + "UnsupportedDeviceException", + "Credentials", ] diff --git a/kasa/cli.py b/kasa/cli.py index cc782432..f0c180c5 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -10,8 +10,19 @@ from typing import Any, Dict, cast import asyncclick as click +from kasa import ( + Credentials, + Discover, + SmartBulb, + SmartDevice, + SmartDimmer, + SmartLightStrip, + SmartPlug, + SmartStrip, +) + try: - from rich import print as echo + from rich import print as _do_echo except ImportError: def _strip_rich_formatting(echo_func): @@ -25,18 +36,11 @@ except ImportError: return wrapper - echo = _strip_rich_formatting(click.echo) + _do_echo = _strip_rich_formatting(click.echo) - -from kasa import ( - Discover, - SmartBulb, - SmartDevice, - SmartDimmer, - SmartLightStrip, - SmartPlug, - SmartStrip, -) +# echo is set to _do_echo so that it can be reset to _do_echo later after +# --json has set it to _nop_echo +echo = _do_echo TYPE_TO_CLASS = { "plug": SmartPlug, @@ -48,7 +52,6 @@ TYPE_TO_CLASS = { click.anyio_backend = "asyncio" - pass_dev = click.make_pass_decorator(SmartDevice) @@ -137,6 +140,20 @@ def json_formatter_cb(result, **kwargs): required=False, help="Timeout for discovery.", ) +@click.option( + "--username", + default=None, + required=False, + envvar="TPLINK_CLOUD_USERNAME", + help="Username/email address to authenticate to device.", +) +@click.option( + "--password", + default=None, + required=False, + envvar="TPLINK_CLOUD_PASSWORD", + help="Password to use to authenticate to device.", +) @click.version_option(package_name="python-kasa") @click.pass_context async def cli( @@ -149,6 +166,8 @@ async def cli( type, json, discovery_timeout, + username, + password, ): """A tool for controlling TP-Link smart home devices.""" # noqa # no need to perform any checks if we are just displaying the help @@ -158,13 +177,17 @@ async def cli( return # If JSON output is requested, disable echo + global echo if json: - global echo def _nop_echo(*args, **kwargs): pass echo = _nop_echo + else: + # Set back to default is required if running tests with CliRunner + global _do_echo + echo = _do_echo logging_config: Dict[str, Any] = { "level": logging.DEBUG if debug > 0 else logging.INFO @@ -195,15 +218,25 @@ async def cli( echo(f"No device with name {alias} found") return + if bool(password) != bool(username): + echo("Using authentication requires both --username and --password") + return + + credentials = Credentials(username=username, password=password) + if host is None: echo("No host name given, trying discovery..") return await ctx.invoke(discover, timeout=discovery_timeout) if type is not None: - dev = TYPE_TO_CLASS[type](host) + dev = TYPE_TO_CLASS[type](host, credentials=credentials) else: echo("No --type defined, discovering..") - dev = await Discover.discover_single(host, port=port) + dev = await Discover.discover_single( + host, + port=port, + credentials=credentials, + ) await dev.update() ctx.obj = dev @@ -261,6 +294,11 @@ async def join(dev: SmartDevice, ssid, password, keytype): async def discover(ctx, timeout, show_unsupported): """Discover devices in the network.""" target = ctx.parent.params["target"] + username = ctx.parent.params["username"] + password = ctx.parent.params["password"] + + credentials = Credentials(username, password) + sem = asyncio.Semaphore() discovered = dict() unsupported = [] @@ -286,6 +324,7 @@ async def discover(ctx, timeout, show_unsupported): timeout=timeout, on_discovered=print_discovered, on_unsupported=print_unsupported, + credentials=credentials, ) echo(f"Found {len(discovered)} devices") diff --git a/kasa/credentials.py b/kasa/credentials.py new file mode 100644 index 00000000..a56f5710 --- /dev/null +++ b/kasa/credentials.py @@ -0,0 +1,12 @@ +"""Credentials class for username / passwords.""" + +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class Credentials: + """Credentials for authentication.""" + + username: Optional[str] = field(default=None, repr=False) + password: Optional[str] = field(default=None, repr=False) diff --git a/kasa/discover.py b/kasa/discover.py index 5a78d193..f8e11a62 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -9,6 +9,7 @@ from typing import Awaitable, Callable, Dict, Optional, Type, cast # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout +from kasa.credentials import Credentials from kasa.exceptions import UnsupportedDeviceException from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads @@ -45,6 +46,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): on_unsupported: Optional[Callable[[Dict], Awaitable[None]]] = None, port: Optional[int] = None, discovered_event: Optional[asyncio.Event] = None, + credentials: Optional[Credentials] = None, ): self.transport = None self.discovery_packets = discovery_packets @@ -58,6 +60,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.invalid_device_exceptions: Dict = {} self.on_unsupported = on_unsupported self.discovered_event = discovered_event + self.credentials = credentials def connection_made(self, transport) -> None: """Set socket options for broadcasting.""" @@ -106,9 +109,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): if self.on_unsupported is not None: asyncio.ensure_future(self.on_unsupported(info)) _LOGGER.debug("[DISCOVERY] Unsupported device found at %s << %s", ip, info) - if self.discovered_event is not None and "255" not in self.target[0].split( - "." - ): + if self.discovered_event is not None: self.discovered_event.set() return @@ -119,13 +120,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): "[DISCOVERY] Unable to find device type from %s: %s", info, ex ) self.invalid_device_exceptions[ip] = ex - if self.discovered_event is not None and "255" not in self.target[0].split( - "." - ): + if self.discovered_event is not None: self.discovered_event.set() return - device = device_class(ip, port=port) + device = device_class(ip, port=port, credentials=self.credentials) device.update_from_discover_info(info) self.discovered_devices[ip] = device @@ -133,7 +132,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): if self.on_discovered is not None: asyncio.ensure_future(self.on_discovered(device)) - if self.discovered_event is not None and "255" not in self.target[0].split("."): + if self.discovered_event is not None: self.discovered_event.set() def error_received(self, ex): @@ -197,6 +196,7 @@ class Discover: discovery_packets=3, interface=None, on_unsupported=None, + credentials=None, ) -> DeviceDict: """Discover supported devices. @@ -225,6 +225,7 @@ class Discover: discovery_packets=discovery_packets, interface=interface, on_unsupported=on_unsupported, + credentials=credentials, ), local_addr=("0.0.0.0", 0), ) @@ -242,7 +243,11 @@ class Discover: @staticmethod async def discover_single( - host: str, *, port: Optional[int] = None, timeout=5 + host: str, + *, + port: Optional[int] = None, + timeout=5, + credentials: Optional[Credentials] = None, ) -> SmartDevice: """Discover a single device by the given IP address. @@ -253,7 +258,9 @@ class Discover: loop = asyncio.get_event_loop() event = asyncio.Event() transport, protocol = await loop.create_datagram_endpoint( - lambda: _DiscoverProtocol(target=host, port=port, discovered_event=event), + lambda: _DiscoverProtocol( + target=host, port=port, discovered_event=event, credentials=credentials + ), local_addr=("0.0.0.0", 0), ) protocol = cast(_DiscoverProtocol, protocol) diff --git a/kasa/exceptions.py b/kasa/exceptions.py index 0d2ff826..35870d1f 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -7,3 +7,7 @@ class SmartDeviceException(Exception): class UnsupportedDeviceException(SmartDeviceException): """Exception for trying to connect to unsupported devices.""" + + +class AuthenticationException(SmartDeviceException): + """Base exception for device authentication errors.""" diff --git a/kasa/smartbulb.py b/kasa/smartbulb.py index ad72701a..2d2f28ca 100644 --- a/kasa/smartbulb.py +++ b/kasa/smartbulb.py @@ -9,6 +9,7 @@ try: except ImportError: from pydantic import BaseModel, Field, root_validator +from .credentials import Credentials from .modules import Antitheft, Cloud, Countdown, Emeter, Schedule, Time, Usage from .smartdevice import DeviceType, SmartDevice, SmartDeviceException, requires_update @@ -202,8 +203,14 @@ class SmartBulb(SmartDevice): SET_LIGHT_METHOD = "transition_light_state" emeter_type = "smartlife.iot.common.emeter" - def __init__(self, host: str, *, port: Optional[int] = None) -> None: - super().__init__(host=host, port=port) + def __init__( + self, + host: str, + *, + port: Optional[int] = None, + credentials: Optional[Credentials] = None + ) -> None: + super().__init__(host=host, port=port, credentials=credentials) self._device_type = DeviceType.Bulb self.add_module("schedule", Schedule(self, "smartlife.iot.common.schedule")) self.add_module("usage", Usage(self, "smartlife.iot.common.schedule")) diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index fd8d3768..5c24c943 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -20,6 +20,7 @@ from datetime import datetime, timedelta from enum import Enum, auto from typing import Any, Dict, List, Optional, Set +from .credentials import Credentials from .emeterstatus import EmeterStatus from .exceptions import SmartDeviceException from .modules import Emeter, Module @@ -191,7 +192,13 @@ class SmartDevice: emeter_type = "emeter" - def __init__(self, host: str, *, port: Optional[int] = None) -> None: + def __init__( + self, + host: str, + *, + port: Optional[int] = None, + credentials: Optional[Credentials] = None, + ) -> None: """Create a new SmartDevice instance. :param str host: host name or ip address on which the device listens @@ -200,6 +207,7 @@ class SmartDevice: self.port = port self.protocol = TPLinkSmartHomeProtocol(host, port=port) + self.credentials = credentials _LOGGER.debug("Initializing %s of type %s", self.host, type(self)) self._device_type = DeviceType.Unknown # TODO: typing Any is just as using Optional[Dict] would require separate checks in diff --git a/kasa/smartdimmer.py b/kasa/smartdimmer.py index 247455e3..05fb75ac 100644 --- a/kasa/smartdimmer.py +++ b/kasa/smartdimmer.py @@ -2,6 +2,7 @@ from enum import Enum from typing import Any, Dict, Optional +from kasa.credentials import Credentials from kasa.modules import AmbientLight, Motion from kasa.smartdevice import DeviceType, SmartDeviceException, requires_update from kasa.smartplug import SmartPlug @@ -62,8 +63,14 @@ class SmartDimmer(SmartPlug): DIMMER_SERVICE = "smartlife.iot.dimmer" - def __init__(self, host: str, *, port: Optional[int] = None) -> None: - super().__init__(host, port=port) + def __init__( + self, + host: str, + *, + port: Optional[int] = None, + credentials: Optional[Credentials] = None, + ) -> None: + super().__init__(host, port=port, credentials=credentials) self._device_type = DeviceType.Dimmer # TODO: need to be verified if it's okay to call these on HS220 w/o these # TODO: need to be figured out what's the best approach to detect support for these diff --git a/kasa/smartlightstrip.py b/kasa/smartlightstrip.py index 6afe5d11..34e58115 100644 --- a/kasa/smartlightstrip.py +++ b/kasa/smartlightstrip.py @@ -1,6 +1,7 @@ """Module for light strips (KL430).""" from typing import Any, Dict, List, Optional +from .credentials import Credentials from .effects import EFFECT_MAPPING_V1, EFFECT_NAMES_V1 from .smartbulb import SmartBulb from .smartdevice import DeviceType, SmartDeviceException, requires_update @@ -41,8 +42,14 @@ class SmartLightStrip(SmartBulb): LIGHT_SERVICE = "smartlife.iot.lightStrip" SET_LIGHT_METHOD = "set_light_state" - def __init__(self, host: str, *, port: Optional[int] = None) -> None: - super().__init__(host, port=port) + def __init__( + self, + host: str, + *, + port: Optional[int] = None, + credentials: Optional[Credentials] = None, + ) -> None: + super().__init__(host, port=port, credentials=credentials) self._device_type = DeviceType.LightStrip @property # type: ignore diff --git a/kasa/smartplug.py b/kasa/smartplug.py index 94a5e350..f3d635d9 100644 --- a/kasa/smartplug.py +++ b/kasa/smartplug.py @@ -2,6 +2,7 @@ import logging from typing import Any, Dict, Optional +from kasa.credentials import Credentials from kasa.modules import Antitheft, Cloud, Schedule, Time, Usage from kasa.smartdevice import DeviceType, SmartDevice, requires_update @@ -37,8 +38,14 @@ class SmartPlug(SmartDevice): For more examples, see the :class:`SmartDevice` class. """ - def __init__(self, host: str, *, port: Optional[int] = None) -> None: - super().__init__(host, port=port) + def __init__( + self, + host: str, + *, + port: Optional[int] = None, + credentials: Optional[Credentials] = None + ) -> None: + super().__init__(host, port=port, credentials=credentials) self._device_type = DeviceType.Plug self.add_module("schedule", Schedule(self, "schedule")) self.add_module("usage", Usage(self, "schedule")) diff --git a/kasa/smartstrip.py b/kasa/smartstrip.py index a970925b..479b0e56 100755 --- a/kasa/smartstrip.py +++ b/kasa/smartstrip.py @@ -14,6 +14,7 @@ from kasa.smartdevice import ( ) from kasa.smartplug import SmartPlug +from .credentials import Credentials from .modules import Antitheft, Countdown, Emeter, Schedule, Time, Usage _LOGGER = logging.getLogger(__name__) @@ -79,8 +80,14 @@ class SmartStrip(SmartDevice): For more examples, see the :class:`SmartDevice` class. """ - def __init__(self, host: str, *, port: Optional[int] = None) -> None: - super().__init__(host=host, port=port) + def __init__( + self, + host: str, + *, + port: Optional[int] = None, + credentials: Optional[Credentials] = None, + ) -> None: + super().__init__(host=host, port=port, credentials=credentials) self.emeter_type = "emeter" self._device_type = DeviceType.Strip self.add_module("antitheft", Antitheft(self, "anti_theft")) diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index f7e04619..4a47284b 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -1,13 +1,26 @@ import json import sys +import asyncclick as click import pytest from asyncclick.testing import CliRunner -from kasa import SmartDevice -from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle +from kasa import SmartDevice, TPLinkSmartHomeProtocol +from kasa.cli import ( + TYPE_TO_CLASS, + alias, + brightness, + cli, + emeter, + raw_command, + state, + sysinfo, + toggle, +) +from kasa.discover import Discover from .conftest import handle_turn_on, turn_on +from .newfakes import FakeTransportProtocol async def test_sysinfo(dev): @@ -121,3 +134,70 @@ async def test_json_output(dev: SmartDevice, mocker): res = await runner.invoke(cli, ["--json", "state"], obj=dev) assert res.exit_code == 0 assert json.loads(res.output) == dev.internal_state + + +async def test_credentials(discovery_data: dict, mocker): + """Test credentials are passed correctly from cli to device.""" + # As this is testing the device constructor need to explicitly wire in + # the FakeTransportProtocol + ftp = FakeTransportProtocol(discovery_data) + mocker.patch.object(TPLinkSmartHomeProtocol, "query", ftp.query) + + # Patch state to echo username and password + pass_dev = click.make_pass_decorator(SmartDevice) + + @pass_dev + async def _state(dev: SmartDevice): + if dev.credentials: + click.echo( + f"Username:{dev.credentials.username} Password:{dev.credentials.password}" + ) + + mocker.patch("kasa.cli.state", new=_state) + + # Get the type string parameter from the discovery_info + for cli_device_type in { + i + for i in TYPE_TO_CLASS + if TYPE_TO_CLASS[i] == Discover._get_device_class(discovery_data) + }: + break + + runner = CliRunner() + res = await runner.invoke( + cli, + [ + "--host", + "127.0.0.1", + "--type", + cli_device_type, + "--username", + "foo", + "--password", + "bar", + ], + ) + assert res.exit_code == 0 + assert res.output == "Username:foo Password:bar\n" + + +@pytest.mark.parametrize("auth_param", ["--username", "--password"]) +async def test_invalid_credential_params(auth_param): + runner = CliRunner() + + # Test for handling only one of username or passowrd supplied. + res = await runner.invoke( + cli, + [ + "--host", + "127.0.0.1", + "--type", + "plug", + auth_param, + "foo", + ], + ) + assert res.exit_code == 0 + assert ( + res.output == "Using authentication requires both --username and --password\n" + ) diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index dd97b081..0839bc06 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -1,14 +1,26 @@ +import inspect from datetime import datetime from unittest.mock import patch import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 -from kasa import SmartDeviceException +import kasa +from kasa import Credentials, SmartDevice, SmartDeviceException from kasa.smartstrip import SmartStripPlug from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol +# List of all SmartXXX classes including the SmartDevice base class +smart_device_classes = [ + dc + for (mn, dc) in inspect.getmembers( + kasa, + lambda member: inspect.isclass(member) + and (member == SmartDevice or issubclass(member, SmartDevice)), + ) +] + async def test_state_info(dev): assert isinstance(dev.state_information, dict) @@ -150,3 +162,15 @@ async def test_features(dev): assert dev.features == set(sysinfo["feature"].split(":")) else: assert dev.features == set() + + +@pytest.mark.parametrize("device_class", smart_device_classes) +def test_device_class_ctors(device_class): + """Make sure constructor api not broken for new and existing SmartDevices.""" + host = "127.0.0.2" + port = 1234 + credentials = Credentials("foo", "bar") + dev = device_class(host, port=port, credentials=credentials) + assert dev.host == host + assert dev.port == port + assert dev.credentials == credentials