Do login entirely within AesTransport (#580)

* Do login entirely within AesTransport

* Remove login and handshake attributes from BaseTransport

* Add AesTransport tests

* Synchronise transport and protocol __init__ signatures and rename internal variables

* Update after review
This commit is contained in:
sdb9696
2023-12-19 14:11:59 +00:00
committed by GitHub
parent 209391c422
commit 20ea6700a5
13 changed files with 468 additions and 237 deletions

View File

@@ -44,35 +44,21 @@ def md5(payload: bytes) -> bytes:
class BaseTransport(ABC):
"""Base class for all TP-Link protocol transports."""
DEFAULT_TIMEOUT = 5
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
"""Create a protocol object."""
self.host = host
self.port = port
self.credentials = credentials
@property
@abstractmethod
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
@property
@abstractmethod
def needs_login(self) -> bool:
"""Return true if the transport needs to do a login."""
@abstractmethod
async def login(self, request: str) -> None:
"""Login to the device."""
@abstractmethod
async def handshake(self) -> None:
"""Perform the encryption handshake."""
self._host = host
self._port = port
self._credentials = credentials or Credentials(username="", password="")
self._timeout = timeout or self.DEFAULT_TIMEOUT
@abstractmethod
async def send(self, request: str) -> Dict:
@@ -90,14 +76,14 @@ class TPLinkProtocol(ABC):
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
transport: Optional[BaseTransport] = None,
transport: BaseTransport,
) -> None:
"""Create a protocol object."""
self.host = host
self.port = port
self.credentials = credentials
self._transport = transport
@property
def _host(self):
return self._transport._host
@abstractmethod
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
@@ -108,6 +94,40 @@ class TPLinkProtocol(ABC):
"""Close the protocol. Abstract method to be overriden."""
class _XorTransport(BaseTransport):
"""Implementation of the Xor encryption transport.
WIP, currently only to ensure consistent __init__ method signatures
for protocol classes. Will eventually incorporate the logic from
TPLinkSmartHomeProtocol to simplify the API and re-use the IotProtocol
class.
"""
DEFAULT_PORT = 9999
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(
host,
port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)
async def send(self, request: str) -> Dict:
"""Send a message to the device and return a response."""
return {}
async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""
class TPLinkSmartHomeProtocol(TPLinkProtocol):
"""Implementation of the TP-Link Smart Home protocol."""
@@ -120,20 +140,18 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
self,
host: str,
*,
port: Optional[int] = None,
timeout: Optional[int] = None,
credentials: Optional[Credentials] = None,
transport: BaseTransport,
) -> None:
"""Create a protocol object."""
super().__init__(
host=host, port=port or self.DEFAULT_PORT, credentials=credentials
)
super().__init__(host, transport=transport)
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
self.query_lock = asyncio.Lock()
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.timeout = timeout or TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
self._timeout = self._transport._timeout
self._port = self._transport._port
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Request information from a TP-Link SmartHome Device.
@@ -149,7 +167,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
assert isinstance(request, str) # noqa: S101
async with self.query_lock:
return await self._query(request, retry_count, self.timeout)
return await self._query(request, retry_count, self._timeout)
async def _connect(self, timeout: int) -> None:
"""Try to connect or reconnect to the device."""
@@ -157,7 +175,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
return
self.reader = self.writer = None
task = asyncio.open_connection(self.host, self.port)
task = asyncio.open_connection(self._host, self._port)
async with asyncio_timeout(timeout):
self.reader, self.writer = await task
sock: socket.socket = self.writer.get_extra_info("socket")
@@ -174,7 +192,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
debug_log = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_log:
_LOGGER.debug("%s >> %s", self.host, request)
_LOGGER.debug("%s >> %s", self._host, request)
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await self.writer.drain()
@@ -185,7 +203,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
response = TPLinkSmartHomeProtocol.decrypt(buffer)
json_payload = json_loads(response)
if debug_log:
_LOGGER.debug("%s << %s", self.host, pf(json_payload))
_LOGGER.debug("%s << %s", self._host, pf(json_payload))
return json_payload
@@ -219,23 +237,23 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
except ConnectionRefusedError as ex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}:{self.port}: {ex}"
f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
) from ex
except OSError as ex:
await self.close()
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
raise SmartDeviceException(
f"Unable to connect to the device:"
f" {self.host}:{self.port}: {ex}"
f" {self._host}:{self._port}: {ex}"
) from ex
continue
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device:"
f" {self.host}:{self.port}: {ex}"
f" {self._host}:{self._port}: {ex}"
) from ex
continue
@@ -247,13 +265,13 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to query the device {self.host}:{self.port}: {ex}"
f"Unable to query the device {self._host}:{self._port}: {ex}"
) from ex
_LOGGER.debug(
"Unable to query the device %s, retrying: %s", self.host, ex
"Unable to query the device %s, retrying: %s", self._host, ex
)
# make mypy happy, this should never be reached..