From 58afeb28a1e48436c0d8ed78f5efaba07284558d Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Tue, 23 Jul 2024 19:02:20 +0100 Subject: [PATCH] Update smart request parameter handling (#1061) Changes to the smart request handling: - Do not send params if null - Drop the requestId parameter - get_preset_rules doesn't send parameters for preset component version less than 3 - get_led_info no longer sends the wrong parameters - get_on_off_gradually_info no longer sends an empty {} parameter --- kasa/smart/modules/led.py | 2 +- kasa/smart/modules/lightpreset.py | 3 + kasa/smart/modules/lighttransition.py | 2 +- kasa/smartprotocol.py | 93 +++------------------------ kasa/tests/fakeprotocol_smart.py | 10 +-- 5 files changed, 19 insertions(+), 91 deletions(-) diff --git a/kasa/smart/modules/led.py b/kasa/smart/modules/led.py index bbfe3579..9c02be85 100644 --- a/kasa/smart/modules/led.py +++ b/kasa/smart/modules/led.py @@ -16,7 +16,7 @@ class Led(SmartModule, LedInterface): def query(self) -> dict: """Query to execute during the update cycle.""" - return {self.QUERY_GETTER_NAME: {"led_rule": None}} + return {self.QUERY_GETTER_NAME: None} @property def mode(self): diff --git a/kasa/smart/modules/lightpreset.py b/kasa/smart/modules/lightpreset.py index 6bb2fb3f..16cd15ae 100644 --- a/kasa/smart/modules/lightpreset.py +++ b/kasa/smart/modules/lightpreset.py @@ -153,6 +153,9 @@ class LightPreset(SmartModule, LightPresetInterface): """Query to execute during the update cycle.""" if self._state_in_sysinfo: # Child lights can have states in the child info return {} + if self.supported_version < 3: + return {self.QUERY_GETTER_NAME: None} + return {self.QUERY_GETTER_NAME: {"start_index": 0}} async def _check_supported(self): diff --git a/kasa/smart/modules/lighttransition.py b/kasa/smart/modules/lighttransition.py index 3a5897d1..e0aeb4d7 100644 --- a/kasa/smart/modules/lighttransition.py +++ b/kasa/smart/modules/lighttransition.py @@ -234,7 +234,7 @@ class LightTransition(SmartModule): if self._state_in_sysinfo: return {} else: - return {self.QUERY_GETTER_NAME: {}} + return {self.QUERY_GETTER_NAME: None} async def _check_supported(self): """Additional check to see if the module is supported by the device.""" diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 8b22f0cb..8f92b94e 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -66,7 +66,6 @@ class SmartProtocol(BaseProtocol): """Create a protocol object.""" super().__init__(transport=transport) self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode() - self._request_id_generator = SnowflakeId(1, 1) self._query_lock = asyncio.Lock() self._multi_request_batch_size = ( self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE @@ -77,11 +76,11 @@ class SmartProtocol(BaseProtocol): """Get a request message as a string.""" request = { "method": method, - "params": params, - "requestID": self._request_id_generator.generate_id(), "request_time_milis": round(time.time() * 1000), "terminal_uuid": self._terminal_uuid, } + if params: + request["params"] = params return json_dumps(request) async def query(self, request: str | dict, retry_count: int = 3) -> dict: @@ -157,8 +156,10 @@ class SmartProtocol(BaseProtocol): debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) multi_result: dict[str, Any] = {} smart_method = "multipleRequest" + multi_requests = [ - {"method": method, "params": params} for method, params in requests.items() + {"method": method, "params": params} if params else {"method": method} + for method, params in requests.items() ] end = len(multi_requests) @@ -168,7 +169,7 @@ class SmartProtocol(BaseProtocol): # If step is 1 do not send request batches for request in multi_requests: method = request["method"] - req = self.get_smart_request(method, request["params"]) + req = self.get_smart_request(method, request.get("params")) resp = await self._transport.send(req) self._handle_response_error_code(resp, method, raise_on_error=False) multi_result[method] = resp["result"] @@ -347,86 +348,6 @@ class SmartProtocol(BaseProtocol): await self._transport.close() -class SnowflakeId: - """Class for generating snowflake ids.""" - - EPOCH = 1420041600000 # Custom epoch (in milliseconds) - WORKER_ID_BITS = 5 - DATA_CENTER_ID_BITS = 5 - SEQUENCE_BITS = 12 - - MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1 - MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1 - - SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1 - - def __init__(self, worker_id, data_center_id): - if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0: - raise ValueError( - "Worker ID can't be greater than " - + str(SnowflakeId.MAX_WORKER_ID) - + " or less than 0" - ) - if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0: - raise ValueError( - "Data center ID can't be greater than " - + str(SnowflakeId.MAX_DATA_CENTER_ID) - + " or less than 0" - ) - - self.worker_id = worker_id - self.data_center_id = data_center_id - self.sequence = 0 - self.last_timestamp = -1 - - def generate_id(self): - """Generate a snowflake id.""" - timestamp = self._current_millis() - - if timestamp < self.last_timestamp: - raise ValueError("Clock moved backwards. Refusing to generate ID.") - - if timestamp == self.last_timestamp: - # Within the same millisecond, increment the sequence number - self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK - if self.sequence == 0: - # Sequence exceeds its bit range, wait until the next millisecond - timestamp = self._wait_next_millis(self.last_timestamp) - else: - # New millisecond, reset the sequence number - self.sequence = 0 - - # Update the last timestamp - self.last_timestamp = timestamp - - # Generate and return the final ID - return ( - ( - (timestamp - SnowflakeId.EPOCH) - << ( - SnowflakeId.WORKER_ID_BITS - + SnowflakeId.SEQUENCE_BITS - + SnowflakeId.DATA_CENTER_ID_BITS - ) - ) - | ( - self.data_center_id - << (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS) - ) - | (self.worker_id << SnowflakeId.SEQUENCE_BITS) - | self.sequence - ) - - def _current_millis(self): - return round(time.monotonic() * 1000) - - def _wait_next_millis(self, last_timestamp): - timestamp = self._current_millis() - while timestamp <= last_timestamp: - timestamp = self._current_millis() - return timestamp - - class _ChildProtocolWrapper(SmartProtocol): """Protocol wrapper for controlling child devices. @@ -456,6 +377,8 @@ class _ChildProtocolWrapper(SmartProtocol): smart_method = "multipleRequest" requests = [ {"method": method, "params": params} + if params + else {"method": method} for method, params in request.items() ] smart_params = {"requests": requests} diff --git a/kasa/tests/fakeprotocol_smart.py b/kasa/tests/fakeprotocol_smart.py index 600cd75d..7a54be17 100644 --- a/kasa/tests/fakeprotocol_smart.py +++ b/kasa/tests/fakeprotocol_smart.py @@ -119,8 +119,9 @@ class FakeSmartTransport(BaseTransport): async def send(self, request: str): request_dict = json_loads(request) method = request_dict["method"] - params = request_dict["params"] + if method == "multipleRequest": + params = request_dict["params"] responses = [] for request in params["requests"]: response = self._send_request(request) # type: ignore[arg-type] @@ -308,12 +309,13 @@ class FakeSmartTransport(BaseTransport): def _send_request(self, request_dict: dict): method = request_dict["method"] - params = request_dict["params"] info = self.info if method == "control_child": - return self._handle_control_child(params) - elif method == "component_nego" or method[:4] == "get_": + return self._handle_control_child(request_dict["params"]) + + params = request_dict.get("params") + if method == "component_nego" or method[:4] == "get_": if method in info: result = copy.deepcopy(info[method]) if "start_index" in result and "sum" in result: