Update discover single to handle hostnames (#539)

This commit is contained in:
sdb9696 2023-11-07 01:15:41 +00:00 committed by GitHub
parent 805e4b8588
commit 26502982a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 7 deletions

View File

@ -1,6 +1,7 @@
"""Discovery module for TP-Link Smart Home devices."""
import asyncio
import binascii
import ipaddress
import logging
import socket
from typing import Awaitable, Callable, Dict, Optional, Type, cast
@ -273,9 +274,34 @@ class Discover:
"""
loop = asyncio.get_event_loop()
event = asyncio.Event()
try:
ipaddress.ip_address(host)
ip = host
except ValueError:
try:
adrrinfo = await loop.getaddrinfo(
host,
0,
type=socket.SOCK_DGRAM,
family=socket.AF_INET,
)
# getaddrinfo returns a list of 5 tuples with the following structure:
# (family, type, proto, canonname, sockaddr)
# where sockaddr is 2 tuple (ip, port).
# hence [0][4][0] is a stable array access because if no socket
# address matches the host for SOCK_DGRAM AF_INET the gaierror
# would be raised.
# https://docs.python.org/3/library/socket.html#socket.getaddrinfo
ip = adrrinfo[0][4][0]
except socket.gaierror as gex:
raise SmartDeviceException(
f"Could not resolve hostname {host}"
) from gex
transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol(
target=host,
target=ip,
port=port,
discovered_event=event,
credentials=credentials,
@ -297,16 +323,17 @@ class Discover:
finally:
transport.close()
if host in protocol.discovered_devices:
dev = protocol.discovered_devices[host]
if ip in protocol.discovered_devices:
dev = protocol.discovered_devices[ip]
dev.host = host
await dev.update()
return dev
elif host in protocol.unsupported_devices:
elif ip in protocol.unsupported_devices:
raise UnsupportedDeviceException(
f"Unsupported device {host}: {protocol.unsupported_devices[host]}"
f"Unsupported device {host}: {protocol.unsupported_devices[ip]}"
)
elif host in protocol.invalid_device_exceptions:
raise protocol.invalid_device_exceptions[host]
elif ip in protocol.invalid_device_exceptions:
raise protocol.invalid_device_exceptions[ip]
else:
raise SmartDeviceException(f"Unable to get discovery response for {host}")

View File

@ -1,5 +1,6 @@
# type: ignore
import re
import socket
import sys
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
@ -74,6 +75,31 @@ async def test_discover_single(discovery_data: dict, mocker, custom_port):
assert x.port == custom_port or x.port == 9999
async def test_discover_single_hostname(discovery_data: dict, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "foobar"
ip = "127.0.0.1"
def mock_discover(self):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:],
(ip, 9999),
)
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
mocker.patch("socket.getaddrinfo", return_value=[(None, None, None, None, (ip, 0))])
x = await Discover.discover_single(host)
assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None
assert x.host == host
mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
with pytest.raises(SmartDeviceException):
x = await Discover.discover_single(host)
@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect_single(discovery_data: dict, mocker, custom_port):
"""Make sure that connect_single returns an initialized SmartDevice instance."""