mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Allow serializing and passing of credentials_hashes in DeviceConfig (#607)
* Allow passing of credentials_hashes in DeviceConfig * Update following review
This commit is contained in:
parent
3692e4812f
commit
e9bf9f58ee
@ -16,6 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from .credentials import Credentials
|
||||
from .deviceconfig import DeviceConfig
|
||||
from .exceptions import (
|
||||
SMART_AUTHENTICATION_ERRORS,
|
||||
@ -62,6 +63,16 @@ class AesTransport(BaseTransport):
|
||||
) -> None:
|
||||
super().__init__(config=config)
|
||||
|
||||
self._login_version = config.connection_type.login_version
|
||||
if not self._credentials and not self._credentials_hash:
|
||||
self._credentials = Credentials()
|
||||
if self._credentials:
|
||||
self._login_params = self._get_login_params()
|
||||
else:
|
||||
self._login_params = json_loads(
|
||||
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
|
||||
)
|
||||
|
||||
self._default_http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
self._handshake_done = False
|
||||
@ -80,6 +91,11 @@ class AesTransport(BaseTransport):
|
||||
"""Default port for the transport."""
|
||||
return self.DEFAULT_PORT
|
||||
|
||||
@property
|
||||
def credentials_hash(self) -> str:
|
||||
"""The hashed credentials used by the transport."""
|
||||
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
|
||||
|
||||
@property
|
||||
def _http_client(self) -> httpx.AsyncClient:
|
||||
if self._config.http_client:
|
||||
@ -88,6 +104,12 @@ class AesTransport(BaseTransport):
|
||||
self._default_http_client = httpx.AsyncClient()
|
||||
return self._default_http_client
|
||||
|
||||
def _get_login_params(self):
|
||||
"""Get the login parameters based on the login_version."""
|
||||
un, pw = self.hash_credentials(self._login_version == 2)
|
||||
password_field_name = "password2" if self._login_version == 2 else "password"
|
||||
return {password_field_name: pw, "username": un}
|
||||
|
||||
def hash_credentials(self, login_v2):
|
||||
"""Hash the credentials."""
|
||||
if login_v2:
|
||||
@ -171,14 +193,12 @@ class AesTransport(BaseTransport):
|
||||
resp_dict = json_loads(response)
|
||||
return resp_dict
|
||||
|
||||
async def _perform_login_for_version(self, *, login_version: int = 1):
|
||||
async def perform_login(self):
|
||||
"""Login to the device."""
|
||||
self._login_token = None
|
||||
un, pw = self.hash_credentials(login_version == 2)
|
||||
password_field_name = "password2" if login_version == 2 else "password"
|
||||
login_request = {
|
||||
"method": "login_device",
|
||||
"params": {password_field_name: pw, "username": un},
|
||||
"params": self._login_params,
|
||||
"request_time_milis": round(time.time() * 1000),
|
||||
}
|
||||
request = json_dumps(login_request)
|
||||
@ -187,15 +207,6 @@ class AesTransport(BaseTransport):
|
||||
self._handle_response_error_code(resp_dict, "Error logging in")
|
||||
self._login_token = resp_dict["result"]["token"]
|
||||
|
||||
async def perform_login(self) -> None:
|
||||
"""Login to the device."""
|
||||
try:
|
||||
await self._perform_login_for_version(login_version=2)
|
||||
except AuthenticationException:
|
||||
_LOGGER.warning("Login version 2 failed, trying version 1")
|
||||
await self.perform_handshake()
|
||||
await self._perform_login_for_version(login_version=1)
|
||||
|
||||
async def perform_handshake(self):
|
||||
"""Perform the handshake."""
|
||||
_LOGGER.debug("Will perform handshaking...")
|
||||
|
30
kasa/cli.py
30
kasa/cli.py
@ -184,6 +184,12 @@ def json_formatter_cb(result, **kwargs):
|
||||
default=None,
|
||||
type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False),
|
||||
)
|
||||
@click.option(
|
||||
"--login-version",
|
||||
envvar="KASA_LOGIN_VERSION",
|
||||
default=None,
|
||||
type=int,
|
||||
)
|
||||
@click.option(
|
||||
"--timeout",
|
||||
envvar="KASA_TIMEOUT",
|
||||
@ -214,6 +220,13 @@ def json_formatter_cb(result, **kwargs):
|
||||
envvar="KASA_PASSWORD",
|
||||
help="Password to use to authenticate to device.",
|
||||
)
|
||||
@click.option(
|
||||
"--credentials-hash",
|
||||
default=None,
|
||||
required=False,
|
||||
envvar="KASA_CREDENTIALS_HASH",
|
||||
help="Hashed credentials used to authenticate to the device.",
|
||||
)
|
||||
@click.version_option(package_name="python-kasa")
|
||||
@click.pass_context
|
||||
async def cli(
|
||||
@ -227,11 +240,13 @@ async def cli(
|
||||
type,
|
||||
encrypt_type,
|
||||
device_family,
|
||||
login_version,
|
||||
json,
|
||||
timeout,
|
||||
discovery_timeout,
|
||||
username,
|
||||
password,
|
||||
credentials_hash,
|
||||
):
|
||||
"""A tool for controlling TP-Link smart home devices.""" # noqa
|
||||
# no need to perform any checks if we are just displaying the help
|
||||
@ -291,7 +306,10 @@ async def cli(
|
||||
"username", "Using authentication requires both --username and --password"
|
||||
)
|
||||
|
||||
credentials = Credentials(username=username, password=password)
|
||||
if username:
|
||||
credentials = Credentials(username=username, password=password)
|
||||
else:
|
||||
credentials = None
|
||||
|
||||
if host is None:
|
||||
echo("No host name given, trying discovery..")
|
||||
@ -300,13 +318,18 @@ async def cli(
|
||||
if type is not None:
|
||||
dev = TYPE_TO_CLASS[type](host)
|
||||
await dev.update()
|
||||
elif device_family or encrypt_type:
|
||||
elif device_family and encrypt_type:
|
||||
ctype = ConnectionType(
|
||||
DeviceFamilyType(device_family),
|
||||
EncryptType(encrypt_type),
|
||||
login_version,
|
||||
)
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=credentials, timeout=timeout, connection_type=ctype
|
||||
host=host,
|
||||
credentials=credentials,
|
||||
credentials_hash=credentials_hash,
|
||||
timeout=timeout,
|
||||
connection_type=ctype,
|
||||
)
|
||||
dev = await SmartDevice.connect(config=config)
|
||||
else:
|
||||
@ -495,6 +518,7 @@ async def state(dev: SmartDevice):
|
||||
echo(f"[bold]== {dev.alias} - {dev.model} ==[/bold]")
|
||||
echo(f"\tHost: {dev.host}")
|
||||
echo(f"\tPort: {dev.port}")
|
||||
echo(f"\tCredentials hash: {dev.credentials_hash}")
|
||||
echo(f"\tDevice state: {dev.is_on}")
|
||||
if dev.is_strip:
|
||||
echo("\t[bold]== Plugs ==[/bold]")
|
||||
|
@ -2,7 +2,7 @@
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass, field, fields, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
@ -69,21 +69,25 @@ class ConnectionType:
|
||||
|
||||
device_family: DeviceFamilyType
|
||||
encryption_type: EncryptType
|
||||
login_version: Optional[int] = None
|
||||
|
||||
@staticmethod
|
||||
def from_values(
|
||||
device_family: str,
|
||||
encryption_type: str,
|
||||
login_version: Optional[int] = None,
|
||||
) -> "ConnectionType":
|
||||
"""Return connection parameters from string values."""
|
||||
try:
|
||||
return ConnectionType(
|
||||
DeviceFamilyType(device_family),
|
||||
EncryptType(encryption_type),
|
||||
login_version,
|
||||
)
|
||||
except ValueError as ex:
|
||||
except (ValueError, TypeError) as ex:
|
||||
raise SmartDeviceException(
|
||||
f"Invalid connection parameters for {device_family}.{encryption_type}"
|
||||
f"Invalid connection parameters for {device_family}."
|
||||
+ f"{encryption_type}.{login_version}"
|
||||
) from ex
|
||||
|
||||
@staticmethod
|
||||
@ -94,18 +98,26 @@ class ConnectionType:
|
||||
and (device_family := connection_type_dict.get("device_family"))
|
||||
and (encryption_type := connection_type_dict.get("encryption_type"))
|
||||
):
|
||||
return ConnectionType.from_values(device_family, encryption_type)
|
||||
if login_version := connection_type_dict.get("login_version"):
|
||||
login_version = int(login_version) # type: ignore[assignment]
|
||||
return ConnectionType.from_values(
|
||||
device_family,
|
||||
encryption_type,
|
||||
login_version, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
raise SmartDeviceException(
|
||||
f"Invalid connection type data for {connection_type_dict}"
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, str]:
|
||||
def to_dict(self) -> Dict[str, Union[str, int]]:
|
||||
"""Convert connection params to dict."""
|
||||
result = {
|
||||
result: Dict[str, Union[str, int]] = {
|
||||
"device_family": self.device_family.value,
|
||||
"encryption_type": self.encryption_type.value,
|
||||
}
|
||||
if self.login_version:
|
||||
result["login_version"] = self.login_version
|
||||
return result
|
||||
|
||||
|
||||
@ -118,10 +130,11 @@ class DeviceConfig:
|
||||
host: str
|
||||
timeout: Optional[int] = DEFAULT_TIMEOUT
|
||||
port_override: Optional[int] = None
|
||||
credentials: Credentials = field(default_factory=lambda: Credentials())
|
||||
credentials: Optional[Credentials] = None
|
||||
credentials_hash: Optional[str] = None
|
||||
connection_type: ConnectionType = field(
|
||||
default_factory=lambda: ConnectionType(
|
||||
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
|
||||
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor, 1
|
||||
)
|
||||
)
|
||||
|
||||
@ -130,15 +143,22 @@ class DeviceConfig:
|
||||
http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.credentials is None:
|
||||
self.credentials = Credentials()
|
||||
if self.connection_type is None:
|
||||
self.connection_type = ConnectionType(
|
||||
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Dict[str, str]]:
|
||||
def to_dict(
|
||||
self,
|
||||
*,
|
||||
credentials_hash: Optional[str] = None,
|
||||
exclude_credentials: bool = False,
|
||||
) -> Dict[str, Dict[str, str]]:
|
||||
"""Convert connection params to dict."""
|
||||
if credentials_hash or exclude_credentials:
|
||||
self.credentials = None
|
||||
if credentials_hash:
|
||||
self.credentials_hash = credentials_hash
|
||||
return _dataclass_to_dict(self)
|
||||
|
||||
@staticmethod
|
||||
|
@ -422,7 +422,9 @@ class Discover:
|
||||
|
||||
try:
|
||||
config.connection_type = ConnectionType.from_values(
|
||||
type_, discovery_result.mgt_encrypt_schm.encrypt_type
|
||||
type_,
|
||||
discovery_result.mgt_encrypt_schm.encrypt_type,
|
||||
discovery_result.mgt_encrypt_schm.lv,
|
||||
)
|
||||
except SmartDeviceException as ex:
|
||||
raise UnsupportedDeviceException(
|
||||
|
@ -41,6 +41,7 @@ https://github.com/python-kasa/python-kasa/pull/117
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import logging
|
||||
@ -99,8 +100,13 @@ class KlapTransport(BaseTransport):
|
||||
|
||||
self._default_http_client: Optional[httpx.AsyncClient] = None
|
||||
self._local_seed: Optional[bytes] = None
|
||||
self._local_auth_hash = self.generate_auth_hash(self._credentials)
|
||||
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
|
||||
if not self._credentials and not self._credentials_hash:
|
||||
self._credentials = Credentials()
|
||||
if self._credentials:
|
||||
self._local_auth_hash = self.generate_auth_hash(self._credentials)
|
||||
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
|
||||
else:
|
||||
self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr]
|
||||
self._kasa_setup_auth_hash = None
|
||||
self._blank_auth_hash = None
|
||||
self._handshake_lock = asyncio.Lock()
|
||||
@ -119,6 +125,11 @@ class KlapTransport(BaseTransport):
|
||||
"""Default port for the transport."""
|
||||
return self.DEFAULT_PORT
|
||||
|
||||
@property
|
||||
def credentials_hash(self) -> str:
|
||||
"""The hashed credentials used by the transport."""
|
||||
return base64.b64encode(self._local_auth_hash).decode()
|
||||
|
||||
@property
|
||||
def _http_client(self) -> httpx.AsyncClient:
|
||||
if self._config.http_client:
|
||||
|
@ -56,6 +56,7 @@ class BaseTransport(ABC):
|
||||
self._host = config.host
|
||||
self._port = config.port_override or self.default_port
|
||||
self._credentials = config.credentials
|
||||
self._credentials_hash = config.credentials_hash
|
||||
self._timeout = config.timeout
|
||||
|
||||
@property
|
||||
@ -63,6 +64,11 @@ class BaseTransport(ABC):
|
||||
def default_port(self) -> int:
|
||||
"""The default port for the transport."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def credentials_hash(self) -> str:
|
||||
"""The hashed credentials used by the transport."""
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, request: str) -> Dict:
|
||||
"""Send a message to the device and return a response."""
|
||||
@ -120,6 +126,11 @@ class _XorTransport(BaseTransport):
|
||||
"""Default port for the transport."""
|
||||
return self.DEFAULT_PORT
|
||||
|
||||
@property
|
||||
def credentials_hash(self) -> str:
|
||||
"""The hashed credentials used by the transport."""
|
||||
return ""
|
||||
|
||||
async def send(self, request: str) -> Dict:
|
||||
"""Send a message to the device and return a response."""
|
||||
return {}
|
||||
|
@ -245,6 +245,11 @@ class SmartDevice:
|
||||
"""The device credentials."""
|
||||
return self.protocol._transport._credentials
|
||||
|
||||
@property
|
||||
def credentials_hash(self) -> Optional[str]:
|
||||
"""Return the connection parameters the device is using."""
|
||||
return self.protocol._transport.credentials_hash
|
||||
|
||||
def add_module(self, name: str, module: Module):
|
||||
"""Register a module."""
|
||||
if name in self.modules:
|
||||
|
@ -38,8 +38,8 @@ class TapoDevice(SmartDevice):
|
||||
|
||||
async def update(self, update_children: bool = True):
|
||||
"""Update the device."""
|
||||
if self.credentials is None:
|
||||
raise AuthenticationException("Device requires authentication.")
|
||||
if self.credentials is None and self.credentials_hash is None:
|
||||
raise AuthenticationException("Tapo plug requires authentication.")
|
||||
|
||||
if self._components_raw is None:
|
||||
resp = await self.protocol.query("component_nego")
|
||||
|
@ -430,6 +430,7 @@ def discovery_mock(all_fixture_data, mocker):
|
||||
query_data: dict
|
||||
device_type: str
|
||||
encrypt_type: str
|
||||
login_version: Optional[int] = None
|
||||
port_override: Optional[int] = None
|
||||
|
||||
if "discovery_result" in all_fixture_data:
|
||||
@ -438,6 +439,9 @@ def discovery_mock(all_fixture_data, mocker):
|
||||
encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][
|
||||
"encrypt_type"
|
||||
]
|
||||
login_version = all_fixture_data["discovery_result"]["mgt_encrypt_schm"].get(
|
||||
"lv"
|
||||
)
|
||||
datagram = (
|
||||
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||
+ json_dumps(discovery_data).encode()
|
||||
@ -450,12 +454,14 @@ def discovery_mock(all_fixture_data, mocker):
|
||||
all_fixture_data,
|
||||
device_type,
|
||||
encrypt_type,
|
||||
login_version,
|
||||
)
|
||||
else:
|
||||
sys_info = all_fixture_data["system"]["get_sysinfo"]
|
||||
discovery_data = {"system": {"get_sysinfo": sys_info}}
|
||||
device_type = sys_info.get("mic_type") or sys_info.get("type")
|
||||
encrypt_type = "XOR"
|
||||
login_version = None
|
||||
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
|
||||
dm = _DiscoveryMock(
|
||||
"127.0.0.123",
|
||||
@ -465,6 +471,7 @@ def discovery_mock(all_fixture_data, mocker):
|
||||
all_fixture_data,
|
||||
device_type,
|
||||
encrypt_type,
|
||||
login_version,
|
||||
)
|
||||
|
||||
def mock_discover(self):
|
||||
|
@ -1,3 +1,4 @@
|
||||
import base64
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
@ -320,6 +321,11 @@ class FakeSmartTransport(BaseTransport):
|
||||
"""Default port for the transport."""
|
||||
return 80
|
||||
|
||||
@property
|
||||
def credentials_hash(self):
|
||||
"""The hashed credentials used by the transport."""
|
||||
return self._credentials.username + self._credentials.password + "hash"
|
||||
|
||||
async def send(self, request: str):
|
||||
request_dict = json_loads(request)
|
||||
method = request_dict["method"]
|
||||
|
@ -19,3 +19,30 @@ def test_serialization():
|
||||
config2_dict = json_loads(config_json)
|
||||
config2 = DeviceConfig.from_dict(config2_dict)
|
||||
assert config == config2
|
||||
|
||||
|
||||
def test_credentials_hash():
|
||||
config = DeviceConfig(
|
||||
host="Foo",
|
||||
http_client=httpx.AsyncClient(),
|
||||
credentials=Credentials("foo", "bar"),
|
||||
)
|
||||
config_dict = config.to_dict(credentials_hash="credhash")
|
||||
config_json = json_dumps(config_dict)
|
||||
config2_dict = json_loads(config_json)
|
||||
config2 = DeviceConfig.from_dict(config2_dict)
|
||||
assert config2.credentials_hash == "credhash"
|
||||
assert config2.credentials is None
|
||||
|
||||
|
||||
def test_no_credentials_serialization():
|
||||
config = DeviceConfig(
|
||||
host="Foo",
|
||||
http_client=httpx.AsyncClient(),
|
||||
credentials=Credentials("foo", "bar"),
|
||||
)
|
||||
config_dict = config.to_dict(exclude_credentials=True)
|
||||
config_json = json_dumps(config_dict)
|
||||
config2_dict = json_loads(config_json)
|
||||
config2 = DeviceConfig.from_dict(config2_dict)
|
||||
assert config2.credentials is None
|
||||
|
@ -110,11 +110,17 @@ async def test_discover_single(discovery_mock, custom_port, mocker):
|
||||
assert update_mock.call_count == 0
|
||||
|
||||
ct = ConnectionType.from_values(
|
||||
discovery_mock.device_type, discovery_mock.encrypt_type
|
||||
discovery_mock.device_type,
|
||||
discovery_mock.encrypt_type,
|
||||
discovery_mock.login_version,
|
||||
)
|
||||
uses_http = discovery_mock.default_port == 80
|
||||
config = DeviceConfig(
|
||||
host=host, port_override=custom_port, connection_type=ct, uses_http=uses_http
|
||||
host=host,
|
||||
port_override=custom_port,
|
||||
connection_type=ct,
|
||||
uses_http=uses_http,
|
||||
credentials=Credentials(),
|
||||
)
|
||||
assert x.config == config
|
||||
|
||||
|
@ -9,8 +9,11 @@ import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from ..aestransport import AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import SmartDeviceException
|
||||
from ..klaptransport import KlapTransport, KlapTransportV2
|
||||
from ..protocol import (
|
||||
BaseTransport,
|
||||
TPLinkProtocol,
|
||||
@ -298,3 +301,19 @@ def test_transport_init_signature(class_name_obj):
|
||||
assert (
|
||||
params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"transport_class", [AesTransport, KlapTransport, KlapTransportV2, _XorTransport]
|
||||
)
|
||||
async def test_transport_credentials_hash(mocker, transport_class):
|
||||
host = "127.0.0.1"
|
||||
|
||||
credentials = Credentials("Foo", "Bar")
|
||||
config = DeviceConfig(host, credentials=credentials)
|
||||
transport = transport_class(config=config)
|
||||
credentials_hash = transport.credentials_hash
|
||||
config = DeviceConfig(host, credentials_hash=credentials_hash)
|
||||
transport = transport_class(config=config)
|
||||
|
||||
assert transport.credentials_hash == credentials_hash
|
||||
|
Loading…
Reference in New Issue
Block a user