Allow serializing and passing of credentials_hashes in DeviceConfig (#607)

* Allow passing of credentials_hashes in DeviceConfig

* Update following review
This commit is contained in:
sdb9696 2024-01-03 21:46:08 +00:00 committed by GitHub
parent 3692e4812f
commit e9bf9f58ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 183 additions and 34 deletions

View File

@ -16,6 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials
from .deviceconfig import DeviceConfig
from .exceptions import (
SMART_AUTHENTICATION_ERRORS,
@ -62,6 +63,16 @@ class AesTransport(BaseTransport):
) -> None:
super().__init__(config=config)
self._login_version = config.connection_type.login_version
if not self._credentials and not self._credentials_hash:
self._credentials = Credentials()
if self._credentials:
self._login_params = self._get_login_params()
else:
self._login_params = json_loads(
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
)
self._default_http_client: Optional[httpx.AsyncClient] = None
self._handshake_done = False
@ -80,6 +91,11 @@ class AesTransport(BaseTransport):
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
@property
def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client:
@ -88,6 +104,12 @@ class AesTransport(BaseTransport):
self._default_http_client = httpx.AsyncClient()
return self._default_http_client
def _get_login_params(self):
"""Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2)
password_field_name = "password2" if self._login_version == 2 else "password"
return {password_field_name: pw, "username": un}
def hash_credentials(self, login_v2):
"""Hash the credentials."""
if login_v2:
@ -171,14 +193,12 @@ class AesTransport(BaseTransport):
resp_dict = json_loads(response)
return resp_dict
async def _perform_login_for_version(self, *, login_version: int = 1):
async def perform_login(self):
"""Login to the device."""
self._login_token = None
un, pw = self.hash_credentials(login_version == 2)
password_field_name = "password2" if login_version == 2 else "password"
login_request = {
"method": "login_device",
"params": {password_field_name: pw, "username": un},
"params": self._login_params,
"request_time_milis": round(time.time() * 1000),
}
request = json_dumps(login_request)
@ -187,15 +207,6 @@ class AesTransport(BaseTransport):
self._handle_response_error_code(resp_dict, "Error logging in")
self._login_token = resp_dict["result"]["token"]
async def perform_login(self) -> None:
"""Login to the device."""
try:
await self._perform_login_for_version(login_version=2)
except AuthenticationException:
_LOGGER.warning("Login version 2 failed, trying version 1")
await self.perform_handshake()
await self._perform_login_for_version(login_version=1)
async def perform_handshake(self):
"""Perform the handshake."""
_LOGGER.debug("Will perform handshaking...")

View File

@ -184,6 +184,12 @@ def json_formatter_cb(result, **kwargs):
default=None,
type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False),
)
@click.option(
"--login-version",
envvar="KASA_LOGIN_VERSION",
default=None,
type=int,
)
@click.option(
"--timeout",
envvar="KASA_TIMEOUT",
@ -214,6 +220,13 @@ def json_formatter_cb(result, **kwargs):
envvar="KASA_PASSWORD",
help="Password to use to authenticate to device.",
)
@click.option(
"--credentials-hash",
default=None,
required=False,
envvar="KASA_CREDENTIALS_HASH",
help="Hashed credentials used to authenticate to the device.",
)
@click.version_option(package_name="python-kasa")
@click.pass_context
async def cli(
@ -227,11 +240,13 @@ async def cli(
type,
encrypt_type,
device_family,
login_version,
json,
timeout,
discovery_timeout,
username,
password,
credentials_hash,
):
"""A tool for controlling TP-Link smart home devices.""" # noqa
# no need to perform any checks if we are just displaying the help
@ -291,7 +306,10 @@ async def cli(
"username", "Using authentication requires both --username and --password"
)
if username:
credentials = Credentials(username=username, password=password)
else:
credentials = None
if host is None:
echo("No host name given, trying discovery..")
@ -300,13 +318,18 @@ async def cli(
if type is not None:
dev = TYPE_TO_CLASS[type](host)
await dev.update()
elif device_family or encrypt_type:
elif device_family and encrypt_type:
ctype = ConnectionType(
DeviceFamilyType(device_family),
EncryptType(encrypt_type),
login_version,
)
config = DeviceConfig(
host=host, credentials=credentials, timeout=timeout, connection_type=ctype
host=host,
credentials=credentials,
credentials_hash=credentials_hash,
timeout=timeout,
connection_type=ctype,
)
dev = await SmartDevice.connect(config=config)
else:
@ -495,6 +518,7 @@ async def state(dev: SmartDevice):
echo(f"[bold]== {dev.alias} - {dev.model} ==[/bold]")
echo(f"\tHost: {dev.host}")
echo(f"\tPort: {dev.port}")
echo(f"\tCredentials hash: {dev.credentials_hash}")
echo(f"\tDevice state: {dev.is_on}")
if dev.is_strip:
echo("\t[bold]== Plugs ==[/bold]")

View File

@ -2,7 +2,7 @@
import logging
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from enum import Enum
from typing import Dict, Optional
from typing import Dict, Optional, Union
import httpx
@ -69,21 +69,25 @@ class ConnectionType:
device_family: DeviceFamilyType
encryption_type: EncryptType
login_version: Optional[int] = None
@staticmethod
def from_values(
device_family: str,
encryption_type: str,
login_version: Optional[int] = None,
) -> "ConnectionType":
"""Return connection parameters from string values."""
try:
return ConnectionType(
DeviceFamilyType(device_family),
EncryptType(encryption_type),
login_version,
)
except ValueError as ex:
except (ValueError, TypeError) as ex:
raise SmartDeviceException(
f"Invalid connection parameters for {device_family}.{encryption_type}"
f"Invalid connection parameters for {device_family}."
+ f"{encryption_type}.{login_version}"
) from ex
@staticmethod
@ -94,18 +98,26 @@ class ConnectionType:
and (device_family := connection_type_dict.get("device_family"))
and (encryption_type := connection_type_dict.get("encryption_type"))
):
return ConnectionType.from_values(device_family, encryption_type)
if login_version := connection_type_dict.get("login_version"):
login_version = int(login_version) # type: ignore[assignment]
return ConnectionType.from_values(
device_family,
encryption_type,
login_version, # type: ignore[arg-type]
)
raise SmartDeviceException(
f"Invalid connection type data for {connection_type_dict}"
)
def to_dict(self) -> Dict[str, str]:
def to_dict(self) -> Dict[str, Union[str, int]]:
"""Convert connection params to dict."""
result = {
result: Dict[str, Union[str, int]] = {
"device_family": self.device_family.value,
"encryption_type": self.encryption_type.value,
}
if self.login_version:
result["login_version"] = self.login_version
return result
@ -118,10 +130,11 @@ class DeviceConfig:
host: str
timeout: Optional[int] = DEFAULT_TIMEOUT
port_override: Optional[int] = None
credentials: Credentials = field(default_factory=lambda: Credentials())
credentials: Optional[Credentials] = None
credentials_hash: Optional[str] = None
connection_type: ConnectionType = field(
default_factory=lambda: ConnectionType(
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor, 1
)
)
@ -130,15 +143,22 @@ class DeviceConfig:
http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False)
def __post_init__(self):
if self.credentials is None:
self.credentials = Credentials()
if self.connection_type is None:
self.connection_type = ConnectionType(
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
)
def to_dict(self) -> Dict[str, Dict[str, str]]:
def to_dict(
self,
*,
credentials_hash: Optional[str] = None,
exclude_credentials: bool = False,
) -> Dict[str, Dict[str, str]]:
"""Convert connection params to dict."""
if credentials_hash or exclude_credentials:
self.credentials = None
if credentials_hash:
self.credentials_hash = credentials_hash
return _dataclass_to_dict(self)
@staticmethod

View File

@ -422,7 +422,9 @@ class Discover:
try:
config.connection_type = ConnectionType.from_values(
type_, discovery_result.mgt_encrypt_schm.encrypt_type
type_,
discovery_result.mgt_encrypt_schm.encrypt_type,
discovery_result.mgt_encrypt_schm.lv,
)
except SmartDeviceException as ex:
raise UnsupportedDeviceException(

View File

@ -41,6 +41,7 @@ https://github.com/python-kasa/python-kasa/pull/117
"""
import asyncio
import base64
import datetime
import hashlib
import logging
@ -99,8 +100,13 @@ class KlapTransport(BaseTransport):
self._default_http_client: Optional[httpx.AsyncClient] = None
self._local_seed: Optional[bytes] = None
if not self._credentials and not self._credentials_hash:
self._credentials = Credentials()
if self._credentials:
self._local_auth_hash = self.generate_auth_hash(self._credentials)
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
else:
self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr]
self._kasa_setup_auth_hash = None
self._blank_auth_hash = None
self._handshake_lock = asyncio.Lock()
@ -119,6 +125,11 @@ class KlapTransport(BaseTransport):
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
return base64.b64encode(self._local_auth_hash).decode()
@property
def _http_client(self) -> httpx.AsyncClient:
if self._config.http_client:

View File

@ -56,6 +56,7 @@ class BaseTransport(ABC):
self._host = config.host
self._port = config.port_override or self.default_port
self._credentials = config.credentials
self._credentials_hash = config.credentials_hash
self._timeout = config.timeout
@property
@ -63,6 +64,11 @@ class BaseTransport(ABC):
def default_port(self) -> int:
"""The default port for the transport."""
@property
@abstractmethod
def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
@abstractmethod
async def send(self, request: str) -> Dict:
"""Send a message to the device and return a response."""
@ -120,6 +126,11 @@ class _XorTransport(BaseTransport):
"""Default port for the transport."""
return self.DEFAULT_PORT
@property
def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
return ""
async def send(self, request: str) -> Dict:
"""Send a message to the device and return a response."""
return {}

View File

@ -245,6 +245,11 @@ class SmartDevice:
"""The device credentials."""
return self.protocol._transport._credentials
@property
def credentials_hash(self) -> Optional[str]:
"""Return the connection parameters the device is using."""
return self.protocol._transport.credentials_hash
def add_module(self, name: str, module: Module):
"""Register a module."""
if name in self.modules:

View File

@ -38,8 +38,8 @@ class TapoDevice(SmartDevice):
async def update(self, update_children: bool = True):
"""Update the device."""
if self.credentials is None:
raise AuthenticationException("Device requires authentication.")
if self.credentials is None and self.credentials_hash is None:
raise AuthenticationException("Tapo plug requires authentication.")
if self._components_raw is None:
resp = await self.protocol.query("component_nego")

View File

@ -430,6 +430,7 @@ def discovery_mock(all_fixture_data, mocker):
query_data: dict
device_type: str
encrypt_type: str
login_version: Optional[int] = None
port_override: Optional[int] = None
if "discovery_result" in all_fixture_data:
@ -438,6 +439,9 @@ def discovery_mock(all_fixture_data, mocker):
encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][
"encrypt_type"
]
login_version = all_fixture_data["discovery_result"]["mgt_encrypt_schm"].get(
"lv"
)
datagram = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
@ -450,12 +454,14 @@ def discovery_mock(all_fixture_data, mocker):
all_fixture_data,
device_type,
encrypt_type,
login_version,
)
else:
sys_info = all_fixture_data["system"]["get_sysinfo"]
discovery_data = {"system": {"get_sysinfo": sys_info}}
device_type = sys_info.get("mic_type") or sys_info.get("type")
encrypt_type = "XOR"
login_version = None
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
dm = _DiscoveryMock(
"127.0.0.123",
@ -465,6 +471,7 @@ def discovery_mock(all_fixture_data, mocker):
all_fixture_data,
device_type,
encrypt_type,
login_version,
)
def mock_discover(self):

View File

@ -1,3 +1,4 @@
import base64
import copy
import logging
import re
@ -320,6 +321,11 @@ class FakeSmartTransport(BaseTransport):
"""Default port for the transport."""
return 80
@property
def credentials_hash(self):
"""The hashed credentials used by the transport."""
return self._credentials.username + self._credentials.password + "hash"
async def send(self, request: str):
request_dict = json_loads(request)
method = request_dict["method"]

View File

@ -19,3 +19,30 @@ def test_serialization():
config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict)
assert config == config2
def test_credentials_hash():
config = DeviceConfig(
host="Foo",
http_client=httpx.AsyncClient(),
credentials=Credentials("foo", "bar"),
)
config_dict = config.to_dict(credentials_hash="credhash")
config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict)
assert config2.credentials_hash == "credhash"
assert config2.credentials is None
def test_no_credentials_serialization():
config = DeviceConfig(
host="Foo",
http_client=httpx.AsyncClient(),
credentials=Credentials("foo", "bar"),
)
config_dict = config.to_dict(exclude_credentials=True)
config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict)
assert config2.credentials is None

View File

@ -110,11 +110,17 @@ async def test_discover_single(discovery_mock, custom_port, mocker):
assert update_mock.call_count == 0
ct = ConnectionType.from_values(
discovery_mock.device_type, discovery_mock.encrypt_type
discovery_mock.device_type,
discovery_mock.encrypt_type,
discovery_mock.login_version,
)
uses_http = discovery_mock.default_port == 80
config = DeviceConfig(
host=host, port_override=custom_port, connection_type=ct, uses_http=uses_http
host=host,
port_override=custom_port,
connection_type=ct,
uses_http=uses_http,
credentials=Credentials(),
)
assert x.config == config

View File

@ -9,8 +9,11 @@ import sys
import pytest
from ..aestransport import AesTransport
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import SmartDeviceException
from ..klaptransport import KlapTransport, KlapTransportV2
from ..protocol import (
BaseTransport,
TPLinkProtocol,
@ -298,3 +301,19 @@ def test_transport_init_signature(class_name_obj):
assert (
params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY
)
@pytest.mark.parametrize(
"transport_class", [AesTransport, KlapTransport, KlapTransportV2, _XorTransport]
)
async def test_transport_credentials_hash(mocker, transport_class):
host = "127.0.0.1"
credentials = Credentials("Foo", "Bar")
config = DeviceConfig(host, credentials=credentials)
transport = transport_class(config=config)
credentials_hash = transport.credentials_hash
config = DeviceConfig(host, credentials_hash=credentials_hash)
transport = transport_class(config=config)
assert transport.credentials_hash == credentials_hash