diff --git a/kasa/cachedzoneinfo.py b/kasa/cachedzoneinfo.py new file mode 100644 index 00000000..c70e8309 --- /dev/null +++ b/kasa/cachedzoneinfo.py @@ -0,0 +1,28 @@ +"""Module for caching ZoneInfos.""" + +from __future__ import annotations + +import asyncio + +from zoneinfo import ZoneInfo + + +class CachedZoneInfo(ZoneInfo): + """Cache ZoneInfo objects.""" + + _cache: dict[str, ZoneInfo] = {} + + @classmethod + async def get_cached_zone_info(cls, time_zone_str: str) -> ZoneInfo: + """Get a cached zone info object.""" + if cached := cls._cache.get(time_zone_str): + return cached + loop = asyncio.get_running_loop() + zinfo = await loop.run_in_executor(None, _get_zone_info, time_zone_str) + cls._cache[time_zone_str] = zinfo + return zinfo + + +def _get_zone_info(time_zone_str: str) -> ZoneInfo: + """Get a time zone object for the given time zone string.""" + return ZoneInfo(time_zone_str) diff --git a/kasa/iot/iottimezone.py b/kasa/iot/iottimezone.py index 53cb219e..ccbed3e7 100644 --- a/kasa/iot/iottimezone.py +++ b/kasa/iot/iottimezone.py @@ -2,11 +2,10 @@ from __future__ import annotations -import asyncio import logging from datetime import datetime, tzinfo -from zoneinfo import ZoneInfo +from ..cachedzoneinfo import CachedZoneInfo _LOGGER = logging.getLogger(__name__) @@ -17,10 +16,10 @@ async def get_timezone(index: int) -> tzinfo: _LOGGER.error( "Unexpected index %s not configured as a timezone, defaulting to UTC", index ) - return await _CachedZoneInfo.get_cached_zone_info("Etc/UTC") + return await CachedZoneInfo.get_cached_zone_info("Etc/UTC") name = TIMEZONE_INDEX[index] - return await _CachedZoneInfo.get_cached_zone_info(name) + return await CachedZoneInfo.get_cached_zone_info(name) async def get_timezone_index(name: str) -> int: @@ -30,7 +29,7 @@ async def get_timezone_index(name: str) -> int: return rev[name] # Try to find a supported timezone matching dst true/false - zone = await _CachedZoneInfo.get_cached_zone_info(name) + zone = await CachedZoneInfo.get_cached_zone_info(name) now = datetime.now() winter = datetime(now.year, 1, 1, 12) summer = datetime(now.year, 7, 1, 12) @@ -43,27 +42,6 @@ async def get_timezone_index(name: str) -> int: raise ValueError("Device does not support timezone %s", name) -class _CachedZoneInfo(ZoneInfo): - """Cache zone info objects.""" - - _cache: dict[str, ZoneInfo] = {} - - @classmethod - async def get_cached_zone_info(cls, time_zone_str: str) -> ZoneInfo: - """Get a cached zone info object.""" - if cached := cls._cache.get(time_zone_str): - return cached - loop = asyncio.get_running_loop() - zinfo = await loop.run_in_executor(None, _get_zone_info, time_zone_str) - cls._cache[time_zone_str] = zinfo - return zinfo - - -def _get_zone_info(time_zone_str: str) -> ZoneInfo: - """Get a time zone object for the given time zone string.""" - return ZoneInfo(time_zone_str) - - TIMEZONE_INDEX = { 0: "Etc/GMT+12", 1: "Pacific/Samoa", diff --git a/kasa/smart/modules/time.py b/kasa/smart/modules/time.py index 13831b2e..21dd13a4 100644 --- a/kasa/smart/modules/time.py +++ b/kasa/smart/modules/time.py @@ -6,8 +6,9 @@ from datetime import datetime, timedelta, timezone, tzinfo from time import mktime from typing import cast -from zoneinfo import ZoneInfo, ZoneInfoNotFoundError +from zoneinfo import ZoneInfoNotFoundError +from ...cachedzoneinfo import CachedZoneInfo from ...feature import Feature from ..smartmodule import SmartModule @@ -18,6 +19,8 @@ class Time(SmartModule): REQUIRED_COMPONENT = "time" QUERY_GETTER_NAME = "get_device_time" + _timezone: tzinfo = timezone.utc + def _initialize_features(self): """Initialize features after the initial update.""" self._add_feature( @@ -32,21 +35,25 @@ class Time(SmartModule): ) ) - @property - def timezone(self) -> tzinfo: - """Return current timezone.""" + async def _post_update_hook(self): + """Perform actions after a device update.""" td = timedelta(minutes=cast(float, self.data.get("time_diff"))) if region := self.data.get("region"): try: # Zoneinfo will return a DST aware object - tz: tzinfo = ZoneInfo(region) + tz: tzinfo = await CachedZoneInfo.get_cached_zone_info(region) except ZoneInfoNotFoundError: tz = timezone(td, region) else: # in case the device returns a blank region this will result in the # tzname being a UTC offset tz = timezone(td) - return tz + self._timezone = tz + + @property + def timezone(self) -> tzinfo: + """Return current timezone.""" + return self._timezone @property def time(self) -> datetime: