Update DiscoveryResult to use Mashumaro instead of pydantic (#1231)

Mashumaro is faster and doesn't come with all versioning problems that
pydantic does.

A basic perf test deserializing all of our discovery results fixtures
shows mashumaro as being about 6 times faster deserializing dicts than
pydantic. It's much faster parsing from a json string but that's likely
because it uses orjson under the hood although that's not really our use
case at the moment.

```
PYDANTIC - ms
=================
json       dict
-----------------
4.7665     1.3268
3.1548     1.5922
3.1130     1.8039
4.2834     2.7606
2.0669     1.3757
2.0163     1.6377
3.1667     1.3561
4.1296     2.7297
2.0132     1.3471
4.0648     1.4105

MASHUMARO - ms
=================
json       dict
-----------------
0.5977     0.5543
0.5336     0.2983
0.3955     0.2549
0.6516     0.2742
0.5386     0.2706
0.6678     0.2580
0.4120     0.2511
0.3836     0.2472
0.4020     0.2465
0.4268     0.2487
```
This commit is contained in:
Steven B. 2024-11-12 21:00:04 +00:00 committed by GitHub
parent 9d5e07b969
commit 254a9af5c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 81 additions and 41 deletions

View File

@ -319,7 +319,7 @@ async def cli(
click.echo("Host and discovery info given, trying connect on %s." % host)
di = json.loads(discovery_info)
dr = DiscoveryResult(**di)
dr = DiscoveryResult.from_dict(di)
connection_type = DeviceConnectionParameters.from_values(
dr.device_type,
dr.mgt_encrypt_schm.encrypt_type,
@ -336,7 +336,7 @@ async def cli(
basedir,
autosave,
device.protocol,
discovery_info=dr.get_dict(),
discovery_info=dr.to_dict(),
batch_size=batch_size,
)
elif device_family and encrypt_type:
@ -443,7 +443,7 @@ async def get_legacy_fixture(protocol, *, discovery_info):
if discovery_info and not discovery_info.get("system"):
# Need to recreate a DiscoverResult here because we don't want the aliases
# in the fixture, we want the actual field names as returned by the device.
dr = DiscoveryResult(**protocol._discovery_info)
dr = DiscoveryResult.from_dict(protocol._discovery_info)
final["discovery_result"] = dr.dict(
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
)
@ -960,10 +960,8 @@ async def get_smart_fixtures(
# Need to recreate a DiscoverResult here because we don't want the aliases
# in the fixture, we want the actual field names as returned by the device.
if discovery_info:
dr = DiscoveryResult(**discovery_info) # type: ignore
final["discovery_result"] = dr.dict(
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
)
dr = DiscoveryResult.from_dict(discovery_info) # type: ignore
final["discovery_result"] = dr.to_dict()
click.echo("Got %s successes" % len(successes))
click.echo(click.style("## device info file ##", bold=True))

View File

@ -207,7 +207,7 @@ def _echo_discovery_info(discovery_info) -> None:
return
try:
dr = DiscoveryResult(**discovery_info)
dr = DiscoveryResult.from_dict(discovery_info)
except ValidationError:
_echo_dictionary(discovery_info)
return

View File

@ -90,6 +90,7 @@ import secrets
import socket
import struct
from asyncio.transports import DatagramTransport
from dataclasses import dataclass, field
from pprint import pformat as pf
from typing import (
TYPE_CHECKING,
@ -108,7 +109,8 @@ from aiohttp import ClientSession
# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
from pydantic.v1 import BaseModel, ValidationError
from mashumaro import field_options
from mashumaro.config import BaseConfig
from kasa import Device
from kasa.credentials import Credentials
@ -130,6 +132,7 @@ from kasa.exceptions import (
from kasa.experimental import Experimental
from kasa.iot.iotdevice import IotDevice
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
from kasa.json import DataClassJSONMixin
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.protocol import mask_mac, redact_data
@ -647,7 +650,7 @@ class Discover:
def _get_device_class(info: dict) -> type[Device]:
"""Find SmartDevice subclass for device described by passed data."""
if "result" in info:
discovery_result = DiscoveryResult(**info["result"])
discovery_result = DiscoveryResult.from_dict(info["result"])
https = discovery_result.mgt_encrypt_schm.is_support_https
dev_class = get_device_class_from_family(
discovery_result.device_type, https=https
@ -721,12 +724,8 @@ class Discover:
f"Unable to read response from device: {config.host}: {ex}"
) from ex
try:
discovery_result = DiscoveryResult(**info["result"])
if (
encrypt_info := discovery_result.encrypt_info
) and encrypt_info.sym_schm == "AES":
Discover._decrypt_discovery_data(discovery_result)
except ValidationError as ex:
discovery_result = DiscoveryResult.from_dict(info["result"])
except Exception as ex:
if debug_enabled:
data = (
redact_data(info, NEW_DISCOVERY_REDACTORS)
@ -742,6 +741,16 @@ class Discover:
f"Unable to parse discovery from device: {config.host}: {ex}",
host=config.host,
) from ex
# Decrypt the data
if (
encrypt_info := discovery_result.encrypt_info
) and encrypt_info.sym_schm == "AES":
try:
Discover._decrypt_discovery_data(discovery_result)
except Exception:
_LOGGER.exception(
"Unable to decrypt discovery data %s: %s", config.host, data
)
type_ = discovery_result.device_type
encrypt_schm = discovery_result.mgt_encrypt_schm
@ -754,7 +763,7 @@ class Discover:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ "with no encryption type",
discovery_result=discovery_result.get_dict(),
discovery_result=discovery_result.to_dict(),
host=config.host,
)
config.connection_type = DeviceConnectionParameters.from_values(
@ -767,7 +776,7 @@ class Discover:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}",
discovery_result=discovery_result.get_dict(),
discovery_result=discovery_result.to_dict(),
host=config.host,
) from ex
if (
@ -778,7 +787,7 @@ class Discover:
_LOGGER.warning("Got unsupported device type: %s", type_)
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_}: {info}",
discovery_result=discovery_result.get_dict(),
discovery_result=discovery_result.to_dict(),
host=config.host,
)
if (protocol := get_protocol(config)) is None:
@ -788,7 +797,7 @@ class Discover:
raise UnsupportedDeviceError(
f"Unsupported encryption scheme {config.host} of "
+ f"type {config.connection_type.to_dict()}: {info}",
discovery_result=discovery_result.get_dict(),
discovery_result=discovery_result.to_dict(),
host=config.host,
)
@ -801,22 +810,35 @@ class Discover:
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data))
device = device_class(config.host, protocol=protocol)
di = discovery_result.get_dict()
di = discovery_result.to_dict()
di["model"], _, _ = discovery_result.device_model.partition("(")
device.update_from_discover_info(di)
return device
class EncryptionScheme(BaseModel):
class _DiscoveryBaseMixin(DataClassJSONMixin):
"""Base class for serialization mixin."""
class Config(BaseConfig):
"""Serialization config."""
omit_none = True
omit_default = True
serialize_by_alias = True
@dataclass
class EncryptionScheme(_DiscoveryBaseMixin):
"""Base model for encryption scheme of discovery result."""
is_support_https: bool
encrypt_type: Optional[str] # noqa: UP007
encrypt_type: Optional[str] = None # noqa: UP007
http_port: Optional[int] = None # noqa: UP007
lv: Optional[int] = None # noqa: UP007
class EncryptionInfo(BaseModel):
@dataclass
class EncryptionInfo(_DiscoveryBaseMixin):
"""Base model for encryption info of discovery result."""
sym_schm: str
@ -824,19 +846,23 @@ class EncryptionInfo(BaseModel):
data: str
class DiscoveryResult(BaseModel):
@dataclass
class DiscoveryResult(_DiscoveryBaseMixin):
"""Base model for discovery result."""
device_type: str
device_model: str
device_name: Optional[str] # noqa: UP007
device_id: str
ip: str
mac: str
mgt_encrypt_schm: EncryptionScheme
device_name: Optional[str] = None # noqa: UP007
encrypt_info: Optional[EncryptionInfo] = None # noqa: UP007
encrypt_type: Optional[list[str]] = None # noqa: UP007
decrypted_data: Optional[dict] = None # noqa: UP007
device_id: str
is_reset_wifi: Optional[bool] = field( # noqa: UP007
metadata=field_options(alias="isResetWiFi"), default=None
)
firmware_version: Optional[str] = None # noqa: UP007
hardware_version: Optional[str] = None # noqa: UP007
@ -845,12 +871,3 @@ class DiscoveryResult(BaseModel):
is_support_iot_cloud: Optional[bool] = None # noqa: UP007
obd_src: Optional[str] = None # noqa: UP007
factory_default: Optional[bool] = None # noqa: UP007
def get_dict(self) -> dict:
"""Return a dict for this discovery result.
containing only the values actually set and with aliases as field names.
"""
return self.dict(
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
)

View File

@ -21,3 +21,13 @@ except ImportError:
return json.dumps(obj, separators=(",", ":"))
loads = json.loads
try:
from mashumaro.mixins.orjson import DataClassORJSONMixin
DataClassJSONMixin = DataClassORJSONMixin
except ImportError:
from mashumaro.mixins.json import DataClassJSONMixin as JSONMixin
DataClassJSONMixin = JSONMixin # type: ignore[assignment, misc]

View File

@ -14,6 +14,7 @@ dependencies = [
"aiohttp>=3",
"typing-extensions>=4.12.2,<5.0",
"tzdata>=2024.2 ; platform_system == 'Windows'",
"mashumaro>=3.14",
]
classifiers = [

View File

@ -616,7 +616,7 @@ async def test_credentials(discovery_mock, mocker, runner):
mocker.patch("kasa.cli.device.state", new=_state)
dr = DiscoveryResult(**discovery_mock.discovery_data["result"])
dr = DiscoveryResult.from_dict(discovery_mock.discovery_data["result"])
res = await runner.invoke(
cli,
[

View File

@ -43,7 +43,7 @@ pytestmark = [pytest.mark.requires_dummy]
def _get_connection_type_device_class(discovery_info):
if "result" in discovery_info:
device_class = Discover._get_device_class(discovery_info)
dr = DiscoveryResult(**discovery_info["result"])
dr = DiscoveryResult.from_dict(discovery_info["result"])
connection_type = DeviceConnectionParameters.from_values(
dr.device_type, dr.mgt_encrypt_schm.encrypt_type

View File

@ -391,8 +391,8 @@ async def test_device_update_from_new_discovery_info(discovery_mock):
discovery_data = discovery_mock.discovery_data
device_class = Discover._get_device_class(discovery_data)
device = device_class("127.0.0.1")
discover_info = DiscoveryResult(**discovery_data["result"])
discover_dump = discover_info.get_dict()
discover_info = DiscoveryResult.from_dict(discovery_data["result"])
discover_dump = discover_info.to_dict()
model, _, _ = discover_dump["device_model"].partition("(")
discover_dump["model"] = model
device.update_from_discover_info(discover_dump)
@ -652,7 +652,7 @@ async def test_discovery_decryption():
"sym_schm": "AES",
}
info = {**UNSUPPORTED["result"], "encrypt_info": encrypt_info}
dr = DiscoveryResult(**info)
dr = DiscoveryResult.from_dict(info)
Discover._decrypt_discovery_data(dr)
assert dr.decrypted_data == data_dict

14
uv.lock
View File

@ -831,6 +831,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/73/085399401383ce949f727afec55ec3abd76648d04b9f22e1c0e99cb4bec3/MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a", size = 15506 },
]
[[package]]
name = "mashumaro"
version = "3.14"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/eb/47/0a450b281bef2d7e97ec02c8e1168d821e283f58e02e6c403b2bb4d73c1c/mashumaro-3.14.tar.gz", hash = "sha256:5ef6f2b963892cbe9a4ceb3441dfbea37f8c3412523f25d42e9b3a7186555f1d", size = 166160 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1b/35/8d63733a2c12149d0c7663c29bf626bdbeea5f0ff963afe58a42b4810981/mashumaro-3.14-py3-none-any.whl", hash = "sha256:c12a649599a8f7b1a0b35d18f12e678423c3066189f7bc7bd8dd431c5c8132c3", size = 92183 },
]
[[package]]
name = "mdit-py-plugins"
version = "0.3.5"
@ -1494,6 +1506,7 @@ dependencies = [
{ name = "async-timeout" },
{ name = "asyncclick" },
{ name = "cryptography" },
{ name = "mashumaro" },
{ name = "pydantic" },
{ name = "typing-extensions" },
{ name = "tzdata", marker = "platform_system == 'Windows'" },
@ -1544,6 +1557,7 @@ requires-dist = [
{ name = "cryptography", specifier = ">=1.9" },
{ name = "docutils", marker = "extra == 'docs'", specifier = ">=0.17" },
{ name = "kasa-crypt", marker = "extra == 'speedups'", specifier = ">=0.2.0" },
{ name = "mashumaro", specifier = ">=3.14" },
{ name = "myst-parser", marker = "extra == 'docs'" },
{ name = "orjson", marker = "extra == 'speedups'", specifier = ">=3.9.1" },
{ name = "ptpython", marker = "extra == 'shell'" },