Update discover single to handle hostnames (#539)

This commit is contained in:
sdb9696 2023-11-07 01:15:41 +00:00 committed by GitHub
parent 805e4b8588
commit 26502982a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 7 deletions

View File

@ -1,6 +1,7 @@
"""Discovery module for TP-Link Smart Home devices.""" """Discovery module for TP-Link Smart Home devices."""
import asyncio import asyncio
import binascii import binascii
import ipaddress
import logging import logging
import socket import socket
from typing import Awaitable, Callable, Dict, Optional, Type, cast from typing import Awaitable, Callable, Dict, Optional, Type, cast
@ -273,9 +274,34 @@ class Discover:
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
event = asyncio.Event() event = asyncio.Event()
try:
ipaddress.ip_address(host)
ip = host
except ValueError:
try:
adrrinfo = await loop.getaddrinfo(
host,
0,
type=socket.SOCK_DGRAM,
family=socket.AF_INET,
)
# getaddrinfo returns a list of 5 tuples with the following structure:
# (family, type, proto, canonname, sockaddr)
# where sockaddr is 2 tuple (ip, port).
# hence [0][4][0] is a stable array access because if no socket
# address matches the host for SOCK_DGRAM AF_INET the gaierror
# would be raised.
# https://docs.python.org/3/library/socket.html#socket.getaddrinfo
ip = adrrinfo[0][4][0]
except socket.gaierror as gex:
raise SmartDeviceException(
f"Could not resolve hostname {host}"
) from gex
transport, protocol = await loop.create_datagram_endpoint( transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol( lambda: _DiscoverProtocol(
target=host, target=ip,
port=port, port=port,
discovered_event=event, discovered_event=event,
credentials=credentials, credentials=credentials,
@ -297,16 +323,17 @@ class Discover:
finally: finally:
transport.close() transport.close()
if host in protocol.discovered_devices: if ip in protocol.discovered_devices:
dev = protocol.discovered_devices[host] dev = protocol.discovered_devices[ip]
dev.host = host
await dev.update() await dev.update()
return dev return dev
elif host in protocol.unsupported_devices: elif ip in protocol.unsupported_devices:
raise UnsupportedDeviceException( raise UnsupportedDeviceException(
f"Unsupported device {host}: {protocol.unsupported_devices[host]}" f"Unsupported device {host}: {protocol.unsupported_devices[ip]}"
) )
elif host in protocol.invalid_device_exceptions: elif ip in protocol.invalid_device_exceptions:
raise protocol.invalid_device_exceptions[host] raise protocol.invalid_device_exceptions[ip]
else: else:
raise SmartDeviceException(f"Unable to get discovery response for {host}") raise SmartDeviceException(f"Unable to get discovery response for {host}")

View File

@ -1,5 +1,6 @@
# type: ignore # type: ignore
import re import re
import socket
import sys import sys
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
@ -74,6 +75,31 @@ async def test_discover_single(discovery_data: dict, mocker, custom_port):
assert x.port == custom_port or x.port == 9999 assert x.port == custom_port or x.port == 9999
async def test_discover_single_hostname(discovery_data: dict, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "foobar"
ip = "127.0.0.1"
def mock_discover(self):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:],
(ip, 9999),
)
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
mocker.patch("socket.getaddrinfo", return_value=[(None, None, None, None, (ip, 0))])
x = await Discover.discover_single(host)
assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None
assert x.host == host
mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
with pytest.raises(SmartDeviceException):
x = await Discover.discover_single(host)
@pytest.mark.parametrize("custom_port", [123, None]) @pytest.mark.parametrize("custom_port", [123, None])
async def test_connect_single(discovery_data: dict, mocker, custom_port): async def test_connect_single(discovery_data: dict, mocker, custom_port):
"""Make sure that connect_single returns an initialized SmartDevice instance.""" """Make sure that connect_single returns an initialized SmartDevice instance."""