diff --git a/kasa/smartbulb.py b/kasa/smartbulb.py index 401214cd..c9587be1 100644 --- a/kasa/smartbulb.py +++ b/kasa/smartbulb.py @@ -71,8 +71,8 @@ class SmartBulb(SmartDevice): LIGHT_SERVICE = "smartlife.iot.smartbulb.lightingservice" - def __init__(self, host: str, *, cache_ttl: int = 3) -> None: - SmartDevice.__init__(self, host=host, cache_ttl=cache_ttl) + def __init__(self, host: str) -> None: + super().__init__(host=host) self.emeter_type = "smartlife.iot.common.emeter" self._device_type = DeviceType.Bulb self._light_state = None diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 1a02ca31..593c035e 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -14,8 +14,7 @@ http://www.apache.org/licenses/LICENSE-2.0 import functools import inspect import logging -from collections import defaultdict -from datetime import datetime, timedelta +from datetime import datetime from enum import Enum from typing import Any, Dict, Optional @@ -103,7 +102,7 @@ def requires_update(f): class SmartDevice: """Base class for all supported device types.""" - def __init__(self, host: str, *, cache_ttl: int = 3) -> None: + def __init__(self, host: str) -> None: """Create a new SmartDevice instance. :param str host: host name or ip address on which the device listens @@ -113,48 +112,10 @@ class SmartDevice: self.protocol = TPLinkSmartHomeProtocol() self.emeter_type = "emeter" - self.cache_ttl = timedelta(seconds=cache_ttl) - _LOGGER.debug("Initializing %s with cache ttl %s", self.host, self.cache_ttl) - self.cache = defaultdict(lambda: defaultdict(lambda: None)) # type: ignore + _LOGGER.debug("Initializing %s", self.host) self._device_type = DeviceType.Unknown self._sys_info: Optional[Dict] = None - def _result_from_cache(self, target, cmd) -> Optional[Dict]: - """Return query result from cache if still fresh. - - Only results from commands starting with `get_` are considered cacheable. - - :param target: Target system - :param cmd: Command - :rtype: query result or None if expired. - """ - _LOGGER.debug("Checking cache for %s %s", target, cmd) - if cmd not in self.cache[target]: - return None - - cached = self.cache[target][cmd] - if cached and cached["last_updated"] is not None: - if cached[ - "last_updated" - ] + self.cache_ttl > datetime.utcnow() and cmd.startswith("get_"): - _LOGGER.debug("Got cached %s %s", target, cmd) - return self.cache[target][cmd] - else: - _LOGGER.debug("Invalidating the cache for %s cmd %s", target, cmd) - for cache_entry in self.cache[target].values(): - cache_entry["last_updated"] = datetime.utcfromtimestamp(0) - return None - - def _insert_to_cache(self, target: str, cmd: str, response: Dict) -> None: - """Add response for a given command to the cache. - - :param target: Target system - :param cmd: Command - :param response: Response to be cached - """ - self.cache[target][cmd] = response.copy() - self.cache[target][cmd]["last_updated"] = datetime.utcnow() - async def _query_helper( self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None ) -> Any: @@ -172,11 +133,7 @@ class SmartDevice: request = {"context": {"child_ids": child_ids}, target: {cmd: arg}} try: - response = self._result_from_cache(target, cmd) - if response is None: - _LOGGER.debug("Got no result from cache, querying the device.") - response = await self.protocol.query(host=self.host, request=request) - self._insert_to_cache(target, cmd, response) + response = await self.protocol.query(host=self.host, request=request) except Exception as ex: raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex diff --git a/kasa/smartplug.py b/kasa/smartplug.py index 949bc319..e21dda92 100644 --- a/kasa/smartplug.py +++ b/kasa/smartplug.py @@ -35,8 +35,8 @@ class SmartPlug(SmartDevice): and should be handled by the user of the library. """ - def __init__(self, host: str, *, cache_ttl: int = 3) -> None: - SmartDevice.__init__(self, host, cache_ttl=cache_ttl) + def __init__(self, host: str) -> None: + super().__init__(host) self.emeter_type = "emeter" self._device_type = DeviceType.Plug diff --git a/kasa/smartstrip.py b/kasa/smartstrip.py index 27ca8817..b382629c 100755 --- a/kasa/smartstrip.py +++ b/kasa/smartstrip.py @@ -50,7 +50,7 @@ class SmartStrip(SmartDevice): return True def __init__(self, host: str, *, cache_ttl: int = 3) -> None: - SmartDevice.__init__(self, host=host, cache_ttl=cache_ttl) + super().__init__(host=host) self.emeter_type = "emeter" self._device_type = DeviceType.Strip self.plugs: List[SmartStripPlug] = [] @@ -78,12 +78,7 @@ class SmartStrip(SmartDevice): _LOGGER.debug("Initializing %s child sockets", len(children)) for child in children: self.plugs.append( - SmartStripPlug( - self.host, - parent=self, - child_id=child["id"], - cache_ttl=self.cache_ttl.total_seconds(), - ) + SmartStripPlug(self.host, parent=self, child_id=child["id"]) ) async def turn_on(self): @@ -232,10 +227,8 @@ class SmartStripPlug(SmartPlug): on the parent device before accessing the properties. """ - def __init__( - self, host: str, parent: "SmartStrip", child_id: str, *, cache_ttl: int = 3 - ) -> None: - super().__init__(host, cache_ttl=cache_ttl) + def __init__(self, host: str, parent: "SmartStrip", child_id: str) -> None: + super().__init__(host) self.parent = parent self.child_id = child_id diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 464e764a..3b29a14c 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -116,8 +116,7 @@ def dev(request): with open(file) as f: sysinfo = json.load(f) model = basename(file) - params = {"host": "123.123.123.123", "cache_ttl": 0} - p = device_for_file(model)(**params) + p = device_for_file(model)(host="123.123.123.123") p.protocol = FakeTransportProtocol(sysinfo) loop.run_until_complete(p.update()) yield p diff --git a/kasa/tests/test_fixtures.py b/kasa/tests/test_fixtures.py index 0e3af70a..1861cc7a 100644 --- a/kasa/tests/test_fixtures.py +++ b/kasa/tests/test_fixtures.py @@ -471,40 +471,6 @@ async def test_all_binary_states(dev): assert state == state_map[index] -# def test_cache(dev): -# from datetime import timedelta - -# dev.cache_ttl = timedelta(seconds=3) -# with patch.object( -# FakeTransportProtocol, "query", wraps=dev.protocol.query -# ) as query_mock: -# CHECK_COUNT = 1 -# # Smartstrip calls sysinfo in its __init__ to request children, so -# # the even first get call here will get its results from the cache. -# if dev.is_strip: -# CHECK_COUNT = 0 - -# dev.sys_info -# assert query_mock.call_count == CHECK_COUNT -# dev.sys_info -# assert query_mock.call_count == CHECK_COUNT - - -# def test_cache_invalidates(dev): -# from datetime import timedelta - -# dev.cache_ttl = timedelta(seconds=0) - -# with patch.object( -# FakeTransportProtocol, "query", wraps=dev.protocol.query -# ) as query_mock: -# dev.sys_info -# assert query_mock.call_count == 1 -# dev.sys_info -# assert query_mock.call_count == 2 -# # assert query_mock.called_once() - - async def test_representation(dev): import re