Remove unnecessary cache (#40)

The cache was useful trick when the property accesses caused I/O,
which is unnecessary now as dev.update() does explicitly cache results until its called again.
This commit is contained in:
Teemu R 2020-04-12 15:57:49 +02:00 committed by GitHub
parent 5ff299664e
commit c90465c5dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 13 additions and 98 deletions

View File

@ -71,8 +71,8 @@ class SmartBulb(SmartDevice):
LIGHT_SERVICE = "smartlife.iot.smartbulb.lightingservice"
def __init__(self, host: str, *, cache_ttl: int = 3) -> None:
SmartDevice.__init__(self, host=host, cache_ttl=cache_ttl)
def __init__(self, host: str) -> None:
super().__init__(host=host)
self.emeter_type = "smartlife.iot.common.emeter"
self._device_type = DeviceType.Bulb
self._light_state = None

View File

@ -14,8 +14,7 @@ http://www.apache.org/licenses/LICENSE-2.0
import functools
import inspect
import logging
from collections import defaultdict
from datetime import datetime, timedelta
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional
@ -103,7 +102,7 @@ def requires_update(f):
class SmartDevice:
"""Base class for all supported device types."""
def __init__(self, host: str, *, cache_ttl: int = 3) -> None:
def __init__(self, host: str) -> None:
"""Create a new SmartDevice instance.
:param str host: host name or ip address on which the device listens
@ -113,48 +112,10 @@ class SmartDevice:
self.protocol = TPLinkSmartHomeProtocol()
self.emeter_type = "emeter"
self.cache_ttl = timedelta(seconds=cache_ttl)
_LOGGER.debug("Initializing %s with cache ttl %s", self.host, self.cache_ttl)
self.cache = defaultdict(lambda: defaultdict(lambda: None)) # type: ignore
_LOGGER.debug("Initializing %s", self.host)
self._device_type = DeviceType.Unknown
self._sys_info: Optional[Dict] = None
def _result_from_cache(self, target, cmd) -> Optional[Dict]:
"""Return query result from cache if still fresh.
Only results from commands starting with `get_` are considered cacheable.
:param target: Target system
:param cmd: Command
:rtype: query result or None if expired.
"""
_LOGGER.debug("Checking cache for %s %s", target, cmd)
if cmd not in self.cache[target]:
return None
cached = self.cache[target][cmd]
if cached and cached["last_updated"] is not None:
if cached[
"last_updated"
] + self.cache_ttl > datetime.utcnow() and cmd.startswith("get_"):
_LOGGER.debug("Got cached %s %s", target, cmd)
return self.cache[target][cmd]
else:
_LOGGER.debug("Invalidating the cache for %s cmd %s", target, cmd)
for cache_entry in self.cache[target].values():
cache_entry["last_updated"] = datetime.utcfromtimestamp(0)
return None
def _insert_to_cache(self, target: str, cmd: str, response: Dict) -> None:
"""Add response for a given command to the cache.
:param target: Target system
:param cmd: Command
:param response: Response to be cached
"""
self.cache[target][cmd] = response.copy()
self.cache[target][cmd]["last_updated"] = datetime.utcnow()
async def _query_helper(
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
) -> Any:
@ -172,11 +133,7 @@ class SmartDevice:
request = {"context": {"child_ids": child_ids}, target: {cmd: arg}}
try:
response = self._result_from_cache(target, cmd)
if response is None:
_LOGGER.debug("Got no result from cache, querying the device.")
response = await self.protocol.query(host=self.host, request=request)
self._insert_to_cache(target, cmd, response)
response = await self.protocol.query(host=self.host, request=request)
except Exception as ex:
raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex

View File

@ -35,8 +35,8 @@ class SmartPlug(SmartDevice):
and should be handled by the user of the library.
"""
def __init__(self, host: str, *, cache_ttl: int = 3) -> None:
SmartDevice.__init__(self, host, cache_ttl=cache_ttl)
def __init__(self, host: str) -> None:
super().__init__(host)
self.emeter_type = "emeter"
self._device_type = DeviceType.Plug

View File

@ -50,7 +50,7 @@ class SmartStrip(SmartDevice):
return True
def __init__(self, host: str, *, cache_ttl: int = 3) -> None:
SmartDevice.__init__(self, host=host, cache_ttl=cache_ttl)
super().__init__(host=host)
self.emeter_type = "emeter"
self._device_type = DeviceType.Strip
self.plugs: List[SmartStripPlug] = []
@ -78,12 +78,7 @@ class SmartStrip(SmartDevice):
_LOGGER.debug("Initializing %s child sockets", len(children))
for child in children:
self.plugs.append(
SmartStripPlug(
self.host,
parent=self,
child_id=child["id"],
cache_ttl=self.cache_ttl.total_seconds(),
)
SmartStripPlug(self.host, parent=self, child_id=child["id"])
)
async def turn_on(self):
@ -232,10 +227,8 @@ class SmartStripPlug(SmartPlug):
on the parent device before accessing the properties.
"""
def __init__(
self, host: str, parent: "SmartStrip", child_id: str, *, cache_ttl: int = 3
) -> None:
super().__init__(host, cache_ttl=cache_ttl)
def __init__(self, host: str, parent: "SmartStrip", child_id: str) -> None:
super().__init__(host)
self.parent = parent
self.child_id = child_id

View File

@ -116,8 +116,7 @@ def dev(request):
with open(file) as f:
sysinfo = json.load(f)
model = basename(file)
params = {"host": "123.123.123.123", "cache_ttl": 0}
p = device_for_file(model)(**params)
p = device_for_file(model)(host="123.123.123.123")
p.protocol = FakeTransportProtocol(sysinfo)
loop.run_until_complete(p.update())
yield p

View File

@ -471,40 +471,6 @@ async def test_all_binary_states(dev):
assert state == state_map[index]
# def test_cache(dev):
# from datetime import timedelta
# dev.cache_ttl = timedelta(seconds=3)
# with patch.object(
# FakeTransportProtocol, "query", wraps=dev.protocol.query
# ) as query_mock:
# CHECK_COUNT = 1
# # Smartstrip calls sysinfo in its __init__ to request children, so
# # the even first get call here will get its results from the cache.
# if dev.is_strip:
# CHECK_COUNT = 0
# dev.sys_info
# assert query_mock.call_count == CHECK_COUNT
# dev.sys_info
# assert query_mock.call_count == CHECK_COUNT
# def test_cache_invalidates(dev):
# from datetime import timedelta
# dev.cache_ttl = timedelta(seconds=0)
# with patch.object(
# FakeTransportProtocol, "query", wraps=dev.protocol.query
# ) as query_mock:
# dev.sys_info
# assert query_mock.call_count == 1
# dev.sys_info
# assert query_mock.call_count == 2
# # assert query_mock.called_once()
async def test_representation(dev):
import re