Add plumbing for passing credentials to devices (#507)

* Add plumbing for passing credentials as far as discovery

* Pass credentials to Smart devices

* Rename authentication exception

* Fix tests failure due to test_json_output leaving echo as nop

* Fix test_credentials test

* Do not print credentials, fix echo function bug and improve get type parameter

* Add device class constructor test

* Add comment for echo handling and move assignment
This commit is contained in:
sdb9696
2023-09-13 14:46:38 +01:00
committed by GitHub
parent f7c22f0a0c
commit 7bb4a456a2
13 changed files with 258 additions and 41 deletions

View File

@@ -9,6 +9,7 @@ from typing import Awaitable, Callable, Dict, Optional, Type, cast
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
from kasa.credentials import Credentials
from kasa.exceptions import UnsupportedDeviceException
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
@@ -45,6 +46,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
on_unsupported: Optional[Callable[[Dict], Awaitable[None]]] = None,
port: Optional[int] = None,
discovered_event: Optional[asyncio.Event] = None,
credentials: Optional[Credentials] = None,
):
self.transport = None
self.discovery_packets = discovery_packets
@@ -58,6 +60,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.invalid_device_exceptions: Dict = {}
self.on_unsupported = on_unsupported
self.discovered_event = discovered_event
self.credentials = credentials
def connection_made(self, transport) -> None:
"""Set socket options for broadcasting."""
@@ -106,9 +109,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
if self.on_unsupported is not None:
asyncio.ensure_future(self.on_unsupported(info))
_LOGGER.debug("[DISCOVERY] Unsupported device found at %s << %s", ip, info)
if self.discovered_event is not None and "255" not in self.target[0].split(
"."
):
if self.discovered_event is not None:
self.discovered_event.set()
return
@@ -119,13 +120,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
"[DISCOVERY] Unable to find device type from %s: %s", info, ex
)
self.invalid_device_exceptions[ip] = ex
if self.discovered_event is not None and "255" not in self.target[0].split(
"."
):
if self.discovered_event is not None:
self.discovered_event.set()
return
device = device_class(ip, port=port)
device = device_class(ip, port=port, credentials=self.credentials)
device.update_from_discover_info(info)
self.discovered_devices[ip] = device
@@ -133,7 +132,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))
if self.discovered_event is not None and "255" not in self.target[0].split("."):
if self.discovered_event is not None:
self.discovered_event.set()
def error_received(self, ex):
@@ -197,6 +196,7 @@ class Discover:
discovery_packets=3,
interface=None,
on_unsupported=None,
credentials=None,
) -> DeviceDict:
"""Discover supported devices.
@@ -225,6 +225,7 @@ class Discover:
discovery_packets=discovery_packets,
interface=interface,
on_unsupported=on_unsupported,
credentials=credentials,
),
local_addr=("0.0.0.0", 0),
)
@@ -242,7 +243,11 @@ class Discover:
@staticmethod
async def discover_single(
host: str, *, port: Optional[int] = None, timeout=5
host: str,
*,
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
) -> SmartDevice:
"""Discover a single device by the given IP address.
@@ -253,7 +258,9 @@ class Discover:
loop = asyncio.get_event_loop()
event = asyncio.Event()
transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol(target=host, port=port, discovered_event=event),
lambda: _DiscoverProtocol(
target=host, port=port, discovered_event=event, credentials=credentials
),
local_addr=("0.0.0.0", 0),
)
protocol = cast(_DiscoverProtocol, protocol)