Fix modularize with strips (#326)

* Fix test_deprecated_type stalling

* Fix strips with modularize

* Fix test_deprecated_type stalling (#325)
This commit is contained in:
J. Nick Koston 2022-04-05 06:16:36 -10:00 committed by Teemu R
parent f0d66e4195
commit 1e4df7ec1b
4 changed files with 24 additions and 22 deletions

View File

@ -70,20 +70,13 @@ class TPLinkSmartHomeProtocol:
async with self.query_lock: async with self.query_lock:
return await self._query(request, retry_count, timeout) return await self._query(request, retry_count, timeout)
async def _connect(self, timeout: int) -> bool: async def _connect(self, timeout: int) -> None:
"""Try to connect or reconnect to the device.""" """Try to connect or reconnect to the device."""
if self.writer: if self.writer:
return True return
with contextlib.suppress(Exception):
self.reader = self.writer = None self.reader = self.writer = None
task = asyncio.open_connection( task = asyncio.open_connection(self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT)
self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT
)
self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout) self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
return True
return False
async def _execute_query(self, request: str) -> Dict: async def _execute_query(self, request: str) -> Dict:
"""Execute a query on the device and wait for the response.""" """Execute a query on the device and wait for the response."""
@ -123,12 +116,14 @@ class TPLinkSmartHomeProtocol:
async def _query(self, request: str, retry_count: int, timeout: int) -> Dict: async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
"""Try to query a device.""" """Try to query a device."""
for retry in range(retry_count + 1): for retry in range(retry_count + 1):
if not await self._connect(timeout): try:
await self._connect(timeout)
except Exception as ex:
await self.close() await self.close()
if retry >= retry_count: if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry) _LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException( raise SmartDeviceException(
f"Unable to connect to the device: {self.host}" f"Unable to connect to the device: {self.host}: {ex}"
) )
continue continue

View File

@ -314,6 +314,11 @@ class SmartDevice:
self._last_update = await self.protocol.query(req) self._last_update = await self.protocol.query(req)
self._sys_info = self._last_update["system"]["get_sysinfo"] self._sys_info = self._last_update["system"]["get_sysinfo"]
await self._modular_update(req)
self._sys_info = self._last_update["system"]["get_sysinfo"]
async def _modular_update(self, req: dict) -> None:
"""Execute an update query."""
if self.has_emeter: if self.has_emeter:
_LOGGER.debug( _LOGGER.debug(
"The device has emeter, querying its information along sysinfo" "The device has emeter, querying its information along sysinfo"
@ -326,10 +331,9 @@ class SmartDevice:
continue continue
q = module.query() q = module.query()
_LOGGER.debug("Adding query for %s: %s", module, q) _LOGGER.debug("Adding query for %s: %s", module, q)
req = merge(req, module.query()) req = merge(req, q)
self._last_update = await self.protocol.query(req) self._last_update = await self.protocol.query(req)
self._sys_info = self._last_update["system"]["get_sysinfo"]
def update_from_discover_info(self, info): def update_from_discover_info(self, info):
"""Update state from info from the discover call.""" """Update state from info from the discover call."""

View File

@ -3,13 +3,14 @@ import logging
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, DefaultDict, Dict, Optional from typing import Any, DefaultDict, Dict, Optional
import asyncio
from kasa.smartdevice import ( from kasa.smartdevice import (
DeviceType, DeviceType,
EmeterStatus, EmeterStatus,
SmartDevice, SmartDevice,
SmartDeviceException, SmartDeviceException,
requires_update, requires_update,
merge,
) )
from kasa.smartplug import SmartPlug from kasa.smartplug import SmartPlug
@ -250,16 +251,16 @@ class SmartStripPlug(SmartPlug):
self._last_update = parent._last_update self._last_update = parent._last_update
self._sys_info = parent._sys_info self._sys_info = parent._sys_info
self._device_type = DeviceType.StripSocket self._device_type = DeviceType.StripSocket
self.modules = {}
self.protocol = parent.protocol # Must use the same connection as the parent
self.add_module("time", Time(self, "time"))
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True):
"""Query the device to update the data. """Query the device to update the data.
Needed for properties that are decorated with `requires_update`. Needed for properties that are decorated with `requires_update`.
""" """
# TODO: it needs to be checked if this still works after modularization await self._modular_update({})
self._last_update = await self.parent.protocol.query(
self._create_emeter_request()
)
def _create_emeter_request(self, year: int = None, month: int = None): def _create_emeter_request(self, year: int = None, month: int = None):
"""Create a request for requesting all emeter statistics at once.""" """Create a request for requesting all emeter statistics at once."""

View File

@ -36,7 +36,9 @@ async def test_initial_update_no_emeter(dev, mocker):
dev._last_update = None dev._last_update = None
spy = mocker.spy(dev.protocol, "query") spy = mocker.spy(dev.protocol, "query")
await dev.update() await dev.update()
assert spy.call_count == 1 # 2 calls are necessary as some devices crash on unexpected modules
# See #105, #120, #161
assert spy.call_count == 2
async def test_query_helper(dev): async def test_query_helper(dev):