diff --git a/kasa/__init__.py b/kasa/__init__.py index 62d54502..394fa72e 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -36,6 +36,7 @@ from kasa.exceptions import ( UnsupportedDeviceError, ) from kasa.feature import Feature +from kasa.firmware import Firmware, FirmwareUpdate from kasa.iot.iotbulb import BulbPreset, TurnOnBehavior, TurnOnBehaviors from kasa.iotprotocol import ( IotProtocol, @@ -72,6 +73,8 @@ __all__ = [ "ConnectionType", "EncryptType", "DeviceFamilyType", + "Firmware", + "FirmwareUpdate", ] from . import iot diff --git a/kasa/cli.py b/kasa/cli.py index 696dee27..386d2e1c 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -1252,5 +1252,47 @@ async def feature(dev: Device, child: str, name: str, value): return response +@cli.group(invoke_without_command=True) +@pass_dev +@click.pass_context +async def firmware(ctx: click.Context, dev: Device): + """Firmware update.""" + if ctx.invoked_subcommand is None: + return await ctx.invoke(firmware_info) + + +@firmware.command(name="info") +@pass_dev +@click.pass_context +async def firmware_info(ctx: click.Context, dev: Device): + """Return firmware information.""" + res = await dev.firmware.check_for_updates() + if res.update_available: + echo("[green bold]Update available![/green bold]") + echo(f"Current firmware: {res.current_version}") + echo(f"Version {res.available_version} released at {res.release_date}") + echo("Release notes") + echo("=============") + echo(res.release_notes) + echo("=============") + else: + echo("[red bold]No updates available.[/red bold]") + + +@firmware.command(name="update") +@pass_dev +@click.pass_context +async def firmware_update(ctx: click.Context, dev: Device): + """Perform firmware update.""" + await ctx.invoke(firmware_info) + click.confirm("Are you sure you want to upgrade the firmware?", abort=True) + + async def progress(x): + echo(f"Progress: {x}") + + echo("Going to update %s", dev) + await dev.firmware.update_firmware(progress_cb=progress) # type: ignore + + if __name__ == "__main__": cli() diff --git a/kasa/device.py b/kasa/device.py index ea358a8d..fc400a04 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -14,6 +14,7 @@ from .deviceconfig import DeviceConfig from .emeterstatus import EmeterStatus from .exceptions import KasaException from .feature import Feature +from .firmware import Firmware from .iotprotocol import IotProtocol from .module import Module, ModuleT from .protocol import BaseProtocol @@ -288,6 +289,11 @@ class Device(ABC): ) return self.children[index] + @property + @abstractmethod + def firmware(self) -> Firmware: + """Return firmware.""" + @property @abstractmethod def time(self) -> datetime: diff --git a/kasa/firmware.py b/kasa/firmware.py new file mode 100644 index 00000000..71592c64 --- /dev/null +++ b/kasa/firmware.py @@ -0,0 +1,41 @@ +"""Interface for firmware updates.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import date +from typing import Any, Awaitable, Callable + +UpdateResult = bool + + +@dataclass +class FirmwareUpdate: + """Update info status object.""" + + update_available: bool | None = None + current_version: str | None = None + available_version: str | None = None + release_date: date | None = None + release_notes: str | None = None + + +class Firmware(ABC): + """Interface to access firmware information and perform updates.""" + + @abstractmethod + async def update_firmware( + self, *, progress_cb: Callable[[Any, Any], Awaitable] + ) -> UpdateResult: + """Perform firmware update. + + This "blocks" until the update process has finished. + You can set *progress_cb* to get progress updates. + """ + raise NotImplementedError + + @abstractmethod + async def check_for_updates(self) -> FirmwareUpdate: + """Return information about available updates.""" + raise NotImplementedError diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index 29ba3155..81735ec2 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -715,3 +715,9 @@ class IotDevice(Device): This should only be used for debugging purposes. """ return self._last_update or self._discovery_info + + @property + @requires_update + def firmware(self) -> Cloud: + """Returns object implementing the firmware handling.""" + return self.modules["cloud"] diff --git a/kasa/iot/modules/cloud.py b/kasa/iot/modules/cloud.py index 5022a68e..4bcfee5c 100644 --- a/kasa/iot/modules/cloud.py +++ b/kasa/iot/modules/cloud.py @@ -1,9 +1,25 @@ """Cloud module implementation.""" +from __future__ import annotations + +import logging + from pydantic.v1 import BaseModel +from datetime import date +from typing import Optional + from ...feature import Feature -from ..iotmodule import IotModule +from ...firmware import ( + Firmware, + UpdateResult, +) +from ...firmware import ( + FirmwareUpdate as FirmwareUpdateInterface, +) +from ..iotmodule import IotModule, merge + +_LOGGER = logging.getLogger(__name__) class CloudInfo(BaseModel): @@ -21,7 +37,31 @@ class CloudInfo(BaseModel): username: str -class Cloud(IotModule): +class FirmwareUpdate(BaseModel): + """Update info status object.""" + + status: int = Field(alias="fwType") + version: Optional[str] = Field(alias="fwVer", default=None) # noqa: UP007 + release_date: Optional[date] = Field(alias="fwReleaseDate", default=None) # noqa: UP007 + release_notes: Optional[str] = Field(alias="fwReleaseLog", default=None) # noqa: UP007 + url: Optional[str] = Field(alias="fwUrl", default=None) # noqa: UP007 + + @validator("release_date", pre=True) + def _release_date_optional(cls, v): + if not v: + return None + + return v + + @property + def update_available(self): + """Return True if update available.""" + if self.status != 0: + return True + return False + + +class Cloud(IotModule, Firmware): """Module implementing support for cloud services.""" def __init__(self, device, module): @@ -46,27 +86,73 @@ class Cloud(IotModule): def query(self): """Request cloud connectivity info.""" - return self.query_for_command("get_info") + req = self.query_for_command("get_info") + + # TODO: this is problematic, as it will fail the whole query on some + # devices if they are not connected to the internet + if self._module in self._device._last_update and self.is_connected: + req = merge(req, self.get_available_firmwares()) + + return req @property def info(self) -> CloudInfo: """Return information about the cloud connectivity.""" return CloudInfo.parse_obj(self.data["get_info"]) - def get_available_firmwares(self): + async def get_available_firmwares(self): """Return list of available firmwares.""" - return self.query_for_command("get_intl_fw_list") + return await self.call("get_intl_fw_list") - def set_server(self, url: str): + async def get_firmware_update(self) -> FirmwareUpdate: + """Return firmware update information.""" + try: + available_fws = (await self.get_available_firmwares()).get("fw_list", []) + if not available_fws: + return FirmwareUpdate(fwType=0) + if len(available_fws) > 1: + _LOGGER.warning( + "Got more than one update, using the first one: %s", available_fws + ) + return FirmwareUpdate.parse_obj(next(iter(available_fws))) + except Exception as ex: + _LOGGER.warning("Unable to check for firmware update: %s", ex) + return FirmwareUpdate(fwType=0) + + async def set_server(self, url: str): """Set the update server URL.""" - return self.query_for_command("set_server_url", {"server": url}) + return await self.call("set_server_url", {"server": url}) - def connect(self, username: str, password: str): + async def connect(self, username: str, password: str): """Login to the cloud using given information.""" - return self.query_for_command( - "bind", {"username": username, "password": password} - ) + return await self.call("bind", {"username": username, "password": password}) - def disconnect(self): + async def disconnect(self): """Disconnect from the cloud.""" - return self.query_for_command("unbind") + return await self.call("unbind") + + async def update_firmware(self, *, progress_cb=None) -> UpdateResult: + """Perform firmware update.""" + raise NotImplementedError + i = 0 + import asyncio + + while i < 100: + await asyncio.sleep(1) + if progress_cb is not None: + await progress_cb(i) + i += 10 + + return UpdateResult("") + + async def check_for_updates(self) -> FirmwareUpdateInterface: + """Return firmware update information.""" + fw = await self.get_firmware_update() + + return FirmwareUpdateInterface( + update_available=fw.update_available, + current_version=self._device.hw_info.get("sw_ver"), + available_version=fw.version, + release_date=fw.release_date, + release_notes=fw.release_notes, + ) diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index 14a23aaa..07616cf4 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -11,13 +11,15 @@ from typing import TYPE_CHECKING, Any, Optional # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout from pydantic.v1 import BaseModel, Field, validator - # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout from ...exceptions import SmartErrorCode from ...feature import Feature, FeatureType +from ...firmware import Firmware as FirmwareInterface +from ...firmware import FirmwareUpdate as FirmwareUpdateInterface +from ...firmware import UpdateResult from ..smartmodule import SmartModule if TYPE_CHECKING: @@ -27,7 +29,7 @@ if TYPE_CHECKING: _LOGGER = logging.getLogger(__name__) -class UpdateInfo(BaseModel): +class FirmwareUpdate(BaseModel): """Update info status object.""" status: int = Field(alias="type") @@ -53,7 +55,7 @@ class UpdateInfo(BaseModel): return False -class Firmware(SmartModule): +class Firmware(SmartModule, FirmwareInterface): """Implementation of firmware module.""" REQUIRED_COMPONENT = "firmware" @@ -143,9 +145,9 @@ class Firmware(SmartModule): fw = self.data.get("get_latest_fw") or self.data if not self._device.is_cloud_connected or isinstance(fw, SmartErrorCode): # Error in response, probably disconnected from the cloud. - return UpdateInfo(type=0, need_to_upgrade=False) + return FirmwareUpdate(type=0, need_to_upgrade=False) - return UpdateInfo.parse_obj(fw) + return FirmwareUpdate.parse_obj(fw) @property def update_available(self) -> bool | None: @@ -192,3 +194,20 @@ class Firmware(SmartModule): """Change autoupdate setting.""" data = {**self.data["get_auto_update_info"], "enable": enabled} await self.call("set_auto_update_info", data) + + async def update_firmware(self, *, progress_cb) -> UpdateResult: + """Update the firmware.""" + # TODO: implement, this is part of the common firmware API + raise NotImplementedError + + async def check_for_updates(self) -> FirmwareUpdateInterface: + """Return firmware update information.""" + # TODO: naming of the common firmware API methods + info = self.firmware_update_info + return FirmwareUpdateInterface( + current_version=self.current_firmware, + update_available=info.update_available, + available_version=info.version, + release_date=info.release_date, + release_notes=info.release_notes, + ) diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 89813387..1e53bc0d 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -625,6 +625,13 @@ class SmartDevice(Bulb, Fan, Device): return self._device_type + @property + def firmware(self) -> FirmwareInterface: + """Return firmware module.""" + # TODO: open question: does it make sense to expose common modules? + fw = cast(FirmwareInterface, self.modules["Firmware"]) + return fw + @staticmethod def _get_device_type_from_components( components: list[str], device_type: str