mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-11-04 06:32:07 +00:00 
			
		
		
		
	Fix modularize with strips (#326)
* Fix test_deprecated_type stalling * Fix strips with modularize * Fix test_deprecated_type stalling (#325)
This commit is contained in:
		@@ -70,20 +70,13 @@ class TPLinkSmartHomeProtocol:
 | 
			
		||||
        async with self.query_lock:
 | 
			
		||||
            return await self._query(request, retry_count, timeout)
 | 
			
		||||
 | 
			
		||||
    async def _connect(self, timeout: int) -> bool:
 | 
			
		||||
    async def _connect(self, timeout: int) -> None:
 | 
			
		||||
        """Try to connect or reconnect to the device."""
 | 
			
		||||
        if self.writer:
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        with contextlib.suppress(Exception):
 | 
			
		||||
            self.reader = self.writer = None
 | 
			
		||||
            task = asyncio.open_connection(
 | 
			
		||||
                self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT
 | 
			
		||||
            )
 | 
			
		||||
            self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        return False
 | 
			
		||||
            return
 | 
			
		||||
        self.reader = self.writer = None
 | 
			
		||||
        task = asyncio.open_connection(self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT)
 | 
			
		||||
        self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
 | 
			
		||||
 | 
			
		||||
    async def _execute_query(self, request: str) -> Dict:
 | 
			
		||||
        """Execute a query on the device and wait for the response."""
 | 
			
		||||
@@ -123,12 +116,14 @@ class TPLinkSmartHomeProtocol:
 | 
			
		||||
    async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
 | 
			
		||||
        """Try to query a device."""
 | 
			
		||||
        for retry in range(retry_count + 1):
 | 
			
		||||
            if not await self._connect(timeout):
 | 
			
		||||
            try:
 | 
			
		||||
                await self._connect(timeout)
 | 
			
		||||
            except Exception as ex:
 | 
			
		||||
                await self.close()
 | 
			
		||||
                if retry >= retry_count:
 | 
			
		||||
                    _LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
 | 
			
		||||
                    raise SmartDeviceException(
 | 
			
		||||
                        f"Unable to connect to the device: {self.host}"
 | 
			
		||||
                        f"Unable to connect to the device: {self.host}: {ex}"
 | 
			
		||||
                    )
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -314,6 +314,11 @@ class SmartDevice:
 | 
			
		||||
            self._last_update = await self.protocol.query(req)
 | 
			
		||||
            self._sys_info = self._last_update["system"]["get_sysinfo"]
 | 
			
		||||
 | 
			
		||||
        await self._modular_update(req)
 | 
			
		||||
        self._sys_info = self._last_update["system"]["get_sysinfo"]
 | 
			
		||||
 | 
			
		||||
    async def _modular_update(self, req: dict) -> None:
 | 
			
		||||
        """Execute an update query."""
 | 
			
		||||
        if self.has_emeter:
 | 
			
		||||
            _LOGGER.debug(
 | 
			
		||||
                "The device has emeter, querying its information along sysinfo"
 | 
			
		||||
@@ -326,10 +331,9 @@ class SmartDevice:
 | 
			
		||||
                continue
 | 
			
		||||
            q = module.query()
 | 
			
		||||
            _LOGGER.debug("Adding query for %s: %s", module, q)
 | 
			
		||||
            req = merge(req, module.query())
 | 
			
		||||
            req = merge(req, q)
 | 
			
		||||
 | 
			
		||||
        self._last_update = await self.protocol.query(req)
 | 
			
		||||
        self._sys_info = self._last_update["system"]["get_sysinfo"]
 | 
			
		||||
 | 
			
		||||
    def update_from_discover_info(self, info):
 | 
			
		||||
        """Update state from info from the discover call."""
 | 
			
		||||
 
 | 
			
		||||
@@ -3,13 +3,14 @@ import logging
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
from typing import Any, DefaultDict, Dict, Optional
 | 
			
		||||
 | 
			
		||||
import asyncio
 | 
			
		||||
from kasa.smartdevice import (
 | 
			
		||||
    DeviceType,
 | 
			
		||||
    EmeterStatus,
 | 
			
		||||
    SmartDevice,
 | 
			
		||||
    SmartDeviceException,
 | 
			
		||||
    requires_update,
 | 
			
		||||
    merge,
 | 
			
		||||
)
 | 
			
		||||
from kasa.smartplug import SmartPlug
 | 
			
		||||
 | 
			
		||||
@@ -250,16 +251,16 @@ class SmartStripPlug(SmartPlug):
 | 
			
		||||
        self._last_update = parent._last_update
 | 
			
		||||
        self._sys_info = parent._sys_info
 | 
			
		||||
        self._device_type = DeviceType.StripSocket
 | 
			
		||||
        self.modules = {}
 | 
			
		||||
        self.protocol = parent.protocol  # Must use the same connection as the parent
 | 
			
		||||
        self.add_module("time", Time(self, "time"))
 | 
			
		||||
 | 
			
		||||
    async def update(self, update_children: bool = True):
 | 
			
		||||
        """Query the device to update the data.
 | 
			
		||||
 | 
			
		||||
        Needed for properties that are decorated with `requires_update`.
 | 
			
		||||
        """
 | 
			
		||||
        # TODO: it needs to be checked if this still works after modularization
 | 
			
		||||
        self._last_update = await self.parent.protocol.query(
 | 
			
		||||
            self._create_emeter_request()
 | 
			
		||||
        )
 | 
			
		||||
        await self._modular_update({})
 | 
			
		||||
 | 
			
		||||
    def _create_emeter_request(self, year: int = None, month: int = None):
 | 
			
		||||
        """Create a request for requesting all emeter statistics at once."""
 | 
			
		||||
 
 | 
			
		||||
@@ -36,7 +36,9 @@ async def test_initial_update_no_emeter(dev, mocker):
 | 
			
		||||
    dev._last_update = None
 | 
			
		||||
    spy = mocker.spy(dev.protocol, "query")
 | 
			
		||||
    await dev.update()
 | 
			
		||||
    assert spy.call_count == 1
 | 
			
		||||
    # 2 calls are necessary as some devices crash on unexpected modules
 | 
			
		||||
    # See #105, #120, #161
 | 
			
		||||
    assert spy.call_count == 2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_query_helper(dev):
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user