mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-10-20 06:18:01 +00:00
Do login entirely within AesTransport (#580)
* Do login entirely within AesTransport * Remove login and handshake attributes from BaseTransport * Add AesTransport tests * Synchronise transport and protocol __init__ signatures and rename internal variables * Update after review
This commit is contained in:
108
kasa/protocol.py
108
kasa/protocol.py
@@ -44,35 +44,21 @@ def md5(payload: bytes) -> bytes:
|
||||
class BaseTransport(ABC):
|
||||
"""Base class for all TP-Link protocol transports."""
|
||||
|
||||
DEFAULT_TIMEOUT = 5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.credentials = credentials
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def needs_handshake(self) -> bool:
|
||||
"""Return true if the transport needs to do a handshake."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def needs_login(self) -> bool:
|
||||
"""Return true if the transport needs to do a login."""
|
||||
|
||||
@abstractmethod
|
||||
async def login(self, request: str) -> None:
|
||||
"""Login to the device."""
|
||||
|
||||
@abstractmethod
|
||||
async def handshake(self) -> None:
|
||||
"""Perform the encryption handshake."""
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._credentials = credentials or Credentials(username="", password="")
|
||||
self._timeout = timeout or self.DEFAULT_TIMEOUT
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, request: str) -> Dict:
|
||||
@@ -90,14 +76,14 @@ class TPLinkProtocol(ABC):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
transport: Optional[BaseTransport] = None,
|
||||
transport: BaseTransport,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.credentials = credentials
|
||||
self._transport = transport
|
||||
|
||||
@property
|
||||
def _host(self):
|
||||
return self._transport._host
|
||||
|
||||
@abstractmethod
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
@@ -108,6 +94,40 @@ class TPLinkProtocol(ABC):
|
||||
"""Close the protocol. Abstract method to be overriden."""
|
||||
|
||||
|
||||
class _XorTransport(BaseTransport):
|
||||
"""Implementation of the Xor encryption transport.
|
||||
|
||||
WIP, currently only to ensure consistent __init__ method signatures
|
||||
for protocol classes. Will eventually incorporate the logic from
|
||||
TPLinkSmartHomeProtocol to simplify the API and re-use the IotProtocol
|
||||
class.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT = 9999
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
host,
|
||||
port=port or self.DEFAULT_PORT,
|
||||
credentials=credentials,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def send(self, request: str) -> Dict:
|
||||
"""Send a message to the device and return a response."""
|
||||
return {}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the transport. Abstract method to be overriden."""
|
||||
|
||||
|
||||
class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
"""Implementation of the TP-Link Smart Home protocol."""
|
||||
|
||||
@@ -120,20 +140,18 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
self,
|
||||
host: str,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
timeout: Optional[int] = None,
|
||||
credentials: Optional[Credentials] = None,
|
||||
transport: BaseTransport,
|
||||
) -> None:
|
||||
"""Create a protocol object."""
|
||||
super().__init__(
|
||||
host=host, port=port or self.DEFAULT_PORT, credentials=credentials
|
||||
)
|
||||
super().__init__(host, transport=transport)
|
||||
|
||||
self.reader: Optional[asyncio.StreamReader] = None
|
||||
self.writer: Optional[asyncio.StreamWriter] = None
|
||||
self.query_lock = asyncio.Lock()
|
||||
self.loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self.timeout = timeout or TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
|
||||
|
||||
self._timeout = self._transport._timeout
|
||||
self._port = self._transport._port
|
||||
|
||||
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
|
||||
"""Request information from a TP-Link SmartHome Device.
|
||||
@@ -149,7 +167,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
assert isinstance(request, str) # noqa: S101
|
||||
|
||||
async with self.query_lock:
|
||||
return await self._query(request, retry_count, self.timeout)
|
||||
return await self._query(request, retry_count, self._timeout)
|
||||
|
||||
async def _connect(self, timeout: int) -> None:
|
||||
"""Try to connect or reconnect to the device."""
|
||||
@@ -157,7 +175,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
return
|
||||
self.reader = self.writer = None
|
||||
|
||||
task = asyncio.open_connection(self.host, self.port)
|
||||
task = asyncio.open_connection(self._host, self._port)
|
||||
async with asyncio_timeout(timeout):
|
||||
self.reader, self.writer = await task
|
||||
sock: socket.socket = self.writer.get_extra_info("socket")
|
||||
@@ -174,7 +192,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
debug_log = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
|
||||
if debug_log:
|
||||
_LOGGER.debug("%s >> %s", self.host, request)
|
||||
_LOGGER.debug("%s >> %s", self._host, request)
|
||||
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
|
||||
await self.writer.drain()
|
||||
|
||||
@@ -185,7 +203,7 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
response = TPLinkSmartHomeProtocol.decrypt(buffer)
|
||||
json_payload = json_loads(response)
|
||||
if debug_log:
|
||||
_LOGGER.debug("%s << %s", self.host, pf(json_payload))
|
||||
_LOGGER.debug("%s << %s", self._host, pf(json_payload))
|
||||
|
||||
return json_payload
|
||||
|
||||
@@ -219,23 +237,23 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
except ConnectionRefusedError as ex:
|
||||
await self.close()
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device: {self.host}:{self.port}: {ex}"
|
||||
f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
except OSError as ex:
|
||||
await self.close()
|
||||
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device:"
|
||||
f" {self.host}:{self.port}: {ex}"
|
||||
f" {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
continue
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to connect to the device:"
|
||||
f" {self.host}:{self.port}: {ex}"
|
||||
f" {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
continue
|
||||
|
||||
@@ -247,13 +265,13 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol):
|
||||
except Exception as ex:
|
||||
await self.close()
|
||||
if retry >= retry_count:
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
|
||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||
raise SmartDeviceException(
|
||||
f"Unable to query the device {self.host}:{self.port}: {ex}"
|
||||
f"Unable to query the device {self._host}:{self._port}: {ex}"
|
||||
) from ex
|
||||
|
||||
_LOGGER.debug(
|
||||
"Unable to query the device %s, retrying: %s", self.host, ex
|
||||
"Unable to query the device %s, retrying: %s", self._host, ex
|
||||
)
|
||||
|
||||
# make mypy happy, this should never be reached..
|
||||
|
Reference in New Issue
Block a user