Update DiscoveryResult to use Mashumaro instead of pydantic (#1231)

Mashumaro is faster and doesn't come with all versioning problems that
pydantic does.

A basic perf test deserializing all of our discovery results fixtures
shows mashumaro as being about 6 times faster deserializing dicts than
pydantic. It's much faster parsing from a json string but that's likely
because it uses orjson under the hood although that's not really our use
case at the moment.

```
PYDANTIC - ms
=================
json       dict
-----------------
4.7665     1.3268
3.1548     1.5922
3.1130     1.8039
4.2834     2.7606
2.0669     1.3757
2.0163     1.6377
3.1667     1.3561
4.1296     2.7297
2.0132     1.3471
4.0648     1.4105

MASHUMARO - ms
=================
json       dict
-----------------
0.5977     0.5543
0.5336     0.2983
0.3955     0.2549
0.6516     0.2742
0.5386     0.2706
0.6678     0.2580
0.4120     0.2511
0.3836     0.2472
0.4020     0.2465
0.4268     0.2487
```
This commit is contained in:
Steven B.
2024-11-12 21:00:04 +00:00
committed by GitHub
parent 9d5e07b969
commit 254a9af5c1
9 changed files with 81 additions and 41 deletions

View File

@@ -207,7 +207,7 @@ def _echo_discovery_info(discovery_info) -> None:
return
try:
dr = DiscoveryResult(**discovery_info)
dr = DiscoveryResult.from_dict(discovery_info)
except ValidationError:
_echo_dictionary(discovery_info)
return

View File

@@ -90,6 +90,7 @@ import secrets
import socket
import struct
from asyncio.transports import DatagramTransport
from dataclasses import dataclass, field
from pprint import pformat as pf
from typing import (
TYPE_CHECKING,
@@ -108,7 +109,8 @@ from aiohttp import ClientSession
# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
from pydantic.v1 import BaseModel, ValidationError
from mashumaro import field_options
from mashumaro.config import BaseConfig
from kasa import Device
from kasa.credentials import Credentials
@@ -130,6 +132,7 @@ from kasa.exceptions import (
from kasa.experimental import Experimental
from kasa.iot.iotdevice import IotDevice
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
from kasa.json import DataClassJSONMixin
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.protocol import mask_mac, redact_data
@@ -647,7 +650,7 @@ class Discover:
def _get_device_class(info: dict) -> type[Device]:
"""Find SmartDevice subclass for device described by passed data."""
if "result" in info:
discovery_result = DiscoveryResult(**info["result"])
discovery_result = DiscoveryResult.from_dict(info["result"])
https = discovery_result.mgt_encrypt_schm.is_support_https
dev_class = get_device_class_from_family(
discovery_result.device_type, https=https
@@ -721,12 +724,8 @@ class Discover:
f"Unable to read response from device: {config.host}: {ex}"
) from ex
try:
discovery_result = DiscoveryResult(**info["result"])
if (
encrypt_info := discovery_result.encrypt_info
) and encrypt_info.sym_schm == "AES":
Discover._decrypt_discovery_data(discovery_result)
except ValidationError as ex:
discovery_result = DiscoveryResult.from_dict(info["result"])
except Exception as ex:
if debug_enabled:
data = (
redact_data(info, NEW_DISCOVERY_REDACTORS)
@@ -742,6 +741,16 @@ class Discover:
f"Unable to parse discovery from device: {config.host}: {ex}",
host=config.host,
) from ex
# Decrypt the data
if (
encrypt_info := discovery_result.encrypt_info
) and encrypt_info.sym_schm == "AES":
try:
Discover._decrypt_discovery_data(discovery_result)
except Exception:
_LOGGER.exception(
"Unable to decrypt discovery data %s: %s", config.host, data
)
type_ = discovery_result.device_type
encrypt_schm = discovery_result.mgt_encrypt_schm
@@ -754,7 +763,7 @@ class Discover:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ "with no encryption type",
discovery_result=discovery_result.get_dict(),
discovery_result=discovery_result.to_dict(),
host=config.host,
)
config.connection_type = DeviceConnectionParameters.from_values(
@@ -767,7 +776,7 @@ class Discover:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}",
discovery_result=discovery_result.get_dict(),
discovery_result=discovery_result.to_dict(),
host=config.host,
) from ex
if (
@@ -778,7 +787,7 @@ class Discover:
_LOGGER.warning("Got unsupported device type: %s", type_)
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_}: {info}",
discovery_result=discovery_result.get_dict(),
discovery_result=discovery_result.to_dict(),
host=config.host,
)
if (protocol := get_protocol(config)) is None:
@@ -788,7 +797,7 @@ class Discover:
raise UnsupportedDeviceError(
f"Unsupported encryption scheme {config.host} of "
+ f"type {config.connection_type.to_dict()}: {info}",
discovery_result=discovery_result.get_dict(),
discovery_result=discovery_result.to_dict(),
host=config.host,
)
@@ -801,22 +810,35 @@ class Discover:
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data))
device = device_class(config.host, protocol=protocol)
di = discovery_result.get_dict()
di = discovery_result.to_dict()
di["model"], _, _ = discovery_result.device_model.partition("(")
device.update_from_discover_info(di)
return device
class EncryptionScheme(BaseModel):
class _DiscoveryBaseMixin(DataClassJSONMixin):
"""Base class for serialization mixin."""
class Config(BaseConfig):
"""Serialization config."""
omit_none = True
omit_default = True
serialize_by_alias = True
@dataclass
class EncryptionScheme(_DiscoveryBaseMixin):
"""Base model for encryption scheme of discovery result."""
is_support_https: bool
encrypt_type: Optional[str] # noqa: UP007
encrypt_type: Optional[str] = None # noqa: UP007
http_port: Optional[int] = None # noqa: UP007
lv: Optional[int] = None # noqa: UP007
class EncryptionInfo(BaseModel):
@dataclass
class EncryptionInfo(_DiscoveryBaseMixin):
"""Base model for encryption info of discovery result."""
sym_schm: str
@@ -824,19 +846,23 @@ class EncryptionInfo(BaseModel):
data: str
class DiscoveryResult(BaseModel):
@dataclass
class DiscoveryResult(_DiscoveryBaseMixin):
"""Base model for discovery result."""
device_type: str
device_model: str
device_name: Optional[str] # noqa: UP007
device_id: str
ip: str
mac: str
mgt_encrypt_schm: EncryptionScheme
device_name: Optional[str] = None # noqa: UP007
encrypt_info: Optional[EncryptionInfo] = None # noqa: UP007
encrypt_type: Optional[list[str]] = None # noqa: UP007
decrypted_data: Optional[dict] = None # noqa: UP007
device_id: str
is_reset_wifi: Optional[bool] = field( # noqa: UP007
metadata=field_options(alias="isResetWiFi"), default=None
)
firmware_version: Optional[str] = None # noqa: UP007
hardware_version: Optional[str] = None # noqa: UP007
@@ -845,12 +871,3 @@ class DiscoveryResult(BaseModel):
is_support_iot_cloud: Optional[bool] = None # noqa: UP007
obd_src: Optional[str] = None # noqa: UP007
factory_default: Optional[bool] = None # noqa: UP007
def get_dict(self) -> dict:
"""Return a dict for this discovery result.
containing only the values actually set and with aliases as field names.
"""
return self.dict(
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
)

View File

@@ -21,3 +21,13 @@ except ImportError:
return json.dumps(obj, separators=(",", ":"))
loads = json.loads
try:
from mashumaro.mixins.orjson import DataClassORJSONMixin
DataClassJSONMixin = DataClassORJSONMixin
except ImportError:
from mashumaro.mixins.json import DataClassJSONMixin as JSONMixin
DataClassJSONMixin = JSONMixin # type: ignore[assignment, misc]