diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index 5956a357..8dd3a6b3 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -6,10 +6,12 @@ import asyncio import logging from asyncio import timeout as asyncio_timeout from collections.abc import Callable, Coroutine +from dataclasses import dataclass, field from datetime import date -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Annotated -from pydantic.v1 import BaseModel, Field, validator +from mashumaro import DataClassDictMixin, field_options +from mashumaro.types import Alias from ...exceptions import KasaException from ...feature import Feature @@ -22,36 +24,36 @@ if TYPE_CHECKING: _LOGGER = logging.getLogger(__name__) -class DownloadState(BaseModel): +@dataclass +class DownloadState(DataClassDictMixin): """Download state.""" # Example: # {'status': 0, 'download_progress': 0, 'reboot_time': 5, # 'upgrade_time': 5, 'auto_upgrade': False} status: int - progress: int = Field(alias="download_progress") + progress: Annotated[int, Alias("download_progress")] reboot_time: int upgrade_time: int auto_upgrade: bool -class UpdateInfo(BaseModel): +@dataclass +class UpdateInfo(DataClassDictMixin): """Update info status object.""" - status: int = Field(alias="type") - version: str | None = Field(alias="fw_ver", default=None) - release_date: date | None = None - release_notes: str | None = Field(alias="release_note", default=None) + status: Annotated[int, Alias("type")] + needs_upgrade: Annotated[bool, Alias("need_to_upgrade")] + version: Annotated[str | None, Alias("fw_ver")] = None + release_date: date | None = field( + default=None, + metadata=field_options( + deserialize=lambda x: date.fromisoformat(x) if x else None + ), + ) + release_notes: Annotated[str | None, Alias("release_note")] = None fw_size: int | None = None oem_id: str | None = None - needs_upgrade: bool = Field(alias="need_to_upgrade") - - @validator("release_date", pre=True) - def _release_date_optional(cls, v: str) -> str | None: - if not v: - return None - - return v @property def update_available(self) -> bool: @@ -139,7 +141,7 @@ class Firmware(SmartModule): """Check for the latest firmware for the device.""" try: fw = await self.call("get_latest_fw") - self._firmware_update_info = UpdateInfo.parse_obj(fw["get_latest_fw"]) + self._firmware_update_info = UpdateInfo.from_dict(fw["get_latest_fw"]) return self._firmware_update_info except Exception: _LOGGER.exception("Error getting latest firmware for %s:", self._device) @@ -174,7 +176,7 @@ class Firmware(SmartModule): """Return update state.""" resp = await self.call("get_fw_download_state") state = resp["get_fw_download_state"] - return DownloadState(**state) + return DownloadState.from_dict(state) @allow_update_after async def update( @@ -232,7 +234,7 @@ class Firmware(SmartModule): else: _LOGGER.warning("Unhandled state code: %s", state) - return state.dict() + return state.to_dict() @property def auto_update_enabled(self) -> bool: diff --git a/tests/smart/modules/test_firmware.py b/tests/smart/modules/test_firmware.py index 3115c56f..0bc0a4ea 100644 --- a/tests/smart/modules/test_firmware.py +++ b/tests/smart/modules/test_firmware.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio import logging from contextlib import nullcontext +from datetime import date from typing import TypedDict import pytest @@ -52,6 +53,20 @@ async def test_firmware_features( assert isinstance(feat.value, type) +@firmware +async def test_firmware_update_info(dev: SmartDevice): + """Test that the firmware UpdateInfo object deserializes correctly.""" + fw = dev.modules.get(Module.Firmware) + assert fw + + if not dev.is_cloud_connected: + pytest.skip("Device is not cloud connected, skipping test") + assert fw.firmware_update_info is None + await fw.check_latest_firmware() + assert fw.firmware_update_info is not None + assert isinstance(fw.firmware_update_info.release_date, date | None) + + @firmware async def test_update_available_without_cloud(dev: SmartDevice): """Test that update_available returns None when disconnected.""" @@ -105,15 +120,15 @@ async def test_firmware_update( } update_states = [ # Unknown 1 - DownloadState(status=1, download_progress=0, **extras), + DownloadState(status=1, progress=0, **extras), # Downloading - DownloadState(status=2, download_progress=10, **extras), - DownloadState(status=2, download_progress=100, **extras), + DownloadState(status=2, progress=10, **extras), + DownloadState(status=2, progress=100, **extras), # Flashing - DownloadState(status=3, download_progress=100, **extras), - DownloadState(status=3, download_progress=100, **extras), + DownloadState(status=3, progress=100, **extras), + DownloadState(status=3, progress=100, **extras), # Done - DownloadState(status=0, download_progress=100, **extras), + DownloadState(status=0, progress=100, **extras), ] asyncio_sleep = asyncio.sleep