Do login entirely within AesTransport (#580)

* Do login entirely within AesTransport

* Remove login and handshake attributes from BaseTransport

* Add AesTransport tests

* Synchronise transport and protocol __init__ signatures and rename internal variables

* Update after review
This commit is contained in:
sdb9696
2023-12-19 14:11:59 +00:00
committed by GitHub
parent 209391c422
commit 20ea6700a5
13 changed files with 468 additions and 237 deletions

View File

@@ -10,12 +10,10 @@ import logging
import time
import uuid
from pprint import pformat as pf
from typing import Dict, Optional, Union
from typing import Dict, Union
import httpx
from .aestransport import AesTransport
from .credentials import Credentials
from .exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
@@ -36,26 +34,17 @@ logging.getLogger("httpx").propagate = False
class SmartProtocol(TPLinkProtocol):
"""Class for the new TPLink SMART protocol."""
DEFAULT_PORT = 80
SLEEP_SECONDS_AFTER_TIMEOUT = 1
def __init__(
self,
host: str,
*,
transport: Optional[BaseTransport] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
transport: BaseTransport,
) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT)
self._credentials: Credentials = credentials or Credentials(
username="", password=""
)
self._transport: BaseTransport = transport or AesTransport(
host, credentials=self._credentials, timeout=timeout
)
self._terminal_uuid: Optional[str] = None
"""Create a protocol object."""
super().__init__(host, transport=transport)
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
self._request_id_generator = SnowflakeId(1, 1)
self._query_lock = asyncio.Lock()
@@ -79,7 +68,7 @@ class SmartProtocol(TPLinkProtocol):
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS:
msg = (
f"Error querying device: {self.host}: "
f"Error querying device: {self._host}: "
+ f"{error_code.name}({error_code.value})"
)
if error_code in SMART_TIMEOUT_ERRORS:
@@ -101,51 +90,53 @@ class SmartProtocol(TPLinkProtocol):
except httpx.CloseError as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}"
f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}"
f"Unable to connect to the device: {self._host}: {cex}"
) from cex
except TimeoutError as tex:
if retry >= retry_count:
await self.close()
raise SmartDeviceException(
"Unable to connect to the device, "
+ f"timed out: {self.host}: {tex}"
+ f"timed out: {self._host}: {tex}"
) from tex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
continue
except AuthenticationException as auex:
await self.close()
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
_LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex
except RetryableException as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
continue
except TimeoutException as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
continue
except Exception as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to query the device {self.host}:{self.port}: {ex}"
f"Unable to connect to the device: {self._host}: {ex}"
) from ex
_LOGGER.debug(
"Unable to query the device %s, retrying: %s", self.host, ex
"Unable to query the device %s, retrying: %s", self._host, ex
)
continue
@@ -160,27 +151,17 @@ class SmartProtocol(TPLinkProtocol):
smart_method = request
smart_params = None
if self._transport.needs_handshake:
await self._transport.handshake()
if self._transport.needs_login:
self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode(
"UTF-8"
)
login_request = self.get_smart_request("login_device")
await self._transport.login(login_request)
smart_request = self.get_smart_request(smart_method, smart_params)
_LOGGER.debug(
"%s >> %s",
self.host,
self._host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(smart_request),
)
response_data = await self._transport.send(smart_request)
_LOGGER.debug(
"%s << %s",
self.host,
self._host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
)