Initial map module implementation

This commit is contained in:
Teemu Rytilahti 2025-01-12 18:30:21 +01:00
parent 2542516009
commit 83d169d35c
3 changed files with 265 additions and 0 deletions

View File

@ -164,6 +164,7 @@ class Module(ABC):
# Vacuum modules
Clean: Final[ModuleName[smart.Clean]] = ModuleName("Clean")
Dustbin: Final[ModuleName[smart.Dustbin]] = ModuleName("Dustbin")
Map: Final[ModuleName[smart.Map]] = ModuleName("Map")
Speaker: Final[ModuleName[smart.Speaker]] = ModuleName("Speaker")
def __init__(self, device: Device, module: str) -> None:

View File

@ -26,6 +26,7 @@ from .lighteffect import LightEffect
from .lightpreset import LightPreset
from .lightstripeffect import LightStripEffect
from .lighttransition import LightTransition
from .map import Map
from .matter import Matter
from .motionsensor import MotionSensor
from .overheatprotection import OverheatProtection
@ -76,4 +77,5 @@ __all__ = [
"HomeKit",
"Matter",
"Dustbin",
"Map",
]

262
kasa/smart/modules/map.py Normal file
View File

@ -0,0 +1,262 @@
"""Implementation for map module."""
import hashlib
import logging
from base64 import b64decode
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from mashumaro import DataClassDictMixin, field_options
from mashumaro.config import BaseConfig
from mashumaro.types import Discriminator, SerializationStrategy
from ...exceptions import KasaException
from ...feature import Feature
from ..smartmodule import SmartModule
from .clean import FanSpeed
if TYPE_CHECKING:
import PIL
_LOGGER = logging.getLogger(__name__)
@dataclass
class MapSummary(DataClassDictMixin):
"""Class representing map summary from mapinfo response."""
map_id: int
rotate_angle: int
is_saved: bool
update_time: int
global_cleaned: int
auto_area_flag: bool
map_locked: int
class Maps(SerializationStrategy):
"""Strategy to deserialize list of maps into a dict."""
def deserialize(self, value: str) -> dict:
"""Deserialize list of maps into a dict."""
maps = {x["map_id"]: MapSummary.from_dict(x) for x in value}
return maps
@dataclass
class Area(DataClassDictMixin):
"""Reprsentation of an area.
This can be a room, a carpet, etc.
"""
class Config(BaseConfig):
"""Configuration."""
discriminator = Discriminator(
field="type",
include_subtypes=True,
)
@dataclass
class Room(Area):
"""Room area."""
fanspeed: FanSpeed = field(metadata=field_options(alias="suction"))
cistern: int = field() # TODO: enumize
clean_count: int = field(metadata=field_options(alias="clean_number"))
id: int
type: str = "room"
@dataclass
class Carpet(Area):
"""Carpet area."""
vertexs: list
carpet_strategy: int
id: int
type: str = "carpet_rectangle"
class Areas(SerializationStrategy):
"""Strategy to deserialize list of areas into a dict."""
def deserialize(self, value: list) -> dict:
"""Deserialize list of areas into a dict."""
areas = {x["id"]: Area.from_dict(x) for x in value}
return areas
@dataclass
class MapInfo(DataClassDictMixin):
"""Class representing getMapInfo response."""
map_num: int
version: str
current_map_id: int
auto_change_map: bool
maps: dict[int, MapSummary] = field(
metadata=field_options(serialization_strategy=Maps(), alias="map_list")
)
@dataclass
class MapData(DataClassDictMixin):
"""Class representing getMapData response."""
auto_area: bool = field(metadata=field_options(alias="auto_area_flag"))
path_id: int
version: str
map_id: int
resolution: int
resolution_unit: str
width: int
height: int
origin_coor: tuple[int, int, int]
real_origin_coor: tuple[int, int, int]
bitnum: int
palette: dict = field(metadata=field_options(alias="bit_list"))
pix_len: int
map_hash: str
pix_lz4len: int
map_data: bytes = field(metadata={"deserialize": b64decode}, repr=False)
areas: dict = field(
metadata=field_options(serialization_strategy=Areas(), alias="area_list")
)
def get_image(self, mapinfo: MapInfo | None = None) -> "PIL.Image":
"""Return image object map getMapData response."""
# TODO: move assert checks to mashumaro checks
try:
import lz4.block
from PIL import Image
except ImportError as ex:
raise KasaException(
"You need to have lz4 and pillow installed to use this function."
) from ex
if len(self.map_data) != self.pix_lz4len:
raise KasaException("Invalid map data length")
if self.width * self.height != self.pix_len:
raise KasaException("Invalid payload")
_LOGGER.debug("resolution: %s %s", self.resolution, self.resolution_unit)
_LOGGER.debug("Size: %s x %s", self.width, self.height)
_LOGGER.debug("Bits per pixel: %s", self.bitnum)
_LOGGER.debug("origin: %s", self.origin_coor)
_LOGGER.debug("real origin: %s", self.real_origin_coor)
for area_id, area in self.areas.items():
_LOGGER.debug("Area %s: %s", area_id, area)
_LOGGER.debug("Palette: %s", self.palette)
# TODO: use nicer palette
img_data = lz4.block.decompress(self.map_data, uncompressed_size=self.pix_len)
img_data_hash = hashlib.md5(img_data).hexdigest().upper() # noqa: S324
if self.map_hash != img_data_hash:
raise KasaException("Invalid map hash")
match self.bitnum:
case 8:
mode = "L" # 8bit gray
case _:
raise KasaException(f"Unknown bitnum {self.bitnum}")
img = Image.frombytes(mode, (self.width, self.height), data=img_data)
# rotate
img = img.rotate(mapinfo.maps[self.map_id].rotate_angle)
return img
@dataclass
class PathData(DataClassDictMixin):
"""Path data container."""
path_id: int
points: int = field(metadata=field_options(alias="point_counts"))
total_points: int
data: bytes = field(
metadata=field_options(deserialize=b64decode, alias="pos_array"), repr=False
)
data_len: int = field(metadata=field_options(alias="pos_len"))
data_lz4len: int = field(metadata=field_options(alias="pos_lz4len"))
def get_decompressed_data(self) -> bytes:
"""Return decompressed path data."""
try:
import lz4.block
except ImportError as ex:
raise KasaException(
"You need to have lz4 and pillow installed to use this function."
) from ex
decompressed = lz4.block.decompress(self.data, uncompressed_size=self.data_len)
if len(decompressed) != self.data_len:
raise KasaException("Invalid data length")
return decompressed
class Map(SmartModule):
"""Implementation of vacuum map module."""
REQUIRED_COMPONENT = "map"
def query(self) -> dict:
"""Query to execute during the update cycle."""
return {
"getMapInfo": {},
"getMapData": {"map_id": -1},
"getPathData": {"start_pos": 0},
}
def _initialize_features(self) -> None:
"""Initialize features."""
self._add_feature(
Feature(
self._device,
id="map_count",
name="Map count",
container=self,
attribute_getter="map_count",
category=Feature.Category.Debug,
type=Feature.Sensor,
)
)
@property
def map_count(self) -> int:
"""Return number of maps."""
return self.map_info.map_num
@property
def map_info(self) -> MapInfo:
"""Return map information."""
return MapInfo.from_dict(self.data["getMapInfo"])
@property
def map_data(self) -> MapData:
"""Return map data."""
return MapData.from_dict(self.data["getMapData"])
@property
def path_data(self) -> PathData:
"""Return path data."""
return PathData.from_dict(self.data["getPathData"])
def get_path(self) -> bytes:
"""Return path as an image."""
return self.path_data.get_decompressed_data()
def get_map_image(self) -> "PIL.Image":
"""Return map as an image."""
return self.map_data.get_image(self.map_info)