Migrate smart firmware module to mashumaro (#1276)

This commit is contained in:
Steven B. 2024-11-20 11:54:13 +00:00 committed by GitHub
parent 03c073c293
commit 999e84d2de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 26 deletions

View File

@ -6,10 +6,12 @@ import asyncio
import logging
from asyncio import timeout as asyncio_timeout
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from datetime import date
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Annotated
from pydantic.v1 import BaseModel, Field, validator
from mashumaro import DataClassDictMixin, field_options
from mashumaro.types import Alias
from ...exceptions import KasaException
from ...feature import Feature
@ -22,36 +24,36 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__)
class DownloadState(BaseModel):
@dataclass
class DownloadState(DataClassDictMixin):
"""Download state."""
# Example:
# {'status': 0, 'download_progress': 0, 'reboot_time': 5,
# 'upgrade_time': 5, 'auto_upgrade': False}
status: int
progress: int = Field(alias="download_progress")
progress: Annotated[int, Alias("download_progress")]
reboot_time: int
upgrade_time: int
auto_upgrade: bool
class UpdateInfo(BaseModel):
@dataclass
class UpdateInfo(DataClassDictMixin):
"""Update info status object."""
status: int = Field(alias="type")
version: str | None = Field(alias="fw_ver", default=None)
release_date: date | None = None
release_notes: str | None = Field(alias="release_note", default=None)
status: Annotated[int, Alias("type")]
needs_upgrade: Annotated[bool, Alias("need_to_upgrade")]
version: Annotated[str | None, Alias("fw_ver")] = None
release_date: date | None = field(
default=None,
metadata=field_options(
deserialize=lambda x: date.fromisoformat(x) if x else None
),
)
release_notes: Annotated[str | None, Alias("release_note")] = None
fw_size: int | None = None
oem_id: str | None = None
needs_upgrade: bool = Field(alias="need_to_upgrade")
@validator("release_date", pre=True)
def _release_date_optional(cls, v: str) -> str | None:
if not v:
return None
return v
@property
def update_available(self) -> bool:
@ -139,7 +141,7 @@ class Firmware(SmartModule):
"""Check for the latest firmware for the device."""
try:
fw = await self.call("get_latest_fw")
self._firmware_update_info = UpdateInfo.parse_obj(fw["get_latest_fw"])
self._firmware_update_info = UpdateInfo.from_dict(fw["get_latest_fw"])
return self._firmware_update_info
except Exception:
_LOGGER.exception("Error getting latest firmware for %s:", self._device)
@ -174,7 +176,7 @@ class Firmware(SmartModule):
"""Return update state."""
resp = await self.call("get_fw_download_state")
state = resp["get_fw_download_state"]
return DownloadState(**state)
return DownloadState.from_dict(state)
@allow_update_after
async def update(
@ -232,7 +234,7 @@ class Firmware(SmartModule):
else:
_LOGGER.warning("Unhandled state code: %s", state)
return state.dict()
return state.to_dict()
@property
def auto_update_enabled(self) -> bool:

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import logging
from contextlib import nullcontext
from datetime import date
from typing import TypedDict
import pytest
@ -52,6 +53,20 @@ async def test_firmware_features(
assert isinstance(feat.value, type)
@firmware
async def test_firmware_update_info(dev: SmartDevice):
"""Test that the firmware UpdateInfo object deserializes correctly."""
fw = dev.modules.get(Module.Firmware)
assert fw
if not dev.is_cloud_connected:
pytest.skip("Device is not cloud connected, skipping test")
assert fw.firmware_update_info is None
await fw.check_latest_firmware()
assert fw.firmware_update_info is not None
assert isinstance(fw.firmware_update_info.release_date, date | None)
@firmware
async def test_update_available_without_cloud(dev: SmartDevice):
"""Test that update_available returns None when disconnected."""
@ -105,15 +120,15 @@ async def test_firmware_update(
}
update_states = [
# Unknown 1
DownloadState(status=1, download_progress=0, **extras),
DownloadState(status=1, progress=0, **extras),
# Downloading
DownloadState(status=2, download_progress=10, **extras),
DownloadState(status=2, download_progress=100, **extras),
DownloadState(status=2, progress=10, **extras),
DownloadState(status=2, progress=100, **extras),
# Flashing
DownloadState(status=3, download_progress=100, **extras),
DownloadState(status=3, download_progress=100, **extras),
DownloadState(status=3, progress=100, **extras),
DownloadState(status=3, progress=100, **extras),
# Done
DownloadState(status=0, download_progress=100, **extras),
DownloadState(status=0, progress=100, **extras),
]
asyncio_sleep = asyncio.sleep