from __future__ import annotations import copy from json import loads as json_loads from typing import Any from kasa import Credentials, DeviceConfig, SmartProtocol from kasa.protocols.smartcamprotocol import SmartCamProtocol from kasa.transports.basetransport import BaseTransport from .fakeprotocol_smart import FakeSmartTransport class FakeSmartCamProtocol(SmartCamProtocol): def __init__(self, info, fixture_name, *, is_child=False, verbatim=False): super().__init__( transport=FakeSmartCamTransport( info, fixture_name, is_child=is_child, verbatim=verbatim ), ) async def query(self, request, retry_count: int = 3): """Implement query here so can still patch SmartProtocol.query.""" resp_dict = await self._query(request, retry_count) return resp_dict class FakeSmartCamTransport(BaseTransport): def __init__( self, info, fixture_name, *, list_return_size=10, is_child=False, verbatim=False, ): super().__init__( config=DeviceConfig( "127.0.0.123", credentials=Credentials( username="dummy_user", password="dummy_password", # noqa: S106 ), ), ) self.fixture_name = fixture_name # When True verbatim will bypass any extra processing of missing # methods and is used to test the fixture creation itself. self.verbatim = verbatim if not is_child: self.info = copy.deepcopy(info) self.child_protocols = FakeSmartTransport._get_child_protocols( self.info, self.fixture_name, "getChildDeviceList" ) else: self.info = info # self.child_protocols = self._get_child_protocols() self.list_return_size = list_return_size @property def default_port(self): """Default port for the transport.""" return 443 @property def credentials_hash(self): """The hashed credentials used by the transport.""" return self._credentials.username + self._credentials.password + "camerahash" async def send(self, request: str): request_dict = json_loads(request) method = request_dict["method"] if method == "multipleRequest": params = request_dict["params"] responses = [] for request in params["requests"]: response = await self._send_request(request) # type: ignore[arg-type] response["method"] = request["method"] # type: ignore[index] responses.append(response) # Devices do not continue after error if response["error_code"] != 0: break return {"result": {"responses": responses}, "error_code": 0} else: return await self._send_request(request_dict) async def _handle_control_child(self, params: dict): """Handle control_child command.""" device_id = params.get("device_id") assert device_id in self.child_protocols, "Fixture does not have child info" child_protocol: SmartProtocol = self.child_protocols[device_id] request_data = params.get("request_data", {}) child_method = request_data.get("method") child_params = request_data.get("params") # noqa: F841 resp = await child_protocol.query({child_method: child_params}) resp["error_code"] = 0 for val in resp.values(): return { "result": {"response_data": {"result": val, "error_code": 0}}, "error_code": 0, } @staticmethod def _get_param_set_value(info: dict, set_keys: list[str], value): for key in set_keys[:-1]: info = info[key] info[set_keys[-1]] = value # Setters for when there's not a simple mapping of setters to getters SETTERS = { ("system", "sys", "dev_alias"): [ "getDeviceInfo", "device_info", "basic_info", "device_alias", ], # setTimezone maps to getClockStatus ("system", "clock_status", "seconds_from_1970"): [ "getClockStatus", "system", "clock_status", "seconds_from_1970", ], # setTimezone maps to getClockStatus ("system", "clock_status", "local_time"): [ "getClockStatus", "system", "clock_status", "local_time", ], } @staticmethod def _get_second_key(request_dict: dict[str, Any]) -> str: assert ( len(request_dict) == 2 ), f"Unexpected dict {request_dict}, should be length 2" it = iter(request_dict) next(it, None) return next(it) async def _send_request(self, request_dict: dict): method = request_dict["method"] info = self.info if method == "controlChild": return await self._handle_control_child( request_dict["params"]["childControl"] ) if method[:3] == "set": get_method = "g" + method[1:] for key, val in request_dict.items(): if key == "method": continue # key is params for multi request and the actual params # for single requests if key == "params": module = next(iter(val)) val = val[module] else: module = key section = next(iter(val)) skey_val = val[section] if not isinstance(skey_val, dict): # single level query section_key = section section_val = skey_val if (get_info := info.get(get_method)) and section_key in get_info: get_info[section_key] = section_val else: return {"error_code": -1} break for skey, sval in skey_val.items(): section_key = skey section_value = sval if setter_keys := self.SETTERS.get((module, section, section_key)): self._get_param_set_value(info, setter_keys, section_value) elif ( section := info.get(get_method, {}) .get(module, {}) .get(section, {}) ) and section_key in section: section[section_key] = section_value else: return {"error_code": -1} break return {"error_code": 0} elif method == "get": module = self._get_second_key(request_dict) get_method = f"get_{module}" if get_method in info: result = copy.deepcopy(info[get_method]["get"]) return {**result, "error_code": 0} else: return {"error_code": -1} elif method[:3] == "get": params = request_dict.get("params") if method in info: result = copy.deepcopy(info[method]) if "start_index" in result and "sum" in result: list_key = next( iter([key for key in result if isinstance(result[key], list)]) ) start_index = ( start_index if (params and (start_index := params.get("start_index"))) else 0 ) result[list_key] = result[list_key][ start_index : start_index + self.list_return_size ] return {"result": result, "error_code": 0} else: return {"error_code": -1} return {"error_code": -1} async def close(self) -> None: pass async def reset(self) -> None: pass