diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index 234ea9fe..1c8b311c 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -14,7 +14,6 @@ http://www.apache.org/licenses/LICENSE-2.0 from __future__ import annotations -import collections.abc import functools import inspect import logging @@ -29,22 +28,12 @@ from ..feature import Feature from ..module import Module from ..modulemapping import ModuleMapping, ModuleName from ..protocol import BaseProtocol -from .iotmodule import IotModule +from .iotmodule import IotModule, merge from .modules import Emeter _LOGGER = logging.getLogger(__name__) -def merge(d, u): - """Update dict recursively.""" - for k, v in u.items(): - if isinstance(v, collections.abc.Mapping): - d[k] = merge(d.get(k, {}), v) - else: - d[k] = v - return d - - def requires_update(f): """Indicate that `update` should be called before accessing this method.""" # noqa: D202 if inspect.iscoroutinefunction(f): diff --git a/kasa/iot/iotmodule.py b/kasa/iot/iotmodule.py index ca0c3adb..7829c856 100644 --- a/kasa/iot/iotmodule.py +++ b/kasa/iot/iotmodule.py @@ -1,6 +1,5 @@ """Base class for IOT module implementations.""" -import collections import logging from ..exceptions import KasaException @@ -9,15 +8,17 @@ from ..module import Module _LOGGER = logging.getLogger(__name__) -# TODO: This is used for query constructing, check for a better place -def merge(d, u): +def _merge_dict(dest: dict, source: dict) -> dict: """Update dict recursively.""" - for k, v in u.items(): - if isinstance(v, collections.abc.Mapping): - d[k] = merge(d.get(k, {}), v) + for k, v in source.items(): + if k in dest and type(v) is dict: # noqa: E721 - only accepts `dict` type + _merge_dict(dest[k], v) else: - d[k] = v - return d + dest[k] = v + return dest + + +merge = _merge_dict class IotModule(Module): diff --git a/kasa/tests/test_iotdevice.py b/kasa/tests/test_iotdevice.py index df37f762..976144fc 100644 --- a/kasa/tests/test_iotdevice.py +++ b/kasa/tests/test_iotdevice.py @@ -18,6 +18,7 @@ from voluptuous import ( from kasa import KasaException, Module from kasa.iot import IotDevice +from kasa.iot.iotmodule import _merge_dict from .conftest import get_device_for_fixture_protocol, handle_turn_on, turn_on from .device_fixtures import device_iot, has_emeter_iot, no_emeter_iot @@ -292,3 +293,21 @@ async def test_get_modules(): module = dummy_device.modules.get(Module.Cloud) assert module is None + + +def test_merge_dict(): + """Test the recursive dict merge.""" + dest = {"a": 1, "b": {"c": 2, "d": 3}} + source = {"b": {"c": 4, "e": 5}} + assert _merge_dict(dest, source) == {"a": 1, "b": {"c": 4, "d": 3, "e": 5}} + + dest = {"smartlife.iot.common.emeter": {"get_realtime": None}} + source = { + "smartlife.iot.common.emeter": {"get_daystat": {"month": 8, "year": 2024}} + } + assert _merge_dict(dest, source) == { + "smartlife.iot.common.emeter": { + "get_realtime": None, + "get_daystat": {"month": 8, "year": 2024}, + } + }