diff --git a/devtools/parse_pcap.py b/devtools/parse_pcap.py index 02d3911c..f08e4dd3 100644 --- a/devtools/parse_pcap.py +++ b/devtools/parse_pcap.py @@ -9,7 +9,7 @@ import dpkt from dpkt.ethernet import ETH_TYPE_IP, Ethernet from kasa.cli.main import echo -from kasa.xortransport import XorEncryption +from kasa.transports.xortransport import XorEncryption def read_payloads_from_file(file): diff --git a/devtools/parse_pcap_klap.py b/devtools/parse_pcap_klap.py index 640c7aef..9af59023 100755 --- a/devtools/parse_pcap_klap.py +++ b/devtools/parse_pcap_klap.py @@ -25,8 +25,8 @@ from kasa.deviceconfig import ( DeviceEncryptionType, DeviceFamily, ) -from kasa.klaptransport import KlapEncryptionSession, KlapTransportV2 from kasa.protocol import DEFAULT_CREDENTIALS, get_default_credentials +from kasa.transports.klaptransport import KlapEncryptionSession, KlapTransportV2 def _get_seq_from_query(packet): diff --git a/docs/source/reference.md b/docs/source/reference.md index c1bc4662..b8ebee9f 100644 --- a/docs/source/reference.md +++ b/docs/source/reference.md @@ -107,35 +107,35 @@ ``` ```{eval-rst} -.. autoclass:: kasa.protocol.BaseTransport +.. autoclass:: kasa.transports.BaseTransport :members: :inherited-members: :undoc-members: ``` ```{eval-rst} -.. autoclass:: kasa.xortransport.XorTransport +.. autoclass:: kasa.transports.XorTransport :members: :inherited-members: :undoc-members: ``` ```{eval-rst} -.. autoclass:: kasa.klaptransport.KlapTransport +.. autoclass:: kasa.transports.KlapTransport :members: :inherited-members: :undoc-members: ``` ```{eval-rst} -.. autoclass:: kasa.klaptransport.KlapTransportV2 +.. autoclass:: kasa.transports.KlapTransportV2 :members: :inherited-members: :undoc-members: ``` ```{eval-rst} -.. autoclass:: kasa.aestransport.AesTransport +.. autoclass:: kasa.transports.AesTransport :members: :inherited-members: :undoc-members: diff --git a/kasa/__init__.py b/kasa/__init__.py index ffeaa503..49e77966 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -41,8 +41,9 @@ from kasa.iotprotocol import ( _deprecated_TPLinkSmartHomeProtocol, # noqa: F401 ) from kasa.module import Module -from kasa.protocol import BaseProtocol, BaseTransport +from kasa.protocol import BaseProtocol from kasa.smartprotocol import SmartProtocol +from kasa.transports import BaseTransport __version__ = version("python-kasa") diff --git a/kasa/device.py b/kasa/device.py index ca16bb6b..acb3af8c 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -128,7 +128,7 @@ from .feature import Feature from .iotprotocol import IotProtocol from .module import Module from .protocol import BaseProtocol -from .xortransport import XorTransport +from .transports import XorTransport if TYPE_CHECKING: from .modulemapping import ModuleMapping, ModuleName diff --git a/kasa/device_factory.py b/kasa/device_factory.py index 0c1ed427..9cdef53e 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -6,7 +6,6 @@ import logging import time from typing import Any -from .aestransport import AesTransport from .device import Device from .device_type import DeviceType from .deviceconfig import DeviceConfig @@ -24,14 +23,18 @@ from .iot import ( IotWallSwitch, ) from .iotprotocol import IotProtocol -from .klaptransport import KlapTransport, KlapTransportV2 from .protocol import ( BaseProtocol, - BaseTransport, ) from .smart import SmartDevice from .smartprotocol import SmartProtocol -from .xortransport import XorTransport +from .transports import ( + AesTransport, + BaseTransport, + KlapTransport, + KlapTransportV2, + XorTransport, +) _LOGGER = logging.getLogger(__name__) diff --git a/kasa/discover.py b/kasa/discover.py index efb1e5e4..bed43e85 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -111,7 +111,6 @@ from async_timeout import timeout as asyncio_timeout from pydantic.v1 import BaseModel, ValidationError from kasa import Device -from kasa.aestransport import AesEncyptionSession, KeyPair from kasa.credentials import Credentials from kasa.device_factory import ( get_device_class_from_family, @@ -134,12 +133,14 @@ from kasa.iotprotocol import REDACTORS as IOT_REDACTORS from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads from kasa.protocol import mask_mac, redact_data -from kasa.xortransport import XorEncryption +from kasa.transports.aestransport import AesEncyptionSession, KeyPair +from kasa.transports.xortransport import XorEncryption _LOGGER = logging.getLogger(__name__) if TYPE_CHECKING: - from kasa import BaseProtocol, BaseTransport + from kasa import BaseProtocol + from kasa.transports import BaseTransport class ConnectAttempt(NamedTuple): diff --git a/kasa/experimental/sslaestransport.py b/kasa/experimental/sslaestransport.py index f188f144..6b5144b1 100644 --- a/kasa/experimental/sslaestransport.py +++ b/kasa/experimental/sslaestransport.py @@ -13,7 +13,6 @@ from typing import TYPE_CHECKING, Any, Dict, cast from yarl import URL -from ..aestransport import AesEncyptionSession from ..credentials import Credentials from ..deviceconfig import DeviceConfig from ..exceptions import ( @@ -28,7 +27,8 @@ from ..exceptions import ( from ..httpclient import HttpClient from ..json import dumps as json_dumps from ..json import loads as json_loads -from ..protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials +from ..protocol import DEFAULT_CREDENTIALS, get_default_credentials +from ..transports import AesEncyptionSession, BaseTransport _LOGGER = logging.getLogger(__name__) diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index 91edb032..bb570498 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -5,7 +5,7 @@ from __future__ import annotations import asyncio import logging from pprint import pformat as pf -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable from .deviceconfig import DeviceConfig from .exceptions import ( @@ -16,8 +16,11 @@ from .exceptions import ( _RetryableError, ) from .json import dumps as json_dumps -from .protocol import BaseProtocol, BaseTransport, mask_mac, redact_data -from .xortransport import XorEncryption, XorTransport +from .protocol import BaseProtocol, mask_mac, redact_data +from .transports import XorEncryption, XorTransport + +if TYPE_CHECKING: + from .transports import BaseTransport _LOGGER = logging.getLogger(__name__) diff --git a/kasa/protocol.py b/kasa/protocol.py index f2560987..8e8a2352 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -18,7 +18,7 @@ import hashlib import logging import struct from abc import ABC, abstractmethod -from typing import Any, Callable, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout @@ -32,6 +32,10 @@ _UNSIGNED_INT_NETWORK_ORDER = struct.Struct(">I") _T = TypeVar("_T") +if TYPE_CHECKING: + from .transports import BaseTransport + + def redact_data(data: _T, redactors: dict[str, Callable[[Any], Any] | None]) -> _T: """Redact sensitive data for logging.""" if not isinstance(data, (dict, list)): @@ -75,49 +79,6 @@ def md5(payload: bytes) -> bytes: return hashlib.md5(payload).digest() # noqa: S324 -class BaseTransport(ABC): - """Base class for all TP-Link protocol transports.""" - - DEFAULT_TIMEOUT = 5 - - def __init__( - self, - *, - config: DeviceConfig, - ) -> None: - """Create a protocol object.""" - self._config = config - self._host = config.host - self._port = config.port_override or self.default_port - self._credentials = config.credentials - self._credentials_hash = config.credentials_hash - if not config.timeout: - config.timeout = self.DEFAULT_TIMEOUT - self._timeout = config.timeout - - @property - @abstractmethod - def default_port(self) -> int: - """The default port for the transport.""" - - @property - @abstractmethod - def credentials_hash(self) -> str | None: - """The hashed credentials used by the transport.""" - - @abstractmethod - async def send(self, request: str) -> dict: - """Send a message to the device and return a response.""" - - @abstractmethod - async def close(self) -> None: - """Close the transport. Abstract method to be overriden.""" - - @abstractmethod - async def reset(self) -> None: - """Reset internal state.""" - - class BaseProtocol(ABC): """Base class for all TP-Link Smart Home communication.""" diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index e497b8e8..e8e2186c 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -9,7 +9,6 @@ from collections.abc import Mapping, Sequence from datetime import datetime, timedelta, timezone, tzinfo from typing import TYPE_CHECKING, Any, cast -from ..aestransport import AesTransport from ..device import Device, WifiNetwork from ..device_type import DeviceType from ..deviceconfig import DeviceConfig @@ -18,6 +17,7 @@ from ..feature import Feature from ..module import Module from ..modulemapping import ModuleMapping, ModuleName from ..smartprotocol import SmartProtocol +from ..transports import AesTransport from .modules import ( ChildDevice, Cloud, diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index e2ff6af7..7d43bdb4 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -12,7 +12,7 @@ import logging import time import uuid from pprint import pformat as pf -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable from .exceptions import ( SMART_AUTHENTICATION_ERRORS, @@ -26,7 +26,11 @@ from .exceptions import ( _RetryableError, ) from .json import dumps as json_dumps -from .protocol import BaseProtocol, BaseTransport, mask_mac, md5, redact_data +from .protocol import BaseProtocol, mask_mac, md5, redact_data + +if TYPE_CHECKING: + from .transports import BaseTransport + _LOGGER = logging.getLogger(__name__) diff --git a/kasa/transports/__init__.py b/kasa/transports/__init__.py new file mode 100644 index 00000000..8ccdae65 --- /dev/null +++ b/kasa/transports/__init__.py @@ -0,0 +1,16 @@ +"""Package containing all supported transports.""" + +from .aestransport import AesEncyptionSession, AesTransport +from .basetransport import BaseTransport +from .klaptransport import KlapTransport, KlapTransportV2 +from .xortransport import XorEncryption, XorTransport + +__all__ = [ + "AesTransport", + "AesEncyptionSession", + "BaseTransport", + "KlapTransport", + "KlapTransportV2", + "XorTransport", + "XorEncryption", +] diff --git a/kasa/aestransport.py b/kasa/transports/aestransport.py similarity index 98% rename from kasa/aestransport.py rename to kasa/transports/aestransport.py index fc807fb3..61b7c27b 100644 --- a/kasa/aestransport.py +++ b/kasa/transports/aestransport.py @@ -20,9 +20,9 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from yarl import URL -from .credentials import Credentials -from .deviceconfig import DeviceConfig -from .exceptions import ( +from kasa.credentials import Credentials +from kasa.deviceconfig import DeviceConfig +from kasa.exceptions import ( SMART_AUTHENTICATION_ERRORS, SMART_RETRYABLE_ERRORS, AuthenticationError, @@ -33,10 +33,12 @@ from .exceptions import ( _ConnectionError, _RetryableError, ) -from .httpclient import HttpClient -from .json import dumps as json_dumps -from .json import loads as json_loads -from .protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials +from kasa.httpclient import HttpClient +from kasa.json import dumps as json_dumps +from kasa.json import loads as json_loads +from kasa.protocol import DEFAULT_CREDENTIALS, get_default_credentials + +from .basetransport import BaseTransport _LOGGER = logging.getLogger(__name__) diff --git a/kasa/transports/basetransport.py b/kasa/transports/basetransport.py new file mode 100644 index 00000000..1f1ed7d9 --- /dev/null +++ b/kasa/transports/basetransport.py @@ -0,0 +1,55 @@ +"""Base class for all transport implementations. + +All transport classes must derive from this to implement the common interface. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from kasa import DeviceConfig + + +class BaseTransport(ABC): + """Base class for all TP-Link protocol transports.""" + + DEFAULT_TIMEOUT = 5 + + def __init__( + self, + *, + config: DeviceConfig, + ) -> None: + """Create a protocol object.""" + self._config = config + self._host = config.host + self._port = config.port_override or self.default_port + self._credentials = config.credentials + self._credentials_hash = config.credentials_hash + if not config.timeout: + config.timeout = self.DEFAULT_TIMEOUT + self._timeout = config.timeout + + @property + @abstractmethod + def default_port(self) -> int: + """The default port for the transport.""" + + @property + @abstractmethod + def credentials_hash(self) -> str | None: + """The hashed credentials used by the transport.""" + + @abstractmethod + async def send(self, request: str) -> dict: + """Send a message to the device and return a response.""" + + @abstractmethod + async def close(self) -> None: + """Close the transport. Abstract method to be overriden.""" + + @abstractmethod + async def reset(self) -> None: + """Reset internal state.""" diff --git a/kasa/klaptransport.py b/kasa/transports/klaptransport.py similarity index 98% rename from kasa/klaptransport.py rename to kasa/transports/klaptransport.py index 870304d1..d9d5e952 100644 --- a/kasa/klaptransport.py +++ b/kasa/transports/klaptransport.py @@ -57,12 +57,18 @@ from cryptography.hazmat.primitives import padding from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from yarl import URL -from .credentials import Credentials -from .deviceconfig import DeviceConfig -from .exceptions import AuthenticationError, KasaException, _RetryableError -from .httpclient import HttpClient -from .json import loads as json_loads -from .protocol import DEFAULT_CREDENTIALS, BaseTransport, get_default_credentials, md5 +from kasa.credentials import Credentials +from kasa.deviceconfig import DeviceConfig +from kasa.exceptions import AuthenticationError, KasaException, _RetryableError +from kasa.httpclient import HttpClient +from kasa.json import loads as json_loads +from kasa.protocol import ( + DEFAULT_CREDENTIALS, + get_default_credentials, + md5, +) + +from .basetransport import BaseTransport _LOGGER = logging.getLogger(__name__) diff --git a/kasa/xortransport.py b/kasa/transports/xortransport.py similarity index 97% rename from kasa/xortransport.py rename to kasa/transports/xortransport.py index 7abc2a3b..932a9415 100644 --- a/kasa/xortransport.py +++ b/kasa/transports/xortransport.py @@ -24,10 +24,11 @@ from collections.abc import Generator # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout -from .deviceconfig import DeviceConfig -from .exceptions import KasaException, _RetryableError -from .json import loads as json_loads -from .protocol import BaseTransport +from kasa.deviceconfig import DeviceConfig +from kasa.exceptions import KasaException, _RetryableError +from kasa.json import loads as json_loads + +from .basetransport import BaseTransport _LOGGER = logging.getLogger(__name__) _NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED} diff --git a/tests/conftest.py b/tests/conftest.py index 0d47080f..c56cba0f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ from kasa import ( DeviceConfig, SmartProtocol, ) -from kasa.protocol import BaseTransport +from kasa.transports.basetransport import BaseTransport from .device_fixtures import * # noqa: F403 from .discovery_fixtures import * # noqa: F403 diff --git a/tests/discovery_fixtures.py b/tests/discovery_fixtures.py index ccad1510..e69a8b73 100644 --- a/tests/discovery_fixtures.py +++ b/tests/discovery_fixtures.py @@ -6,7 +6,7 @@ from json import dumps as json_dumps import pytest -from kasa.xortransport import XorEncryption +from kasa.transports.xortransport import XorEncryption from .fakeprotocol_iot import FakeIotProtocol from .fakeprotocol_smart import FakeSmartProtocol, FakeSmartTransport diff --git a/tests/fakeprotocol_iot.py b/tests/fakeprotocol_iot.py index c8897d9b..1249ec21 100644 --- a/tests/fakeprotocol_iot.py +++ b/tests/fakeprotocol_iot.py @@ -3,7 +3,7 @@ import logging from kasa.deviceconfig import DeviceConfig from kasa.iotprotocol import IotProtocol -from kasa.protocol import BaseTransport +from kasa.transports.basetransport import BaseTransport _LOGGER = logging.getLogger(__name__) diff --git a/tests/fakeprotocol_smart.py b/tests/fakeprotocol_smart.py index 842147f3..ce60a61b 100644 --- a/tests/fakeprotocol_smart.py +++ b/tests/fakeprotocol_smart.py @@ -6,8 +6,8 @@ import pytest from kasa import Credentials, DeviceConfig, SmartProtocol from kasa.exceptions import SmartErrorCode -from kasa.protocol import BaseTransport from kasa.smart import SmartChildDevice +from kasa.transports.basetransport import BaseTransport class FakeSmartProtocol(SmartProtocol): diff --git a/tests/fakeprotocol_smartcamera.py b/tests/fakeprotocol_smartcamera.py index d7465489..7ff0bab2 100644 --- a/tests/fakeprotocol_smartcamera.py +++ b/tests/fakeprotocol_smartcamera.py @@ -5,7 +5,7 @@ from json import loads as json_loads from kasa import Credentials, DeviceConfig, SmartProtocol from kasa.experimental.smartcameraprotocol import SmartCameraProtocol -from kasa.protocol import BaseTransport +from kasa.transports.basetransport import BaseTransport from .fakeprotocol_smart import FakeSmartTransport diff --git a/tests/test_aestransport.py b/tests/test_aestransport.py index 302f195a..4c95289a 100644 --- a/tests/test_aestransport.py +++ b/tests/test_aestransport.py @@ -18,7 +18,6 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd from freezegun.api import FrozenDateTimeFactory from yarl import URL -from kasa.aestransport import AesEncyptionSession, AesTransport, TransportState from kasa.credentials import Credentials from kasa.deviceconfig import DeviceConfig from kasa.exceptions import ( @@ -28,6 +27,11 @@ from kasa.exceptions import ( _ConnectionError, ) from kasa.httpclient import HttpClient +from kasa.transports.aestransport import ( + AesEncyptionSession, + AesTransport, + TransportState, +) pytestmark = [pytest.mark.requires_dummy] diff --git a/tests/test_discovery.py b/tests/test_discovery.py index 7f69977e..32330dca 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -23,7 +23,6 @@ from kasa import ( IotProtocol, KasaException, ) -from kasa.aestransport import AesEncyptionSession from kasa.device_factory import ( get_device_class_from_family, get_device_class_from_sys_info, @@ -41,7 +40,8 @@ from kasa.discover import ( ) from kasa.exceptions import AuthenticationError, UnsupportedDeviceError from kasa.iot import IotDevice -from kasa.xortransport import XorEncryption, XorTransport +from kasa.transports.aestransport import AesEncyptionSession +from kasa.transports.xortransport import XorEncryption, XorTransport from .conftest import ( bulb_iot, diff --git a/tests/test_klapprotocol.py b/tests/test_klapprotocol.py index 524a6be3..bdb05490 100644 --- a/tests/test_klapprotocol.py +++ b/tests/test_klapprotocol.py @@ -9,7 +9,6 @@ import aiohttp import pytest from yarl import URL -from kasa.aestransport import AesTransport from kasa.credentials import Credentials from kasa.deviceconfig import DeviceConfig from kasa.exceptions import ( @@ -21,14 +20,15 @@ from kasa.exceptions import ( ) from kasa.httpclient import HttpClient from kasa.iotprotocol import IotProtocol -from kasa.klaptransport import ( +from kasa.protocol import DEFAULT_CREDENTIALS, get_default_credentials +from kasa.smartprotocol import SmartProtocol +from kasa.transports.aestransport import AesTransport +from kasa.transports.klaptransport import ( KlapEncryptionSession, KlapTransport, KlapTransportV2, _sha256, ) -from kasa.protocol import DEFAULT_CREDENTIALS, get_default_credentials -from kasa.smartprotocol import SmartProtocol DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 7638a4bf..11e2afcf 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -13,21 +13,21 @@ from unittest.mock import AsyncMock import pytest -from kasa.aestransport import AesTransport from kasa.credentials import Credentials from kasa.device import Device from kasa.deviceconfig import DeviceConfig from kasa.exceptions import KasaException from kasa.iot import IotDevice from kasa.iotprotocol import IotProtocol, _deprecated_TPLinkSmartHomeProtocol -from kasa.klaptransport import KlapTransport, KlapTransportV2 from kasa.protocol import ( BaseProtocol, - BaseTransport, mask_mac, redact_data, ) -from kasa.xortransport import XorEncryption, XorTransport +from kasa.transports.aestransport import AesTransport +from kasa.transports.basetransport import BaseTransport +from kasa.transports.klaptransport import KlapTransport, KlapTransportV2 +from kasa.transports.xortransport import XorEncryption, XorTransport from .conftest import device_iot from .fakeprotocol_iot import FakeIotTransport diff --git a/tests/test_sslaestransport.py b/tests/test_sslaestransport.py index 49605d37..52507892 100644 --- a/tests/test_sslaestransport.py +++ b/tests/test_sslaestransport.py @@ -11,7 +11,6 @@ import aiohttp import pytest from yarl import URL -from kasa.aestransport import AesEncyptionSession from kasa.credentials import Credentials from kasa.deviceconfig import DeviceConfig from kasa.exceptions import ( @@ -26,6 +25,7 @@ from kasa.experimental.sslaestransport import ( ) from kasa.httpclient import HttpClient from kasa.protocol import DEFAULT_CREDENTIALS, get_default_credentials +from kasa.transports.aestransport import AesEncyptionSession # Transport tests are not designed for real devices pytestmark = [pytest.mark.requires_dummy]