Replace custom deviceconfig serialization with mashumaru (#1274)

This commit is contained in:
Steven B. 2024-11-20 08:35:32 +00:00 committed by GitHub
parent bf23f73cce
commit 79ac9547e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 163 additions and 94 deletions

View File

@ -17,9 +17,10 @@ None
>>> config_dict = device.config.to_dict() >>> config_dict = device.config.to_dict()
>>> # DeviceConfig.to_dict() can be used to store for later >>> # DeviceConfig.to_dict() can be used to store for later
>>> print(config_dict) >>> print(config_dict)
{'host': '127.0.0.3', 'timeout': 5, 'credentials': Credentials(), 'connection_type'\ {'host': '127.0.0.3', 'timeout': 5, 'credentials': {'username': 'user@example.com', \
: {'device_family': 'SMART.TAPOBULB', 'encryption_type': 'KLAP', 'https': False, \ 'password': 'great_password'}, 'connection_type'\
'login_version': 2}, 'uses_http': True} : {'device_family': 'SMART.TAPOBULB', 'encryption_type': 'KLAP', 'login_version': 2, \
'https': False}, 'uses_http': True}
>>> later_device = await Device.connect(config=Device.Config.from_dict(config_dict)) >>> later_device = await Device.connect(config=Device.Config.from_dict(config_dict))
>>> print(later_device.alias) # Alias is available as connect() calls update() >>> print(later_device.alias) # Alias is available as connect() calls update()
@ -27,15 +28,21 @@ Living Room Bulb
""" """
# Module cannot use from __future__ import annotations until migrated to mashumaru from __future__ import annotations
# as dataclass.fields() will not resolve the type.
import logging import logging
from dataclasses import asdict, dataclass, field, fields, is_dataclass from dataclasses import dataclass, field, replace
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, TypedDict from typing import TYPE_CHECKING, Any, Self, TypedDict
from aiohttp import ClientSession
from mashumaro import field_options
from mashumaro.config import BaseConfig
from mashumaro.types import SerializationStrategy
from .credentials import Credentials from .credentials import Credentials
from .exceptions import KasaException from .exceptions import KasaException
from .json import DataClassJSONMixin
if TYPE_CHECKING: if TYPE_CHECKING:
from aiohttp import ClientSession from aiohttp import ClientSession
@ -73,45 +80,17 @@ class DeviceFamily(Enum):
SmartIpCamera = "SMART.IPCAMERA" SmartIpCamera = "SMART.IPCAMERA"
def _dataclass_from_dict(klass: Any, in_val: dict) -> Any: class _DeviceConfigBaseMixin(DataClassJSONMixin):
if is_dataclass(klass): """Base class for serialization mixin."""
fieldtypes = {f.name: f.type for f in fields(klass)}
val = {}
for dict_key in in_val:
if dict_key in fieldtypes:
if hasattr(fieldtypes[dict_key], "from_dict"):
val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key]) # type: ignore[union-attr]
else:
val[dict_key] = _dataclass_from_dict(
fieldtypes[dict_key], in_val[dict_key]
)
else:
raise KasaException(
f"Cannot create dataclass from dict, unknown key: {dict_key}"
)
return klass(**val) # type: ignore[operator]
else:
return in_val
class Config(BaseConfig):
"""Serialization config."""
def _dataclass_to_dict(in_val: Any) -> dict: omit_none = True
fieldtypes = {f.name: f.type for f in fields(in_val) if f.compare}
out_val = {}
for field_name in fieldtypes:
val = getattr(in_val, field_name)
if val is None:
continue
elif hasattr(val, "to_dict"):
out_val[field_name] = val.to_dict()
elif is_dataclass(fieldtypes[field_name]):
out_val[field_name] = asdict(val)
else:
out_val[field_name] = val
return out_val
@dataclass @dataclass
class DeviceConnectionParameters: class DeviceConnectionParameters(_DeviceConfigBaseMixin):
"""Class to hold the the parameters determining connection type.""" """Class to hold the the parameters determining connection type."""
device_family: DeviceFamily device_family: DeviceFamily
@ -125,7 +104,7 @@ class DeviceConnectionParameters:
encryption_type: str, encryption_type: str,
login_version: int | None = None, login_version: int | None = None,
https: bool | None = None, https: bool | None = None,
) -> "DeviceConnectionParameters": ) -> DeviceConnectionParameters:
"""Return connection parameters from string values.""" """Return connection parameters from string values."""
try: try:
if https is None: if https is None:
@ -142,39 +121,17 @@ class DeviceConnectionParameters:
+ f"{encryption_type}.{login_version}" + f"{encryption_type}.{login_version}"
) from ex ) from ex
@staticmethod
def from_dict(connection_type_dict: dict[str, Any]) -> "DeviceConnectionParameters":
"""Return connection parameters from dict."""
if (
isinstance(connection_type_dict, dict)
and (device_family := connection_type_dict.get("device_family"))
and (encryption_type := connection_type_dict.get("encryption_type"))
):
if login_version := connection_type_dict.get("login_version"):
login_version = int(login_version) # type: ignore[assignment]
return DeviceConnectionParameters.from_values(
device_family,
encryption_type,
login_version, # type: ignore[arg-type]
connection_type_dict.get("https", False),
)
raise KasaException(f"Invalid connection type data for {connection_type_dict}") class _DoNotSerialize(SerializationStrategy):
def serialize(self, value: Any) -> None:
return None # pragma: no cover
def to_dict(self) -> dict[str, str | int | bool]: def deserialize(self, value: Any) -> None:
"""Convert connection params to dict.""" return None # pragma: no cover
result: dict[str, str | int] = {
"device_family": self.device_family.value,
"encryption_type": self.encryption_type.value,
"https": self.https,
}
if self.login_version:
result["login_version"] = self.login_version
return result
@dataclass @dataclass
class DeviceConfig: class DeviceConfig(_DeviceConfigBaseMixin):
"""Class to represent paramaters that determine how to connect to devices.""" """Class to represent paramaters that determine how to connect to devices."""
DEFAULT_TIMEOUT = 5 DEFAULT_TIMEOUT = 5
@ -202,9 +159,12 @@ class DeviceConfig:
#: in order to determine whether they should pass a custom http client if desired. #: in order to determine whether they should pass a custom http client if desired.
uses_http: bool = False uses_http: bool = False
# compare=False will be excluded from the serialization and object comparison.
#: Set a custom http_client for the device to use. #: Set a custom http_client for the device to use.
http_client: Optional["ClientSession"] = field(default=None, compare=False) http_client: ClientSession | None = field(
default=None,
compare=False,
metadata=field_options(serialization_strategy=_DoNotSerialize()),
)
aes_keys: KeyPairDict | None = None aes_keys: KeyPairDict | None = None
@ -214,22 +174,30 @@ class DeviceConfig:
DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor
) )
def to_dict( def __pre_serialize__(self) -> Self:
return replace(self, http_client=None)
def to_dict_control_credentials(
self, self,
*, *,
credentials_hash: str | None = None, credentials_hash: str | None = None,
exclude_credentials: bool = False, exclude_credentials: bool = False,
) -> dict[str, dict[str, str]]: ) -> dict[str, dict[str, str]]:
"""Convert device config to dict.""" """Convert deviceconfig to dict controlling how to serialize credentials.
if credentials_hash is not None or exclude_credentials:
self.credentials = None
if credentials_hash:
self.credentials_hash = credentials_hash
return _dataclass_to_dict(self)
@staticmethod If credentials_hash is provided credentials will be None.
def from_dict(config_dict: dict[str, dict[str, str]]) -> "DeviceConfig": If credentials_hash is '' credentials_hash and credentials will be None.
"""Return device config from dict.""" exclude credentials controls whether to include credentials.
if isinstance(config_dict, dict): The defaults are the same as calling to_dict().
return _dataclass_from_dict(DeviceConfig, config_dict) """
raise KasaException(f"Invalid device config data: {config_dict}") if credentials_hash is None:
if not exclude_credentials:
return self.to_dict()
else:
return replace(self, credentials=None).to_dict()
return replace(
self,
credentials_hash=credentials_hash if credentials_hash else None,
credentials=None,
).to_dict()

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
import sys import sys
import warnings import warnings
from pathlib import Path
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@ -21,6 +22,13 @@ from .fixtureinfo import fixture_info # noqa: F401
turn_on = pytest.mark.parametrize("turn_on", [True, False]) turn_on = pytest.mark.parametrize("turn_on", [True, False])
def load_fixture(foldername, filename):
"""Load a fixture."""
path = Path(Path(__file__).parent / "fixtures" / foldername / filename)
with path.open() as fdp:
return fdp.read()
async def handle_turn_on(dev, turn_on): async def handle_turn_on(dev, turn_on):
if turn_on: if turn_on:
await dev.turn_on() await dev.turn_on()

View File

@ -0,0 +1,10 @@
{
"host": "127.0.0.1",
"timeout": 5,
"connection_type": {
"device_family": "SMART.IPCAMERA",
"encryption_type": "AES",
"https": true
},
"uses_http": false
}

View File

@ -0,0 +1,11 @@
{
"host": "127.0.0.1",
"timeout": 5,
"connection_type": {
"device_family": "SMART.TAPOPLUG",
"encryption_type": "KLAP",
"https": false,
"login_version": 2
},
"uses_http": false
}

View File

@ -0,0 +1,10 @@
{
"host": "127.0.0.1",
"timeout": 5,
"connection_type": {
"device_family": "IOT.SMARTPLUGSWITCH",
"encryption_type": "XOR",
"https": false
},
"uses_http": false
}

View File

@ -1,35 +1,97 @@
import json
from dataclasses import replace
from json import dumps as json_dumps from json import dumps as json_dumps
from json import loads as json_loads from json import loads as json_loads
import aiohttp import aiohttp
import pytest import pytest
from mashumaro import MissingField
from kasa.credentials import Credentials from kasa.credentials import Credentials
from kasa.deviceconfig import ( from kasa.deviceconfig import (
DeviceConfig, DeviceConfig,
DeviceConnectionParameters,
DeviceEncryptionType,
DeviceFamily,
)
from .conftest import load_fixture
PLUG_XOR_CONFIG = DeviceConfig(host="127.0.0.1")
PLUG_KLAP_CONFIG = DeviceConfig(
host="127.0.0.1",
connection_type=DeviceConnectionParameters(
DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Klap, login_version=2
),
)
CAMERA_AES_CONFIG = DeviceConfig(
host="127.0.0.1",
connection_type=DeviceConnectionParameters(
DeviceFamily.SmartIpCamera, DeviceEncryptionType.Aes, https=True
),
) )
from kasa.exceptions import KasaException
async def test_serialization(): async def test_serialization():
"""Test device config serialization."""
config = DeviceConfig(host="Foo", http_client=aiohttp.ClientSession()) config = DeviceConfig(host="Foo", http_client=aiohttp.ClientSession())
config_dict = config.to_dict() config_dict = config.to_dict()
config_json = json_dumps(config_dict) config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json) config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict) config2 = DeviceConfig.from_dict(config2_dict)
assert config == config2 assert config == config2
assert config.to_dict_control_credentials() == config.to_dict()
@pytest.mark.parametrize( @pytest.mark.parametrize(
("input_value", "expected_msg"), ("fixture_name", "expected_value"),
[ [
({"Foo": "Bar"}, "Cannot create dataclass from dict, unknown key: Foo"), ("deviceconfig_plug-xor.json", PLUG_XOR_CONFIG),
("foobar", "Invalid device config data: foobar"), ("deviceconfig_plug-klap.json", PLUG_KLAP_CONFIG),
("deviceconfig_camera-aes-https.json", CAMERA_AES_CONFIG),
],
ids=lambda arg: arg.split("_")[-1] if isinstance(arg, str) else "",
)
async def test_deserialization(fixture_name: str, expected_value: DeviceConfig):
"""Test device config deserialization."""
dict_val = json.loads(load_fixture("serialization", fixture_name))
config = DeviceConfig.from_dict(dict_val)
assert config == expected_value
assert expected_value.to_dict() == dict_val
async def test_serialization_http_client():
"""Test that the http client does not try to serialize."""
dict_val = json.loads(load_fixture("serialization", "deviceconfig_plug-klap.json"))
config = replace(PLUG_KLAP_CONFIG, http_client=object())
assert config.http_client
assert config.to_dict() == dict_val
async def test_conn_param_no_https():
"""Test no https in connection param defaults to False."""
dict_val = {
"device_family": "SMART.TAPOPLUG",
"encryption_type": "KLAP",
"login_version": 2,
}
param = DeviceConnectionParameters.from_dict(dict_val)
assert param.https is False
assert param.to_dict() == {**dict_val, "https": False}
@pytest.mark.parametrize(
("input_value", "expected_error"),
[
({"Foo": "Bar"}, MissingField),
("foobar", ValueError),
], ],
ids=["invalid-dict", "not-dict"], ids=["invalid-dict", "not-dict"],
) )
def test_deserialization_errors(input_value, expected_msg): def test_deserialization_errors(input_value, expected_error):
with pytest.raises(KasaException, match=expected_msg): with pytest.raises(expected_error):
DeviceConfig.from_dict(input_value) DeviceConfig.from_dict(input_value)
@ -39,7 +101,7 @@ async def test_credentials_hash():
http_client=aiohttp.ClientSession(), http_client=aiohttp.ClientSession(),
credentials=Credentials("foo", "bar"), credentials=Credentials("foo", "bar"),
) )
config_dict = config.to_dict(credentials_hash="credhash") config_dict = config.to_dict_control_credentials(credentials_hash="credhash")
config_json = json_dumps(config_dict) config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json) config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict) config2 = DeviceConfig.from_dict(config2_dict)
@ -53,7 +115,7 @@ async def test_blank_credentials_hash():
http_client=aiohttp.ClientSession(), http_client=aiohttp.ClientSession(),
credentials=Credentials("foo", "bar"), credentials=Credentials("foo", "bar"),
) )
config_dict = config.to_dict(credentials_hash="") config_dict = config.to_dict_control_credentials(credentials_hash="")
config_json = json_dumps(config_dict) config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json) config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict) config2 = DeviceConfig.from_dict(config2_dict)
@ -67,7 +129,7 @@ async def test_exclude_credentials():
http_client=aiohttp.ClientSession(), http_client=aiohttp.ClientSession(),
credentials=Credentials("foo", "bar"), credentials=Credentials("foo", "bar"),
) )
config_dict = config.to_dict(exclude_credentials=True) config_dict = config.to_dict_control_credentials(exclude_credentials=True)
config_json = json_dumps(config_dict) config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json) config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict) config2 = DeviceConfig.from_dict(config2_dict)