mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-11-24 18:28:21 +00:00
Add klap support for TAPO protocol by splitting out Transports and Protocols (#557)
* Add support for TAPO/SMART KLAP and separate transports from protocols * Add tests and some review changes * Update following review * Updates following review
This commit is contained in:
219
kasa/smartprotocol.py
Normal file
219
kasa/smartprotocol.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Implementation of the TP-Link AES Protocol.
|
||||
|
||||
Based on the work of https://github.com/petretiandrea/plugp100
|
||||
under compatible GNU GPL3 license.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from pprint import pformat as pf
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from .aestransport import AesTransport
|
||||
from .credentials import Credentials
|
||||
from .exceptions import AuthenticationException, SmartDeviceException
|
||||
from .json import dumps as json_dumps
|
||||
from .protocol import BaseTransport, TPLinkProtocol, md5
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
logging.getLogger("httpx").propagate = False
|
||||
|
||||
|
||||
class SmartProtocol(TPLinkProtocol):
    """Class for the new TPLink SMART protocol.

    Wraps a :class:`BaseTransport` (an :class:`AesTransport` by default),
    taking care of handshaking and logging in when the transport requires
    it, and wrapping each payload in the SMART request envelope before
    sending it to the device.
    """

    DEFAULT_PORT = 80

    def __init__(
        self,
        host: str,
        *,
        transport: Optional[BaseTransport] = None,
        credentials: Optional[Credentials] = None,
        timeout: Optional[int] = None,
    ) -> None:
        """Initialize the protocol.

        :param host: device hostname or IP address
        :param transport: transport to use; defaults to :class:`AesTransport`
        :param credentials: credentials for authentication; defaults to
            empty username/password
        :param timeout: timeout passed to the default transport
        """
        super().__init__(host=host, port=self.DEFAULT_PORT)

        self._credentials: Credentials = credentials or Credentials(
            username="", password=""
        )
        self._transport: BaseTransport = transport or AesTransport(
            host, credentials=self._credentials, timeout=timeout
        )
        self._terminal_uuid: Optional[str] = None
        self._request_id_generator = SnowflakeId(1, 1)
        # Serialize queries; the underlying transport session is not safe
        # for concurrent use.
        self._query_lock = asyncio.Lock()

    def get_smart_request(self, method, params=None) -> str:
        """Get a request message as a string.

        :param method: SMART method name, e.g. ``login_device``
        :param params: optional parameters for the method
        """
        request = {
            "method": method,
            "params": params,
            "requestID": self._request_id_generator.generate_id(),
            # NOTE: the "milis" misspelling is part of the wire protocol.
            "request_time_milis": round(time.time() * 1000),
            "terminal_uuid": self._terminal_uuid,
        }
        return json_dumps(request)

    async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
        """Query the device retrying for retry_count on failure.

        Returns the ``result`` element of the response, or an empty dict
        when the response carries no result.
        """
        async with self._query_lock:
            resp_dict = await self._query(request, retry_count)
            if "result" in resp_dict:
                return resp_dict["result"]
            return {}

    async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
        """Run the query, retrying recoverable errors up to *retry_count* times.

        Connection errors, timeouts and authentication failures are fatal
        and raised immediately; other errors reset the transport and retry.
        """
        for retry in range(retry_count + 1):
            try:
                return await self._execute_query(request, retry)
            except httpx.CloseError as sdex:
                # The connection was closed under us; reset and retry.
                await self.close()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
                    raise SmartDeviceException(
                        f"Unable to connect to the device: {self.host}: {sdex}"
                    ) from sdex
                continue
            except httpx.ConnectError as cex:
                # Could not connect at all; retrying will not help.
                await self.close()
                raise SmartDeviceException(
                    f"Unable to connect to the device: {self.host}: {cex}"
                ) from cex
            except TimeoutError as tex:
                await self.close()
                raise SmartDeviceException(
                    f"Unable to connect to the device, timed out: {self.host}: {tex}"
                ) from tex
            except AuthenticationException as auex:
                # Bad credentials will not get better with retries.
                _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
                raise auex
            except Exception as ex:
                await self.close()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
                    raise SmartDeviceException(
                        f"Unable to connect to the device: {self.host}: {ex}"
                    ) from ex
                continue

        # make mypy happy, this should never be reached..
        raise SmartDeviceException("Query loop ended without returning or raising")

    async def _execute_query(self, request: Union[str, Dict], retry: int) -> Dict:
        """Execute a single query attempt.

        Performs the transport handshake and login first when needed.

        :param request: either a method-name string or a ``{method: params}``
            dict with a single entry
        :param retry: current attempt number (informational only)
        """
        if isinstance(request, dict):
            smart_method = next(iter(request))
            smart_params = request[smart_method]
        else:
            smart_method = request
            smart_params = None

        if self._transport.needs_handshake:
            await self._transport.handshake()

        if self._transport.needs_login:
            # A fresh terminal uuid is generated for each login session.
            self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode(
                "UTF-8"
            )
            login_request = self.get_smart_request("login_device")
            await self._transport.login(login_request)

        smart_request = self.get_smart_request(smart_method, smart_params)
        response_data = await self._transport.send(smart_request)

        _LOGGER.debug(
            "%s << %s",
            self.host,
            # Avoid the cost of pretty-printing unless debug logging is on.
            _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
        )

        return response_data

    async def close(self) -> None:
        """Close the protocol."""
        await self._transport.close()
|
||||
|
||||
|
||||
class SnowflakeId:
    """Class for generating snowflake ids.

    Ids pack a millisecond timestamp (relative to a custom epoch), a data
    center id, a worker id and a per-millisecond sequence counter into a
    single integer, so ids are unique and time-sortable.
    """

    EPOCH = 1420041600000  # Custom epoch (in milliseconds)
    WORKER_ID_BITS = 5
    DATA_CENTER_ID_BITS = 5
    SEQUENCE_BITS = 12

    MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
    MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1

    SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1

    def __init__(self, worker_id, data_center_id):
        # Both ids must fit in their bit fields; reject anything outside.
        if not 0 <= worker_id <= SnowflakeId.MAX_WORKER_ID:
            raise ValueError(
                f"Worker ID can't be greater than "
                f"{SnowflakeId.MAX_WORKER_ID} or less than 0"
            )
        if not 0 <= data_center_id <= SnowflakeId.MAX_DATA_CENTER_ID:
            raise ValueError(
                f"Data center ID can't be greater than "
                f"{SnowflakeId.MAX_DATA_CENTER_ID} or less than 0"
            )

        self.worker_id = worker_id
        self.data_center_id = data_center_id
        self.sequence = 0
        self.last_timestamp = -1

    def generate_id(self):
        """Generate a snowflake id."""
        now = self._current_millis()

        if now < self.last_timestamp:
            raise ValueError("Clock moved backwards. Refusing to generate ID.")

        if now == self.last_timestamp:
            # Same millisecond as the previous id: bump the sequence and,
            # if it wraps around, spin until the next millisecond.
            self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
            if self.sequence == 0:
                now = self._wait_next_millis(self.last_timestamp)
        else:
            # New millisecond: the sequence starts over.
            self.sequence = 0

        self.last_timestamp = now

        # Assemble the id field by field, highest bits first.
        time_part = (now - SnowflakeId.EPOCH) << (
            SnowflakeId.WORKER_ID_BITS
            + SnowflakeId.SEQUENCE_BITS
            + SnowflakeId.DATA_CENTER_ID_BITS
        )
        data_center_part = self.data_center_id << (
            SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS
        )
        worker_part = self.worker_id << SnowflakeId.SEQUENCE_BITS
        return time_part | data_center_part | worker_part | self.sequence

    def _current_millis(self):
        """Return the current unix time in whole milliseconds."""
        return round(time.time() * 1000)

    def _wait_next_millis(self, last_timestamp):
        """Busy-wait until the clock moves past *last_timestamp*."""
        now = self._current_millis()
        while now <= last_timestamp:
            now = self._current_millis()
        return now
|
||||
Reference in New Issue
Block a user