From 26502982a0d41bf0be8cc2d825633cfbacd70bec Mon Sep 17 00:00:00 2001 From: sdb9696 <51370195+sdb9696@users.noreply.github.com> Date: Tue, 7 Nov 2023 01:15:41 +0000 Subject: [PATCH] Update discover single to handle hostnames (#539) --- kasa/discover.py | 41 ++++++++++++++++++++++++++++++------ kasa/tests/test_discovery.py | 26 +++++++++++++++++++++++ 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/kasa/discover.py b/kasa/discover.py index b43df57b..5b11bed5 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -1,6 +1,7 @@ """Discovery module for TP-Link Smart Home devices.""" import asyncio import binascii +import ipaddress import logging import socket from typing import Awaitable, Callable, Dict, Optional, Type, cast @@ -273,9 +274,34 @@ class Discover: """ loop = asyncio.get_event_loop() event = asyncio.Event() + + try: + ipaddress.ip_address(host) + ip = host + except ValueError: + try: + adrrinfo = await loop.getaddrinfo( + host, + 0, + type=socket.SOCK_DGRAM, + family=socket.AF_INET, + ) + # getaddrinfo returns a list of 5 tuples with the following structure: + # (family, type, proto, canonname, sockaddr) + # where sockaddr is 2 tuple (ip, port). + # hence [0][4][0] is a stable array access because if no socket + # address matches the host for SOCK_DGRAM AF_INET the gaierror + # would be raised. + # https://docs.python.org/3/library/socket.html#socket.getaddrinfo + ip = adrrinfo[0][4][0] + except socket.gaierror as gex: + raise SmartDeviceException( + f"Could not resolve hostname {host}" + ) from gex + transport, protocol = await loop.create_datagram_endpoint( lambda: _DiscoverProtocol( - target=host, + target=ip, port=port, discovered_event=event, credentials=credentials, @@ -297,16 +323,17 @@ class Discover: finally: transport.close() - if host in protocol.discovered_devices: - dev = protocol.discovered_devices[host] + if ip in protocol.discovered_devices: + dev = protocol.discovered_devices[ip] + dev.host = host await dev.update() return dev - elif host in protocol.unsupported_devices: + elif ip in protocol.unsupported_devices: raise UnsupportedDeviceException( - f"Unsupported device {host}: {protocol.unsupported_devices[host]}" + f"Unsupported device {host}: {protocol.unsupported_devices[ip]}" ) - elif host in protocol.invalid_device_exceptions: - raise protocol.invalid_device_exceptions[host] + elif ip in protocol.invalid_device_exceptions: + raise protocol.invalid_device_exceptions[ip] else: raise SmartDeviceException(f"Unable to get discovery response for {host}") diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 3039f30c..7aeabe2f 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -1,5 +1,6 @@ # type: ignore import re +import socket import sys import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 @@ -74,6 +75,31 @@ async def test_discover_single(discovery_data: dict, mocker, custom_port): assert x.port == custom_port or x.port == 9999 +async def test_discover_single_hostname(discovery_data: dict, mocker): + """Make sure that discover_single returns an initialized SmartDevice instance.""" + host = "foobar" + ip = "127.0.0.1" + + def mock_discover(self): + self.datagram_received( + protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:], + (ip, 9999), + ) + + mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover) + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) + mocker.patch("socket.getaddrinfo", return_value=[(None, None, None, None, (ip, 0))]) + + x = await Discover.discover_single(host) + assert issubclass(x.__class__, SmartDevice) + assert x._sys_info is not None + assert x.host == host + + mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror()) + with pytest.raises(SmartDeviceException): + x = await Discover.discover_single(host) + + @pytest.mark.parametrize("custom_port", [123, None]) async def test_connect_single(discovery_data: dict, mocker, custom_port): """Make sure that connect_single returns an initialized SmartDevice instance."""