Update DiscoveryResult to use Mashumaro instead of pydantic (#1231)

Mashumaro is faster and doesn't come with all the versioning problems that
pydantic does.

A basic perf test deserializing all of our discovery result fixtures
shows mashumaro to be about 6 times faster than pydantic at deserializing
dicts. It is also much faster at parsing from a JSON string, but that is
likely because it uses orjson under the hood, and parsing JSON strings is
not really our use case at the moment.

```
PYDANTIC - ms
=================
json       dict
-----------------
4.7665     1.3268
3.1548     1.5922
3.1130     1.8039
4.2834     2.7606
2.0669     1.3757
2.0163     1.6377
3.1667     1.3561
4.1296     2.7297
2.0132     1.3471
4.0648     1.4105

MASHUMARO - ms
=================
json       dict
-----------------
0.5977     0.5543
0.5336     0.2983
0.3955     0.2549
0.6516     0.2742
0.5386     0.2706
0.6678     0.2580
0.4120     0.2511
0.3836     0.2472
0.4020     0.2465
0.4268     0.2487
```
This commit is contained in:
Steven B. 2024-11-12 21:00:04 +00:00 committed by GitHub
parent 9d5e07b969
commit 254a9af5c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 81 additions and 41 deletions

View File

@ -319,7 +319,7 @@ async def cli(
click.echo("Host and discovery info given, trying connect on %s." % host) click.echo("Host and discovery info given, trying connect on %s." % host)
di = json.loads(discovery_info) di = json.loads(discovery_info)
dr = DiscoveryResult(**di) dr = DiscoveryResult.from_dict(di)
connection_type = DeviceConnectionParameters.from_values( connection_type = DeviceConnectionParameters.from_values(
dr.device_type, dr.device_type,
dr.mgt_encrypt_schm.encrypt_type, dr.mgt_encrypt_schm.encrypt_type,
@ -336,7 +336,7 @@ async def cli(
basedir, basedir,
autosave, autosave,
device.protocol, device.protocol,
discovery_info=dr.get_dict(), discovery_info=dr.to_dict(),
batch_size=batch_size, batch_size=batch_size,
) )
elif device_family and encrypt_type: elif device_family and encrypt_type:
@ -443,7 +443,7 @@ async def get_legacy_fixture(protocol, *, discovery_info):
if discovery_info and not discovery_info.get("system"): if discovery_info and not discovery_info.get("system"):
# Need to recreate a DiscoverResult here because we don't want the aliases # Need to recreate a DiscoverResult here because we don't want the aliases
# in the fixture, we want the actual field names as returned by the device. # in the fixture, we want the actual field names as returned by the device.
dr = DiscoveryResult(**protocol._discovery_info) dr = DiscoveryResult.from_dict(protocol._discovery_info)
final["discovery_result"] = dr.dict( final["discovery_result"] = dr.dict(
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
) )
@ -960,10 +960,8 @@ async def get_smart_fixtures(
# Need to recreate a DiscoverResult here because we don't want the aliases # Need to recreate a DiscoverResult here because we don't want the aliases
# in the fixture, we want the actual field names as returned by the device. # in the fixture, we want the actual field names as returned by the device.
if discovery_info: if discovery_info:
dr = DiscoveryResult(**discovery_info) # type: ignore dr = DiscoveryResult.from_dict(discovery_info) # type: ignore
final["discovery_result"] = dr.dict( final["discovery_result"] = dr.to_dict()
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
)
click.echo("Got %s successes" % len(successes)) click.echo("Got %s successes" % len(successes))
click.echo(click.style("## device info file ##", bold=True)) click.echo(click.style("## device info file ##", bold=True))

View File

@ -207,7 +207,7 @@ def _echo_discovery_info(discovery_info) -> None:
return return
try: try:
dr = DiscoveryResult(**discovery_info) dr = DiscoveryResult.from_dict(discovery_info)
except ValidationError: except ValidationError:
_echo_dictionary(discovery_info) _echo_dictionary(discovery_info)
return return

View File

@ -90,6 +90,7 @@ import secrets
import socket import socket
import struct import struct
from asyncio.transports import DatagramTransport from asyncio.transports import DatagramTransport
from dataclasses import dataclass, field
from pprint import pformat as pf from pprint import pformat as pf
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -108,7 +109,8 @@ from aiohttp import ClientSession
# When support for cpython older than 3.11 is dropped # When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout # async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout from async_timeout import timeout as asyncio_timeout
from pydantic.v1 import BaseModel, ValidationError from mashumaro import field_options
from mashumaro.config import BaseConfig
from kasa import Device from kasa import Device
from kasa.credentials import Credentials from kasa.credentials import Credentials
@ -130,6 +132,7 @@ from kasa.exceptions import (
from kasa.experimental import Experimental from kasa.experimental import Experimental
from kasa.iot.iotdevice import IotDevice from kasa.iot.iotdevice import IotDevice
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
from kasa.json import DataClassJSONMixin
from kasa.json import dumps as json_dumps from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads from kasa.json import loads as json_loads
from kasa.protocol import mask_mac, redact_data from kasa.protocol import mask_mac, redact_data
@ -647,7 +650,7 @@ class Discover:
def _get_device_class(info: dict) -> type[Device]: def _get_device_class(info: dict) -> type[Device]:
"""Find SmartDevice subclass for device described by passed data.""" """Find SmartDevice subclass for device described by passed data."""
if "result" in info: if "result" in info:
discovery_result = DiscoveryResult(**info["result"]) discovery_result = DiscoveryResult.from_dict(info["result"])
https = discovery_result.mgt_encrypt_schm.is_support_https https = discovery_result.mgt_encrypt_schm.is_support_https
dev_class = get_device_class_from_family( dev_class = get_device_class_from_family(
discovery_result.device_type, https=https discovery_result.device_type, https=https
@ -721,12 +724,8 @@ class Discover:
f"Unable to read response from device: {config.host}: {ex}" f"Unable to read response from device: {config.host}: {ex}"
) from ex ) from ex
try: try:
discovery_result = DiscoveryResult(**info["result"]) discovery_result = DiscoveryResult.from_dict(info["result"])
if ( except Exception as ex:
encrypt_info := discovery_result.encrypt_info
) and encrypt_info.sym_schm == "AES":
Discover._decrypt_discovery_data(discovery_result)
except ValidationError as ex:
if debug_enabled: if debug_enabled:
data = ( data = (
redact_data(info, NEW_DISCOVERY_REDACTORS) redact_data(info, NEW_DISCOVERY_REDACTORS)
@ -742,6 +741,16 @@ class Discover:
f"Unable to parse discovery from device: {config.host}: {ex}", f"Unable to parse discovery from device: {config.host}: {ex}",
host=config.host, host=config.host,
) from ex ) from ex
# Decrypt the data
if (
encrypt_info := discovery_result.encrypt_info
) and encrypt_info.sym_schm == "AES":
try:
Discover._decrypt_discovery_data(discovery_result)
except Exception:
_LOGGER.exception(
"Unable to decrypt discovery data %s: %s", config.host, data
)
type_ = discovery_result.device_type type_ = discovery_result.device_type
encrypt_schm = discovery_result.mgt_encrypt_schm encrypt_schm = discovery_result.mgt_encrypt_schm
@ -754,7 +763,7 @@ class Discover:
raise UnsupportedDeviceError( raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} " f"Unsupported device {config.host} of type {type_} "
+ "with no encryption type", + "with no encryption type",
discovery_result=discovery_result.get_dict(), discovery_result=discovery_result.to_dict(),
host=config.host, host=config.host,
) )
config.connection_type = DeviceConnectionParameters.from_values( config.connection_type = DeviceConnectionParameters.from_values(
@ -767,7 +776,7 @@ class Discover:
raise UnsupportedDeviceError( raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} " f"Unsupported device {config.host} of type {type_} "
+ f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}", + f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}",
discovery_result=discovery_result.get_dict(), discovery_result=discovery_result.to_dict(),
host=config.host, host=config.host,
) from ex ) from ex
if ( if (
@ -778,7 +787,7 @@ class Discover:
_LOGGER.warning("Got unsupported device type: %s", type_) _LOGGER.warning("Got unsupported device type: %s", type_)
raise UnsupportedDeviceError( raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_}: {info}", f"Unsupported device {config.host} of type {type_}: {info}",
discovery_result=discovery_result.get_dict(), discovery_result=discovery_result.to_dict(),
host=config.host, host=config.host,
) )
if (protocol := get_protocol(config)) is None: if (protocol := get_protocol(config)) is None:
@ -788,7 +797,7 @@ class Discover:
raise UnsupportedDeviceError( raise UnsupportedDeviceError(
f"Unsupported encryption scheme {config.host} of " f"Unsupported encryption scheme {config.host} of "
+ f"type {config.connection_type.to_dict()}: {info}", + f"type {config.connection_type.to_dict()}: {info}",
discovery_result=discovery_result.get_dict(), discovery_result=discovery_result.to_dict(),
host=config.host, host=config.host,
) )
@ -801,22 +810,35 @@ class Discover:
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data)) _LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data))
device = device_class(config.host, protocol=protocol) device = device_class(config.host, protocol=protocol)
di = discovery_result.get_dict() di = discovery_result.to_dict()
di["model"], _, _ = discovery_result.device_model.partition("(") di["model"], _, _ = discovery_result.device_model.partition("(")
device.update_from_discover_info(di) device.update_from_discover_info(di)
return device return device
class EncryptionScheme(BaseModel): class _DiscoveryBaseMixin(DataClassJSONMixin):
"""Base class for serialization mixin."""
class Config(BaseConfig):
"""Serialization config."""
omit_none = True
omit_default = True
serialize_by_alias = True
@dataclass
class EncryptionScheme(_DiscoveryBaseMixin):
"""Base model for encryption scheme of discovery result.""" """Base model for encryption scheme of discovery result."""
is_support_https: bool is_support_https: bool
encrypt_type: Optional[str] # noqa: UP007 encrypt_type: Optional[str] = None # noqa: UP007
http_port: Optional[int] = None # noqa: UP007 http_port: Optional[int] = None # noqa: UP007
lv: Optional[int] = None # noqa: UP007 lv: Optional[int] = None # noqa: UP007
class EncryptionInfo(BaseModel): @dataclass
class EncryptionInfo(_DiscoveryBaseMixin):
"""Base model for encryption info of discovery result.""" """Base model for encryption info of discovery result."""
sym_schm: str sym_schm: str
@ -824,19 +846,23 @@ class EncryptionInfo(BaseModel):
data: str data: str
class DiscoveryResult(BaseModel): @dataclass
class DiscoveryResult(_DiscoveryBaseMixin):
"""Base model for discovery result.""" """Base model for discovery result."""
device_type: str device_type: str
device_model: str device_model: str
device_name: Optional[str] # noqa: UP007 device_id: str
ip: str ip: str
mac: str mac: str
mgt_encrypt_schm: EncryptionScheme mgt_encrypt_schm: EncryptionScheme
device_name: Optional[str] = None # noqa: UP007
encrypt_info: Optional[EncryptionInfo] = None # noqa: UP007 encrypt_info: Optional[EncryptionInfo] = None # noqa: UP007
encrypt_type: Optional[list[str]] = None # noqa: UP007 encrypt_type: Optional[list[str]] = None # noqa: UP007
decrypted_data: Optional[dict] = None # noqa: UP007 decrypted_data: Optional[dict] = None # noqa: UP007
device_id: str is_reset_wifi: Optional[bool] = field( # noqa: UP007
metadata=field_options(alias="isResetWiFi"), default=None
)
firmware_version: Optional[str] = None # noqa: UP007 firmware_version: Optional[str] = None # noqa: UP007
hardware_version: Optional[str] = None # noqa: UP007 hardware_version: Optional[str] = None # noqa: UP007
@ -845,12 +871,3 @@ class DiscoveryResult(BaseModel):
is_support_iot_cloud: Optional[bool] = None # noqa: UP007 is_support_iot_cloud: Optional[bool] = None # noqa: UP007
obd_src: Optional[str] = None # noqa: UP007 obd_src: Optional[str] = None # noqa: UP007
factory_default: Optional[bool] = None # noqa: UP007 factory_default: Optional[bool] = None # noqa: UP007
def get_dict(self) -> dict:
"""Return a dict for this discovery result.
containing only the values actually set and with aliases as field names.
"""
return self.dict(
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
)

View File

@ -21,3 +21,13 @@ except ImportError:
return json.dumps(obj, separators=(",", ":")) return json.dumps(obj, separators=(",", ":"))
loads = json.loads loads = json.loads
try:
from mashumaro.mixins.orjson import DataClassORJSONMixin
DataClassJSONMixin = DataClassORJSONMixin
except ImportError:
from mashumaro.mixins.json import DataClassJSONMixin as JSONMixin
DataClassJSONMixin = JSONMixin # type: ignore[assignment, misc]

View File

@ -14,6 +14,7 @@ dependencies = [
"aiohttp>=3", "aiohttp>=3",
"typing-extensions>=4.12.2,<5.0", "typing-extensions>=4.12.2,<5.0",
"tzdata>=2024.2 ; platform_system == 'Windows'", "tzdata>=2024.2 ; platform_system == 'Windows'",
"mashumaro>=3.14",
] ]
classifiers = [ classifiers = [

View File

@ -616,7 +616,7 @@ async def test_credentials(discovery_mock, mocker, runner):
mocker.patch("kasa.cli.device.state", new=_state) mocker.patch("kasa.cli.device.state", new=_state)
dr = DiscoveryResult(**discovery_mock.discovery_data["result"]) dr = DiscoveryResult.from_dict(discovery_mock.discovery_data["result"])
res = await runner.invoke( res = await runner.invoke(
cli, cli,
[ [

View File

@ -43,7 +43,7 @@ pytestmark = [pytest.mark.requires_dummy]
def _get_connection_type_device_class(discovery_info): def _get_connection_type_device_class(discovery_info):
if "result" in discovery_info: if "result" in discovery_info:
device_class = Discover._get_device_class(discovery_info) device_class = Discover._get_device_class(discovery_info)
dr = DiscoveryResult(**discovery_info["result"]) dr = DiscoveryResult.from_dict(discovery_info["result"])
connection_type = DeviceConnectionParameters.from_values( connection_type = DeviceConnectionParameters.from_values(
dr.device_type, dr.mgt_encrypt_schm.encrypt_type dr.device_type, dr.mgt_encrypt_schm.encrypt_type

View File

@ -391,8 +391,8 @@ async def test_device_update_from_new_discovery_info(discovery_mock):
discovery_data = discovery_mock.discovery_data discovery_data = discovery_mock.discovery_data
device_class = Discover._get_device_class(discovery_data) device_class = Discover._get_device_class(discovery_data)
device = device_class("127.0.0.1") device = device_class("127.0.0.1")
discover_info = DiscoveryResult(**discovery_data["result"]) discover_info = DiscoveryResult.from_dict(discovery_data["result"])
discover_dump = discover_info.get_dict() discover_dump = discover_info.to_dict()
model, _, _ = discover_dump["device_model"].partition("(") model, _, _ = discover_dump["device_model"].partition("(")
discover_dump["model"] = model discover_dump["model"] = model
device.update_from_discover_info(discover_dump) device.update_from_discover_info(discover_dump)
@ -652,7 +652,7 @@ async def test_discovery_decryption():
"sym_schm": "AES", "sym_schm": "AES",
} }
info = {**UNSUPPORTED["result"], "encrypt_info": encrypt_info} info = {**UNSUPPORTED["result"], "encrypt_info": encrypt_info}
dr = DiscoveryResult(**info) dr = DiscoveryResult.from_dict(info)
Discover._decrypt_discovery_data(dr) Discover._decrypt_discovery_data(dr)
assert dr.decrypted_data == data_dict assert dr.decrypted_data == data_dict

14
uv.lock
View File

@ -831,6 +831,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/73/085399401383ce949f727afec55ec3abd76648d04b9f22e1c0e99cb4bec3/MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a", size = 15506 }, { url = "https://files.pythonhosted.org/packages/b3/73/085399401383ce949f727afec55ec3abd76648d04b9f22e1c0e99cb4bec3/MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a", size = 15506 },
] ]
[[package]]
name = "mashumaro"
version = "3.14"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/eb/47/0a450b281bef2d7e97ec02c8e1168d821e283f58e02e6c403b2bb4d73c1c/mashumaro-3.14.tar.gz", hash = "sha256:5ef6f2b963892cbe9a4ceb3441dfbea37f8c3412523f25d42e9b3a7186555f1d", size = 166160 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1b/35/8d63733a2c12149d0c7663c29bf626bdbeea5f0ff963afe58a42b4810981/mashumaro-3.14-py3-none-any.whl", hash = "sha256:c12a649599a8f7b1a0b35d18f12e678423c3066189f7bc7bd8dd431c5c8132c3", size = 92183 },
]
[[package]] [[package]]
name = "mdit-py-plugins" name = "mdit-py-plugins"
version = "0.3.5" version = "0.3.5"
@ -1494,6 +1506,7 @@ dependencies = [
{ name = "async-timeout" }, { name = "async-timeout" },
{ name = "asyncclick" }, { name = "asyncclick" },
{ name = "cryptography" }, { name = "cryptography" },
{ name = "mashumaro" },
{ name = "pydantic" }, { name = "pydantic" },
{ name = "typing-extensions" }, { name = "typing-extensions" },
{ name = "tzdata", marker = "platform_system == 'Windows'" }, { name = "tzdata", marker = "platform_system == 'Windows'" },
@ -1544,6 +1557,7 @@ requires-dist = [
{ name = "cryptography", specifier = ">=1.9" }, { name = "cryptography", specifier = ">=1.9" },
{ name = "docutils", marker = "extra == 'docs'", specifier = ">=0.17" }, { name = "docutils", marker = "extra == 'docs'", specifier = ">=0.17" },
{ name = "kasa-crypt", marker = "extra == 'speedups'", specifier = ">=0.2.0" }, { name = "kasa-crypt", marker = "extra == 'speedups'", specifier = ">=0.2.0" },
{ name = "mashumaro", specifier = ">=3.14" },
{ name = "myst-parser", marker = "extra == 'docs'" }, { name = "myst-parser", marker = "extra == 'docs'" },
{ name = "orjson", marker = "extra == 'speedups'", specifier = ">=3.9.1" }, { name = "orjson", marker = "extra == 'speedups'", specifier = ">=3.9.1" },
{ name = "ptpython", marker = "extra == 'shell'" }, { name = "ptpython", marker = "extra == 'shell'" },