# Mirror of https://github.com/python-kasa/python-kasa.git
# (synced 2025-01-24 05:37:59 +00:00; 220 lines, 7.6 KiB, Python).
"""Implementation of the TP-Link AES Protocol.

Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""
|
||
|
|
||
|
import asyncio
|
||
|
import base64
|
||
|
import logging
|
||
|
import time
|
||
|
import uuid
|
||
|
from pprint import pformat as pf
|
||
|
from typing import Dict, Optional, Union
|
||
|
|
||
|
import httpx
|
||
|
|
||
|
from .aestransport import AesTransport
|
||
|
from .credentials import Credentials
|
||
|
from .exceptions import AuthenticationException, SmartDeviceException
|
||
|
from .json import dumps as json_dumps
|
||
|
from .protocol import BaseTransport, TPLinkProtocol, md5
|
||
|
|
||
|
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
|
||
|
# Import-time side effect: records emitted on the "httpx" logger are no longer
# passed on to the handlers of its ancestor loggers (e.g. the root logger).
logging.getLogger("httpx").propagate = False
|
||
|
|
||
|
|
||
|
class SmartProtocol(TPLinkProtocol):
    """Class for the new TPLink SMART protocol.

    Requests are sent through a transport object (an ``AesTransport`` unless
    another transport is supplied). Before each request the transport is asked
    whether it needs a handshake and/or a login, and those steps are performed
    first when required.
    """

    DEFAULT_PORT = 80

    def __init__(
        self,
        host: str,
        *,
        transport: Optional[BaseTransport] = None,
        credentials: Optional[Credentials] = None,
        timeout: Optional[int] = None,
    ) -> None:
        """Create a protocol object for *host*.

        :param host: Address of the device; also handed to the default transport.
        :param transport: Transport to use. When omitted an ``AesTransport`` is
            created from *host*, the credentials and *timeout*.
        :param credentials: Credentials for the default transport. When omitted,
            credentials with an empty username and password are used.
        :param timeout: Passed to the default ``AesTransport``; not used when a
            *transport* is supplied.
        """
        super().__init__(host=host, port=self.DEFAULT_PORT)

        self._credentials: Credentials = credentials or Credentials(
            username="", password=""
        )
        self._transport: BaseTransport = transport or AesTransport(
            host, credentials=self._credentials, timeout=timeout
        )
        # Set to a fresh value each time a login is performed.
        self._terminal_uuid: Optional[str] = None
        self._request_id_generator = SnowflakeId(1, 1)
        # Serialises query() calls made on this protocol instance.
        self._query_lock = asyncio.Lock()

    def get_smart_request(self, method, params=None) -> str:
        """Get a request message as a string.

        The message carries the method, its params, a newly generated request
        id, the current time in milliseconds and the current terminal uuid.
        """
        request = {
            "method": method,
            "params": params,
            "requestID": self._request_id_generator.generate_id(),
            "request_time_milis": round(time.time() * 1000),
            "terminal_uuid": self._terminal_uuid,
        }
        return json_dumps(request)

    async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
        """Query the device retrying for retry_count on failure.

        :param request: Either a method name, or a dict whose first key is the
            method name and whose value holds the params for that method.
        :param retry_count: Number of retries after the first failed attempt.
        :return: The ``"result"`` member of the response, or an empty dict when
            the response has no such member.
        :raises SmartDeviceException: When the request is an empty dict or the
            query ultimately fails.
        :raises AuthenticationException: When authentication fails.
        """
        async with self._query_lock:
            resp_dict = await self._query(request, retry_count)
            if "result" in resp_dict:
                return resp_dict["result"]
            return {}

    async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
        """Run the query, closing the transport and retrying on failure.

        ``httpx.CloseError`` and otherwise unhandled exceptions are retried up
        to *retry_count* times. ``httpx.ConnectError``, ``TimeoutError`` and
        ``AuthenticationException`` are not retried.
        """
        if isinstance(request, dict) and not request:
            # An empty dict carries no method name. Left alone it would fail
            # inside _execute_query, be treated like a connection problem and
            # be retried (closing the transport each time) before finally
            # being reported as "Unable to connect".
            raise SmartDeviceException(
                f"Invalid empty request for the device: {self.host}"
            )

        for retry in range(retry_count + 1):
            try:
                return await self._execute_query(request, retry)
            except httpx.CloseError as sdex:
                await self.close()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
                    raise SmartDeviceException(
                        f"Unable to connect to the device: {self.host}: {sdex}"
                    ) from sdex
                continue
            except httpx.ConnectError as cex:
                await self.close()
                raise SmartDeviceException(
                    f"Unable to connect to the device: {self.host}: {cex}"
                ) from cex
            except TimeoutError as tex:
                await self.close()
                raise SmartDeviceException(
                    f"Unable to connect to the device, timed out: {self.host}: {tex}"
                ) from tex
            except AuthenticationException:
                _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
                # Bare raise propagates the original exception unchanged.
                raise
            except Exception as ex:
                await self.close()
                if retry >= retry_count:
                    _LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
                    raise SmartDeviceException(
                        f"Unable to connect to the device: {self.host}: {ex}"
                    ) from ex
                continue

        # make mypy happy, this should never be reached..
        raise SmartDeviceException("Query reached somehow to unreachable")

    async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
        """Perform handshake/login if the transport needs them, then send.

        :param request: A method name, or a dict whose first key is the method
            name and whose value is that method's params.
        :param retry_count: Accepted from the caller; not used in this method.
        :return: The data returned by the transport's ``send``.
        """
        if isinstance(request, dict):
            # Only the first key of the dict is used as the method name.
            smart_method = next(iter(request))
            smart_params = request[smart_method]
        else:
            smart_method = request
            smart_params = None

        if self._transport.needs_handshake:
            await self._transport.handshake()

        if self._transport.needs_login:
            # A new terminal uuid is generated for every login.
            self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode(
                "UTF-8"
            )
            login_request = self.get_smart_request("login_device")
            await self._transport.login(login_request)

        smart_request = self.get_smart_request(smart_method, smart_params)
        response_data = await self._transport.send(smart_request)

        _LOGGER.debug(
            "%s << %s",
            self.host,
            # Only pretty-print the response when debug logging is enabled.
            _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
        )

        return response_data

    async def close(self) -> None:
        """Close the protocol."""
        await self._transport.close()
|
||
|
|
||
|
|
||
|
class SnowflakeId:
    """Class for generating snowflake ids.

    An id is an integer built from, highest bits first: the milliseconds
    elapsed since ``EPOCH``, the data center id, the worker id and a
    per-millisecond sequence number.
    """

    EPOCH = 1420041600000  # Custom epoch (in milliseconds)
    WORKER_ID_BITS = 5
    DATA_CENTER_ID_BITS = 5
    SEQUENCE_BITS = 12

    MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
    MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1

    SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1

    def __init__(self, worker_id, data_center_id):
        """Create a generator for the given worker and data center ids.

        Raises ValueError when either id is negative or above its maximum.
        """
        if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
            raise ValueError(
                f"Worker ID can't be greater than {SnowflakeId.MAX_WORKER_ID}"
                " or less than 0"
            )
        if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
            raise ValueError(
                "Data center ID can't be greater than "
                f"{SnowflakeId.MAX_DATA_CENTER_ID} or less than 0"
            )

        self.worker_id = worker_id
        self.data_center_id = data_center_id
        self.sequence = 0
        self.last_timestamp = -1

    def generate_id(self):
        """Generate a snowflake id."""
        now = self._current_millis()
        previous = self.last_timestamp

        if now < previous:
            raise ValueError("Clock moved backwards. Refusing to generate ID.")

        if now != previous:
            # First id of a new millisecond: start the sequence over.
            self.sequence = 0
        else:
            # Another id in the same millisecond: advance the sequence.
            self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
            if self.sequence == 0:
                # Sequence wrapped around; move on to the next millisecond.
                now = self._wait_next_millis(previous)

        # Remember the timestamp used for this id.
        self.last_timestamp = now

        # Bit offsets of each field, counted from the low-order end.
        worker_shift = SnowflakeId.SEQUENCE_BITS
        data_center_shift = SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS
        timestamp_shift = (
            SnowflakeId.WORKER_ID_BITS
            + SnowflakeId.SEQUENCE_BITS
            + SnowflakeId.DATA_CENTER_ID_BITS
        )

        elapsed = now - SnowflakeId.EPOCH
        return (
            (elapsed << timestamp_shift)
            | (self.data_center_id << data_center_shift)
            | (self.worker_id << worker_shift)
            | self.sequence
        )

    def _current_millis(self):
        """Return the current time in milliseconds, rounded to an integer."""
        return round(time.time() * 1000)

    def _wait_next_millis(self, last_timestamp):
        """Poll the clock until it reads later than last_timestamp."""
        while True:
            candidate = self._current_millis()
            if candidate > last_timestamp:
                return candidate
|