Allow serializing and passing of credentials_hashes in DeviceConfig (#607)

* Allow passing of credentials_hashes in DeviceConfig

* Update following review
This commit is contained in:
sdb9696 2024-01-03 21:46:08 +00:00 committed by GitHub
parent 3692e4812f
commit e9bf9f58ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 183 additions and 34 deletions

View File

@ -16,6 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials
from .deviceconfig import DeviceConfig from .deviceconfig import DeviceConfig
from .exceptions import ( from .exceptions import (
SMART_AUTHENTICATION_ERRORS, SMART_AUTHENTICATION_ERRORS,
@ -62,6 +63,16 @@ class AesTransport(BaseTransport):
) -> None: ) -> None:
super().__init__(config=config) super().__init__(config=config)
self._login_version = config.connection_type.login_version
if not self._credentials and not self._credentials_hash:
self._credentials = Credentials()
if self._credentials:
self._login_params = self._get_login_params()
else:
self._login_params = json_loads(
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
)
self._default_http_client: Optional[httpx.AsyncClient] = None self._default_http_client: Optional[httpx.AsyncClient] = None
self._handshake_done = False self._handshake_done = False
@ -80,6 +91,11 @@ class AesTransport(BaseTransport):
"""Default port for the transport.""" """Default port for the transport."""
return self.DEFAULT_PORT return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
@property @property
def _http_client(self) -> httpx.AsyncClient: def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client: if self._config.http_client:
@ -88,6 +104,12 @@ class AesTransport(BaseTransport):
self._default_http_client = httpx.AsyncClient() self._default_http_client = httpx.AsyncClient()
return self._default_http_client return self._default_http_client
def _get_login_params(self):
"""Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2)
password_field_name = "password2" if self._login_version == 2 else "password"
return {password_field_name: pw, "username": un}
def hash_credentials(self, login_v2): def hash_credentials(self, login_v2):
"""Hash the credentials.""" """Hash the credentials."""
if login_v2: if login_v2:
@ -171,14 +193,12 @@ class AesTransport(BaseTransport):
resp_dict = json_loads(response) resp_dict = json_loads(response)
return resp_dict return resp_dict
async def _perform_login_for_version(self, *, login_version: int = 1): async def perform_login(self):
"""Login to the device.""" """Login to the device."""
self._login_token = None self._login_token = None
un, pw = self.hash_credentials(login_version == 2)
password_field_name = "password2" if login_version == 2 else "password"
login_request = { login_request = {
"method": "login_device", "method": "login_device",
"params": {password_field_name: pw, "username": un}, "params": self._login_params,
"request_time_milis": round(time.time() * 1000), "request_time_milis": round(time.time() * 1000),
} }
request = json_dumps(login_request) request = json_dumps(login_request)
@ -187,15 +207,6 @@ class AesTransport(BaseTransport):
self._handle_response_error_code(resp_dict, "Error logging in") self._handle_response_error_code(resp_dict, "Error logging in")
self._login_token = resp_dict["result"]["token"] self._login_token = resp_dict["result"]["token"]
async def perform_login(self) -> None:
"""Login to the device."""
try:
await self._perform_login_for_version(login_version=2)
except AuthenticationException:
_LOGGER.warning("Login version 2 failed, trying version 1")
await self.perform_handshake()
await self._perform_login_for_version(login_version=1)
async def perform_handshake(self): async def perform_handshake(self):
"""Perform the handshake.""" """Perform the handshake."""
_LOGGER.debug("Will perform handshaking...") _LOGGER.debug("Will perform handshaking...")

View File

@ -184,6 +184,12 @@ def json_formatter_cb(result, **kwargs):
default=None, default=None,
type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False), type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False),
) )
@click.option(
"--login-version",
envvar="KASA_LOGIN_VERSION",
default=None,
type=int,
)
@click.option( @click.option(
"--timeout", "--timeout",
envvar="KASA_TIMEOUT", envvar="KASA_TIMEOUT",
@ -214,6 +220,13 @@ def json_formatter_cb(result, **kwargs):
envvar="KASA_PASSWORD", envvar="KASA_PASSWORD",
help="Password to use to authenticate to device.", help="Password to use to authenticate to device.",
) )
@click.option(
"--credentials-hash",
default=None,
required=False,
envvar="KASA_CREDENTIALS_HASH",
help="Hashed credentials used to authenticate to the device.",
)
@click.version_option(package_name="python-kasa") @click.version_option(package_name="python-kasa")
@click.pass_context @click.pass_context
async def cli( async def cli(
@ -227,11 +240,13 @@ async def cli(
type, type,
encrypt_type, encrypt_type,
device_family, device_family,
login_version,
json, json,
timeout, timeout,
discovery_timeout, discovery_timeout,
username, username,
password, password,
credentials_hash,
): ):
"""A tool for controlling TP-Link smart home devices.""" # noqa """A tool for controlling TP-Link smart home devices.""" # noqa
# no need to perform any checks if we are just displaying the help # no need to perform any checks if we are just displaying the help
@ -291,7 +306,10 @@ async def cli(
"username", "Using authentication requires both --username and --password" "username", "Using authentication requires both --username and --password"
) )
credentials = Credentials(username=username, password=password) if username:
credentials = Credentials(username=username, password=password)
else:
credentials = None
if host is None: if host is None:
echo("No host name given, trying discovery..") echo("No host name given, trying discovery..")
@ -300,13 +318,18 @@ async def cli(
if type is not None: if type is not None:
dev = TYPE_TO_CLASS[type](host) dev = TYPE_TO_CLASS[type](host)
await dev.update() await dev.update()
elif device_family or encrypt_type: elif device_family and encrypt_type:
ctype = ConnectionType( ctype = ConnectionType(
DeviceFamilyType(device_family), DeviceFamilyType(device_family),
EncryptType(encrypt_type), EncryptType(encrypt_type),
login_version,
) )
config = DeviceConfig( config = DeviceConfig(
host=host, credentials=credentials, timeout=timeout, connection_type=ctype host=host,
credentials=credentials,
credentials_hash=credentials_hash,
timeout=timeout,
connection_type=ctype,
) )
dev = await SmartDevice.connect(config=config) dev = await SmartDevice.connect(config=config)
else: else:
@ -495,6 +518,7 @@ async def state(dev: SmartDevice):
echo(f"[bold]== {dev.alias} - {dev.model} ==[/bold]") echo(f"[bold]== {dev.alias} - {dev.model} ==[/bold]")
echo(f"\tHost: {dev.host}") echo(f"\tHost: {dev.host}")
echo(f"\tPort: {dev.port}") echo(f"\tPort: {dev.port}")
echo(f"\tCredentials hash: {dev.credentials_hash}")
echo(f"\tDevice state: {dev.is_on}") echo(f"\tDevice state: {dev.is_on}")
if dev.is_strip: if dev.is_strip:
echo("\t[bold]== Plugs ==[/bold]") echo("\t[bold]== Plugs ==[/bold]")

View File

@ -2,7 +2,7 @@
import logging import logging
from dataclasses import asdict, dataclass, field, fields, is_dataclass from dataclasses import asdict, dataclass, field, fields, is_dataclass
from enum import Enum from enum import Enum
from typing import Dict, Optional from typing import Dict, Optional, Union
import httpx import httpx
@ -69,21 +69,25 @@ class ConnectionType:
device_family: DeviceFamilyType device_family: DeviceFamilyType
encryption_type: EncryptType encryption_type: EncryptType
login_version: Optional[int] = None
@staticmethod @staticmethod
def from_values( def from_values(
device_family: str, device_family: str,
encryption_type: str, encryption_type: str,
login_version: Optional[int] = None,
) -> "ConnectionType": ) -> "ConnectionType":
"""Return connection parameters from string values.""" """Return connection parameters from string values."""
try: try:
return ConnectionType( return ConnectionType(
DeviceFamilyType(device_family), DeviceFamilyType(device_family),
EncryptType(encryption_type), EncryptType(encryption_type),
login_version,
) )
except ValueError as ex: except (ValueError, TypeError) as ex:
raise SmartDeviceException( raise SmartDeviceException(
f"Invalid connection parameters for {device_family}.{encryption_type}" f"Invalid connection parameters for {device_family}."
+ f"{encryption_type}.{login_version}"
) from ex ) from ex
@staticmethod @staticmethod
@ -94,18 +98,26 @@ class ConnectionType:
and (device_family := connection_type_dict.get("device_family")) and (device_family := connection_type_dict.get("device_family"))
and (encryption_type := connection_type_dict.get("encryption_type")) and (encryption_type := connection_type_dict.get("encryption_type"))
): ):
return ConnectionType.from_values(device_family, encryption_type) if login_version := connection_type_dict.get("login_version"):
login_version = int(login_version) # type: ignore[assignment]
return ConnectionType.from_values(
device_family,
encryption_type,
login_version, # type: ignore[arg-type]
)
raise SmartDeviceException( raise SmartDeviceException(
f"Invalid connection type data for {connection_type_dict}" f"Invalid connection type data for {connection_type_dict}"
) )
def to_dict(self) -> Dict[str, str]: def to_dict(self) -> Dict[str, Union[str, int]]:
"""Convert connection params to dict.""" """Convert connection params to dict."""
result = { result: Dict[str, Union[str, int]] = {
"device_family": self.device_family.value, "device_family": self.device_family.value,
"encryption_type": self.encryption_type.value, "encryption_type": self.encryption_type.value,
} }
if self.login_version:
result["login_version"] = self.login_version
return result return result
@ -118,10 +130,11 @@ class DeviceConfig:
host: str host: str
timeout: Optional[int] = DEFAULT_TIMEOUT timeout: Optional[int] = DEFAULT_TIMEOUT
port_override: Optional[int] = None port_override: Optional[int] = None
credentials: Credentials = field(default_factory=lambda: Credentials()) credentials: Optional[Credentials] = None
credentials_hash: Optional[str] = None
connection_type: ConnectionType = field( connection_type: ConnectionType = field(
default_factory=lambda: ConnectionType( default_factory=lambda: ConnectionType(
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor, 1
) )
) )
@ -130,15 +143,22 @@ class DeviceConfig:
http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False) http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False)
def __post_init__(self): def __post_init__(self):
if self.credentials is None:
self.credentials = Credentials()
if self.connection_type is None: if self.connection_type is None:
self.connection_type = ConnectionType( self.connection_type = ConnectionType(
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
) )
def to_dict(self) -> Dict[str, Dict[str, str]]: def to_dict(
self,
*,
credentials_hash: Optional[str] = None,
exclude_credentials: bool = False,
) -> Dict[str, Dict[str, str]]:
"""Convert connection params to dict.""" """Convert connection params to dict."""
if credentials_hash or exclude_credentials:
self.credentials = None
if credentials_hash:
self.credentials_hash = credentials_hash
return _dataclass_to_dict(self) return _dataclass_to_dict(self)
@staticmethod @staticmethod

View File

@ -422,7 +422,9 @@ class Discover:
try: try:
config.connection_type = ConnectionType.from_values( config.connection_type = ConnectionType.from_values(
type_, discovery_result.mgt_encrypt_schm.encrypt_type type_,
discovery_result.mgt_encrypt_schm.encrypt_type,
discovery_result.mgt_encrypt_schm.lv,
) )
except SmartDeviceException as ex: except SmartDeviceException as ex:
raise UnsupportedDeviceException( raise UnsupportedDeviceException(

View File

@ -41,6 +41,7 @@ https://github.com/python-kasa/python-kasa/pull/117
""" """
import asyncio import asyncio
import base64
import datetime import datetime
import hashlib import hashlib
import logging import logging
@ -99,8 +100,13 @@ class KlapTransport(BaseTransport):
self._default_http_client: Optional[httpx.AsyncClient] = None self._default_http_client: Optional[httpx.AsyncClient] = None
self._local_seed: Optional[bytes] = None self._local_seed: Optional[bytes] = None
self._local_auth_hash = self.generate_auth_hash(self._credentials) if not self._credentials and not self._credentials_hash:
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() self._credentials = Credentials()
if self._credentials:
self._local_auth_hash = self.generate_auth_hash(self._credentials)
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
else:
self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr]
self._kasa_setup_auth_hash = None self._kasa_setup_auth_hash = None
self._blank_auth_hash = None self._blank_auth_hash = None
self._handshake_lock = asyncio.Lock() self._handshake_lock = asyncio.Lock()
@ -119,6 +125,11 @@ class KlapTransport(BaseTransport):
"""Default port for the transport.""" """Default port for the transport."""
return self.DEFAULT_PORT return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
return base64.b64encode(self._local_auth_hash).decode()
@property @property
def _http_client(self) -> httpx.AsyncClient: def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client: if self._config.http_client:

View File

@ -56,6 +56,7 @@ class BaseTransport(ABC):
self._host = config.host self._host = config.host
self._port = config.port_override or self.default_port self._port = config.port_override or self.default_port
self._credentials = config.credentials self._credentials = config.credentials
self._credentials_hash = config.credentials_hash
self._timeout = config.timeout self._timeout = config.timeout
@property @property
@ -63,6 +64,11 @@ class BaseTransport(ABC):
def default_port(self) -> int: def default_port(self) -> int:
"""The default port for the transport.""" """The default port for the transport."""
@property
@abstractmethod
def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
@abstractmethod @abstractmethod
async def send(self, request: str) -> Dict: async def send(self, request: str) -> Dict:
"""Send a message to the device and return a response.""" """Send a message to the device and return a response."""
@ -120,6 +126,11 @@ class _XorTransport(BaseTransport):
"""Default port for the transport.""" """Default port for the transport."""
return self.DEFAULT_PORT return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
return ""
async def send(self, request: str) -> Dict: async def send(self, request: str) -> Dict:
"""Send a message to the device and return a response.""" """Send a message to the device and return a response."""
return {} return {}

View File

@ -245,6 +245,11 @@ class SmartDevice:
"""The device credentials.""" """The device credentials."""
return self.protocol._transport._credentials return self.protocol._transport._credentials
@property
def credentials_hash(self) -> Optional[str]:
"""Return the connection parameters the device is using."""
return self.protocol._transport.credentials_hash
def add_module(self, name: str, module: Module): def add_module(self, name: str, module: Module):
"""Register a module.""" """Register a module."""
if name in self.modules: if name in self.modules:

View File

@ -38,8 +38,8 @@ class TapoDevice(SmartDevice):
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True):
"""Update the device.""" """Update the device."""
if self.credentials is None: if self.credentials is None and self.credentials_hash is None:
raise AuthenticationException("Device requires authentication.") raise AuthenticationException("Tapo plug requires authentication.")
if self._components_raw is None: if self._components_raw is None:
resp = await self.protocol.query("component_nego") resp = await self.protocol.query("component_nego")

View File

@ -430,6 +430,7 @@ def discovery_mock(all_fixture_data, mocker):
query_data: dict query_data: dict
device_type: str device_type: str
encrypt_type: str encrypt_type: str
login_version: Optional[int] = None
port_override: Optional[int] = None port_override: Optional[int] = None
if "discovery_result" in all_fixture_data: if "discovery_result" in all_fixture_data:
@ -438,6 +439,9 @@ def discovery_mock(all_fixture_data, mocker):
encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][ encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][
"encrypt_type" "encrypt_type"
] ]
login_version = all_fixture_data["discovery_result"]["mgt_encrypt_schm"].get(
"lv"
)
datagram = ( datagram = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode() + json_dumps(discovery_data).encode()
@ -450,12 +454,14 @@ def discovery_mock(all_fixture_data, mocker):
all_fixture_data, all_fixture_data,
device_type, device_type,
encrypt_type, encrypt_type,
login_version,
) )
else: else:
sys_info = all_fixture_data["system"]["get_sysinfo"] sys_info = all_fixture_data["system"]["get_sysinfo"]
discovery_data = {"system": {"get_sysinfo": sys_info}} discovery_data = {"system": {"get_sysinfo": sys_info}}
device_type = sys_info.get("mic_type") or sys_info.get("type") device_type = sys_info.get("mic_type") or sys_info.get("type")
encrypt_type = "XOR" encrypt_type = "XOR"
login_version = None
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:] datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
dm = _DiscoveryMock( dm = _DiscoveryMock(
"127.0.0.123", "127.0.0.123",
@ -465,6 +471,7 @@ def discovery_mock(all_fixture_data, mocker):
all_fixture_data, all_fixture_data,
device_type, device_type,
encrypt_type, encrypt_type,
login_version,
) )
def mock_discover(self): def mock_discover(self):

View File

@ -1,3 +1,4 @@
import base64
import copy import copy
import logging import logging
import re import re
@ -320,6 +321,11 @@ class FakeSmartTransport(BaseTransport):
"""Default port for the transport.""" """Default port for the transport."""
return 80 return 80
@property
def credentials_hash(self):
"""The hashed credentials used by the transport."""
return self._credentials.username + self._credentials.password + "hash"
async def send(self, request: str): async def send(self, request: str):
request_dict = json_loads(request) request_dict = json_loads(request)
method = request_dict["method"] method = request_dict["method"]

View File

@ -19,3 +19,30 @@ def test_serialization():
config2_dict = json_loads(config_json) config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict) config2 = DeviceConfig.from_dict(config2_dict)
assert config == config2 assert config == config2
def test_credentials_hash():
config = DeviceConfig(
host="Foo",
http_client=httpx.AsyncClient(),
credentials=Credentials("foo", "bar"),
)
config_dict = config.to_dict(credentials_hash="credhash")
config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict)
assert config2.credentials_hash == "credhash"
assert config2.credentials is None
def test_no_credentials_serialization():
config = DeviceConfig(
host="Foo",
http_client=httpx.AsyncClient(),
credentials=Credentials("foo", "bar"),
)
config_dict = config.to_dict(exclude_credentials=True)
config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict)
assert config2.credentials is None

View File

@ -110,11 +110,17 @@ async def test_discover_single(discovery_mock, custom_port, mocker):
assert update_mock.call_count == 0 assert update_mock.call_count == 0
ct = ConnectionType.from_values( ct = ConnectionType.from_values(
discovery_mock.device_type, discovery_mock.encrypt_type discovery_mock.device_type,
discovery_mock.encrypt_type,
discovery_mock.login_version,
) )
uses_http = discovery_mock.default_port == 80 uses_http = discovery_mock.default_port == 80
config = DeviceConfig( config = DeviceConfig(
host=host, port_override=custom_port, connection_type=ct, uses_http=uses_http host=host,
port_override=custom_port,
connection_type=ct,
uses_http=uses_http,
credentials=Credentials(),
) )
assert x.config == config assert x.config == config

View File

@ -9,8 +9,11 @@ import sys
import pytest import pytest
from ..aestransport import AesTransport
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig from ..deviceconfig import DeviceConfig
from ..exceptions import SmartDeviceException from ..exceptions import SmartDeviceException
from ..klaptransport import KlapTransport, KlapTransportV2
from ..protocol import ( from ..protocol import (
BaseTransport, BaseTransport,
TPLinkProtocol, TPLinkProtocol,
@ -298,3 +301,19 @@ def test_transport_init_signature(class_name_obj):
assert ( assert (
params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY
) )
@pytest.mark.parametrize(
"transport_class", [AesTransport, KlapTransport, KlapTransportV2, _XorTransport]
)
async def test_transport_credentials_hash(mocker, transport_class):
host = "127.0.0.1"
credentials = Credentials("Foo", "Bar")
config = DeviceConfig(host, credentials=credentials)
transport = transport_class(config=config)
credentials_hash = transport.credentials_hash
config = DeviceConfig(host, credentials_hash=credentials_hash)
transport = transport_class(config=config)
assert transport.credentials_hash == credentials_hash