Move transports into their own package (#1247)

This moves all transport implementations into a new `transports` package
for cleaner main package & easier to understand project structure.
This commit is contained in:
Teemu R.
2024-11-12 14:40:44 +01:00
committed by GitHub
parent 71ae06fa83
commit 668ba748c5
27 changed files with 159 additions and 102 deletions

View File

@@ -41,8 +41,9 @@ from kasa.iotprotocol import (
_deprecated_TPLinkSmartHomeProtocol, # noqa: F401
)
from kasa.module import Module
from kasa.protocol import BaseProtocol, BaseTransport
from kasa.protocol import BaseProtocol
from kasa.smartprotocol import SmartProtocol
from kasa.transports import BaseTransport
__version__ = version("python-kasa")

View File

@@ -128,7 +128,7 @@ from .feature import Feature
from .iotprotocol import IotProtocol
from .module import Module
from .protocol import BaseProtocol
from .xortransport import XorTransport
from .transports import XorTransport
if TYPE_CHECKING:
from .modulemapping import ModuleMapping, ModuleName

View File

@@ -6,7 +6,6 @@ import logging
import time
from typing import Any
from .aestransport import AesTransport
from .device import Device
from .device_type import DeviceType
from .deviceconfig import DeviceConfig
@@ -24,14 +23,18 @@ from .iot import (
IotWallSwitch,
)
from .iotprotocol import IotProtocol
from .klaptransport import KlapTransport, KlapTransportV2
from .protocol import (
BaseProtocol,
BaseTransport,
)
from .smart import SmartDevice
from .smartprotocol import SmartProtocol
from .xortransport import XorTransport
from .transports import (
AesTransport,
BaseTransport,
KlapTransport,
KlapTransportV2,
XorTransport,
)
_LOGGER = logging.getLogger(__name__)

View File

@@ -111,7 +111,6 @@ from async_timeout import timeout as asyncio_timeout
from pydantic.v1 import BaseModel, ValidationError
from kasa import Device
from kasa.aestransport import AesEncyptionSession, KeyPair
from kasa.credentials import Credentials
from kasa.device_factory import (
get_device_class_from_family,
@@ -134,12 +133,14 @@ from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.protocol import mask_mac, redact_data
from kasa.xortransport import XorEncryption
from kasa.transports.aestransport import AesEncyptionSession, KeyPair
from kasa.transports.xortransport import XorEncryption
_LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from kasa import BaseProtocol, BaseTransport
from kasa import BaseProtocol
from kasa.transports import BaseTransport
class ConnectAttempt(NamedTuple):

View File

@@ -13,7 +13,6 @@ from typing import TYPE_CHECKING, Any, Dict, cast
from yarl import URL
from ..aestransport import AesEncyptionSession
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import (
@@ -28,7 +27,8 @@ from ..exceptions import (
from ..httpclient import HttpClient
from ..json import dumps as json_dumps
from ..json import loads as json_loads
from ..protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials
from ..protocol import DEFAULT_CREDENTIALS, get_default_credentials
from ..transports import AesEncyptionSession, BaseTransport
_LOGGER = logging.getLogger(__name__)

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio
import logging
from pprint import pformat as pf
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable
from .deviceconfig import DeviceConfig
from .exceptions import (
@@ -16,8 +16,11 @@ from .exceptions import (
_RetryableError,
)
from .json import dumps as json_dumps
from .protocol import BaseProtocol, BaseTransport, mask_mac, redact_data
from .xortransport import XorEncryption, XorTransport
from .protocol import BaseProtocol, mask_mac, redact_data
from .transports import XorEncryption, XorTransport
if TYPE_CHECKING:
from .transports import BaseTransport
_LOGGER = logging.getLogger(__name__)

View File

@@ -18,7 +18,7 @@ import hashlib
import logging
import struct
from abc import ABC, abstractmethod
from typing import Any, Callable, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
@@ -32,6 +32,10 @@ _UNSIGNED_INT_NETWORK_ORDER = struct.Struct(">I")
_T = TypeVar("_T")
if TYPE_CHECKING:
from .transports import BaseTransport
def redact_data(data: _T, redactors: dict[str, Callable[[Any], Any] | None]) -> _T:
"""Redact sensitive data for logging."""
if not isinstance(data, (dict, list)):
@@ -75,49 +79,6 @@ def md5(payload: bytes) -> bytes:
return hashlib.md5(payload).digest() # noqa: S324
class BaseTransport(ABC):
"""Base class for all TP-Link protocol transports."""
DEFAULT_TIMEOUT = 5
def __init__(
self,
*,
config: DeviceConfig,
) -> None:
"""Create a protocol object."""
self._config = config
self._host = config.host
self._port = config.port_override or self.default_port
self._credentials = config.credentials
self._credentials_hash = config.credentials_hash
if not config.timeout:
config.timeout = self.DEFAULT_TIMEOUT
self._timeout = config.timeout
@property
@abstractmethod
def default_port(self) -> int:
"""The default port for the transport."""
@property
@abstractmethod
def credentials_hash(self) -> str | None:
"""The hashed credentials used by the transport."""
@abstractmethod
async def send(self, request: str) -> dict:
"""Send a message to the device and return a response."""
@abstractmethod
async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""
@abstractmethod
async def reset(self) -> None:
"""Reset internal state."""
class BaseProtocol(ABC):
"""Base class for all TP-Link Smart Home communication."""

View File

@@ -9,7 +9,6 @@ from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta, timezone, tzinfo
from typing import TYPE_CHECKING, Any, cast
from ..aestransport import AesTransport
from ..device import Device, WifiNetwork
from ..device_type import DeviceType
from ..deviceconfig import DeviceConfig
@@ -18,6 +17,7 @@ from ..feature import Feature
from ..module import Module
from ..modulemapping import ModuleMapping, ModuleName
from ..smartprotocol import SmartProtocol
from ..transports import AesTransport
from .modules import (
ChildDevice,
Cloud,

View File

@@ -12,7 +12,7 @@ import logging
import time
import uuid
from pprint import pformat as pf
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable
from .exceptions import (
SMART_AUTHENTICATION_ERRORS,
@@ -26,7 +26,11 @@ from .exceptions import (
_RetryableError,
)
from .json import dumps as json_dumps
from .protocol import BaseProtocol, BaseTransport, mask_mac, md5, redact_data
from .protocol import BaseProtocol, mask_mac, md5, redact_data
if TYPE_CHECKING:
from .transports import BaseTransport
_LOGGER = logging.getLogger(__name__)

View File

@@ -0,0 +1,16 @@
"""Package containing all supported transports."""
from .aestransport import AesEncyptionSession, AesTransport
from .basetransport import BaseTransport
from .klaptransport import KlapTransport, KlapTransportV2
from .xortransport import XorEncryption, XorTransport
__all__ = [
"AesTransport",
"AesEncyptionSession",
"BaseTransport",
"KlapTransport",
"KlapTransportV2",
"XorTransport",
"XorEncryption",
]

View File

@@ -20,9 +20,9 @@ from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from yarl import URL
from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .exceptions import (
from kasa.credentials import Credentials
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
AuthenticationError,
@@ -33,10 +33,12 @@ from .exceptions import (
_ConnectionError,
_RetryableError,
)
from .httpclient import HttpClient
from .json import dumps as json_dumps
from .json import loads as json_loads
from .protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials
from kasa.httpclient import HttpClient
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.protocol import DEFAULT_CREDENTIALS, get_default_credentials
from .basetransport import BaseTransport
_LOGGER = logging.getLogger(__name__)

View File

@@ -0,0 +1,55 @@
"""Base class for all transport implementations.
All transport classes must derive from this to implement the common interface.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from kasa import DeviceConfig
class BaseTransport(ABC):
"""Base class for all TP-Link protocol transports."""
DEFAULT_TIMEOUT = 5
def __init__(
self,
*,
config: DeviceConfig,
) -> None:
"""Create a protocol object."""
self._config = config
self._host = config.host
self._port = config.port_override or self.default_port
self._credentials = config.credentials
self._credentials_hash = config.credentials_hash
if not config.timeout:
config.timeout = self.DEFAULT_TIMEOUT
self._timeout = config.timeout
@property
@abstractmethod
def default_port(self) -> int:
"""The default port for the transport."""
@property
@abstractmethod
def credentials_hash(self) -> str | None:
"""The hashed credentials used by the transport."""
@abstractmethod
async def send(self, request: str) -> dict:
"""Send a message to the device and return a response."""
@abstractmethod
async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""
@abstractmethod
async def reset(self) -> None:
"""Reset internal state."""

View File

@@ -57,12 +57,18 @@ from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from yarl import URL
from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .exceptions import AuthenticationError, KasaException, _RetryableError
from .httpclient import HttpClient
from .json import loads as json_loads
from .protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials, md5
from kasa.credentials import Credentials
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import AuthenticationError, KasaException, _RetryableError
from kasa.httpclient import HttpClient
from kasa.json import loads as json_loads
from kasa.protocol import (
DEFAULT_CREDENTIALS,
get_default_credentials,
md5,
)
from .basetransport import BaseTransport
_LOGGER = logging.getLogger(__name__)

View File

@@ -24,10 +24,11 @@ from collections.abc import Generator
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
from .deviceconfig import DeviceConfig
from .exceptions import KasaException, _RetryableError
from .json import loads as json_loads
from .protocol import BaseTransport
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import KasaException, _RetryableError
from kasa.json import loads as json_loads
from .basetransport import BaseTransport
_LOGGER = logging.getLogger(__name__)
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}