diff --git a/kasa/module.py b/kasa/module.py index 0c5a0489..a5fe4a42 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -164,6 +164,7 @@ class Module(ABC): # Vacuum modules Clean: Final[ModuleName[smart.Clean]] = ModuleName("Clean") Dustbin: Final[ModuleName[smart.Dustbin]] = ModuleName("Dustbin") + Map: Final[ModuleName[smart.Map]] = ModuleName("Map") Speaker: Final[ModuleName[smart.Speaker]] = ModuleName("Speaker") def __init__(self, device: Device, module: str) -> None: diff --git a/kasa/smart/modules/__init__.py b/kasa/smart/modules/__init__.py index deb09f4f..95cc94f3 100644 --- a/kasa/smart/modules/__init__.py +++ b/kasa/smart/modules/__init__.py @@ -26,6 +26,7 @@ from .lighteffect import LightEffect from .lightpreset import LightPreset from .lightstripeffect import LightStripEffect from .lighttransition import LightTransition +from .map import Map from .matter import Matter from .motionsensor import MotionSensor from .overheatprotection import OverheatProtection @@ -76,4 +77,5 @@ __all__ = [ "HomeKit", "Matter", "Dustbin", + "Map", ] diff --git a/kasa/smart/modules/map.py b/kasa/smart/modules/map.py new file mode 100644 index 00000000..6a7c9041 --- /dev/null +++ b/kasa/smart/modules/map.py @@ -0,0 +1,262 @@ +"""Implementation for map module.""" + +import hashlib +import logging +from base64 import b64decode +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from mashumaro import DataClassDictMixin, field_options +from mashumaro.config import BaseConfig +from mashumaro.types import Discriminator, SerializationStrategy + +from ...exceptions import KasaException +from ...feature import Feature +from ..smartmodule import SmartModule +from .clean import FanSpeed + +if TYPE_CHECKING: + import PIL + +_LOGGER = logging.getLogger(__name__) + + +@dataclass +class MapSummary(DataClassDictMixin): + """Class representing map summary from mapinfo response.""" + + map_id: int + rotate_angle: int + is_saved: bool + update_time: int + global_cleaned: int + auto_area_flag: bool + map_locked: int + + +class Maps(SerializationStrategy): + """Strategy to deserialize list of maps into a dict.""" + + def deserialize(self, value: str) -> dict: + """Deserialize list of maps into a dict.""" + maps = {x["map_id"]: MapSummary.from_dict(x) for x in value} + return maps + + +@dataclass +class Area(DataClassDictMixin): + """Reprsentation of an area. + + This can be a room, a carpet, etc. + """ + + class Config(BaseConfig): + """Configuration.""" + + discriminator = Discriminator( + field="type", + include_subtypes=True, + ) + + +@dataclass +class Room(Area): + """Room area.""" + + fanspeed: FanSpeed = field(metadata=field_options(alias="suction")) + cistern: int = field() # TODO: enumize + + clean_count: int = field(metadata=field_options(alias="clean_number")) + id: int + type: str = "room" + + +@dataclass +class Carpet(Area): + """Carpet area.""" + + vertexs: list + carpet_strategy: int + id: int + type: str = "carpet_rectangle" + + +class Areas(SerializationStrategy): + """Strategy to deserialize list of areas into a dict.""" + + def deserialize(self, value: list) -> dict: + """Deserialize list of areas into a dict.""" + areas = {x["id"]: Area.from_dict(x) for x in value} + return areas + + +@dataclass +class MapInfo(DataClassDictMixin): + """Class representing getMapInfo response.""" + + map_num: int + version: str + current_map_id: int + auto_change_map: bool + maps: dict[int, MapSummary] = field( + metadata=field_options(serialization_strategy=Maps(), alias="map_list") + ) + + +@dataclass +class MapData(DataClassDictMixin): + """Class representing getMapData response.""" + + auto_area: bool = field(metadata=field_options(alias="auto_area_flag")) + path_id: int + version: str + map_id: int + resolution: int + resolution_unit: str + width: int + height: int + origin_coor: tuple[int, int, int] + real_origin_coor: tuple[int, int, int] + bitnum: int + palette: dict = field(metadata=field_options(alias="bit_list")) + pix_len: int + map_hash: str + pix_lz4len: int + map_data: bytes = field(metadata={"deserialize": b64decode}, repr=False) + areas: dict = field( + metadata=field_options(serialization_strategy=Areas(), alias="area_list") + ) + + def get_image(self, mapinfo: MapInfo | None = None) -> "PIL.Image": + """Return image object map getMapData response.""" + # TODO: move assert checks to mashumaro checks + + try: + import lz4.block + from PIL import Image + except ImportError as ex: + raise KasaException( + "You need to have lz4 and pillow installed to use this function." + ) from ex + + if len(self.map_data) != self.pix_lz4len: + raise KasaException("Invalid map data length") + + if self.width * self.height != self.pix_len: + raise KasaException("Invalid payload") + + _LOGGER.debug("resolution: %s %s", self.resolution, self.resolution_unit) + _LOGGER.debug("Size: %s x %s", self.width, self.height) + _LOGGER.debug("Bits per pixel: %s", self.bitnum) + _LOGGER.debug("origin: %s", self.origin_coor) + _LOGGER.debug("real origin: %s", self.real_origin_coor) + for area_id, area in self.areas.items(): + _LOGGER.debug("Area %s: %s", area_id, area) + + _LOGGER.debug("Palette: %s", self.palette) + + # TODO: use nicer palette + + img_data = lz4.block.decompress(self.map_data, uncompressed_size=self.pix_len) + + img_data_hash = hashlib.md5(img_data).hexdigest().upper() # noqa: S324 + if self.map_hash != img_data_hash: + raise KasaException("Invalid map hash") + + match self.bitnum: + case 8: + mode = "L" # 8bit gray + case _: + raise KasaException(f"Unknown bitnum {self.bitnum}") + + img = Image.frombytes(mode, (self.width, self.height), data=img_data) + + # rotate + img = img.rotate(mapinfo.maps[self.map_id].rotate_angle) + + return img + + +@dataclass +class PathData(DataClassDictMixin): + """Path data container.""" + + path_id: int + points: int = field(metadata=field_options(alias="point_counts")) + total_points: int + data: bytes = field( + metadata=field_options(deserialize=b64decode, alias="pos_array"), repr=False + ) + data_len: int = field(metadata=field_options(alias="pos_len")) + data_lz4len: int = field(metadata=field_options(alias="pos_lz4len")) + + def get_decompressed_data(self) -> bytes: + """Return decompressed path data.""" + try: + import lz4.block + except ImportError as ex: + raise KasaException( + "You need to have lz4 and pillow installed to use this function." + ) from ex + + decompressed = lz4.block.decompress(self.data, uncompressed_size=self.data_len) + if len(decompressed) != self.data_len: + raise KasaException("Invalid data length") + + return decompressed + + +class Map(SmartModule): + """Implementation of vacuum map module.""" + + REQUIRED_COMPONENT = "map" + + def query(self) -> dict: + """Query to execute during the update cycle.""" + return { + "getMapInfo": {}, + "getMapData": {"map_id": -1}, + "getPathData": {"start_pos": 0}, + } + + def _initialize_features(self) -> None: + """Initialize features.""" + self._add_feature( + Feature( + self._device, + id="map_count", + name="Map count", + container=self, + attribute_getter="map_count", + category=Feature.Category.Debug, + type=Feature.Sensor, + ) + ) + + @property + def map_count(self) -> int: + """Return number of maps.""" + return self.map_info.map_num + + @property + def map_info(self) -> MapInfo: + """Return map information.""" + return MapInfo.from_dict(self.data["getMapInfo"]) + + @property + def map_data(self) -> MapData: + """Return map data.""" + return MapData.from_dict(self.data["getMapData"]) + + @property + def path_data(self) -> PathData: + """Return path data.""" + return PathData.from_dict(self.data["getPathData"]) + + def get_path(self) -> bytes: + """Return path as an image.""" + return self.path_data.get_decompressed_data() + + def get_map_image(self) -> "PIL.Image": + """Return map as an image.""" + return self.map_data.get_image(self.map_info)