mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-05-30 21:51:24 +00:00
263 lines
7.3 KiB
Python
263 lines
7.3 KiB
Python
"""Implementation for map module."""
|
|
|
|
import hashlib
|
|
import logging
|
|
from base64 import b64decode
|
|
from dataclasses import dataclass, field
|
|
from typing import TYPE_CHECKING
|
|
|
|
from mashumaro import DataClassDictMixin, field_options
|
|
from mashumaro.config import BaseConfig
|
|
from mashumaro.types import Discriminator, SerializationStrategy
|
|
|
|
from ...exceptions import KasaException
|
|
from ...feature import Feature
|
|
from ..smartmodule import SmartModule
|
|
from .clean import FanSpeed
|
|
|
|
if TYPE_CHECKING:
|
|
import PIL
|
|
|
|
# Module-level logger; MapData.get_image() writes debug details about the map payload to it.
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class MapSummary(DataClassDictMixin):
    """Summary of one stored map, taken from the getMapInfo response.

    Instances are created from the entries of the ``map_list`` array.
    """

    map_id: int
    # Angle used to rotate the rendered map image.
    rotate_angle: int
    is_saved: bool
    update_time: int
    auto_area_flag: bool
    # Optional keys; stay None when the device omits them.
    map_locked: int | None = None
    global_cleaned: int | None = None
|
|
|
|
|
|
class Maps(SerializationStrategy):
    """Strategy to deserialize list of maps into a dict."""

    def deserialize(self, value: list[dict]) -> dict[int, MapSummary]:
        """Deserialize the ``map_list`` array into a dict keyed by map id.

        :param value: raw map summary dicts, each carrying a ``map_id`` key.
        :return: mapping of ``map_id`` to the parsed :class:`MapSummary`.
        """
        # The annotations document the real payload shape (a list of dicts,
        # not a str). mashumaro only uses strategy annotations when the
        # strategy opts in with use_annotations=True, which this one does not,
        # so the raw value is passed through unchanged.
        return {entry["map_id"]: MapSummary.from_dict(entry) for entry in value}
|
|
|
|
|
|
@dataclass
class Area(DataClassDictMixin):
    """Representation of an area.

    This can be a room, a carpet, etc. ``Area.from_dict`` returns the concrete
    subclass selected by the payload's ``type`` key.
    """

    class Config(BaseConfig):
        """Configuration."""

        # Dispatch deserialization to the subclass whose ``type`` value
        # matches the ``type`` key of the incoming dict.
        discriminator = Discriminator(
            field="type",
            include_subtypes=True,
        )
|
|
|
|
|
|
@dataclass
class Room(Area):
    """A room on the map, matched when the payload ``type`` is "room"."""

    # Fan speed for this room, read from the payload key "suction".
    fanspeed: FanSpeed = field(metadata=field_options(alias="suction"))
    cistern: int  # TODO: enumize

    # Read from the payload key "clean_number" (presumably the number of
    # cleaning passes — confirm against device behaviour).
    clean_count: int = field(metadata=field_options(alias="clean_number"))
    id: int
    type: str = "room"
|
|
|
|
|
|
@dataclass
class Carpet(Area):
    """A carpet region, matched when the payload ``type`` is "carpet_rectangle"."""

    # Vertex list exactly as sent by the device; not parsed further here.
    vertexs: list
    # NOTE(review): meaning of the strategy values is not visible here.
    carpet_strategy: int
    id: int
    type: str = "carpet_rectangle"
|
|
|
|
|
|
class Areas(SerializationStrategy):
    """Turn the ``area_list`` payload into a dict of areas keyed by id."""

    def deserialize(self, value: list) -> dict:
        """Build ``{area id: Area}`` from the raw list of area dicts.

        Each entry goes through ``Area.from_dict``, which picks the concrete
        subclass from the entry's ``type`` key.
        """
        result = {}
        for raw_area in value:
            # Look up the key first so a missing "id" fails before parsing.
            area_id = raw_area["id"]
            result[area_id] = Area.from_dict(raw_area)
        return result
|
|
|
|
|
|
@dataclass
class MapInfo(DataClassDictMixin):
    """Parsed form of the getMapInfo response."""

    map_num: int
    version: str
    current_map_id: int
    auto_change_map: bool
    # Payload key "map_list"; the Maps strategy converts it to {map_id: MapSummary}.
    maps: dict[int, MapSummary] = field(
        metadata=field_options(alias="map_list", serialization_strategy=Maps())
    )
|
|
|
|
|
|
@dataclass
|
|
class MapData(DataClassDictMixin):
|
|
"""Class representing getMapData response."""
|
|
|
|
auto_area: bool = field(metadata=field_options(alias="auto_area_flag"))
|
|
path_id: int
|
|
version: str
|
|
map_id: int
|
|
resolution: int
|
|
resolution_unit: str
|
|
width: int
|
|
height: int
|
|
origin_coor: tuple[int, int, int]
|
|
real_origin_coor: tuple[int, int, int]
|
|
bitnum: int
|
|
palette: dict = field(metadata=field_options(alias="bit_list"))
|
|
pix_len: int
|
|
map_hash: str
|
|
pix_lz4len: int
|
|
map_data: bytes = field(metadata={"deserialize": b64decode}, repr=False)
|
|
areas: dict = field(
|
|
metadata=field_options(serialization_strategy=Areas(), alias="area_list")
|
|
)
|
|
|
|
def get_image(self, mapinfo: MapInfo | None = None) -> "PIL.Image":
|
|
"""Return image object map getMapData response."""
|
|
# TODO: move assert checks to mashumaro checks
|
|
|
|
try:
|
|
import lz4.block
|
|
from PIL import Image
|
|
except ImportError as ex:
|
|
raise KasaException(
|
|
"You need to have lz4 and pillow installed to use this function."
|
|
) from ex
|
|
|
|
if len(self.map_data) != self.pix_lz4len:
|
|
raise KasaException("Invalid map data length")
|
|
|
|
if self.width * self.height != self.pix_len:
|
|
raise KasaException("Invalid payload")
|
|
|
|
_LOGGER.debug("resolution: %s %s", self.resolution, self.resolution_unit)
|
|
_LOGGER.debug("Size: %s x %s", self.width, self.height)
|
|
_LOGGER.debug("Bits per pixel: %s", self.bitnum)
|
|
_LOGGER.debug("origin: %s", self.origin_coor)
|
|
_LOGGER.debug("real origin: %s", self.real_origin_coor)
|
|
for area_id, area in self.areas.items():
|
|
_LOGGER.debug("Area %s: %s", area_id, area)
|
|
|
|
_LOGGER.debug("Palette: %s", self.palette)
|
|
|
|
# TODO: use nicer palette
|
|
|
|
img_data = lz4.block.decompress(self.map_data, uncompressed_size=self.pix_len)
|
|
|
|
img_data_hash = hashlib.md5(img_data).hexdigest().upper() # noqa: S324
|
|
if self.map_hash != img_data_hash:
|
|
raise KasaException("Invalid map hash")
|
|
|
|
match self.bitnum:
|
|
case 8:
|
|
mode = "L" # 8bit gray
|
|
case _:
|
|
raise KasaException(f"Unknown bitnum {self.bitnum}")
|
|
|
|
img = Image.frombytes(mode, (self.width, self.height), data=img_data)
|
|
|
|
# rotate
|
|
img = img.rotate(mapinfo.maps[self.map_id].rotate_angle)
|
|
|
|
return img
|
|
|
|
|
|
@dataclass
class PathData(DataClassDictMixin):
    """Path data container, parsed from the getPathData response."""

    path_id: int
    # Payload key "point_counts".
    points: int = field(metadata=field_options(alias="point_counts"))
    total_points: int
    # LZ4 block-compressed path data, base64-decoded from "pos_array".
    data: bytes = field(
        metadata=field_options(deserialize=b64decode, alias="pos_array"), repr=False
    )
    # Expected size of the decompressed data (payload key "pos_len").
    data_len: int = field(metadata=field_options(alias="pos_len"))
    # Compressed size as reported by the device (payload key "pos_lz4len").
    data_lz4len: int = field(metadata=field_options(alias="pos_lz4len"))

    def get_decompressed_data(self) -> bytes:
        """Return decompressed path data.

        :raises KasaException: if lz4 is not installed or the decompressed
            size does not equal ``data_len``.
        """
        try:
            import lz4.block
        except ImportError as ex:
            # Only lz4 is imported here, so do not ask for pillow as well.
            raise KasaException(
                "You need to have lz4 installed to use this function."
            ) from ex

        # With uncompressed_size given, lz4 treats it as a maximum buffer size,
        # so a shorter result is possible; verify the length explicitly.
        decompressed = lz4.block.decompress(self.data, uncompressed_size=self.data_len)
        if len(decompressed) != self.data_len:
            raise KasaException("Invalid data length")

        return decompressed
|
|
|
|
|
|
class Map(SmartModule):
    """Implementation of vacuum map module.

    Parses the map list, map data and path data fetched during each update,
    and exposes the number of stored maps as a debug feature.
    """

    REQUIRED_COMPONENT = "map"

    def query(self) -> dict:
        """Query to execute during the update cycle.

        Requests the map list, the map data for ``map_id`` -1 (presumably the
        currently active map — confirm against device behaviour) and the path
        data starting at position 0.
        """
        return {
            "getMapInfo": {},
            "getMapData": {"map_id": -1},
            "getPathData": {"start_pos": 0},
        }

    def _initialize_features(self) -> None:
        """Initialize features."""
        self._add_feature(
            Feature(
                self._device,
                id="map_count",
                name="Map count",
                container=self,
                attribute_getter="map_count",
                category=Feature.Category.Debug,
                type=Feature.Sensor,
            )
        )

    @property
    def map_count(self) -> int:
        """Return number of maps, as reported by getMapInfo."""
        return self.map_info.map_num

    @property
    def map_info(self) -> MapInfo:
        """Return map information.

        The getMapInfo response is re-parsed on every access.
        """
        return MapInfo.from_dict(self.data["getMapInfo"])

    @property
    def map_data(self) -> MapData:
        """Return map data.

        The getMapData response is re-parsed on every access.
        """
        return MapData.from_dict(self.data["getMapData"])

    @property
    def path_data(self) -> PathData:
        """Return path data.

        The getPathData response is re-parsed on every access.
        """
        return PathData.from_dict(self.data["getPathData"])

    def get_path(self) -> bytes:
        """Return the decompressed path data as raw bytes.

        The byte layout is not decoded here; this is not an image.
        """
        return self.path_data.get_decompressed_data()

    def get_map_image(self) -> "PIL.Image":
        """Return the map as an image, rotated according to the map info."""
        return self.map_data.get_image(self.map_info)
|