Mirror of https://github.com/python-kasa/python-kasa.git, synced 2024-12-22 11:13:34 +00:00
Update smart request parameter handling (#1061)

Changes to the smart request handling:
- Do not send params if null
- Drop the requestId parameter
- get_preset_rules no longer sends parameters when the preset component version is less than 3
- get_led_info no longer sends the wrong parameters
- get_on_off_gradually_info no longer sends an empty {} parameter
This commit is contained in:
parent 06ff598d9c
commit 58afeb28a1
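To make the first two bullets concrete, the sketch below (not the library's code; the method name, parameter values and terminal_uuid are illustrative) shows the shape of the serialised request after this change: the requestID field is dropped entirely, and params is only included when it is non-null.

import json
import time


def build_request(method, params=None):
    """Sketch of the new request envelope: no requestID, params only when non-null."""
    request = {
        "method": method,
        "request_time_milis": round(time.time() * 1000),  # field name as used on the wire
        "terminal_uuid": "illustrative-terminal-uuid",
    }
    if params:
        request["params"] = params
    return json.dumps(request)


# Parameterless getter: previously this carried "params": null and a "requestID".
print(build_request("get_led_info"))
# A setter still carries its parameters (the values here are made up).
print(build_request("set_led_info", {"led_rule": "always"}))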
@@ -16,7 +16,7 @@ class Led(SmartModule, LedInterface):
 
     def query(self) -> dict:
         """Query to execute during the update cycle."""
-        return {self.QUERY_GETTER_NAME: {"led_rule": None}}
+        return {self.QUERY_GETTER_NAME: None}
 
     @property
     def mode(self):
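Assuming this module's QUERY_GETTER_NAME is "get_led_info" (inferred from the module name, not shown in the hunk), the query the module hands to the protocol changes roughly as follows:

# Before: a "led_rule" parameter was sent with the getter.
old_query = {"get_led_info": {"led_rule": None}}
# After: None means "no parameters"; together with the protocol change,
# no "params" field is serialised for this getter at all.
new_query = {"get_led_info": None}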
@@ -153,6 +153,9 @@ class LightPreset(SmartModule, LightPresetInterface):
         """Query to execute during the update cycle."""
         if self._state_in_sysinfo:  # Child lights can have states in the child info
             return {}
+        if self.supported_version < 3:
+            return {self.QUERY_GETTER_NAME: None}
+
         return {self.QUERY_GETTER_NAME: {"start_index": 0}}
 
     async def _check_supported(self):
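A minimal sketch of the resulting behaviour, assuming QUERY_GETTER_NAME is "get_preset_rules" (consistent with the commit message) and using a free function in place of the module instance:

def preset_query(state_in_sysinfo: bool, supported_version: int) -> dict:
    """Mirrors the new LightPreset.query() logic (sketch, not library code)."""
    if state_in_sysinfo:  # child lights report their presets via the child info
        return {}
    if supported_version < 3:  # older preset component versions take no parameters
        return {"get_preset_rules": None}
    return {"get_preset_rules": {"start_index": 0}}


assert preset_query(False, 2) == {"get_preset_rules": None}
assert preset_query(False, 3) == {"get_preset_rules": {"start_index": 0}}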
@@ -234,7 +234,7 @@ class LightTransition(SmartModule):
         if self._state_in_sysinfo:
             return {}
         else:
-            return {self.QUERY_GETTER_NAME: {}}
+            return {self.QUERY_GETTER_NAME: None}
 
     async def _check_supported(self):
         """Additional check to see if the module is supported by the device."""
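The same pattern applies to the gradual on/off module, assuming its QUERY_GETTER_NAME is "get_on_off_gradually_info":

# Before: an empty dict was sent as the getter's parameters.
old_query = {"get_on_off_gradually_info": {}}
# After: None signals "no parameters", so the protocol omits "params" entirely.
new_query = {"get_on_off_gradually_info": None}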
@@ -66,7 +66,6 @@ class SmartProtocol(BaseProtocol):
         """Create a protocol object."""
         super().__init__(transport=transport)
         self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
-        self._request_id_generator = SnowflakeId(1, 1)
         self._query_lock = asyncio.Lock()
         self._multi_request_batch_size = (
             self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
@@ -77,11 +76,11 @@ class SmartProtocol(BaseProtocol):
         """Get a request message as a string."""
         request = {
             "method": method,
-            "params": params,
-            "requestID": self._request_id_generator.generate_id(),
             "request_time_milis": round(time.time() * 1000),
             "terminal_uuid": self._terminal_uuid,
         }
+        if params:
+            request["params"] = params
         return json_dumps(request)
 
     async def query(self, request: str | dict, retry_count: int = 3) -> dict:
@@ -157,8 +156,10 @@ class SmartProtocol(BaseProtocol):
         debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
         multi_result: dict[str, Any] = {}
         smart_method = "multipleRequest"
+
         multi_requests = [
-            {"method": method, "params": params} for method, params in requests.items()
+            {"method": method, "params": params} if params else {"method": method}
+            for method, params in requests.items()
         ]
 
         end = len(multi_requests)
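With the conditional comprehension, parameterless getters no longer carry a "params" key inside the multipleRequest batch. A quick illustration with made-up module queries:

requests = {
    "get_device_info": None,
    "get_preset_rules": {"start_index": 0},
}
multi_requests = [
    {"method": method, "params": params} if params else {"method": method}
    for method, params in requests.items()
]
assert multi_requests == [
    {"method": "get_device_info"},
    {"method": "get_preset_rules", "params": {"start_index": 0}},
]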
@@ -168,7 +169,7 @@ class SmartProtocol(BaseProtocol):
             # If step is 1 do not send request batches
             for request in multi_requests:
                 method = request["method"]
-                req = self.get_smart_request(method, request["params"])
+                req = self.get_smart_request(method, request.get("params"))
                 resp = await self._transport.send(req)
                 self._handle_response_error_code(resp, method, raise_on_error=False)
                 multi_result[method] = resp["result"]
@@ -347,86 +348,6 @@ class SmartProtocol(BaseProtocol):
         await self._transport.close()
 
 
-class SnowflakeId:
-    """Class for generating snowflake ids."""
-
-    EPOCH = 1420041600000  # Custom epoch (in milliseconds)
-    WORKER_ID_BITS = 5
-    DATA_CENTER_ID_BITS = 5
-    SEQUENCE_BITS = 12
-
-    MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
-    MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
-
-    SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
-
-    def __init__(self, worker_id, data_center_id):
-        if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
-            raise ValueError(
-                "Worker ID can't be greater than "
-                + str(SnowflakeId.MAX_WORKER_ID)
-                + " or less than 0"
-            )
-        if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
-            raise ValueError(
-                "Data center ID can't be greater than "
-                + str(SnowflakeId.MAX_DATA_CENTER_ID)
-                + " or less than 0"
-            )
-
-        self.worker_id = worker_id
-        self.data_center_id = data_center_id
-        self.sequence = 0
-        self.last_timestamp = -1
-
-    def generate_id(self):
-        """Generate a snowflake id."""
-        timestamp = self._current_millis()
-
-        if timestamp < self.last_timestamp:
-            raise ValueError("Clock moved backwards. Refusing to generate ID.")
-
-        if timestamp == self.last_timestamp:
-            # Within the same millisecond, increment the sequence number
-            self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
-            if self.sequence == 0:
-                # Sequence exceeds its bit range, wait until the next millisecond
-                timestamp = self._wait_next_millis(self.last_timestamp)
-        else:
-            # New millisecond, reset the sequence number
-            self.sequence = 0
-
-        # Update the last timestamp
-        self.last_timestamp = timestamp
-
-        # Generate and return the final ID
-        return (
-            (
-                (timestamp - SnowflakeId.EPOCH)
-                << (
-                    SnowflakeId.WORKER_ID_BITS
-                    + SnowflakeId.SEQUENCE_BITS
-                    + SnowflakeId.DATA_CENTER_ID_BITS
-                )
-            )
-            | (
-                self.data_center_id
-                << (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
-            )
-            | (self.worker_id << SnowflakeId.SEQUENCE_BITS)
-            | self.sequence
-        )
-
-    def _current_millis(self):
-        return round(time.monotonic() * 1000)
-
-    def _wait_next_millis(self, last_timestamp):
-        timestamp = self._current_millis()
-        while timestamp <= last_timestamp:
-            timestamp = self._current_millis()
-        return timestamp
-
-
 class _ChildProtocolWrapper(SmartProtocol):
     """Protocol wrapper for controlling child devices.
 
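For reference, the removed generator packed a millisecond offset from its custom epoch together with a 5-bit data-center id, a 5-bit worker id and a 12-bit sequence into one integer. A short worked check of that bit layout (the values are arbitrary):

EPOCH = 1420041600000
timestamp, data_center_id, worker_id, sequence = EPOCH + 1_000, 1, 1, 7
snowflake = (
    ((timestamp - EPOCH) << (5 + 12 + 5))  # WORKER_ID_BITS + SEQUENCE_BITS + DATA_CENTER_ID_BITS
    | (data_center_id << (12 + 5))         # SEQUENCE_BITS + WORKER_ID_BITS
    | (worker_id << 12)                    # SEQUENCE_BITS
    | sequence
)
assert snowflake >> 22 == 1_000          # millisecond offset since the epoch
assert (snowflake >> 17) & 0b11111 == 1  # data-center id (5 bits)
assert (snowflake >> 12) & 0b11111 == 1  # worker id (5 bits)
assert snowflake & 0xFFF == 7            # sequence (12 bits)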
@@ -456,6 +377,8 @@ class _ChildProtocolWrapper(SmartProtocol):
         smart_method = "multipleRequest"
         requests = [
             {"method": method, "params": params}
+            if params
+            else {"method": method}
             for method, params in request.items()
         ]
         smart_params = {"requests": requests}
@@ -119,8 +119,9 @@ class FakeSmartTransport(BaseTransport):
     async def send(self, request: str):
         request_dict = json_loads(request)
         method = request_dict["method"]
-        params = request_dict["params"]
+
         if method == "multipleRequest":
+            params = request_dict["params"]
             responses = []
             for request in params["requests"]:
                 response = self._send_request(request)  # type: ignore[arg-type]
@@ -308,12 +309,13 @@ class FakeSmartTransport(BaseTransport):
     def _send_request(self, request_dict: dict):
         method = request_dict["method"]
-        params = request_dict["params"]
 
         info = self.info
         if method == "control_child":
-            return self._handle_control_child(params)
-        elif method == "component_nego" or method[:4] == "get_":
+            return self._handle_control_child(request_dict["params"])
+
+        params = request_dict.get("params")
+        if method == "component_nego" or method[:4] == "get_":
             if method in info:
                 result = copy.deepcopy(info[method])
                 if "start_index" in result and "sum" in result:
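Because the real protocol no longer serialises a "params" key for parameterless getters, the fake transport has to tolerate its absence; request_dict.get("params") simply yields None instead of raising. A minimal illustration:

request_dict = {"method": "get_device_info"}  # no "params" key on the wire any more
params = request_dict.get("params")           # old code did request_dict["params"] -> KeyError
assert params is None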