diff --git a/kasa/deviceconfig.py b/kasa/deviceconfig.py index ede3f595..56e97f5e 100644 --- a/kasa/deviceconfig.py +++ b/kasa/deviceconfig.py @@ -17,9 +17,10 @@ None >>> config_dict = device.config.to_dict() >>> # DeviceConfig.to_dict() can be used to store for later >>> print(config_dict) -{'host': '127.0.0.3', 'timeout': 5, 'credentials': Credentials(), 'connection_type'\ -: {'device_family': 'SMART.TAPOBULB', 'encryption_type': 'KLAP', 'https': False, \ -'login_version': 2}, 'uses_http': True} +{'host': '127.0.0.3', 'timeout': 5, 'credentials': {'username': 'user@example.com', \ +'password': 'great_password'}, 'connection_type'\ +: {'device_family': 'SMART.TAPOBULB', 'encryption_type': 'KLAP', 'login_version': 2, \ +'https': False}, 'uses_http': True} >>> later_device = await Device.connect(config=Device.Config.from_dict(config_dict)) >>> print(later_device.alias) # Alias is available as connect() calls update() @@ -27,15 +28,21 @@ Living Room Bulb """ -# Module cannot use from __future__ import annotations until migrated to mashumaru -# as dataclass.fields() will not resolve the type. +from __future__ import annotations + import logging -from dataclasses import asdict, dataclass, field, fields, is_dataclass +from dataclasses import dataclass, field, replace from enum import Enum -from typing import TYPE_CHECKING, Any, Optional, TypedDict +from typing import TYPE_CHECKING, Any, Self, TypedDict + +from aiohttp import ClientSession +from mashumaro import field_options +from mashumaro.config import BaseConfig +from mashumaro.types import SerializationStrategy from .credentials import Credentials from .exceptions import KasaException +from .json import DataClassJSONMixin if TYPE_CHECKING: from aiohttp import ClientSession @@ -73,45 +80,17 @@ class DeviceFamily(Enum): SmartIpCamera = "SMART.IPCAMERA" -def _dataclass_from_dict(klass: Any, in_val: dict) -> Any: - if is_dataclass(klass): - fieldtypes = {f.name: f.type for f in fields(klass)} - val = {} - for dict_key in in_val: - if dict_key in fieldtypes: - if hasattr(fieldtypes[dict_key], "from_dict"): - val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key]) # type: ignore[union-attr] - else: - val[dict_key] = _dataclass_from_dict( - fieldtypes[dict_key], in_val[dict_key] - ) - else: - raise KasaException( - f"Cannot create dataclass from dict, unknown key: {dict_key}" - ) - return klass(**val) # type: ignore[operator] - else: - return in_val +class _DeviceConfigBaseMixin(DataClassJSONMixin): + """Base class for serialization mixin.""" + class Config(BaseConfig): + """Serialization config.""" -def _dataclass_to_dict(in_val: Any) -> dict: - fieldtypes = {f.name: f.type for f in fields(in_val) if f.compare} - out_val = {} - for field_name in fieldtypes: - val = getattr(in_val, field_name) - if val is None: - continue - elif hasattr(val, "to_dict"): - out_val[field_name] = val.to_dict() - elif is_dataclass(fieldtypes[field_name]): - out_val[field_name] = asdict(val) - else: - out_val[field_name] = val - return out_val + omit_none = True @dataclass -class DeviceConnectionParameters: +class DeviceConnectionParameters(_DeviceConfigBaseMixin): """Class to hold the the parameters determining connection type.""" device_family: DeviceFamily @@ -125,7 +104,7 @@ class DeviceConnectionParameters: encryption_type: str, login_version: int | None = None, https: bool | None = None, - ) -> "DeviceConnectionParameters": + ) -> DeviceConnectionParameters: """Return connection parameters from string values.""" try: if https is None: @@ -142,39 +121,17 @@ class DeviceConnectionParameters: + f"{encryption_type}.{login_version}" ) from ex - @staticmethod - def from_dict(connection_type_dict: dict[str, Any]) -> "DeviceConnectionParameters": - """Return connection parameters from dict.""" - if ( - isinstance(connection_type_dict, dict) - and (device_family := connection_type_dict.get("device_family")) - and (encryption_type := connection_type_dict.get("encryption_type")) - ): - if login_version := connection_type_dict.get("login_version"): - login_version = int(login_version) # type: ignore[assignment] - return DeviceConnectionParameters.from_values( - device_family, - encryption_type, - login_version, # type: ignore[arg-type] - connection_type_dict.get("https", False), - ) - raise KasaException(f"Invalid connection type data for {connection_type_dict}") +class _DoNotSerialize(SerializationStrategy): + def serialize(self, value: Any) -> None: + return None # pragma: no cover - def to_dict(self) -> dict[str, str | int | bool]: - """Convert connection params to dict.""" - result: dict[str, str | int] = { - "device_family": self.device_family.value, - "encryption_type": self.encryption_type.value, - "https": self.https, - } - if self.login_version: - result["login_version"] = self.login_version - return result + def deserialize(self, value: Any) -> None: + return None # pragma: no cover @dataclass -class DeviceConfig: +class DeviceConfig(_DeviceConfigBaseMixin): """Class to represent paramaters that determine how to connect to devices.""" DEFAULT_TIMEOUT = 5 @@ -202,9 +159,12 @@ class DeviceConfig: #: in order to determine whether they should pass a custom http client if desired. uses_http: bool = False - # compare=False will be excluded from the serialization and object comparison. #: Set a custom http_client for the device to use. - http_client: Optional["ClientSession"] = field(default=None, compare=False) + http_client: ClientSession | None = field( + default=None, + compare=False, + metadata=field_options(serialization_strategy=_DoNotSerialize()), + ) aes_keys: KeyPairDict | None = None @@ -214,22 +174,30 @@ class DeviceConfig: DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor ) - def to_dict( + def __pre_serialize__(self) -> Self: + return replace(self, http_client=None) + + def to_dict_control_credentials( self, *, credentials_hash: str | None = None, exclude_credentials: bool = False, ) -> dict[str, dict[str, str]]: - """Convert device config to dict.""" - if credentials_hash is not None or exclude_credentials: - self.credentials = None - if credentials_hash: - self.credentials_hash = credentials_hash - return _dataclass_to_dict(self) + """Convert deviceconfig to dict controlling how to serialize credentials. - @staticmethod - def from_dict(config_dict: dict[str, dict[str, str]]) -> "DeviceConfig": - """Return device config from dict.""" - if isinstance(config_dict, dict): - return _dataclass_from_dict(DeviceConfig, config_dict) - raise KasaException(f"Invalid device config data: {config_dict}") + If credentials_hash is provided credentials will be None. + If credentials_hash is '' credentials_hash and credentials will be None. + exclude credentials controls whether to include credentials. + The defaults are the same as calling to_dict(). + """ + if credentials_hash is None: + if not exclude_credentials: + return self.to_dict() + else: + return replace(self, credentials=None).to_dict() + + return replace( + self, + credentials_hash=credentials_hash if credentials_hash else None, + credentials=None, + ).to_dict() diff --git a/tests/conftest.py b/tests/conftest.py index 1b11e01b..3da689c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio import sys import warnings +from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -21,6 +22,13 @@ from .fixtureinfo import fixture_info # noqa: F401 turn_on = pytest.mark.parametrize("turn_on", [True, False]) +def load_fixture(foldername, filename): + """Load a fixture.""" + path = Path(Path(__file__).parent / "fixtures" / foldername / filename) + with path.open() as fdp: + return fdp.read() + + async def handle_turn_on(dev, turn_on): if turn_on: await dev.turn_on() diff --git a/tests/fixtures/serialization/deviceconfig_camera-aes-https.json b/tests/fixtures/serialization/deviceconfig_camera-aes-https.json new file mode 100644 index 00000000..559e834b --- /dev/null +++ b/tests/fixtures/serialization/deviceconfig_camera-aes-https.json @@ -0,0 +1,10 @@ +{ + "host": "127.0.0.1", + "timeout": 5, + "connection_type": { + "device_family": "SMART.IPCAMERA", + "encryption_type": "AES", + "https": true + }, + "uses_http": false +} diff --git a/tests/fixtures/serialization/deviceconfig_plug-klap.json b/tests/fixtures/serialization/deviceconfig_plug-klap.json new file mode 100644 index 00000000..ef42bb2f --- /dev/null +++ b/tests/fixtures/serialization/deviceconfig_plug-klap.json @@ -0,0 +1,11 @@ +{ + "host": "127.0.0.1", + "timeout": 5, + "connection_type": { + "device_family": "SMART.TAPOPLUG", + "encryption_type": "KLAP", + "https": false, + "login_version": 2 + }, + "uses_http": false +} diff --git a/tests/fixtures/serialization/deviceconfig_plug-xor.json b/tests/fixtures/serialization/deviceconfig_plug-xor.json new file mode 100644 index 00000000..78cc05a9 --- /dev/null +++ b/tests/fixtures/serialization/deviceconfig_plug-xor.json @@ -0,0 +1,10 @@ +{ + "host": "127.0.0.1", + "timeout": 5, + "connection_type": { + "device_family": "IOT.SMARTPLUGSWITCH", + "encryption_type": "XOR", + "https": false + }, + "uses_http": false +} diff --git a/tests/test_deviceconfig.py b/tests/test_deviceconfig.py index cefc6179..aebdd3a6 100644 --- a/tests/test_deviceconfig.py +++ b/tests/test_deviceconfig.py @@ -1,35 +1,97 @@ +import json +from dataclasses import replace from json import dumps as json_dumps from json import loads as json_loads import aiohttp import pytest +from mashumaro import MissingField from kasa.credentials import Credentials from kasa.deviceconfig import ( DeviceConfig, + DeviceConnectionParameters, + DeviceEncryptionType, + DeviceFamily, +) + +from .conftest import load_fixture + +PLUG_XOR_CONFIG = DeviceConfig(host="127.0.0.1") +PLUG_KLAP_CONFIG = DeviceConfig( + host="127.0.0.1", + connection_type=DeviceConnectionParameters( + DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Klap, login_version=2 + ), +) +CAMERA_AES_CONFIG = DeviceConfig( + host="127.0.0.1", + connection_type=DeviceConnectionParameters( + DeviceFamily.SmartIpCamera, DeviceEncryptionType.Aes, https=True + ), ) -from kasa.exceptions import KasaException async def test_serialization(): + """Test device config serialization.""" config = DeviceConfig(host="Foo", http_client=aiohttp.ClientSession()) config_dict = config.to_dict() config_json = json_dumps(config_dict) config2_dict = json_loads(config_json) config2 = DeviceConfig.from_dict(config2_dict) assert config == config2 + assert config.to_dict_control_credentials() == config.to_dict() @pytest.mark.parametrize( - ("input_value", "expected_msg"), + ("fixture_name", "expected_value"), [ - ({"Foo": "Bar"}, "Cannot create dataclass from dict, unknown key: Foo"), - ("foobar", "Invalid device config data: foobar"), + ("deviceconfig_plug-xor.json", PLUG_XOR_CONFIG), + ("deviceconfig_plug-klap.json", PLUG_KLAP_CONFIG), + ("deviceconfig_camera-aes-https.json", CAMERA_AES_CONFIG), + ], + ids=lambda arg: arg.split("_")[-1] if isinstance(arg, str) else "", +) +async def test_deserialization(fixture_name: str, expected_value: DeviceConfig): + """Test device config deserialization.""" + dict_val = json.loads(load_fixture("serialization", fixture_name)) + config = DeviceConfig.from_dict(dict_val) + assert config == expected_value + assert expected_value.to_dict() == dict_val + + +async def test_serialization_http_client(): + """Test that the http client does not try to serialize.""" + dict_val = json.loads(load_fixture("serialization", "deviceconfig_plug-klap.json")) + + config = replace(PLUG_KLAP_CONFIG, http_client=object()) + assert config.http_client + + assert config.to_dict() == dict_val + + +async def test_conn_param_no_https(): + """Test no https in connection param defaults to False.""" + dict_val = { + "device_family": "SMART.TAPOPLUG", + "encryption_type": "KLAP", + "login_version": 2, + } + param = DeviceConnectionParameters.from_dict(dict_val) + assert param.https is False + assert param.to_dict() == {**dict_val, "https": False} + + +@pytest.mark.parametrize( + ("input_value", "expected_error"), + [ + ({"Foo": "Bar"}, MissingField), + ("foobar", ValueError), ], ids=["invalid-dict", "not-dict"], ) -def test_deserialization_errors(input_value, expected_msg): - with pytest.raises(KasaException, match=expected_msg): +def test_deserialization_errors(input_value, expected_error): + with pytest.raises(expected_error): DeviceConfig.from_dict(input_value) @@ -39,7 +101,7 @@ async def test_credentials_hash(): http_client=aiohttp.ClientSession(), credentials=Credentials("foo", "bar"), ) - config_dict = config.to_dict(credentials_hash="credhash") + config_dict = config.to_dict_control_credentials(credentials_hash="credhash") config_json = json_dumps(config_dict) config2_dict = json_loads(config_json) config2 = DeviceConfig.from_dict(config2_dict) @@ -53,7 +115,7 @@ async def test_blank_credentials_hash(): http_client=aiohttp.ClientSession(), credentials=Credentials("foo", "bar"), ) - config_dict = config.to_dict(credentials_hash="") + config_dict = config.to_dict_control_credentials(credentials_hash="") config_json = json_dumps(config_dict) config2_dict = json_loads(config_json) config2 = DeviceConfig.from_dict(config2_dict) @@ -67,7 +129,7 @@ async def test_exclude_credentials(): http_client=aiohttp.ClientSession(), credentials=Credentials("foo", "bar"), ) - config_dict = config.to_dict(exclude_credentials=True) + config_dict = config.to_dict_control_credentials(exclude_credentials=True) config_json = json_dumps(config_dict) config2_dict = json_loads(config_json) config2 = DeviceConfig.from_dict(config2_dict)