mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Update smart request parameter handling (#1061)
Changes to the smart request handling: - Do not send params if null - Drop the requestId parameter - get_preset_rules doesn't send parameters for preset component version less than 3 - get_led_info no longer sends the wrong parameters - get_on_off_gradually_info no longer sends an empty {} parameter
This commit is contained in:
parent
06ff598d9c
commit
58afeb28a1
@ -16,7 +16,7 @@ class Led(SmartModule, LedInterface):
|
|||||||
|
|
||||||
def query(self) -> dict:
|
def query(self) -> dict:
|
||||||
"""Query to execute during the update cycle."""
|
"""Query to execute during the update cycle."""
|
||||||
return {self.QUERY_GETTER_NAME: {"led_rule": None}}
|
return {self.QUERY_GETTER_NAME: None}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mode(self):
|
def mode(self):
|
||||||
|
@ -153,6 +153,9 @@ class LightPreset(SmartModule, LightPresetInterface):
|
|||||||
"""Query to execute during the update cycle."""
|
"""Query to execute during the update cycle."""
|
||||||
if self._state_in_sysinfo: # Child lights can have states in the child info
|
if self._state_in_sysinfo: # Child lights can have states in the child info
|
||||||
return {}
|
return {}
|
||||||
|
if self.supported_version < 3:
|
||||||
|
return {self.QUERY_GETTER_NAME: None}
|
||||||
|
|
||||||
return {self.QUERY_GETTER_NAME: {"start_index": 0}}
|
return {self.QUERY_GETTER_NAME: {"start_index": 0}}
|
||||||
|
|
||||||
async def _check_supported(self):
|
async def _check_supported(self):
|
||||||
|
@ -234,7 +234,7 @@ class LightTransition(SmartModule):
|
|||||||
if self._state_in_sysinfo:
|
if self._state_in_sysinfo:
|
||||||
return {}
|
return {}
|
||||||
else:
|
else:
|
||||||
return {self.QUERY_GETTER_NAME: {}}
|
return {self.QUERY_GETTER_NAME: None}
|
||||||
|
|
||||||
async def _check_supported(self):
|
async def _check_supported(self):
|
||||||
"""Additional check to see if the module is supported by the device."""
|
"""Additional check to see if the module is supported by the device."""
|
||||||
|
@ -66,7 +66,6 @@ class SmartProtocol(BaseProtocol):
|
|||||||
"""Create a protocol object."""
|
"""Create a protocol object."""
|
||||||
super().__init__(transport=transport)
|
super().__init__(transport=transport)
|
||||||
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
|
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
|
||||||
self._request_id_generator = SnowflakeId(1, 1)
|
|
||||||
self._query_lock = asyncio.Lock()
|
self._query_lock = asyncio.Lock()
|
||||||
self._multi_request_batch_size = (
|
self._multi_request_batch_size = (
|
||||||
self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
|
self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
|
||||||
@ -77,11 +76,11 @@ class SmartProtocol(BaseProtocol):
|
|||||||
"""Get a request message as a string."""
|
"""Get a request message as a string."""
|
||||||
request = {
|
request = {
|
||||||
"method": method,
|
"method": method,
|
||||||
"params": params,
|
|
||||||
"requestID": self._request_id_generator.generate_id(),
|
|
||||||
"request_time_milis": round(time.time() * 1000),
|
"request_time_milis": round(time.time() * 1000),
|
||||||
"terminal_uuid": self._terminal_uuid,
|
"terminal_uuid": self._terminal_uuid,
|
||||||
}
|
}
|
||||||
|
if params:
|
||||||
|
request["params"] = params
|
||||||
return json_dumps(request)
|
return json_dumps(request)
|
||||||
|
|
||||||
async def query(self, request: str | dict, retry_count: int = 3) -> dict:
|
async def query(self, request: str | dict, retry_count: int = 3) -> dict:
|
||||||
@ -157,8 +156,10 @@ class SmartProtocol(BaseProtocol):
|
|||||||
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||||
multi_result: dict[str, Any] = {}
|
multi_result: dict[str, Any] = {}
|
||||||
smart_method = "multipleRequest"
|
smart_method = "multipleRequest"
|
||||||
|
|
||||||
multi_requests = [
|
multi_requests = [
|
||||||
{"method": method, "params": params} for method, params in requests.items()
|
{"method": method, "params": params} if params else {"method": method}
|
||||||
|
for method, params in requests.items()
|
||||||
]
|
]
|
||||||
|
|
||||||
end = len(multi_requests)
|
end = len(multi_requests)
|
||||||
@ -168,7 +169,7 @@ class SmartProtocol(BaseProtocol):
|
|||||||
# If step is 1 do not send request batches
|
# If step is 1 do not send request batches
|
||||||
for request in multi_requests:
|
for request in multi_requests:
|
||||||
method = request["method"]
|
method = request["method"]
|
||||||
req = self.get_smart_request(method, request["params"])
|
req = self.get_smart_request(method, request.get("params"))
|
||||||
resp = await self._transport.send(req)
|
resp = await self._transport.send(req)
|
||||||
self._handle_response_error_code(resp, method, raise_on_error=False)
|
self._handle_response_error_code(resp, method, raise_on_error=False)
|
||||||
multi_result[method] = resp["result"]
|
multi_result[method] = resp["result"]
|
||||||
@ -347,86 +348,6 @@ class SmartProtocol(BaseProtocol):
|
|||||||
await self._transport.close()
|
await self._transport.close()
|
||||||
|
|
||||||
|
|
||||||
class SnowflakeId:
|
|
||||||
"""Class for generating snowflake ids."""
|
|
||||||
|
|
||||||
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
|
|
||||||
WORKER_ID_BITS = 5
|
|
||||||
DATA_CENTER_ID_BITS = 5
|
|
||||||
SEQUENCE_BITS = 12
|
|
||||||
|
|
||||||
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
|
|
||||||
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
|
|
||||||
|
|
||||||
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
|
|
||||||
|
|
||||||
def __init__(self, worker_id, data_center_id):
|
|
||||||
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
|
|
||||||
raise ValueError(
|
|
||||||
"Worker ID can't be greater than "
|
|
||||||
+ str(SnowflakeId.MAX_WORKER_ID)
|
|
||||||
+ " or less than 0"
|
|
||||||
)
|
|
||||||
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
|
|
||||||
raise ValueError(
|
|
||||||
"Data center ID can't be greater than "
|
|
||||||
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
|
|
||||||
+ " or less than 0"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.worker_id = worker_id
|
|
||||||
self.data_center_id = data_center_id
|
|
||||||
self.sequence = 0
|
|
||||||
self.last_timestamp = -1
|
|
||||||
|
|
||||||
def generate_id(self):
|
|
||||||
"""Generate a snowflake id."""
|
|
||||||
timestamp = self._current_millis()
|
|
||||||
|
|
||||||
if timestamp < self.last_timestamp:
|
|
||||||
raise ValueError("Clock moved backwards. Refusing to generate ID.")
|
|
||||||
|
|
||||||
if timestamp == self.last_timestamp:
|
|
||||||
# Within the same millisecond, increment the sequence number
|
|
||||||
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
|
|
||||||
if self.sequence == 0:
|
|
||||||
# Sequence exceeds its bit range, wait until the next millisecond
|
|
||||||
timestamp = self._wait_next_millis(self.last_timestamp)
|
|
||||||
else:
|
|
||||||
# New millisecond, reset the sequence number
|
|
||||||
self.sequence = 0
|
|
||||||
|
|
||||||
# Update the last timestamp
|
|
||||||
self.last_timestamp = timestamp
|
|
||||||
|
|
||||||
# Generate and return the final ID
|
|
||||||
return (
|
|
||||||
(
|
|
||||||
(timestamp - SnowflakeId.EPOCH)
|
|
||||||
<< (
|
|
||||||
SnowflakeId.WORKER_ID_BITS
|
|
||||||
+ SnowflakeId.SEQUENCE_BITS
|
|
||||||
+ SnowflakeId.DATA_CENTER_ID_BITS
|
|
||||||
)
|
|
||||||
)
|
|
||||||
| (
|
|
||||||
self.data_center_id
|
|
||||||
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
|
|
||||||
)
|
|
||||||
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
|
|
||||||
| self.sequence
|
|
||||||
)
|
|
||||||
|
|
||||||
def _current_millis(self):
|
|
||||||
return round(time.monotonic() * 1000)
|
|
||||||
|
|
||||||
def _wait_next_millis(self, last_timestamp):
|
|
||||||
timestamp = self._current_millis()
|
|
||||||
while timestamp <= last_timestamp:
|
|
||||||
timestamp = self._current_millis()
|
|
||||||
return timestamp
|
|
||||||
|
|
||||||
|
|
||||||
class _ChildProtocolWrapper(SmartProtocol):
|
class _ChildProtocolWrapper(SmartProtocol):
|
||||||
"""Protocol wrapper for controlling child devices.
|
"""Protocol wrapper for controlling child devices.
|
||||||
|
|
||||||
@ -456,6 +377,8 @@ class _ChildProtocolWrapper(SmartProtocol):
|
|||||||
smart_method = "multipleRequest"
|
smart_method = "multipleRequest"
|
||||||
requests = [
|
requests = [
|
||||||
{"method": method, "params": params}
|
{"method": method, "params": params}
|
||||||
|
if params
|
||||||
|
else {"method": method}
|
||||||
for method, params in request.items()
|
for method, params in request.items()
|
||||||
]
|
]
|
||||||
smart_params = {"requests": requests}
|
smart_params = {"requests": requests}
|
||||||
|
@ -119,8 +119,9 @@ class FakeSmartTransport(BaseTransport):
|
|||||||
async def send(self, request: str):
|
async def send(self, request: str):
|
||||||
request_dict = json_loads(request)
|
request_dict = json_loads(request)
|
||||||
method = request_dict["method"]
|
method = request_dict["method"]
|
||||||
params = request_dict["params"]
|
|
||||||
if method == "multipleRequest":
|
if method == "multipleRequest":
|
||||||
|
params = request_dict["params"]
|
||||||
responses = []
|
responses = []
|
||||||
for request in params["requests"]:
|
for request in params["requests"]:
|
||||||
response = self._send_request(request) # type: ignore[arg-type]
|
response = self._send_request(request) # type: ignore[arg-type]
|
||||||
@ -308,12 +309,13 @@ class FakeSmartTransport(BaseTransport):
|
|||||||
|
|
||||||
def _send_request(self, request_dict: dict):
|
def _send_request(self, request_dict: dict):
|
||||||
method = request_dict["method"]
|
method = request_dict["method"]
|
||||||
params = request_dict["params"]
|
|
||||||
|
|
||||||
info = self.info
|
info = self.info
|
||||||
if method == "control_child":
|
if method == "control_child":
|
||||||
return self._handle_control_child(params)
|
return self._handle_control_child(request_dict["params"])
|
||||||
elif method == "component_nego" or method[:4] == "get_":
|
|
||||||
|
params = request_dict.get("params")
|
||||||
|
if method == "component_nego" or method[:4] == "get_":
|
||||||
if method in info:
|
if method in info:
|
||||||
result = copy.deepcopy(info[method])
|
result = copy.deepcopy(info[method])
|
||||||
if "start_index" in result and "sum" in result:
|
if "start_index" in result and "sum" in result:
|
||||||
|
Loading…
Reference in New Issue
Block a user