263 lines
7.3 KiB
Python

"""Implementation for map module."""
import hashlib
import logging
from base64 import b64decode
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from mashumaro import DataClassDictMixin, field_options
from mashumaro.config import BaseConfig
from mashumaro.types import Discriminator, SerializationStrategy
from ...exceptions import KasaException
from ...feature import Feature
from ..smartmodule import SmartModule
from .clean import FanSpeed
if TYPE_CHECKING:
import PIL
_LOGGER = logging.getLogger(__name__)
@dataclass
class MapSummary(DataClassDictMixin):
"""Class representing map summary from mapinfo response."""
map_id: int
rotate_angle: int
is_saved: bool
update_time: int
auto_area_flag: bool
map_locked: int | None = field(default=None)
global_cleaned: int | None = field(default=None)
class Maps(SerializationStrategy):
"""Strategy to deserialize list of maps into a dict."""
def deserialize(self, value: str) -> dict:
"""Deserialize list of maps into a dict."""
maps = {x["map_id"]: MapSummary.from_dict(x) for x in value}
return maps
@dataclass
class Area(DataClassDictMixin):
"""Reprsentation of an area.
This can be a room, a carpet, etc.
"""
class Config(BaseConfig):
"""Configuration."""
discriminator = Discriminator(
field="type",
include_subtypes=True,
)
@dataclass
class Room(Area):
"""Room area."""
fanspeed: FanSpeed = field(metadata=field_options(alias="suction"))
cistern: int = field() # TODO: enumize
clean_count: int = field(metadata=field_options(alias="clean_number"))
id: int
type: str = "room"
@dataclass
class Carpet(Area):
"""Carpet area."""
vertexs: list
carpet_strategy: int
id: int
type: str = "carpet_rectangle"
class Areas(SerializationStrategy):
"""Strategy to deserialize list of areas into a dict."""
def deserialize(self, value: list) -> dict:
"""Deserialize list of areas into a dict."""
areas = {x["id"]: Area.from_dict(x) for x in value}
return areas
@dataclass
class MapInfo(DataClassDictMixin):
"""Class representing getMapInfo response."""
map_num: int
version: str
current_map_id: int
auto_change_map: bool
maps: dict[int, MapSummary] = field(
metadata=field_options(serialization_strategy=Maps(), alias="map_list")
)
@dataclass
class MapData(DataClassDictMixin):
"""Class representing getMapData response."""
auto_area: bool = field(metadata=field_options(alias="auto_area_flag"))
path_id: int
version: str
map_id: int
resolution: int
resolution_unit: str
width: int
height: int
origin_coor: tuple[int, int, int]
real_origin_coor: tuple[int, int, int]
bitnum: int
palette: dict = field(metadata=field_options(alias="bit_list"))
pix_len: int
map_hash: str
pix_lz4len: int
map_data: bytes = field(metadata={"deserialize": b64decode}, repr=False)
areas: dict = field(
metadata=field_options(serialization_strategy=Areas(), alias="area_list")
)
def get_image(self, mapinfo: MapInfo | None = None) -> "PIL.Image":
"""Return image object map getMapData response."""
# TODO: move assert checks to mashumaro checks
try:
import lz4.block
from PIL import Image
except ImportError as ex:
raise KasaException(
"You need to have lz4 and pillow installed to use this function."
) from ex
if len(self.map_data) != self.pix_lz4len:
raise KasaException("Invalid map data length")
if self.width * self.height != self.pix_len:
raise KasaException("Invalid payload")
_LOGGER.debug("resolution: %s %s", self.resolution, self.resolution_unit)
_LOGGER.debug("Size: %s x %s", self.width, self.height)
_LOGGER.debug("Bits per pixel: %s", self.bitnum)
_LOGGER.debug("origin: %s", self.origin_coor)
_LOGGER.debug("real origin: %s", self.real_origin_coor)
for area_id, area in self.areas.items():
_LOGGER.debug("Area %s: %s", area_id, area)
_LOGGER.debug("Palette: %s", self.palette)
# TODO: use nicer palette
img_data = lz4.block.decompress(self.map_data, uncompressed_size=self.pix_len)
img_data_hash = hashlib.md5(img_data).hexdigest().upper() # noqa: S324
if self.map_hash != img_data_hash:
raise KasaException("Invalid map hash")
match self.bitnum:
case 8:
mode = "L" # 8bit gray
case _:
raise KasaException(f"Unknown bitnum {self.bitnum}")
img = Image.frombytes(mode, (self.width, self.height), data=img_data)
# rotate
img = img.rotate(mapinfo.maps[self.map_id].rotate_angle)
return img
@dataclass
class PathData(DataClassDictMixin):
"""Path data container."""
path_id: int
points: int = field(metadata=field_options(alias="point_counts"))
total_points: int
data: bytes = field(
metadata=field_options(deserialize=b64decode, alias="pos_array"), repr=False
)
data_len: int = field(metadata=field_options(alias="pos_len"))
data_lz4len: int = field(metadata=field_options(alias="pos_lz4len"))
def get_decompressed_data(self) -> bytes:
"""Return decompressed path data."""
try:
import lz4.block
except ImportError as ex:
raise KasaException(
"You need to have lz4 and pillow installed to use this function."
) from ex
decompressed = lz4.block.decompress(self.data, uncompressed_size=self.data_len)
if len(decompressed) != self.data_len:
raise KasaException("Invalid data length")
return decompressed
class Map(SmartModule):
"""Implementation of vacuum map module."""
REQUIRED_COMPONENT = "map"
def query(self) -> dict:
"""Query to execute during the update cycle."""
return {
"getMapInfo": {},
"getMapData": {"map_id": -1},
"getPathData": {"start_pos": 0},
}
def _initialize_features(self) -> None:
"""Initialize features."""
self._add_feature(
Feature(
self._device,
id="map_count",
name="Map count",
container=self,
attribute_getter="map_count",
category=Feature.Category.Debug,
type=Feature.Sensor,
)
)
@property
def map_count(self) -> int:
"""Return number of maps."""
return self.map_info.map_num
@property
def map_info(self) -> MapInfo:
"""Return map information."""
return MapInfo.from_dict(self.data["getMapInfo"])
@property
def map_data(self) -> MapData:
"""Return map data."""
return MapData.from_dict(self.data["getMapData"])
@property
def path_data(self) -> PathData:
"""Return path data."""
return PathData.from_dict(self.data["getPathData"])
def get_path(self) -> bytes:
"""Return path as an image."""
return self.path_data.get_decompressed_data()
def get_map_image(self) -> "PIL.Image":
"""Return map as an image."""
return self.map_data.get_image(self.map_info)