From 0ec0826cc770a85fc2ed5a52b69551872b5c27f1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 7 Oct 2023 08:58:00 -1000 Subject: [PATCH] Make timeout adjustable (#494) --- kasa/discover.py | 15 ++++++++++++--- kasa/protocol.py | 9 +++++---- kasa/smartbulb.py | 7 ++++--- kasa/smartdevice.py | 3 ++- kasa/smartdimmer.py | 3 ++- kasa/smartlightstrip.py | 3 ++- kasa/smartplug.py | 5 +++-- kasa/smartstrip.py | 3 ++- kasa/tests/conftest.py | 2 +- kasa/tests/test_smartdevice.py | 6 ++++++ 10 files changed, 39 insertions(+), 17 deletions(-) diff --git a/kasa/discover.py b/kasa/discover.py index f8e11a62..a39b2790 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -47,7 +47,8 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): port: Optional[int] = None, discovered_event: Optional[asyncio.Event] = None, credentials: Optional[Credentials] = None, - ): + timeout: Optional[int] = None, + ) -> None: self.transport = None self.discovery_packets = discovery_packets self.interface = interface @@ -61,6 +62,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.on_unsupported = on_unsupported self.discovered_event = discovered_event self.credentials = credentials + self.timeout = timeout def connection_made(self, transport) -> None: """Set socket options for broadcasting.""" @@ -124,7 +126,9 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): self.discovered_event.set() return - device = device_class(ip, port=port, credentials=self.credentials) + device = device_class( + ip, port=port, credentials=self.credentials, timeout=self.timeout + ) device.update_from_discover_info(info) self.discovered_devices[ip] = device @@ -226,6 +230,7 @@ class Discover: interface=interface, on_unsupported=on_unsupported, credentials=credentials, + timeout=timeout, ), local_addr=("0.0.0.0", 0), ) @@ -259,7 +264,11 @@ class Discover: event = asyncio.Event() transport, protocol = await loop.create_datagram_endpoint( lambda: _DiscoverProtocol( - target=host, port=port, discovered_event=event, credentials=credentials + target=host, + port=port, + discovered_event=event, + credentials=credentials, + timeout=timeout, ), local_addr=("0.0.0.0", 0), ) diff --git a/kasa/protocol.py b/kasa/protocol.py index 461dd85a..3558b820 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -37,7 +37,9 @@ class TPLinkSmartHomeProtocol: DEFAULT_TIMEOUT = 5 BLOCK_SIZE = 4 - def __init__(self, host: str, *, port: Optional[int] = None) -> None: + def __init__( + self, host: str, *, port: Optional[int] = None, timeout: Optional[int] = None + ) -> None: """Create a protocol object.""" self.host = host self.port = port or TPLinkSmartHomeProtocol.DEFAULT_PORT @@ -45,6 +47,7 @@ class TPLinkSmartHomeProtocol: self.writer: Optional[asyncio.StreamWriter] = None self.query_lock: Optional[asyncio.Lock] = None self.loop: Optional[asyncio.AbstractEventLoop] = None + self.timeout = timeout or TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT def _detect_event_loop_change(self) -> None: """Check if this object has been reused betwen event loops.""" @@ -73,10 +76,8 @@ class TPLinkSmartHomeProtocol: request = json_dumps(request) assert isinstance(request, str) - timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT - async with self.query_lock: - return await self._query(request, retry_count, timeout) + return await self._query(request, retry_count, self.timeout) async def _connect(self, timeout: int) -> None: """Try to connect or reconnect to the device.""" diff --git a/kasa/smartbulb.py b/kasa/smartbulb.py index a09487d2..09d42053 100644 --- a/kasa/smartbulb.py +++ b/kasa/smartbulb.py @@ -208,9 +208,10 @@ class SmartBulb(SmartDevice): host: str, *, port: Optional[int] = None, - credentials: Optional[Credentials] = None + credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, ) -> None: - super().__init__(host=host, port=port, credentials=credentials) + super().__init__(host=host, port=port, credentials=credentials, timeout=timeout) self._device_type = DeviceType.Bulb self.add_module("schedule", Schedule(self, "smartlife.iot.common.schedule")) self.add_module("usage", Usage(self, "smartlife.iot.common.schedule")) @@ -372,7 +373,7 @@ class SmartBulb(SmartDevice): saturation: int, value: Optional[int] = None, *, - transition: Optional[int] = None + transition: Optional[int] = None, ) -> Dict: """Set new HSV. diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 4c1a3b93..bdef809a 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -195,6 +195,7 @@ class SmartDevice: *, port: Optional[int] = None, credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, ) -> None: """Create a new SmartDevice instance. @@ -203,7 +204,7 @@ class SmartDevice: self.host = host self.port = port - self.protocol = TPLinkSmartHomeProtocol(host, port=port) + self.protocol = TPLinkSmartHomeProtocol(host, port=port, timeout=timeout) self.credentials = credentials _LOGGER.debug("Initializing %s of type %s", self.host, type(self)) self._device_type = DeviceType.Unknown diff --git a/kasa/smartdimmer.py b/kasa/smartdimmer.py index 05fb75ac..a412021c 100644 --- a/kasa/smartdimmer.py +++ b/kasa/smartdimmer.py @@ -69,8 +69,9 @@ class SmartDimmer(SmartPlug): *, port: Optional[int] = None, credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, ) -> None: - super().__init__(host, port=port, credentials=credentials) + super().__init__(host, port=port, credentials=credentials, timeout=timeout) self._device_type = DeviceType.Dimmer # TODO: need to be verified if it's okay to call these on HS220 w/o these # TODO: need to be figured out what's the best approach to detect support for these diff --git a/kasa/smartlightstrip.py b/kasa/smartlightstrip.py index 34e58115..e3dfc15f 100644 --- a/kasa/smartlightstrip.py +++ b/kasa/smartlightstrip.py @@ -48,8 +48,9 @@ class SmartLightStrip(SmartBulb): *, port: Optional[int] = None, credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, ) -> None: - super().__init__(host, port=port, credentials=credentials) + super().__init__(host, port=port, credentials=credentials, timeout=timeout) self._device_type = DeviceType.LightStrip @property # type: ignore diff --git a/kasa/smartplug.py b/kasa/smartplug.py index f3d635d9..cd323c8d 100644 --- a/kasa/smartplug.py +++ b/kasa/smartplug.py @@ -43,9 +43,10 @@ class SmartPlug(SmartDevice): host: str, *, port: Optional[int] = None, - credentials: Optional[Credentials] = None + credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, ) -> None: - super().__init__(host, port=port, credentials=credentials) + super().__init__(host, port=port, credentials=credentials, timeout=timeout) self._device_type = DeviceType.Plug self.add_module("schedule", Schedule(self, "schedule")) self.add_module("usage", Usage(self, "schedule")) diff --git a/kasa/smartstrip.py b/kasa/smartstrip.py index 479b0e56..2a55b2a8 100755 --- a/kasa/smartstrip.py +++ b/kasa/smartstrip.py @@ -86,8 +86,9 @@ class SmartStrip(SmartDevice): *, port: Optional[int] = None, credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, ) -> None: - super().__init__(host=host, port=port, credentials=credentials) + super().__init__(host=host, port=port, credentials=credentials, timeout=timeout) self.emeter_type = "emeter" self._device_type = DeviceType.Strip self.add_module("antitheft", Antitheft(self, "anti_theft")) diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 9b5a394d..2b2adc7d 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -166,7 +166,7 @@ async def _update_and_close(d): async def _discover_update_and_close(ip): - d = await Discover.discover_single(ip) + d = await Discover.discover_single(ip, timeout=10) return await _update_and_close(d) diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 283fcfef..ec4e3d56 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -194,3 +194,9 @@ async def test_modules_preserved(dev: SmartDevice): dev._last_update["some_module_not_being_updated"] = "should_be_kept" await dev.update() assert dev._last_update["some_module_not_being_updated"] == "should_be_kept" + + +async def test_create_smart_device_with_timeout(): + """Make sure timeout is passed to the protocol.""" + dev = SmartDevice(host="127.0.0.1", timeout=100) + assert dev.protocol.timeout == 100