diff --git a/kasa/aestransport.py b/kasa/aestransport.py index df26c4c4..919732cc 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -16,6 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from .credentials import Credentials from .deviceconfig import DeviceConfig from .exceptions import ( SMART_AUTHENTICATION_ERRORS, @@ -62,6 +63,16 @@ class AesTransport(BaseTransport): ) -> None: super().__init__(config=config) + self._login_version = config.connection_type.login_version + if not self._credentials and not self._credentials_hash: + self._credentials = Credentials() + if self._credentials: + self._login_params = self._get_login_params() + else: + self._login_params = json_loads( + base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr] + ) + self._default_http_client: Optional[httpx.AsyncClient] = None self._handshake_done = False @@ -80,6 +91,11 @@ class AesTransport(BaseTransport): """Default port for the transport.""" return self.DEFAULT_PORT + @property + def credentials_hash(self) -> str: + """The hashed credentials used by the transport.""" + return base64.b64encode(json_dumps(self._login_params).encode()).decode() + @property def _http_client(self) -> httpx.AsyncClient: if self._config.http_client: @@ -88,6 +104,12 @@ class AesTransport(BaseTransport): self._default_http_client = httpx.AsyncClient() return self._default_http_client + def _get_login_params(self): + """Get the login parameters based on the login_version.""" + un, pw = self.hash_credentials(self._login_version == 2) + password_field_name = "password2" if self._login_version == 2 else "password" + return {password_field_name: pw, "username": un} + def hash_credentials(self, login_v2): """Hash the credentials.""" if login_v2: @@ -171,14 +193,12 @@ class AesTransport(BaseTransport): resp_dict = json_loads(response) return resp_dict - async def _perform_login_for_version(self, *, login_version: int = 1): + async def perform_login(self): """Login to the device.""" self._login_token = None - un, pw = self.hash_credentials(login_version == 2) - password_field_name = "password2" if login_version == 2 else "password" login_request = { "method": "login_device", - "params": {password_field_name: pw, "username": un}, + "params": self._login_params, "request_time_milis": round(time.time() * 1000), } request = json_dumps(login_request) @@ -187,15 +207,6 @@ class AesTransport(BaseTransport): self._handle_response_error_code(resp_dict, "Error logging in") self._login_token = resp_dict["result"]["token"] - async def perform_login(self) -> None: - """Login to the device.""" - try: - await self._perform_login_for_version(login_version=2) - except AuthenticationException: - _LOGGER.warning("Login version 2 failed, trying version 1") - await self.perform_handshake() - await self._perform_login_for_version(login_version=1) - async def perform_handshake(self): """Perform the handshake.""" _LOGGER.debug("Will perform handshaking...") diff --git a/kasa/cli.py b/kasa/cli.py index 1fb522cf..6c60332d 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -184,6 +184,12 @@ def json_formatter_cb(result, **kwargs): default=None, type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False), ) +@click.option( + "--login-version", + envvar="KASA_LOGIN_VERSION", + default=None, + type=int, +) @click.option( "--timeout", envvar="KASA_TIMEOUT", @@ -214,6 +220,13 @@ def json_formatter_cb(result, **kwargs): envvar="KASA_PASSWORD", help="Password to use to authenticate to device.", ) +@click.option( + "--credentials-hash", + default=None, + required=False, + envvar="KASA_CREDENTIALS_HASH", + help="Hashed credentials used to authenticate to the device.", +) @click.version_option(package_name="python-kasa") @click.pass_context async def cli( @@ -227,11 +240,13 @@ async def cli( type, encrypt_type, device_family, + login_version, json, timeout, discovery_timeout, username, password, + credentials_hash, ): """A tool for controlling TP-Link smart home devices.""" # noqa # no need to perform any checks if we are just displaying the help @@ -291,7 +306,10 @@ async def cli( "username", "Using authentication requires both --username and --password" ) - credentials = Credentials(username=username, password=password) + if username: + credentials = Credentials(username=username, password=password) + else: + credentials = None if host is None: echo("No host name given, trying discovery..") @@ -300,13 +318,18 @@ async def cli( if type is not None: dev = TYPE_TO_CLASS[type](host) await dev.update() - elif device_family or encrypt_type: + elif device_family and encrypt_type: ctype = ConnectionType( DeviceFamilyType(device_family), EncryptType(encrypt_type), + login_version, ) config = DeviceConfig( - host=host, credentials=credentials, timeout=timeout, connection_type=ctype + host=host, + credentials=credentials, + credentials_hash=credentials_hash, + timeout=timeout, + connection_type=ctype, ) dev = await SmartDevice.connect(config=config) else: @@ -495,6 +518,7 @@ async def state(dev: SmartDevice): echo(f"[bold]== {dev.alias} - {dev.model} ==[/bold]") echo(f"\tHost: {dev.host}") echo(f"\tPort: {dev.port}") + echo(f"\tCredentials hash: {dev.credentials_hash}") echo(f"\tDevice state: {dev.is_on}") if dev.is_strip: echo("\t[bold]== Plugs ==[/bold]") diff --git a/kasa/deviceconfig.py b/kasa/deviceconfig.py index c753c2bc..5235868f 100644 --- a/kasa/deviceconfig.py +++ b/kasa/deviceconfig.py @@ -2,7 +2,7 @@ import logging from dataclasses import asdict, dataclass, field, fields, is_dataclass from enum import Enum -from typing import Dict, Optional +from typing import Dict, Optional, Union import httpx @@ -69,21 +69,25 @@ class ConnectionType: device_family: DeviceFamilyType encryption_type: EncryptType + login_version: Optional[int] = None @staticmethod def from_values( device_family: str, encryption_type: str, + login_version: Optional[int] = None, ) -> "ConnectionType": """Return connection parameters from string values.""" try: return ConnectionType( DeviceFamilyType(device_family), EncryptType(encryption_type), + login_version, ) - except ValueError as ex: + except (ValueError, TypeError) as ex: raise SmartDeviceException( - f"Invalid connection parameters for {device_family}.{encryption_type}" + f"Invalid connection parameters for {device_family}." + + f"{encryption_type}.{login_version}" ) from ex @staticmethod @@ -94,18 +98,26 @@ class ConnectionType: and (device_family := connection_type_dict.get("device_family")) and (encryption_type := connection_type_dict.get("encryption_type")) ): - return ConnectionType.from_values(device_family, encryption_type) + if login_version := connection_type_dict.get("login_version"): + login_version = int(login_version) # type: ignore[assignment] + return ConnectionType.from_values( + device_family, + encryption_type, + login_version, # type: ignore[arg-type] + ) raise SmartDeviceException( f"Invalid connection type data for {connection_type_dict}" ) - def to_dict(self) -> Dict[str, str]: + def to_dict(self) -> Dict[str, Union[str, int]]: """Convert connection params to dict.""" - result = { + result: Dict[str, Union[str, int]] = { "device_family": self.device_family.value, "encryption_type": self.encryption_type.value, } + if self.login_version: + result["login_version"] = self.login_version return result @@ -118,10 +130,11 @@ class DeviceConfig: host: str timeout: Optional[int] = DEFAULT_TIMEOUT port_override: Optional[int] = None - credentials: Credentials = field(default_factory=lambda: Credentials()) + credentials: Optional[Credentials] = None + credentials_hash: Optional[str] = None connection_type: ConnectionType = field( default_factory=lambda: ConnectionType( - DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor + DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor, 1 ) ) @@ -130,15 +143,22 @@ class DeviceConfig: http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False) def __post_init__(self): - if self.credentials is None: - self.credentials = Credentials() if self.connection_type is None: self.connection_type = ConnectionType( DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor ) - def to_dict(self) -> Dict[str, Dict[str, str]]: + def to_dict( + self, + *, + credentials_hash: Optional[str] = None, + exclude_credentials: bool = False, + ) -> Dict[str, Dict[str, str]]: """Convert connection params to dict.""" + if credentials_hash or exclude_credentials: + self.credentials = None + if credentials_hash: + self.credentials_hash = credentials_hash return _dataclass_to_dict(self) @staticmethod diff --git a/kasa/discover.py b/kasa/discover.py index e39122f3..8fbd6ff0 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -422,7 +422,9 @@ class Discover: try: config.connection_type = ConnectionType.from_values( - type_, discovery_result.mgt_encrypt_schm.encrypt_type + type_, + discovery_result.mgt_encrypt_schm.encrypt_type, + discovery_result.mgt_encrypt_schm.lv, ) except SmartDeviceException as ex: raise UnsupportedDeviceException( diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index 945346ee..8a77a775 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -41,6 +41,7 @@ https://github.com/python-kasa/python-kasa/pull/117 """ import asyncio +import base64 import datetime import hashlib import logging @@ -99,8 +100,13 @@ class KlapTransport(BaseTransport): self._default_http_client: Optional[httpx.AsyncClient] = None self._local_seed: Optional[bytes] = None - self._local_auth_hash = self.generate_auth_hash(self._credentials) - self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() + if not self._credentials and not self._credentials_hash: + self._credentials = Credentials() + if self._credentials: + self._local_auth_hash = self.generate_auth_hash(self._credentials) + self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() + else: + self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr] self._kasa_setup_auth_hash = None self._blank_auth_hash = None self._handshake_lock = asyncio.Lock() @@ -119,6 +125,11 @@ class KlapTransport(BaseTransport): """Default port for the transport.""" return self.DEFAULT_PORT + @property + def credentials_hash(self) -> str: + """The hashed credentials used by the transport.""" + return base64.b64encode(self._local_auth_hash).decode() + @property def _http_client(self) -> httpx.AsyncClient: if self._config.http_client: diff --git a/kasa/protocol.py b/kasa/protocol.py index c998807c..47d4a90b 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -56,6 +56,7 @@ class BaseTransport(ABC): self._host = config.host self._port = config.port_override or self.default_port self._credentials = config.credentials + self._credentials_hash = config.credentials_hash self._timeout = config.timeout @property @@ -63,6 +64,11 @@ class BaseTransport(ABC): def default_port(self) -> int: """The default port for the transport.""" + @property + @abstractmethod + def credentials_hash(self) -> str: + """The hashed credentials used by the transport.""" + @abstractmethod async def send(self, request: str) -> Dict: """Send a message to the device and return a response.""" @@ -120,6 +126,11 @@ class _XorTransport(BaseTransport): """Default port for the transport.""" return self.DEFAULT_PORT + @property + def credentials_hash(self) -> str: + """The hashed credentials used by the transport.""" + return "" + async def send(self, request: str) -> Dict: """Send a message to the device and return a response.""" return {} diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index c3561812..912f7cd9 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -245,6 +245,11 @@ class SmartDevice: """The device credentials.""" return self.protocol._transport._credentials + @property + def credentials_hash(self) -> Optional[str]: + """Return the connection parameters the device is using.""" + return self.protocol._transport.credentials_hash + def add_module(self, name: str, module: Module): """Register a module.""" if name in self.modules: diff --git a/kasa/tapo/tapodevice.py b/kasa/tapo/tapodevice.py index 785269a3..4e5a96cd 100644 --- a/kasa/tapo/tapodevice.py +++ b/kasa/tapo/tapodevice.py @@ -38,8 +38,8 @@ class TapoDevice(SmartDevice): async def update(self, update_children: bool = True): """Update the device.""" - if self.credentials is None: - raise AuthenticationException("Device requires authentication.") + if self.credentials is None and self.credentials_hash is None: + raise AuthenticationException("Tapo plug requires authentication.") if self._components_raw is None: resp = await self.protocol.query("component_nego") diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index fc5a1b9e..8ef47000 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -430,6 +430,7 @@ def discovery_mock(all_fixture_data, mocker): query_data: dict device_type: str encrypt_type: str + login_version: Optional[int] = None port_override: Optional[int] = None if "discovery_result" in all_fixture_data: @@ -438,6 +439,9 @@ def discovery_mock(all_fixture_data, mocker): encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][ "encrypt_type" ] + login_version = all_fixture_data["discovery_result"]["mgt_encrypt_schm"].get( + "lv" + ) datagram = ( b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" + json_dumps(discovery_data).encode() @@ -450,12 +454,14 @@ def discovery_mock(all_fixture_data, mocker): all_fixture_data, device_type, encrypt_type, + login_version, ) else: sys_info = all_fixture_data["system"]["get_sysinfo"] discovery_data = {"system": {"get_sysinfo": sys_info}} device_type = sys_info.get("mic_type") or sys_info.get("type") encrypt_type = "XOR" + login_version = None datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:] dm = _DiscoveryMock( "127.0.0.123", @@ -465,6 +471,7 @@ def discovery_mock(all_fixture_data, mocker): all_fixture_data, device_type, encrypt_type, + login_version, ) def mock_discover(self): diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index ec50321c..064dbaeb 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -1,3 +1,4 @@ +import base64 import copy import logging import re @@ -320,6 +321,11 @@ class FakeSmartTransport(BaseTransport): """Default port for the transport.""" return 80 + @property + def credentials_hash(self): + """The hashed credentials used by the transport.""" + return self._credentials.username + self._credentials.password + "hash" + async def send(self, request: str): request_dict = json_loads(request) method = request_dict["method"] diff --git a/kasa/tests/test_deviceconfig.py b/kasa/tests/test_deviceconfig.py index 7970449d..22d42b81 100644 --- a/kasa/tests/test_deviceconfig.py +++ b/kasa/tests/test_deviceconfig.py @@ -19,3 +19,30 @@ def test_serialization(): config2_dict = json_loads(config_json) config2 = DeviceConfig.from_dict(config2_dict) assert config == config2 + + +def test_credentials_hash(): + config = DeviceConfig( + host="Foo", + http_client=httpx.AsyncClient(), + credentials=Credentials("foo", "bar"), + ) + config_dict = config.to_dict(credentials_hash="credhash") + config_json = json_dumps(config_dict) + config2_dict = json_loads(config_json) + config2 = DeviceConfig.from_dict(config2_dict) + assert config2.credentials_hash == "credhash" + assert config2.credentials is None + + +def test_no_credentials_serialization(): + config = DeviceConfig( + host="Foo", + http_client=httpx.AsyncClient(), + credentials=Credentials("foo", "bar"), + ) + config_dict = config.to_dict(exclude_credentials=True) + config_json = json_dumps(config_dict) + config2_dict = json_loads(config_json) + config2 = DeviceConfig.from_dict(config2_dict) + assert config2.credentials is None diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 396ef2f2..51aedfb7 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -110,11 +110,17 @@ async def test_discover_single(discovery_mock, custom_port, mocker): assert update_mock.call_count == 0 ct = ConnectionType.from_values( - discovery_mock.device_type, discovery_mock.encrypt_type + discovery_mock.device_type, + discovery_mock.encrypt_type, + discovery_mock.login_version, ) uses_http = discovery_mock.default_port == 80 config = DeviceConfig( - host=host, port_override=custom_port, connection_type=ct, uses_http=uses_http + host=host, + port_override=custom_port, + connection_type=ct, + uses_http=uses_http, + credentials=Credentials(), ) assert x.config == config diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 0e74da3b..05ae40f3 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -9,8 +9,11 @@ import sys import pytest +from ..aestransport import AesTransport +from ..credentials import Credentials from ..deviceconfig import DeviceConfig from ..exceptions import SmartDeviceException +from ..klaptransport import KlapTransport, KlapTransportV2 from ..protocol import ( BaseTransport, TPLinkProtocol, @@ -298,3 +301,19 @@ def test_transport_init_signature(class_name_obj): assert ( params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY ) + + +@pytest.mark.parametrize( + "transport_class", [AesTransport, KlapTransport, KlapTransportV2, _XorTransport] +) +async def test_transport_credentials_hash(mocker, transport_class): + host = "127.0.0.1" + + credentials = Credentials("Foo", "Bar") + config = DeviceConfig(host, credentials=credentials) + transport = transport_class(config=config) + credentials_hash = transport.credentials_hash + config = DeviceConfig(host, credentials_hash=credentials_hash) + transport = transport_class(config=config) + + assert transport.credentials_hash == credentials_hash