From 254a9af5c1f057a6bbbf5a87363bbc5422a3889b Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Tue, 12 Nov 2024 21:00:04 +0000 Subject: [PATCH] Update DiscoveryResult to use Mashumaro instead of pydantic (#1231) Mashumaro is faster and doesn't come with all the versioning problems that pydantic does. A basic perf test deserializing all of our discovery result fixtures shows mashumaro to be about 6 times faster than pydantic at deserializing dicts. It is also much faster at parsing from a JSON string, but that is likely because it uses orjson under the hood, and parsing from JSON strings is not really our use case at the moment. ``` PYDANTIC - ms ================= json dict ----------------- 4.7665 1.3268 3.1548 1.5922 3.1130 1.8039 4.2834 2.7606 2.0669 1.3757 2.0163 1.6377 3.1667 1.3561 4.1296 2.7297 2.0132 1.3471 4.0648 1.4105 MASHUMARO - ms ================= json dict ----------------- 0.5977 0.5543 0.5336 0.2983 0.3955 0.2549 0.6516 0.2742 0.5386 0.2706 0.6678 0.2580 0.4120 0.2511 0.3836 0.2472 0.4020 0.2465 0.4268 0.2487 ``` --- devtools/dump_devinfo.py | 12 +++--- kasa/cli/discover.py | 2 +- kasa/discover.py | 73 ++++++++++++++++++++++-------------- kasa/json.py | 10 +++++ pyproject.toml | 1 + tests/test_cli.py | 2 +- tests/test_device_factory.py | 2 +- tests/test_discovery.py | 6 +-- uv.lock | 14 +++++++ 9 files changed, 81 insertions(+), 41 deletions(-) diff --git a/devtools/dump_devinfo.py b/devtools/dump_devinfo.py index 5cbfff76..83df9dcd 100644 --- a/devtools/dump_devinfo.py +++ b/devtools/dump_devinfo.py @@ -319,7 +319,7 @@ async def cli( click.echo("Host and discovery info given, trying connect on %s."
% host) di = json.loads(discovery_info) - dr = DiscoveryResult(**di) + dr = DiscoveryResult.from_dict(di) connection_type = DeviceConnectionParameters.from_values( dr.device_type, dr.mgt_encrypt_schm.encrypt_type, @@ -336,7 +336,7 @@ async def cli( basedir, autosave, device.protocol, - discovery_info=dr.get_dict(), + discovery_info=dr.to_dict(), batch_size=batch_size, ) elif device_family and encrypt_type: @@ -443,7 +443,7 @@ async def get_legacy_fixture(protocol, *, discovery_info): if discovery_info and not discovery_info.get("system"): # Need to recreate a DiscoverResult here because we don't want the aliases # in the fixture, we want the actual field names as returned by the device. - dr = DiscoveryResult(**protocol._discovery_info) + dr = DiscoveryResult.from_dict(protocol._discovery_info) final["discovery_result"] = dr.dict( by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True ) @@ -960,10 +960,8 @@ async def get_smart_fixtures( # Need to recreate a DiscoverResult here because we don't want the aliases # in the fixture, we want the actual field names as returned by the device. 
if discovery_info: - dr = DiscoveryResult(**discovery_info) # type: ignore - final["discovery_result"] = dr.dict( - by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True - ) + dr = DiscoveryResult.from_dict(discovery_info) # type: ignore + final["discovery_result"] = dr.to_dict() click.echo("Got %s successes" % len(successes)) click.echo(click.style("## device info file ##", bold=True)) diff --git a/kasa/cli/discover.py b/kasa/cli/discover.py index 8df59de8..3ebb4a9f 100644 --- a/kasa/cli/discover.py +++ b/kasa/cli/discover.py @@ -207,7 +207,7 @@ def _echo_discovery_info(discovery_info) -> None: return try: - dr = DiscoveryResult(**discovery_info) + dr = DiscoveryResult.from_dict(discovery_info) except ValidationError: _echo_dictionary(discovery_info) return diff --git a/kasa/discover.py b/kasa/discover.py index bed43e85..d1240aa8 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -90,6 +90,7 @@ import secrets import socket import struct from asyncio.transports import DatagramTransport +from dataclasses import dataclass, field from pprint import pformat as pf from typing import ( TYPE_CHECKING, @@ -108,7 +109,8 @@ from aiohttp import ClientSession # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout -from pydantic.v1 import BaseModel, ValidationError +from mashumaro import field_options +from mashumaro.config import BaseConfig from kasa import Device from kasa.credentials import Credentials @@ -130,6 +132,7 @@ from kasa.exceptions import ( from kasa.experimental import Experimental from kasa.iot.iotdevice import IotDevice from kasa.iotprotocol import REDACTORS as IOT_REDACTORS +from kasa.json import DataClassJSONMixin from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads from kasa.protocol import mask_mac, redact_data @@ -647,7 +650,7 @@ class Discover: def _get_device_class(info: dict) -> type[Device]: 
"""Find SmartDevice subclass for device described by passed data.""" if "result" in info: - discovery_result = DiscoveryResult(**info["result"]) + discovery_result = DiscoveryResult.from_dict(info["result"]) https = discovery_result.mgt_encrypt_schm.is_support_https dev_class = get_device_class_from_family( discovery_result.device_type, https=https @@ -721,12 +724,8 @@ class Discover: f"Unable to read response from device: {config.host}: {ex}" ) from ex try: - discovery_result = DiscoveryResult(**info["result"]) - if ( - encrypt_info := discovery_result.encrypt_info - ) and encrypt_info.sym_schm == "AES": - Discover._decrypt_discovery_data(discovery_result) - except ValidationError as ex: + discovery_result = DiscoveryResult.from_dict(info["result"]) + except Exception as ex: if debug_enabled: data = ( redact_data(info, NEW_DISCOVERY_REDACTORS) @@ -742,6 +741,16 @@ class Discover: f"Unable to parse discovery from device: {config.host}: {ex}", host=config.host, ) from ex + # Decrypt the data + if ( + encrypt_info := discovery_result.encrypt_info + ) and encrypt_info.sym_schm == "AES": + try: + Discover._decrypt_discovery_data(discovery_result) + except Exception: + _LOGGER.exception( + "Unable to decrypt discovery data %s: %s", config.host, data + ) type_ = discovery_result.device_type encrypt_schm = discovery_result.mgt_encrypt_schm @@ -754,7 +763,7 @@ class Discover: raise UnsupportedDeviceError( f"Unsupported device {config.host} of type {type_} " + "with no encryption type", - discovery_result=discovery_result.get_dict(), + discovery_result=discovery_result.to_dict(), host=config.host, ) config.connection_type = DeviceConnectionParameters.from_values( @@ -767,7 +776,7 @@ class Discover: raise UnsupportedDeviceError( f"Unsupported device {config.host} of type {type_} " + f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}", - discovery_result=discovery_result.get_dict(), + discovery_result=discovery_result.to_dict(), host=config.host, ) from ex 
if ( @@ -778,7 +787,7 @@ class Discover: _LOGGER.warning("Got unsupported device type: %s", type_) raise UnsupportedDeviceError( f"Unsupported device {config.host} of type {type_}: {info}", - discovery_result=discovery_result.get_dict(), + discovery_result=discovery_result.to_dict(), host=config.host, ) if (protocol := get_protocol(config)) is None: @@ -788,7 +797,7 @@ class Discover: raise UnsupportedDeviceError( f"Unsupported encryption scheme {config.host} of " + f"type {config.connection_type.to_dict()}: {info}", - discovery_result=discovery_result.get_dict(), + discovery_result=discovery_result.to_dict(), host=config.host, ) @@ -801,22 +810,35 @@ class Discover: _LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data)) device = device_class(config.host, protocol=protocol) - di = discovery_result.get_dict() + di = discovery_result.to_dict() di["model"], _, _ = discovery_result.device_model.partition("(") device.update_from_discover_info(di) return device -class EncryptionScheme(BaseModel): +class _DiscoveryBaseMixin(DataClassJSONMixin): + """Base class for serialization mixin.""" + + class Config(BaseConfig): + """Serialization config.""" + + omit_none = True + omit_default = True + serialize_by_alias = True + + +@dataclass +class EncryptionScheme(_DiscoveryBaseMixin): """Base model for encryption scheme of discovery result.""" is_support_https: bool - encrypt_type: Optional[str] # noqa: UP007 + encrypt_type: Optional[str] = None # noqa: UP007 http_port: Optional[int] = None # noqa: UP007 lv: Optional[int] = None # noqa: UP007 -class EncryptionInfo(BaseModel): +@dataclass +class EncryptionInfo(_DiscoveryBaseMixin): """Base model for encryption info of discovery result.""" sym_schm: str @@ -824,19 +846,23 @@ class EncryptionInfo(BaseModel): data: str -class DiscoveryResult(BaseModel): +@dataclass +class DiscoveryResult(_DiscoveryBaseMixin): """Base model for discovery result.""" device_type: str device_model: str - device_name: Optional[str] # noqa: UP007 + 
device_id: str ip: str mac: str mgt_encrypt_schm: EncryptionScheme + device_name: Optional[str] = None # noqa: UP007 encrypt_info: Optional[EncryptionInfo] = None # noqa: UP007 encrypt_type: Optional[list[str]] = None # noqa: UP007 decrypted_data: Optional[dict] = None # noqa: UP007 - device_id: str + is_reset_wifi: Optional[bool] = field( # noqa: UP007 + metadata=field_options(alias="isResetWiFi"), default=None + ) firmware_version: Optional[str] = None # noqa: UP007 hardware_version: Optional[str] = None # noqa: UP007 @@ -845,12 +871,3 @@ class DiscoveryResult(BaseModel): is_support_iot_cloud: Optional[bool] = None # noqa: UP007 obd_src: Optional[str] = None # noqa: UP007 factory_default: Optional[bool] = None # noqa: UP007 - - def get_dict(self) -> dict: - """Return a dict for this discovery result. - - containing only the values actually set and with aliases as field names. - """ - return self.dict( - by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True - ) diff --git a/kasa/json.py b/kasa/json.py index 10edc690..6f1149fa 100755 --- a/kasa/json.py +++ b/kasa/json.py @@ -21,3 +21,13 @@ except ImportError: return json.dumps(obj, separators=(",", ":")) loads = json.loads + + +try: + from mashumaro.mixins.orjson import DataClassORJSONMixin + + DataClassJSONMixin = DataClassORJSONMixin +except ImportError: + from mashumaro.mixins.json import DataClassJSONMixin as JSONMixin + + DataClassJSONMixin = JSONMixin # type: ignore[assignment, misc] diff --git a/pyproject.toml b/pyproject.toml index 92ef7bbe..44959c6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "aiohttp>=3", "typing-extensions>=4.12.2,<5.0", "tzdata>=2024.2 ; platform_system == 'Windows'", + "mashumaro>=3.14", ] classifiers = [ diff --git a/tests/test_cli.py b/tests/test_cli.py index d22bb112..b6bcdfd4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -616,7 +616,7 @@ async def test_credentials(discovery_mock, mocker, runner): 
mocker.patch("kasa.cli.device.state", new=_state) - dr = DiscoveryResult(**discovery_mock.discovery_data["result"]) + dr = DiscoveryResult.from_dict(discovery_mock.discovery_data["result"]) res = await runner.invoke( cli, [ diff --git a/tests/test_device_factory.py b/tests/test_device_factory.py index 8690e580..0042d6e2 100644 --- a/tests/test_device_factory.py +++ b/tests/test_device_factory.py @@ -43,7 +43,7 @@ pytestmark = [pytest.mark.requires_dummy] def _get_connection_type_device_class(discovery_info): if "result" in discovery_info: device_class = Discover._get_device_class(discovery_info) - dr = DiscoveryResult(**discovery_info["result"]) + dr = DiscoveryResult.from_dict(discovery_info["result"]) connection_type = DeviceConnectionParameters.from_values( dr.device_type, dr.mgt_encrypt_schm.encrypt_type diff --git a/tests/test_discovery.py b/tests/test_discovery.py index 32330dca..aeda423e 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -391,8 +391,8 @@ async def test_device_update_from_new_discovery_info(discovery_mock): discovery_data = discovery_mock.discovery_data device_class = Discover._get_device_class(discovery_data) device = device_class("127.0.0.1") - discover_info = DiscoveryResult(**discovery_data["result"]) - discover_dump = discover_info.get_dict() + discover_info = DiscoveryResult.from_dict(discovery_data["result"]) + discover_dump = discover_info.to_dict() model, _, _ = discover_dump["device_model"].partition("(") discover_dump["model"] = model device.update_from_discover_info(discover_dump) @@ -652,7 +652,7 @@ async def test_discovery_decryption(): "sym_schm": "AES", } info = {**UNSUPPORTED["result"], "encrypt_info": encrypt_info} - dr = DiscoveryResult(**info) + dr = DiscoveryResult.from_dict(info) Discover._decrypt_discovery_data(dr) assert dr.decrypted_data == data_dict diff --git a/uv.lock b/uv.lock index 27a1100a..79a6f989 100644 --- a/uv.lock +++ b/uv.lock @@ -831,6 +831,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/b3/73/085399401383ce949f727afec55ec3abd76648d04b9f22e1c0e99cb4bec3/MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a", size = 15506 }, ] +[[package]] +name = "mashumaro" +version = "3.14" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/47/0a450b281bef2d7e97ec02c8e1168d821e283f58e02e6c403b2bb4d73c1c/mashumaro-3.14.tar.gz", hash = "sha256:5ef6f2b963892cbe9a4ceb3441dfbea37f8c3412523f25d42e9b3a7186555f1d", size = 166160 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1b/35/8d63733a2c12149d0c7663c29bf626bdbeea5f0ff963afe58a42b4810981/mashumaro-3.14-py3-none-any.whl", hash = "sha256:c12a649599a8f7b1a0b35d18f12e678423c3066189f7bc7bd8dd431c5c8132c3", size = 92183 }, +] + [[package]] name = "mdit-py-plugins" version = "0.3.5" @@ -1494,6 +1506,7 @@ dependencies = [ { name = "async-timeout" }, { name = "asyncclick" }, { name = "cryptography" }, + { name = "mashumaro" }, { name = "pydantic" }, { name = "typing-extensions" }, { name = "tzdata", marker = "platform_system == 'Windows'" }, @@ -1544,6 +1557,7 @@ requires-dist = [ { name = "cryptography", specifier = ">=1.9" }, { name = "docutils", marker = "extra == 'docs'", specifier = ">=0.17" }, { name = "kasa-crypt", marker = "extra == 'speedups'", specifier = ">=0.2.0" }, + { name = "mashumaro", specifier = ">=3.14" }, { name = "myst-parser", marker = "extra == 'docs'" }, { name = "orjson", marker = "extra == 'speedups'", specifier = ">=3.9.1" }, { name = "ptpython", marker = "extra == 'shell'" },