diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index 626add0f..430515e4 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -2,9 +2,14 @@ from __future__ import annotations +import asyncio +import logging from datetime import date -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional +# When support for cpython older than 3.11 is dropped +# async_timeout can be replaced with asyncio.timeout +from async_timeout import timeout as asyncio_timeout from pydantic.v1 import BaseModel, Field, validator from ...exceptions import SmartErrorCode @@ -15,11 +20,27 @@ if TYPE_CHECKING: from ..smartdevice import SmartDevice +_LOGGER = logging.getLogger(__name__) + + +class DownloadState(BaseModel): + """Download state.""" + + # Example: + # {'status': 0, 'download_progress': 0, 'reboot_time': 5, + # 'upgrade_time': 5, 'auto_upgrade': False} + status: int + progress: int = Field(alias="download_progress") + reboot_time: int + upgrade_time: int + auto_upgrade: bool + + class UpdateInfo(BaseModel): """Update info status object.""" status: int = Field(alias="type") - fw_ver: Optional[str] = None # noqa: UP007 + version: Optional[str] = Field(alias="fw_ver", default=None) # noqa: UP007 release_date: Optional[date] = None # noqa: UP007 release_notes: Optional[str] = Field(alias="release_note", default=None) # noqa: UP007 fw_size: Optional[int] = None # noqa: UP007 @@ -71,6 +92,26 @@ class Firmware(SmartModule): category=Feature.Category.Info, ) ) + self._add_feature( + Feature( + device, + id="current_firmware_version", + name="Current firmware version", + container=self, + attribute_getter="current_firmware", + category=Feature.Category.Debug, + ) + ) + self._add_feature( + Feature( + device, + id="available_firmware_version", + name="Available firmware version", + container=self, + attribute_getter="latest_firmware", + category=Feature.Category.Debug, + ) + ) def query(self) -> dict: """Query to execute during the update cycle.""" @@ -80,7 +121,17 @@ class Firmware(SmartModule): return req @property - def latest_firmware(self): + def current_firmware(self) -> str: + """Return the current firmware version.""" + return self._device.hw_info["sw_ver"] + + @property + def latest_firmware(self) -> str: + """Return the latest firmware version.""" + return self.firmware_update_info.version + + @property + def firmware_update_info(self): """Return latest firmware information.""" fw = self.data.get("get_latest_fw") or self.data if not self._device.is_cloud_connected or isinstance(fw, SmartErrorCode): @@ -94,15 +145,62 @@ class Firmware(SmartModule): """Return True if update is available.""" if not self._device.is_cloud_connected: return None - return self.latest_firmware.update_available + return self.firmware_update_info.update_available - async def get_update_state(self): + async def get_update_state(self) -> DownloadState: """Return update state.""" - return await self.call("get_fw_download_state") + resp = await self.call("get_fw_download_state") + state = resp["get_fw_download_state"] + return DownloadState(**state) - async def update(self): + async def update( + self, progress_cb: Callable[[DownloadState], Coroutine] | None = None + ): """Update the device firmware.""" - return await self.call("fw_download") + current_fw = self.current_firmware + _LOGGER.info( + "Going to upgrade from %s to %s", + current_fw, + self.firmware_update_info.version, + ) + await self.call("fw_download") + + # TODO: read timeout from get_auto_update_info or from get_fw_download_state? + async with asyncio_timeout(60 * 5): + while True: + await asyncio.sleep(0.5) + try: + state = await self.get_update_state() + except Exception as ex: + _LOGGER.warning( + "Got exception, maybe the device is rebooting? %s", ex + ) + continue + + _LOGGER.debug("Update state: %s" % state) + if progress_cb is not None: + asyncio.create_task(progress_cb(state)) + + if state.status == 0: + _LOGGER.info( + "Update idle, hopefully updated to %s", + self.firmware_update_info.version, + ) + break + elif state.status == 2: + _LOGGER.info("Downloading firmware, progress: %s", state.progress) + elif state.status == 3: + upgrade_sleep = state.upgrade_time + _LOGGER.info( + "Flashing firmware, sleeping for %s before checking status", + upgrade_sleep, + ) + await asyncio.sleep(upgrade_sleep) + elif state.status < 0: + _LOGGER.error("Got error: %s", state.status) + break + else: + _LOGGER.warning("Unhandled state code: %s", state) @property def auto_update_enabled(self): @@ -115,4 +213,4 @@ class Firmware(SmartModule): async def set_auto_update_enabled(self, enabled: bool): """Change autoupdate setting.""" data = {**self.data["get_auto_update_info"], "enable": enabled} - await self.call("set_auto_update_info", data) # {"enable": enabled}) + await self.call("set_auto_update_info", data) diff --git a/kasa/tests/fakeprotocol_smart.py b/kasa/tests/fakeprotocol_smart.py index ae1a7ad6..5ca4a8ae 100644 --- a/kasa/tests/fakeprotocol_smart.py +++ b/kasa/tests/fakeprotocol_smart.py @@ -234,7 +234,7 @@ class FakeSmartTransport(BaseTransport): pytest.fixtures_missing_methods[self.fixture_name] = set() pytest.fixtures_missing_methods[self.fixture_name].add(method) return retval - elif method == "set_qs_info": + elif method in ["set_qs_info", "fw_download"]: return {"error_code": 0} elif method == "set_dynamic_light_effect_rule_enable": self._set_light_effect(info, params) diff --git a/kasa/tests/smart/modules/test_firmware.py b/kasa/tests/smart/modules/test_firmware.py new file mode 100644 index 00000000..d0df87ca --- /dev/null +++ b/kasa/tests/smart/modules/test_firmware.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import asyncio +import logging + +import pytest +from pytest_mock import MockerFixture + +from kasa.smart import SmartDevice +from kasa.smart.modules import Firmware +from kasa.smart.modules.firmware import DownloadState +from kasa.tests.device_fixtures import parametrize + +firmware = parametrize( + "has firmware", component_filter="firmware", protocol_filter={"SMART"} +) + + +@firmware +@pytest.mark.parametrize( + "feature, prop_name, type, required_version", + [ + ("auto_update_enabled", "auto_update_enabled", bool, 2), + ("update_available", "update_available", bool, 1), + ("update_available", "update_available", bool, 1), + ("current_firmware_version", "current_firmware", str, 1), + ("available_firmware_version", "latest_firmware", str, 1), + ], +) +async def test_firmware_features( + dev: SmartDevice, feature, prop_name, type, required_version, mocker: MockerFixture +): + """Test light effect.""" + fw = dev.get_module(Firmware) + assert fw + + if not dev.is_cloud_connected: + pytest.skip("Device is not cloud connected, skipping test") + + if fw.supported_version < required_version: + pytest.skip("Feature %s requires newer version" % feature) + + prop = getattr(fw, prop_name) + assert isinstance(prop, type) + + feat = fw._module_features[feature] + assert feat.value == prop + assert isinstance(feat.value, type) + + +@firmware +async def test_update_available_without_cloud(dev: SmartDevice): + """Test that update_available returns None when disconnected.""" + fw = dev.get_module(Firmware) + assert fw + + if dev.is_cloud_connected: + assert isinstance(fw.update_available, bool) + else: + assert fw.update_available is None + + +@firmware +async def test_firmware_update( + dev: SmartDevice, mocker: MockerFixture, caplog: pytest.LogCaptureFixture +): + """Test updating firmware.""" + caplog.set_level(logging.INFO) + + fw = dev.get_module(Firmware) + assert fw + + upgrade_time = 5 + extras = {"reboot_time": 5, "upgrade_time": upgrade_time, "auto_upgrade": False} + update_states = [ + # Unknown 1 + DownloadState(status=1, download_progress=0, **extras), + # Downloading + DownloadState(status=2, download_progress=10, **extras), + DownloadState(status=2, download_progress=100, **extras), + # Flashing + DownloadState(status=3, download_progress=100, **extras), + DownloadState(status=3, download_progress=100, **extras), + # Done + DownloadState(status=0, download_progress=100, **extras), + ] + + asyncio_sleep = asyncio.sleep + sleep = mocker.patch("asyncio.sleep") + mocker.patch.object(fw, "get_update_state", side_effect=update_states) + + cb_mock = mocker.AsyncMock() + + await fw.update(progress_cb=cb_mock) + + # This is necessary to allow the eventloop to process the created tasks + await asyncio_sleep(0) + + assert "Unhandled state code" in caplog.text + assert "Downloading firmware, progress: 10" in caplog.text + assert "Flashing firmware, sleeping" in caplog.text + assert "Update idle" in caplog.text + + for state in update_states: + cb_mock.assert_any_await(state) + + # sleep based on the upgrade_time + sleep.assert_any_call(upgrade_time)