mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 11:13:34 +00:00
Update DiscoveryResult to use Mashumaro instead of pydantic (#1231)
Mashumaro is faster and doesn't come with all versioning problems that pydantic does. A basic perf test deserializing all of our discovery results fixtures shows mashumaro as being about 6 times faster deserializing dicts than pydantic. It's much faster parsing from a json string but that's likely because it uses orjson under the hood although that's not really our use case at the moment. ``` PYDANTIC - ms ================= json dict ----------------- 4.7665 1.3268 3.1548 1.5922 3.1130 1.8039 4.2834 2.7606 2.0669 1.3757 2.0163 1.6377 3.1667 1.3561 4.1296 2.7297 2.0132 1.3471 4.0648 1.4105 MASHUMARO - ms ================= json dict ----------------- 0.5977 0.5543 0.5336 0.2983 0.3955 0.2549 0.6516 0.2742 0.5386 0.2706 0.6678 0.2580 0.4120 0.2511 0.3836 0.2472 0.4020 0.2465 0.4268 0.2487 ```
This commit is contained in:
parent
9d5e07b969
commit
254a9af5c1
@ -319,7 +319,7 @@ async def cli(
|
||||
click.echo("Host and discovery info given, trying connect on %s." % host)
|
||||
|
||||
di = json.loads(discovery_info)
|
||||
dr = DiscoveryResult(**di)
|
||||
dr = DiscoveryResult.from_dict(di)
|
||||
connection_type = DeviceConnectionParameters.from_values(
|
||||
dr.device_type,
|
||||
dr.mgt_encrypt_schm.encrypt_type,
|
||||
@ -336,7 +336,7 @@ async def cli(
|
||||
basedir,
|
||||
autosave,
|
||||
device.protocol,
|
||||
discovery_info=dr.get_dict(),
|
||||
discovery_info=dr.to_dict(),
|
||||
batch_size=batch_size,
|
||||
)
|
||||
elif device_family and encrypt_type:
|
||||
@ -443,7 +443,7 @@ async def get_legacy_fixture(protocol, *, discovery_info):
|
||||
if discovery_info and not discovery_info.get("system"):
|
||||
# Need to recreate a DiscoverResult here because we don't want the aliases
|
||||
# in the fixture, we want the actual field names as returned by the device.
|
||||
dr = DiscoveryResult(**protocol._discovery_info)
|
||||
dr = DiscoveryResult.from_dict(protocol._discovery_info)
|
||||
final["discovery_result"] = dr.dict(
|
||||
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
|
||||
)
|
||||
@ -960,10 +960,8 @@ async def get_smart_fixtures(
|
||||
# Need to recreate a DiscoverResult here because we don't want the aliases
|
||||
# in the fixture, we want the actual field names as returned by the device.
|
||||
if discovery_info:
|
||||
dr = DiscoveryResult(**discovery_info) # type: ignore
|
||||
final["discovery_result"] = dr.dict(
|
||||
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
|
||||
)
|
||||
dr = DiscoveryResult.from_dict(discovery_info) # type: ignore
|
||||
final["discovery_result"] = dr.to_dict()
|
||||
|
||||
click.echo("Got %s successes" % len(successes))
|
||||
click.echo(click.style("## device info file ##", bold=True))
|
||||
|
@ -207,7 +207,7 @@ def _echo_discovery_info(discovery_info) -> None:
|
||||
return
|
||||
|
||||
try:
|
||||
dr = DiscoveryResult(**discovery_info)
|
||||
dr = DiscoveryResult.from_dict(discovery_info)
|
||||
except ValidationError:
|
||||
_echo_dictionary(discovery_info)
|
||||
return
|
||||
|
@ -90,6 +90,7 @@ import secrets
|
||||
import socket
|
||||
import struct
|
||||
from asyncio.transports import DatagramTransport
|
||||
from dataclasses import dataclass, field
|
||||
from pprint import pformat as pf
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -108,7 +109,8 @@ from aiohttp import ClientSession
|
||||
# When support for cpython older than 3.11 is dropped
|
||||
# async_timeout can be replaced with asyncio.timeout
|
||||
from async_timeout import timeout as asyncio_timeout
|
||||
from pydantic.v1 import BaseModel, ValidationError
|
||||
from mashumaro import field_options
|
||||
from mashumaro.config import BaseConfig
|
||||
|
||||
from kasa import Device
|
||||
from kasa.credentials import Credentials
|
||||
@ -130,6 +132,7 @@ from kasa.exceptions import (
|
||||
from kasa.experimental import Experimental
|
||||
from kasa.iot.iotdevice import IotDevice
|
||||
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
|
||||
from kasa.json import DataClassJSONMixin
|
||||
from kasa.json import dumps as json_dumps
|
||||
from kasa.json import loads as json_loads
|
||||
from kasa.protocol import mask_mac, redact_data
|
||||
@ -647,7 +650,7 @@ class Discover:
|
||||
def _get_device_class(info: dict) -> type[Device]:
|
||||
"""Find SmartDevice subclass for device described by passed data."""
|
||||
if "result" in info:
|
||||
discovery_result = DiscoveryResult(**info["result"])
|
||||
discovery_result = DiscoveryResult.from_dict(info["result"])
|
||||
https = discovery_result.mgt_encrypt_schm.is_support_https
|
||||
dev_class = get_device_class_from_family(
|
||||
discovery_result.device_type, https=https
|
||||
@ -721,12 +724,8 @@ class Discover:
|
||||
f"Unable to read response from device: {config.host}: {ex}"
|
||||
) from ex
|
||||
try:
|
||||
discovery_result = DiscoveryResult(**info["result"])
|
||||
if (
|
||||
encrypt_info := discovery_result.encrypt_info
|
||||
) and encrypt_info.sym_schm == "AES":
|
||||
Discover._decrypt_discovery_data(discovery_result)
|
||||
except ValidationError as ex:
|
||||
discovery_result = DiscoveryResult.from_dict(info["result"])
|
||||
except Exception as ex:
|
||||
if debug_enabled:
|
||||
data = (
|
||||
redact_data(info, NEW_DISCOVERY_REDACTORS)
|
||||
@ -742,6 +741,16 @@ class Discover:
|
||||
f"Unable to parse discovery from device: {config.host}: {ex}",
|
||||
host=config.host,
|
||||
) from ex
|
||||
# Decrypt the data
|
||||
if (
|
||||
encrypt_info := discovery_result.encrypt_info
|
||||
) and encrypt_info.sym_schm == "AES":
|
||||
try:
|
||||
Discover._decrypt_discovery_data(discovery_result)
|
||||
except Exception:
|
||||
_LOGGER.exception(
|
||||
"Unable to decrypt discovery data %s: %s", config.host, data
|
||||
)
|
||||
|
||||
type_ = discovery_result.device_type
|
||||
encrypt_schm = discovery_result.mgt_encrypt_schm
|
||||
@ -754,7 +763,7 @@ class Discover:
|
||||
raise UnsupportedDeviceError(
|
||||
f"Unsupported device {config.host} of type {type_} "
|
||||
+ "with no encryption type",
|
||||
discovery_result=discovery_result.get_dict(),
|
||||
discovery_result=discovery_result.to_dict(),
|
||||
host=config.host,
|
||||
)
|
||||
config.connection_type = DeviceConnectionParameters.from_values(
|
||||
@ -767,7 +776,7 @@ class Discover:
|
||||
raise UnsupportedDeviceError(
|
||||
f"Unsupported device {config.host} of type {type_} "
|
||||
+ f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}",
|
||||
discovery_result=discovery_result.get_dict(),
|
||||
discovery_result=discovery_result.to_dict(),
|
||||
host=config.host,
|
||||
) from ex
|
||||
if (
|
||||
@ -778,7 +787,7 @@ class Discover:
|
||||
_LOGGER.warning("Got unsupported device type: %s", type_)
|
||||
raise UnsupportedDeviceError(
|
||||
f"Unsupported device {config.host} of type {type_}: {info}",
|
||||
discovery_result=discovery_result.get_dict(),
|
||||
discovery_result=discovery_result.to_dict(),
|
||||
host=config.host,
|
||||
)
|
||||
if (protocol := get_protocol(config)) is None:
|
||||
@ -788,7 +797,7 @@ class Discover:
|
||||
raise UnsupportedDeviceError(
|
||||
f"Unsupported encryption scheme {config.host} of "
|
||||
+ f"type {config.connection_type.to_dict()}: {info}",
|
||||
discovery_result=discovery_result.get_dict(),
|
||||
discovery_result=discovery_result.to_dict(),
|
||||
host=config.host,
|
||||
)
|
||||
|
||||
@ -801,22 +810,35 @@ class Discover:
|
||||
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data))
|
||||
device = device_class(config.host, protocol=protocol)
|
||||
|
||||
di = discovery_result.get_dict()
|
||||
di = discovery_result.to_dict()
|
||||
di["model"], _, _ = discovery_result.device_model.partition("(")
|
||||
device.update_from_discover_info(di)
|
||||
return device
|
||||
|
||||
|
||||
class EncryptionScheme(BaseModel):
|
||||
class _DiscoveryBaseMixin(DataClassJSONMixin):
|
||||
"""Base class for serialization mixin."""
|
||||
|
||||
class Config(BaseConfig):
|
||||
"""Serialization config."""
|
||||
|
||||
omit_none = True
|
||||
omit_default = True
|
||||
serialize_by_alias = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class EncryptionScheme(_DiscoveryBaseMixin):
|
||||
"""Base model for encryption scheme of discovery result."""
|
||||
|
||||
is_support_https: bool
|
||||
encrypt_type: Optional[str] # noqa: UP007
|
||||
encrypt_type: Optional[str] = None # noqa: UP007
|
||||
http_port: Optional[int] = None # noqa: UP007
|
||||
lv: Optional[int] = None # noqa: UP007
|
||||
|
||||
|
||||
class EncryptionInfo(BaseModel):
|
||||
@dataclass
|
||||
class EncryptionInfo(_DiscoveryBaseMixin):
|
||||
"""Base model for encryption info of discovery result."""
|
||||
|
||||
sym_schm: str
|
||||
@ -824,19 +846,23 @@ class EncryptionInfo(BaseModel):
|
||||
data: str
|
||||
|
||||
|
||||
class DiscoveryResult(BaseModel):
|
||||
@dataclass
|
||||
class DiscoveryResult(_DiscoveryBaseMixin):
|
||||
"""Base model for discovery result."""
|
||||
|
||||
device_type: str
|
||||
device_model: str
|
||||
device_name: Optional[str] # noqa: UP007
|
||||
device_id: str
|
||||
ip: str
|
||||
mac: str
|
||||
mgt_encrypt_schm: EncryptionScheme
|
||||
device_name: Optional[str] = None # noqa: UP007
|
||||
encrypt_info: Optional[EncryptionInfo] = None # noqa: UP007
|
||||
encrypt_type: Optional[list[str]] = None # noqa: UP007
|
||||
decrypted_data: Optional[dict] = None # noqa: UP007
|
||||
device_id: str
|
||||
is_reset_wifi: Optional[bool] = field( # noqa: UP007
|
||||
metadata=field_options(alias="isResetWiFi"), default=None
|
||||
)
|
||||
|
||||
firmware_version: Optional[str] = None # noqa: UP007
|
||||
hardware_version: Optional[str] = None # noqa: UP007
|
||||
@ -845,12 +871,3 @@ class DiscoveryResult(BaseModel):
|
||||
is_support_iot_cloud: Optional[bool] = None # noqa: UP007
|
||||
obd_src: Optional[str] = None # noqa: UP007
|
||||
factory_default: Optional[bool] = None # noqa: UP007
|
||||
|
||||
def get_dict(self) -> dict:
|
||||
"""Return a dict for this discovery result.
|
||||
|
||||
containing only the values actually set and with aliases as field names.
|
||||
"""
|
||||
return self.dict(
|
||||
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
|
||||
)
|
||||
|
10
kasa/json.py
10
kasa/json.py
@ -21,3 +21,13 @@ except ImportError:
|
||||
return json.dumps(obj, separators=(",", ":"))
|
||||
|
||||
loads = json.loads
|
||||
|
||||
|
||||
try:
|
||||
from mashumaro.mixins.orjson import DataClassORJSONMixin
|
||||
|
||||
DataClassJSONMixin = DataClassORJSONMixin
|
||||
except ImportError:
|
||||
from mashumaro.mixins.json import DataClassJSONMixin as JSONMixin
|
||||
|
||||
DataClassJSONMixin = JSONMixin # type: ignore[assignment, misc]
|
||||
|
@ -14,6 +14,7 @@ dependencies = [
|
||||
"aiohttp>=3",
|
||||
"typing-extensions>=4.12.2,<5.0",
|
||||
"tzdata>=2024.2 ; platform_system == 'Windows'",
|
||||
"mashumaro>=3.14",
|
||||
]
|
||||
|
||||
classifiers = [
|
||||
|
@ -616,7 +616,7 @@ async def test_credentials(discovery_mock, mocker, runner):
|
||||
|
||||
mocker.patch("kasa.cli.device.state", new=_state)
|
||||
|
||||
dr = DiscoveryResult(**discovery_mock.discovery_data["result"])
|
||||
dr = DiscoveryResult.from_dict(discovery_mock.discovery_data["result"])
|
||||
res = await runner.invoke(
|
||||
cli,
|
||||
[
|
||||
|
@ -43,7 +43,7 @@ pytestmark = [pytest.mark.requires_dummy]
|
||||
def _get_connection_type_device_class(discovery_info):
|
||||
if "result" in discovery_info:
|
||||
device_class = Discover._get_device_class(discovery_info)
|
||||
dr = DiscoveryResult(**discovery_info["result"])
|
||||
dr = DiscoveryResult.from_dict(discovery_info["result"])
|
||||
|
||||
connection_type = DeviceConnectionParameters.from_values(
|
||||
dr.device_type, dr.mgt_encrypt_schm.encrypt_type
|
||||
|
@ -391,8 +391,8 @@ async def test_device_update_from_new_discovery_info(discovery_mock):
|
||||
discovery_data = discovery_mock.discovery_data
|
||||
device_class = Discover._get_device_class(discovery_data)
|
||||
device = device_class("127.0.0.1")
|
||||
discover_info = DiscoveryResult(**discovery_data["result"])
|
||||
discover_dump = discover_info.get_dict()
|
||||
discover_info = DiscoveryResult.from_dict(discovery_data["result"])
|
||||
discover_dump = discover_info.to_dict()
|
||||
model, _, _ = discover_dump["device_model"].partition("(")
|
||||
discover_dump["model"] = model
|
||||
device.update_from_discover_info(discover_dump)
|
||||
@ -652,7 +652,7 @@ async def test_discovery_decryption():
|
||||
"sym_schm": "AES",
|
||||
}
|
||||
info = {**UNSUPPORTED["result"], "encrypt_info": encrypt_info}
|
||||
dr = DiscoveryResult(**info)
|
||||
dr = DiscoveryResult.from_dict(info)
|
||||
Discover._decrypt_discovery_data(dr)
|
||||
assert dr.decrypted_data == data_dict
|
||||
|
||||
|
14
uv.lock
14
uv.lock
@ -831,6 +831,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/73/085399401383ce949f727afec55ec3abd76648d04b9f22e1c0e99cb4bec3/MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a", size = 15506 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mashumaro"
|
||||
version = "3.14"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/eb/47/0a450b281bef2d7e97ec02c8e1168d821e283f58e02e6c403b2bb4d73c1c/mashumaro-3.14.tar.gz", hash = "sha256:5ef6f2b963892cbe9a4ceb3441dfbea37f8c3412523f25d42e9b3a7186555f1d", size = 166160 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1b/35/8d63733a2c12149d0c7663c29bf626bdbeea5f0ff963afe58a42b4810981/mashumaro-3.14-py3-none-any.whl", hash = "sha256:c12a649599a8f7b1a0b35d18f12e678423c3066189f7bc7bd8dd431c5c8132c3", size = 92183 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mdit-py-plugins"
|
||||
version = "0.3.5"
|
||||
@ -1494,6 +1506,7 @@ dependencies = [
|
||||
{ name = "async-timeout" },
|
||||
{ name = "asyncclick" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "mashumaro" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "tzdata", marker = "platform_system == 'Windows'" },
|
||||
@ -1544,6 +1557,7 @@ requires-dist = [
|
||||
{ name = "cryptography", specifier = ">=1.9" },
|
||||
{ name = "docutils", marker = "extra == 'docs'", specifier = ">=0.17" },
|
||||
{ name = "kasa-crypt", marker = "extra == 'speedups'", specifier = ">=0.2.0" },
|
||||
{ name = "mashumaro", specifier = ">=3.14" },
|
||||
{ name = "myst-parser", marker = "extra == 'docs'" },
|
||||
{ name = "orjson", marker = "extra == 'speedups'", specifier = ">=3.9.1" },
|
||||
{ name = "ptpython", marker = "extra == 'shell'" },
|
||||
|
Loading…
Reference in New Issue
Block a user