Migrate smart firmware module to mashumaro (#1276)

This commit is contained in:
Steven B. 2024-11-20 11:54:13 +00:00 committed by GitHub
parent 03c073c293
commit 999e84d2de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 26 deletions

View File

@ -6,10 +6,12 @@ import asyncio
import logging import logging
from asyncio import timeout as asyncio_timeout from asyncio import timeout as asyncio_timeout
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from datetime import date from datetime import date
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Annotated
from pydantic.v1 import BaseModel, Field, validator from mashumaro import DataClassDictMixin, field_options
from mashumaro.types import Alias
from ...exceptions import KasaException from ...exceptions import KasaException
from ...feature import Feature from ...feature import Feature
@ -22,36 +24,36 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class DownloadState(BaseModel): @dataclass
class DownloadState(DataClassDictMixin):
"""Download state.""" """Download state."""
# Example: # Example:
# {'status': 0, 'download_progress': 0, 'reboot_time': 5, # {'status': 0, 'download_progress': 0, 'reboot_time': 5,
# 'upgrade_time': 5, 'auto_upgrade': False} # 'upgrade_time': 5, 'auto_upgrade': False}
status: int status: int
progress: int = Field(alias="download_progress") progress: Annotated[int, Alias("download_progress")]
reboot_time: int reboot_time: int
upgrade_time: int upgrade_time: int
auto_upgrade: bool auto_upgrade: bool
class UpdateInfo(BaseModel): @dataclass
class UpdateInfo(DataClassDictMixin):
"""Update info status object.""" """Update info status object."""
status: int = Field(alias="type") status: Annotated[int, Alias("type")]
version: str | None = Field(alias="fw_ver", default=None) needs_upgrade: Annotated[bool, Alias("need_to_upgrade")]
release_date: date | None = None version: Annotated[str | None, Alias("fw_ver")] = None
release_notes: str | None = Field(alias="release_note", default=None) release_date: date | None = field(
default=None,
metadata=field_options(
deserialize=lambda x: date.fromisoformat(x) if x else None
),
)
release_notes: Annotated[str | None, Alias("release_note")] = None
fw_size: int | None = None fw_size: int | None = None
oem_id: str | None = None oem_id: str | None = None
needs_upgrade: bool = Field(alias="need_to_upgrade")
@validator("release_date", pre=True)
def _release_date_optional(cls, v: str) -> str | None:
if not v:
return None
return v
@property @property
def update_available(self) -> bool: def update_available(self) -> bool:
@ -139,7 +141,7 @@ class Firmware(SmartModule):
"""Check for the latest firmware for the device.""" """Check for the latest firmware for the device."""
try: try:
fw = await self.call("get_latest_fw") fw = await self.call("get_latest_fw")
self._firmware_update_info = UpdateInfo.parse_obj(fw["get_latest_fw"]) self._firmware_update_info = UpdateInfo.from_dict(fw["get_latest_fw"])
return self._firmware_update_info return self._firmware_update_info
except Exception: except Exception:
_LOGGER.exception("Error getting latest firmware for %s:", self._device) _LOGGER.exception("Error getting latest firmware for %s:", self._device)
@ -174,7 +176,7 @@ class Firmware(SmartModule):
"""Return update state.""" """Return update state."""
resp = await self.call("get_fw_download_state") resp = await self.call("get_fw_download_state")
state = resp["get_fw_download_state"] state = resp["get_fw_download_state"]
return DownloadState(**state) return DownloadState.from_dict(state)
@allow_update_after @allow_update_after
async def update( async def update(
@ -232,7 +234,7 @@ class Firmware(SmartModule):
else: else:
_LOGGER.warning("Unhandled state code: %s", state) _LOGGER.warning("Unhandled state code: %s", state)
return state.dict() return state.to_dict()
@property @property
def auto_update_enabled(self) -> bool: def auto_update_enabled(self) -> bool:

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from contextlib import nullcontext from contextlib import nullcontext
from datetime import date
from typing import TypedDict from typing import TypedDict
import pytest import pytest
@ -52,6 +53,20 @@ async def test_firmware_features(
assert isinstance(feat.value, type) assert isinstance(feat.value, type)
@firmware
async def test_firmware_update_info(dev: SmartDevice):
"""Test that the firmware UpdateInfo object deserializes correctly."""
fw = dev.modules.get(Module.Firmware)
assert fw
if not dev.is_cloud_connected:
pytest.skip("Device is not cloud connected, skipping test")
assert fw.firmware_update_info is None
await fw.check_latest_firmware()
assert fw.firmware_update_info is not None
assert isinstance(fw.firmware_update_info.release_date, date | None)
@firmware @firmware
async def test_update_available_without_cloud(dev: SmartDevice): async def test_update_available_without_cloud(dev: SmartDevice):
"""Test that update_available returns None when disconnected.""" """Test that update_available returns None when disconnected."""
@ -105,15 +120,15 @@ async def test_firmware_update(
} }
update_states = [ update_states = [
# Unknown 1 # Unknown 1
DownloadState(status=1, download_progress=0, **extras), DownloadState(status=1, progress=0, **extras),
# Downloading # Downloading
DownloadState(status=2, download_progress=10, **extras), DownloadState(status=2, progress=10, **extras),
DownloadState(status=2, download_progress=100, **extras), DownloadState(status=2, progress=100, **extras),
# Flashing # Flashing
DownloadState(status=3, download_progress=100, **extras), DownloadState(status=3, progress=100, **extras),
DownloadState(status=3, download_progress=100, **extras), DownloadState(status=3, progress=100, **extras),
# Done # Done
DownloadState(status=0, download_progress=100, **extras), DownloadState(status=0, progress=100, **extras),
] ]
asyncio_sleep = asyncio.sleep asyncio_sleep = asyncio.sleep