mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-08 05:47:06 +00:00
229 lines
8.0 KiB
Python
229 lines
8.0 KiB
Python
from __future__ import annotations
|
|
|
|
import copy
|
|
from json import loads as json_loads
|
|
from typing import Any
|
|
|
|
from kasa import Credentials, DeviceConfig, SmartProtocol
|
|
from kasa.protocols.smartcamprotocol import SmartCamProtocol
|
|
from kasa.transports.basetransport import BaseTransport
|
|
|
|
from .fakeprotocol_smart import FakeSmartTransport
|
|
|
|
|
|
class FakeSmartCamProtocol(SmartCamProtocol):
    """SmartCamProtocol whose transport is an in-memory FakeSmartCamTransport.

    All keyword flags are passed straight through to the fake transport.
    """

    def __init__(self, info, fixture_name, *, is_child=False, verbatim=False):
        fake_transport = FakeSmartCamTransport(
            info,
            fixture_name,
            is_child=is_child,
            verbatim=verbatim,
        )
        super().__init__(transport=fake_transport)

    async def query(self, request, retry_count: int = 3):
        """Run *request* through ``_query``.

        Overridden here so that ``SmartProtocol.query`` can still be patched
        independently of this fake.
        """
        return await self._query(request, retry_count)
|
|
|
|
|
|
class FakeSmartCamTransport(BaseTransport):
    """In-memory transport that answers smartcam requests from fixture data.

    ``info`` is a fixture dict keyed by request method name (for example
    ``getDeviceInfo``). ``get*`` requests return a copy of the matching
    entry, ``set*`` requests write into the matching ``get*`` entry, and
    ``controlChild`` requests are forwarded to per-child fake protocols.
    """

    def __init__(
        self,
        info,
        fixture_name,
        *,
        list_return_size=10,
        is_child=False,
        verbatim=False,
    ):
        """Create the fake transport.

        :param info: Fixture data keyed by request method name.
        :param fixture_name: Fixture name, used when building child protocols.
        :param list_return_size: Maximum number of list items returned by a
            paginated ``get*`` response.
        :param is_child: When True, ``info`` is used as-is (not deep-copied)
            and no child protocols are built.
        :param verbatim: See comment below.
        """
        super().__init__(
            config=DeviceConfig(
                "127.0.0.123",
                credentials=Credentials(
                    username="dummy_user",
                    password="dummy_password",  # noqa: S106
                ),
            ),
        )
        self.fixture_name = fixture_name
        # When True verbatim will bypass any extra processing of missing
        # methods and is used to test the fixture creation itself.
        self.verbatim = verbatim
        if not is_child:
            # Deep-copy so that set* requests do not mutate the caller's fixture.
            self.info = copy.deepcopy(info)
            self.child_protocols = FakeSmartTransport._get_child_protocols(
                self.info, self.fixture_name, "getChildDeviceList"
            )
        else:
            # NOTE: child transports get no ``child_protocols`` attribute, so a
            # controlChild request sent to one would raise AttributeError.
            self.info = info
        self.list_return_size = list_return_size

    @property
    def default_port(self):
        """Default port for the transport."""
        return 443

    @property
    def credentials_hash(self):
        """The hashed credentials used by the transport."""
        return self._credentials.username + self._credentials.password + "camerahash"

    async def send(self, request: str):
        """Decode a JSON *request* and answer it from the fixture.

        ``multipleRequest`` batches are answered one sub-request at a time.
        Each sub-response is tagged with its ``method``, and processing stops
        after the first sub-response with a non-zero ``error_code``.
        """
        request_dict = json_loads(request)
        method = request_dict["method"]

        if method != "multipleRequest":
            return await self._send_request(request_dict)

        responses = []
        # Use a distinct loop name so the ``request: str`` parameter is not
        # shadowed (previously required type: ignore comments).
        for sub_request in request_dict["params"]["requests"]:
            response = await self._send_request(sub_request)
            response["method"] = sub_request["method"]
            responses.append(response)
            # Devices do not continue after error
            if response["error_code"] != 0:
                break
        return {"result": {"responses": responses}, "error_code": 0}

    async def _handle_control_child(self, params: dict):
        """Handle control_child command.

        Forwards the wrapped request to the fake protocol registered for
        ``params["device_id"]`` and wraps its answer in a ``response_data``
        envelope.
        """
        device_id = params.get("device_id")
        assert device_id in self.child_protocols, "Fixture does not have child info"

        child_protocol: SmartProtocol = self.child_protocols[device_id]

        request_data = params.get("request_data", {})
        child_method = request_data.get("method")
        child_params = request_data.get("params")

        resp = await child_protocol.query({child_method: child_params})
        # Adding error_code also guarantees resp is non-empty below.
        resp["error_code"] = 0
        # Forward the first value of the child's response; presumably the
        # result for child_method, since dicts keep insertion order.
        result = next(iter(resp.values()))
        return {
            "result": {"response_data": {"result": result, "error_code": 0}},
            "error_code": 0,
        }

    @staticmethod
    def _get_param_set_value(info: dict, set_keys: list[str], value):
        """Store *value* in *info* at the nested key path *set_keys*."""
        for key in set_keys[:-1]:
            info = info[key]
        info[set_keys[-1]] = value

    # Setters for when there's not a simple mapping of setters to getters.
    # Maps (module, section, key) of a set request to the key path inside the
    # fixture where the value should be written.
    SETTERS = {
        ("system", "sys", "dev_alias"): [
            "getDeviceInfo",
            "device_info",
            "basic_info",
            "device_alias",
        ],
        # setTimezone maps to getClockStatus
        ("system", "clock_status", "seconds_from_1970"): [
            "getClockStatus",
            "system",
            "clock_status",
            "seconds_from_1970",
        ],
        # setTimezone maps to getClockStatus
        ("system", "clock_status", "local_time"): [
            "getClockStatus",
            "system",
            "clock_status",
            "local_time",
        ],
    }

    @staticmethod
    def _get_second_key(request_dict: dict[str, Any]) -> str:
        """Return the second key of a two-key request dict.

        Used for generic ``get`` requests, where the key after ``method`` is
        the module name.
        """
        assert (
            len(request_dict) == 2
        ), f"Unexpected dict {request_dict}, should be length 2"
        it = iter(request_dict)
        next(it, None)
        return next(it)

    async def _send_request(self, request_dict: dict):
        """Answer a single (non-batched) request from the fixture data."""
        method = request_dict["method"]

        if method == "controlChild":
            return await self._handle_control_child(
                request_dict["params"]["childControl"]
            )
        if method.startswith("set"):
            return self._handle_set_request(method, request_dict)
        # The exact "get" check must come before the "get" prefix check.
        if method == "get":
            return self._handle_module_get_request(request_dict)
        if method.startswith("get"):
            return self._handle_get_request(method, request_dict)
        return {"error_code": -1}

    def _handle_set_request(self, method: str, request_dict: dict) -> dict:
        """Write a ``set*`` request into the matching ``get*`` fixture entry.

        Only the first parameter key after ``method`` is processed. Returns
        ``error_code`` -1 when a value has no existing slot to be written to.
        """
        info = self.info
        get_method = "g" + method[1:]
        for key, val in request_dict.items():
            if key == "method":
                continue
            # key is params for multi request and the actual params
            # for single requests
            if key == "params":
                module = next(iter(val))
                val = val[module]
            else:
                module = key
            section = next(iter(val))
            skey_val = val[section]
            if not isinstance(skey_val, dict):  # single level query
                get_info = info.get(get_method)
                if not get_info or section not in get_info:
                    return {"error_code": -1}
                get_info[section] = skey_val
                break
            for section_key, section_value in skey_val.items():
                if setter_keys := self.SETTERS.get((module, section, section_key)):
                    self._get_param_set_value(info, setter_keys, section_value)
                # Bind the looked-up dict to its own name: reusing ``section``
                # here made the next iteration hash a dict (TypeError).
                elif (
                    section_dict := info.get(get_method, {})
                    .get(module, {})
                    .get(section, {})
                ) and section_key in section_dict:
                    section_dict[section_key] = section_value
                else:
                    return {"error_code": -1}
            break
        return {"error_code": 0}

    def _handle_module_get_request(self, request_dict: dict) -> dict:
        """Answer a generic ``get`` request from the ``get_<module>`` entry."""
        module = self._get_second_key(request_dict)
        get_method = f"get_{module}"
        if get_method not in self.info:
            return {"error_code": -1}
        result = copy.deepcopy(self.info[get_method]["get"])
        return {**result, "error_code": 0}

    def _handle_get_request(self, method: str, request_dict: dict) -> dict:
        """Answer a ``get*`` request, paginating list results when needed."""
        if method not in self.info:
            return {"error_code": -1}
        result = copy.deepcopy(self.info[method])
        # Entries carrying start_index and sum are paginated: return at most
        # list_return_size items of the first list value, from start_index.
        if "start_index" in result and "sum" in result:
            list_key = next(
                key for key, value in result.items() if isinstance(value, list)
            )
            params = request_dict.get("params")
            start_index = (params or {}).get("start_index") or 0
            result[list_key] = result[list_key][
                start_index : start_index + self.list_return_size
            ]
        return {"result": result, "error_code": 0}

    async def close(self) -> None:
        """Nothing to release for the fake transport."""

    async def reset(self) -> None:
        """Nothing to reset for the fake transport."""
|