Improve performance of dict merge code (#1097)

Co-authored-by: Teemu R. <tpr@iki.fi>
This commit is contained in:
J. Nick Koston 2024-08-14 16:33:54 -05:00 committed by GitHub
parent 633f57dcce
commit 4669e08605
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 29 additions and 20 deletions

View File

@ -14,7 +14,6 @@ http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations from __future__ import annotations
import collections.abc
import functools import functools
import inspect import inspect
import logging import logging
@ -29,22 +28,12 @@ from ..feature import Feature
from ..module import Module from ..module import Module
from ..modulemapping import ModuleMapping, ModuleName from ..modulemapping import ModuleMapping, ModuleName
from ..protocol import BaseProtocol from ..protocol import BaseProtocol
from .iotmodule import IotModule from .iotmodule import IotModule, merge
from .modules import Emeter from .modules import Emeter
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def merge(d, u):
"""Update dict recursively."""
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = merge(d.get(k, {}), v)
else:
d[k] = v
return d
def requires_update(f): def requires_update(f):
"""Indicate that `update` should be called before accessing this method.""" # noqa: D202 """Indicate that `update` should be called before accessing this method.""" # noqa: D202
if inspect.iscoroutinefunction(f): if inspect.iscoroutinefunction(f):

View File

@ -1,6 +1,5 @@
"""Base class for IOT module implementations.""" """Base class for IOT module implementations."""
import collections
import logging import logging
from ..exceptions import KasaException from ..exceptions import KasaException
@ -9,15 +8,17 @@ from ..module import Module
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# TODO: This is used for query constructing, check for a better place def _merge_dict(dest: dict, source: dict) -> dict:
def merge(d, u):
"""Update dict recursively.""" """Update dict recursively."""
for k, v in u.items(): for k, v in source.items():
if isinstance(v, collections.abc.Mapping): if k in dest and type(v) is dict: # noqa: E721 - only accepts `dict` type
d[k] = merge(d.get(k, {}), v) _merge_dict(dest[k], v)
else: else:
d[k] = v dest[k] = v
return d return dest
merge = _merge_dict
class IotModule(Module): class IotModule(Module):

View File

@ -18,6 +18,7 @@ from voluptuous import (
from kasa import KasaException, Module from kasa import KasaException, Module
from kasa.iot import IotDevice from kasa.iot import IotDevice
from kasa.iot.iotmodule import _merge_dict
from .conftest import get_device_for_fixture_protocol, handle_turn_on, turn_on from .conftest import get_device_for_fixture_protocol, handle_turn_on, turn_on
from .device_fixtures import device_iot, has_emeter_iot, no_emeter_iot from .device_fixtures import device_iot, has_emeter_iot, no_emeter_iot
@ -292,3 +293,21 @@ async def test_get_modules():
module = dummy_device.modules.get(Module.Cloud) module = dummy_device.modules.get(Module.Cloud)
assert module is None assert module is None
def test_merge_dict():
"""Test the recursive dict merge."""
dest = {"a": 1, "b": {"c": 2, "d": 3}}
source = {"b": {"c": 4, "e": 5}}
assert _merge_dict(dest, source) == {"a": 1, "b": {"c": 4, "d": 3, "e": 5}}
dest = {"smartlife.iot.common.emeter": {"get_realtime": None}}
source = {
"smartlife.iot.common.emeter": {"get_daystat": {"month": 8, "year": 2024}}
}
assert _merge_dict(dest, source) == {
"smartlife.iot.common.emeter": {
"get_realtime": None,
"get_daystat": {"month": 8, "year": 2024},
}
}