Update smart request parameter handling (#1061)

Changes to the smart request handling:
- Do not send params if null
- Drop the requestId parameter
- get_preset_rules doesn't send parameters for preset component version less than 3
- get_led_info no longer sends the wrong parameters
- get_on_off_gradually_info no longer sends an empty {} parameter
This commit is contained in:
Steven B. 2024-07-23 19:02:20 +01:00 committed by GitHub
parent 06ff598d9c
commit 58afeb28a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 19 additions and 91 deletions

View File

@ -16,7 +16,7 @@ class Led(SmartModule, LedInterface):
def query(self) -> dict: def query(self) -> dict:
"""Query to execute during the update cycle.""" """Query to execute during the update cycle."""
return {self.QUERY_GETTER_NAME: {"led_rule": None}} return {self.QUERY_GETTER_NAME: None}
@property @property
def mode(self): def mode(self):

View File

@ -153,6 +153,9 @@ class LightPreset(SmartModule, LightPresetInterface):
"""Query to execute during the update cycle.""" """Query to execute during the update cycle."""
if self._state_in_sysinfo: # Child lights can have states in the child info if self._state_in_sysinfo: # Child lights can have states in the child info
return {} return {}
if self.supported_version < 3:
return {self.QUERY_GETTER_NAME: None}
return {self.QUERY_GETTER_NAME: {"start_index": 0}} return {self.QUERY_GETTER_NAME: {"start_index": 0}}
async def _check_supported(self): async def _check_supported(self):

View File

@ -234,7 +234,7 @@ class LightTransition(SmartModule):
if self._state_in_sysinfo: if self._state_in_sysinfo:
return {} return {}
else: else:
return {self.QUERY_GETTER_NAME: {}} return {self.QUERY_GETTER_NAME: None}
async def _check_supported(self): async def _check_supported(self):
"""Additional check to see if the module is supported by the device.""" """Additional check to see if the module is supported by the device."""

View File

@ -66,7 +66,6 @@ class SmartProtocol(BaseProtocol):
"""Create a protocol object.""" """Create a protocol object."""
super().__init__(transport=transport) super().__init__(transport=transport)
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode() self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
self._request_id_generator = SnowflakeId(1, 1)
self._query_lock = asyncio.Lock() self._query_lock = asyncio.Lock()
self._multi_request_batch_size = ( self._multi_request_batch_size = (
self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
@ -77,11 +76,11 @@ class SmartProtocol(BaseProtocol):
"""Get a request message as a string.""" """Get a request message as a string."""
request = { request = {
"method": method, "method": method,
"params": params,
"requestID": self._request_id_generator.generate_id(),
"request_time_milis": round(time.time() * 1000), "request_time_milis": round(time.time() * 1000),
"terminal_uuid": self._terminal_uuid, "terminal_uuid": self._terminal_uuid,
} }
if params:
request["params"] = params
return json_dumps(request) return json_dumps(request)
async def query(self, request: str | dict, retry_count: int = 3) -> dict: async def query(self, request: str | dict, retry_count: int = 3) -> dict:
@ -157,8 +156,10 @@ class SmartProtocol(BaseProtocol):
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
multi_result: dict[str, Any] = {} multi_result: dict[str, Any] = {}
smart_method = "multipleRequest" smart_method = "multipleRequest"
multi_requests = [ multi_requests = [
{"method": method, "params": params} for method, params in requests.items() {"method": method, "params": params} if params else {"method": method}
for method, params in requests.items()
] ]
end = len(multi_requests) end = len(multi_requests)
@ -168,7 +169,7 @@ class SmartProtocol(BaseProtocol):
# If step is 1 do not send request batches # If step is 1 do not send request batches
for request in multi_requests: for request in multi_requests:
method = request["method"] method = request["method"]
req = self.get_smart_request(method, request["params"]) req = self.get_smart_request(method, request.get("params"))
resp = await self._transport.send(req) resp = await self._transport.send(req)
self._handle_response_error_code(resp, method, raise_on_error=False) self._handle_response_error_code(resp, method, raise_on_error=False)
multi_result[method] = resp["result"] multi_result[method] = resp["result"]
@ -347,86 +348,6 @@ class SmartProtocol(BaseProtocol):
await self._transport.close() await self._transport.close()
class SnowflakeId:
"""Class for generating snowflake ids."""
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
WORKER_ID_BITS = 5
DATA_CENTER_ID_BITS = 5
SEQUENCE_BITS = 12
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
def __init__(self, worker_id, data_center_id):
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
raise ValueError(
"Worker ID can't be greater than "
+ str(SnowflakeId.MAX_WORKER_ID)
+ " or less than 0"
)
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
raise ValueError(
"Data center ID can't be greater than "
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
+ " or less than 0"
)
self.worker_id = worker_id
self.data_center_id = data_center_id
self.sequence = 0
self.last_timestamp = -1
def generate_id(self):
"""Generate a snowflake id."""
timestamp = self._current_millis()
if timestamp < self.last_timestamp:
raise ValueError("Clock moved backwards. Refusing to generate ID.")
if timestamp == self.last_timestamp:
# Within the same millisecond, increment the sequence number
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
if self.sequence == 0:
# Sequence exceeds its bit range, wait until the next millisecond
timestamp = self._wait_next_millis(self.last_timestamp)
else:
# New millisecond, reset the sequence number
self.sequence = 0
# Update the last timestamp
self.last_timestamp = timestamp
# Generate and return the final ID
return (
(
(timestamp - SnowflakeId.EPOCH)
<< (
SnowflakeId.WORKER_ID_BITS
+ SnowflakeId.SEQUENCE_BITS
+ SnowflakeId.DATA_CENTER_ID_BITS
)
)
| (
self.data_center_id
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
)
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
| self.sequence
)
def _current_millis(self):
return round(time.monotonic() * 1000)
def _wait_next_millis(self, last_timestamp):
timestamp = self._current_millis()
while timestamp <= last_timestamp:
timestamp = self._current_millis()
return timestamp
class _ChildProtocolWrapper(SmartProtocol): class _ChildProtocolWrapper(SmartProtocol):
"""Protocol wrapper for controlling child devices. """Protocol wrapper for controlling child devices.
@ -456,6 +377,8 @@ class _ChildProtocolWrapper(SmartProtocol):
smart_method = "multipleRequest" smart_method = "multipleRequest"
requests = [ requests = [
{"method": method, "params": params} {"method": method, "params": params}
if params
else {"method": method}
for method, params in request.items() for method, params in request.items()
] ]
smart_params = {"requests": requests} smart_params = {"requests": requests}

View File

@ -119,8 +119,9 @@ class FakeSmartTransport(BaseTransport):
async def send(self, request: str): async def send(self, request: str):
request_dict = json_loads(request) request_dict = json_loads(request)
method = request_dict["method"] method = request_dict["method"]
params = request_dict["params"]
if method == "multipleRequest": if method == "multipleRequest":
params = request_dict["params"]
responses = [] responses = []
for request in params["requests"]: for request in params["requests"]:
response = self._send_request(request) # type: ignore[arg-type] response = self._send_request(request) # type: ignore[arg-type]
@ -308,12 +309,13 @@ class FakeSmartTransport(BaseTransport):
def _send_request(self, request_dict: dict): def _send_request(self, request_dict: dict):
method = request_dict["method"] method = request_dict["method"]
params = request_dict["params"]
info = self.info info = self.info
if method == "control_child": if method == "control_child":
return self._handle_control_child(params) return self._handle_control_child(request_dict["params"])
elif method == "component_nego" or method[:4] == "get_":
params = request_dict.get("params")
if method == "component_nego" or method[:4] == "get_":
if method in info: if method in info:
result = copy.deepcopy(info[method]) result = copy.deepcopy(info[method])
if "start_index" in result and "sum" in result: if "start_index" in result and "sum" in result: