Add plumbing for passing credentials to devices (#507)

* Add plumbing for passing credentials as far as discovery

* Pass credentials to Smart devices

* Rename authentication exception

* Fix tests failure due to test_json_output leaving echo as nop

* Fix test_credentials test

* Do not print credentials, fix echo function bug and improve get type parameter

* Add device class constructor test

* Add comment for echo handling and move assignment
This commit is contained in:
sdb9696 2023-09-13 14:46:38 +01:00 committed by GitHub
parent f7c22f0a0c
commit 7bb4a456a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 258 additions and 41 deletions

View File

@ -13,9 +13,14 @@ to be handled by the user of the library.
"""
from importlib.metadata import version
from kasa.credentials import Credentials
from kasa.discover import Discover
from kasa.emeterstatus import EmeterStatus
from kasa.exceptions import SmartDeviceException
from kasa.exceptions import (
AuthenticationException,
SmartDeviceException,
UnsupportedDeviceException,
)
from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb, SmartBulbPreset, TurnOnBehavior, TurnOnBehaviors
from kasa.smartdevice import DeviceType, SmartDevice
@ -42,4 +47,7 @@ __all__ = [
"SmartStrip",
"SmartDimmer",
"SmartLightStrip",
"AuthenticationException",
"UnsupportedDeviceException",
"Credentials",
]

View File

@ -10,8 +10,19 @@ from typing import Any, Dict, cast
import asyncclick as click
from kasa import (
Credentials,
Discover,
SmartBulb,
SmartDevice,
SmartDimmer,
SmartLightStrip,
SmartPlug,
SmartStrip,
)
try:
from rich import print as echo
from rich import print as _do_echo
except ImportError:
def _strip_rich_formatting(echo_func):
@ -25,18 +36,11 @@ except ImportError:
return wrapper
echo = _strip_rich_formatting(click.echo)
_do_echo = _strip_rich_formatting(click.echo)
from kasa import (
Discover,
SmartBulb,
SmartDevice,
SmartDimmer,
SmartLightStrip,
SmartPlug,
SmartStrip,
)
# echo is set to _do_echo so that it can be reset to _do_echo later after
# --json has set it to _nop_echo
echo = _do_echo
TYPE_TO_CLASS = {
"plug": SmartPlug,
@ -48,7 +52,6 @@ TYPE_TO_CLASS = {
click.anyio_backend = "asyncio"
pass_dev = click.make_pass_decorator(SmartDevice)
@ -137,6 +140,20 @@ def json_formatter_cb(result, **kwargs):
required=False,
help="Timeout for discovery.",
)
@click.option(
"--username",
default=None,
required=False,
envvar="TPLINK_CLOUD_USERNAME",
help="Username/email address to authenticate to device.",
)
@click.option(
"--password",
default=None,
required=False,
envvar="TPLINK_CLOUD_PASSWORD",
help="Password to use to authenticate to device.",
)
@click.version_option(package_name="python-kasa")
@click.pass_context
async def cli(
@ -149,6 +166,8 @@ async def cli(
type,
json,
discovery_timeout,
username,
password,
):
"""A tool for controlling TP-Link smart home devices.""" # noqa
# no need to perform any checks if we are just displaying the help
@ -158,13 +177,17 @@ async def cli(
return
# If JSON output is requested, disable echo
global echo
if json:
global echo
def _nop_echo(*args, **kwargs):
pass
echo = _nop_echo
else:
# Set back to default is required if running tests with CliRunner
global _do_echo
echo = _do_echo
logging_config: Dict[str, Any] = {
"level": logging.DEBUG if debug > 0 else logging.INFO
@ -195,15 +218,25 @@ async def cli(
echo(f"No device with name {alias} found")
return
if bool(password) != bool(username):
echo("Using authentication requires both --username and --password")
return
credentials = Credentials(username=username, password=password)
if host is None:
echo("No host name given, trying discovery..")
return await ctx.invoke(discover, timeout=discovery_timeout)
if type is not None:
dev = TYPE_TO_CLASS[type](host)
dev = TYPE_TO_CLASS[type](host, credentials=credentials)
else:
echo("No --type defined, discovering..")
dev = await Discover.discover_single(host, port=port)
dev = await Discover.discover_single(
host,
port=port,
credentials=credentials,
)
await dev.update()
ctx.obj = dev
@ -261,6 +294,11 @@ async def join(dev: SmartDevice, ssid, password, keytype):
async def discover(ctx, timeout, show_unsupported):
"""Discover devices in the network."""
target = ctx.parent.params["target"]
username = ctx.parent.params["username"]
password = ctx.parent.params["password"]
credentials = Credentials(username, password)
sem = asyncio.Semaphore()
discovered = dict()
unsupported = []
@ -286,6 +324,7 @@ async def discover(ctx, timeout, show_unsupported):
timeout=timeout,
on_discovered=print_discovered,
on_unsupported=print_unsupported,
credentials=credentials,
)
echo(f"Found {len(discovered)} devices")

12
kasa/credentials.py Normal file
View File

@ -0,0 +1,12 @@
"""Credentials class for username / passwords."""
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class Credentials:
"""Credentials for authentication."""
username: Optional[str] = field(default=None, repr=False)
password: Optional[str] = field(default=None, repr=False)

View File

@ -9,6 +9,7 @@ from typing import Awaitable, Callable, Dict, Optional, Type, cast
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
from kasa.credentials import Credentials
from kasa.exceptions import UnsupportedDeviceException
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
@ -45,6 +46,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
on_unsupported: Optional[Callable[[Dict], Awaitable[None]]] = None,
port: Optional[int] = None,
discovered_event: Optional[asyncio.Event] = None,
credentials: Optional[Credentials] = None,
):
self.transport = None
self.discovery_packets = discovery_packets
@ -58,6 +60,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.invalid_device_exceptions: Dict = {}
self.on_unsupported = on_unsupported
self.discovered_event = discovered_event
self.credentials = credentials
def connection_made(self, transport) -> None:
"""Set socket options for broadcasting."""
@ -106,9 +109,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
if self.on_unsupported is not None:
asyncio.ensure_future(self.on_unsupported(info))
_LOGGER.debug("[DISCOVERY] Unsupported device found at %s << %s", ip, info)
if self.discovered_event is not None and "255" not in self.target[0].split(
"."
):
if self.discovered_event is not None:
self.discovered_event.set()
return
@ -119,13 +120,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
"[DISCOVERY] Unable to find device type from %s: %s", info, ex
)
self.invalid_device_exceptions[ip] = ex
if self.discovered_event is not None and "255" not in self.target[0].split(
"."
):
if self.discovered_event is not None:
self.discovered_event.set()
return
device = device_class(ip, port=port)
device = device_class(ip, port=port, credentials=self.credentials)
device.update_from_discover_info(info)
self.discovered_devices[ip] = device
@ -133,7 +132,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))
if self.discovered_event is not None and "255" not in self.target[0].split("."):
if self.discovered_event is not None:
self.discovered_event.set()
def error_received(self, ex):
@ -197,6 +196,7 @@ class Discover:
discovery_packets=3,
interface=None,
on_unsupported=None,
credentials=None,
) -> DeviceDict:
"""Discover supported devices.
@ -225,6 +225,7 @@ class Discover:
discovery_packets=discovery_packets,
interface=interface,
on_unsupported=on_unsupported,
credentials=credentials,
),
local_addr=("0.0.0.0", 0),
)
@ -242,7 +243,11 @@ class Discover:
@staticmethod
async def discover_single(
host: str, *, port: Optional[int] = None, timeout=5
host: str,
*,
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
) -> SmartDevice:
"""Discover a single device by the given IP address.
@ -253,7 +258,9 @@ class Discover:
loop = asyncio.get_event_loop()
event = asyncio.Event()
transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol(target=host, port=port, discovered_event=event),
lambda: _DiscoverProtocol(
target=host, port=port, discovered_event=event, credentials=credentials
),
local_addr=("0.0.0.0", 0),
)
protocol = cast(_DiscoverProtocol, protocol)

View File

@ -7,3 +7,7 @@ class SmartDeviceException(Exception):
class UnsupportedDeviceException(SmartDeviceException):
"""Exception for trying to connect to unsupported devices."""
class AuthenticationException(SmartDeviceException):
"""Base exception for device authentication errors."""

View File

@ -9,6 +9,7 @@ try:
except ImportError:
from pydantic import BaseModel, Field, root_validator
from .credentials import Credentials
from .modules import Antitheft, Cloud, Countdown, Emeter, Schedule, Time, Usage
from .smartdevice import DeviceType, SmartDevice, SmartDeviceException, requires_update
@ -202,8 +203,14 @@ class SmartBulb(SmartDevice):
SET_LIGHT_METHOD = "transition_light_state"
emeter_type = "smartlife.iot.common.emeter"
def __init__(self, host: str, *, port: Optional[int] = None) -> None:
super().__init__(host=host, port=port)
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None
) -> None:
super().__init__(host=host, port=port, credentials=credentials)
self._device_type = DeviceType.Bulb
self.add_module("schedule", Schedule(self, "smartlife.iot.common.schedule"))
self.add_module("usage", Usage(self, "smartlife.iot.common.schedule"))

View File

@ -20,6 +20,7 @@ from datetime import datetime, timedelta
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Set
from .credentials import Credentials
from .emeterstatus import EmeterStatus
from .exceptions import SmartDeviceException
from .modules import Emeter, Module
@ -191,7 +192,13 @@ class SmartDevice:
emeter_type = "emeter"
def __init__(self, host: str, *, port: Optional[int] = None) -> None:
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
) -> None:
"""Create a new SmartDevice instance.
:param str host: host name or ip address on which the device listens
@ -200,6 +207,7 @@ class SmartDevice:
self.port = port
self.protocol = TPLinkSmartHomeProtocol(host, port=port)
self.credentials = credentials
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
self._device_type = DeviceType.Unknown
# TODO: typing Any is just as using Optional[Dict] would require separate checks in

View File

@ -2,6 +2,7 @@
from enum import Enum
from typing import Any, Dict, Optional
from kasa.credentials import Credentials
from kasa.modules import AmbientLight, Motion
from kasa.smartdevice import DeviceType, SmartDeviceException, requires_update
from kasa.smartplug import SmartPlug
@ -62,8 +63,14 @@ class SmartDimmer(SmartPlug):
DIMMER_SERVICE = "smartlife.iot.dimmer"
def __init__(self, host: str, *, port: Optional[int] = None) -> None:
super().__init__(host, port=port)
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
) -> None:
super().__init__(host, port=port, credentials=credentials)
self._device_type = DeviceType.Dimmer
# TODO: need to be verified if it's okay to call these on HS220 w/o these
# TODO: need to be figured out what's the best approach to detect support for these

View File

@ -1,6 +1,7 @@
"""Module for light strips (KL430)."""
from typing import Any, Dict, List, Optional
from .credentials import Credentials
from .effects import EFFECT_MAPPING_V1, EFFECT_NAMES_V1
from .smartbulb import SmartBulb
from .smartdevice import DeviceType, SmartDeviceException, requires_update
@ -41,8 +42,14 @@ class SmartLightStrip(SmartBulb):
LIGHT_SERVICE = "smartlife.iot.lightStrip"
SET_LIGHT_METHOD = "set_light_state"
def __init__(self, host: str, *, port: Optional[int] = None) -> None:
super().__init__(host, port=port)
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
) -> None:
super().__init__(host, port=port, credentials=credentials)
self._device_type = DeviceType.LightStrip
@property # type: ignore

View File

@ -2,6 +2,7 @@
import logging
from typing import Any, Dict, Optional
from kasa.credentials import Credentials
from kasa.modules import Antitheft, Cloud, Schedule, Time, Usage
from kasa.smartdevice import DeviceType, SmartDevice, requires_update
@ -37,8 +38,14 @@ class SmartPlug(SmartDevice):
For more examples, see the :class:`SmartDevice` class.
"""
def __init__(self, host: str, *, port: Optional[int] = None) -> None:
super().__init__(host, port=port)
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None
) -> None:
super().__init__(host, port=port, credentials=credentials)
self._device_type = DeviceType.Plug
self.add_module("schedule", Schedule(self, "schedule"))
self.add_module("usage", Usage(self, "schedule"))

View File

@ -14,6 +14,7 @@ from kasa.smartdevice import (
)
from kasa.smartplug import SmartPlug
from .credentials import Credentials
from .modules import Antitheft, Countdown, Emeter, Schedule, Time, Usage
_LOGGER = logging.getLogger(__name__)
@ -79,8 +80,14 @@ class SmartStrip(SmartDevice):
For more examples, see the :class:`SmartDevice` class.
"""
def __init__(self, host: str, *, port: Optional[int] = None) -> None:
super().__init__(host=host, port=port)
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
) -> None:
super().__init__(host=host, port=port, credentials=credentials)
self.emeter_type = "emeter"
self._device_type = DeviceType.Strip
self.add_module("antitheft", Antitheft(self, "anti_theft"))

View File

@ -1,13 +1,26 @@
import json
import sys
import asyncclick as click
import pytest
from asyncclick.testing import CliRunner
from kasa import SmartDevice
from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle
from kasa import SmartDevice, TPLinkSmartHomeProtocol
from kasa.cli import (
TYPE_TO_CLASS,
alias,
brightness,
cli,
emeter,
raw_command,
state,
sysinfo,
toggle,
)
from kasa.discover import Discover
from .conftest import handle_turn_on, turn_on
from .newfakes import FakeTransportProtocol
async def test_sysinfo(dev):
@ -121,3 +134,70 @@ async def test_json_output(dev: SmartDevice, mocker):
res = await runner.invoke(cli, ["--json", "state"], obj=dev)
assert res.exit_code == 0
assert json.loads(res.output) == dev.internal_state
async def test_credentials(discovery_data: dict, mocker):
"""Test credentials are passed correctly from cli to device."""
# As this is testing the device constructor need to explicitly wire in
# the FakeTransportProtocol
ftp = FakeTransportProtocol(discovery_data)
mocker.patch.object(TPLinkSmartHomeProtocol, "query", ftp.query)
# Patch state to echo username and password
pass_dev = click.make_pass_decorator(SmartDevice)
@pass_dev
async def _state(dev: SmartDevice):
if dev.credentials:
click.echo(
f"Username:{dev.credentials.username} Password:{dev.credentials.password}"
)
mocker.patch("kasa.cli.state", new=_state)
# Get the type string parameter from the discovery_info
for cli_device_type in {
i
for i in TYPE_TO_CLASS
if TYPE_TO_CLASS[i] == Discover._get_device_class(discovery_data)
}:
break
runner = CliRunner()
res = await runner.invoke(
cli,
[
"--host",
"127.0.0.1",
"--type",
cli_device_type,
"--username",
"foo",
"--password",
"bar",
],
)
assert res.exit_code == 0
assert res.output == "Username:foo Password:bar\n"
@pytest.mark.parametrize("auth_param", ["--username", "--password"])
async def test_invalid_credential_params(auth_param):
runner = CliRunner()
# Test for handling only one of username or passowrd supplied.
res = await runner.invoke(
cli,
[
"--host",
"127.0.0.1",
"--type",
"plug",
auth_param,
"foo",
],
)
assert res.exit_code == 0
assert (
res.output == "Using authentication requires both --username and --password\n"
)

View File

@ -1,14 +1,26 @@
import inspect
from datetime import datetime
from unittest.mock import patch
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import SmartDeviceException
import kasa
from kasa import Credentials, SmartDevice, SmartDeviceException
from kasa.smartstrip import SmartStripPlug
from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
# List of all SmartXXX classes including the SmartDevice base class
smart_device_classes = [
dc
for (mn, dc) in inspect.getmembers(
kasa,
lambda member: inspect.isclass(member)
and (member == SmartDevice or issubclass(member, SmartDevice)),
)
]
async def test_state_info(dev):
assert isinstance(dev.state_information, dict)
@ -150,3 +162,15 @@ async def test_features(dev):
assert dev.features == set(sysinfo["feature"].split(":"))
else:
assert dev.features == set()
@pytest.mark.parametrize("device_class", smart_device_classes)
def test_device_class_ctors(device_class):
"""Make sure constructor api not broken for new and existing SmartDevices."""
host = "127.0.0.2"
port = 1234
credentials = Credentials("foo", "bar")
dev = device_class(host, port=port, credentials=credentials)
assert dev.host == host
assert dev.port == port
assert dev.credentials == credentials