mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-23 03:33:35 +00:00
Allow serializing and passing of credentials_hashes in DeviceConfig (#607)
* Allow passing of credentials_hashes in DeviceConfig * Update following review
This commit is contained in:
parent
3692e4812f
commit
e9bf9f58ee
@ -16,6 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
|
|||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
|
||||||
|
from .credentials import Credentials
|
||||||
from .deviceconfig import DeviceConfig
|
from .deviceconfig import DeviceConfig
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
SMART_AUTHENTICATION_ERRORS,
|
SMART_AUTHENTICATION_ERRORS,
|
||||||
@ -62,6 +63,16 @@ class AesTransport(BaseTransport):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
|
|
||||||
|
self._login_version = config.connection_type.login_version
|
||||||
|
if not self._credentials and not self._credentials_hash:
|
||||||
|
self._credentials = Credentials()
|
||||||
|
if self._credentials:
|
||||||
|
self._login_params = self._get_login_params()
|
||||||
|
else:
|
||||||
|
self._login_params = json_loads(
|
||||||
|
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
|
||||||
|
)
|
||||||
|
|
||||||
self._default_http_client: Optional[httpx.AsyncClient] = None
|
self._default_http_client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
self._handshake_done = False
|
self._handshake_done = False
|
||||||
@ -80,6 +91,11 @@ class AesTransport(BaseTransport):
|
|||||||
"""Default port for the transport."""
|
"""Default port for the transport."""
|
||||||
return self.DEFAULT_PORT
|
return self.DEFAULT_PORT
|
||||||
|
|
||||||
|
@property
|
||||||
|
def credentials_hash(self) -> str:
|
||||||
|
"""The hashed credentials used by the transport."""
|
||||||
|
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _http_client(self) -> httpx.AsyncClient:
|
def _http_client(self) -> httpx.AsyncClient:
|
||||||
if self._config.http_client:
|
if self._config.http_client:
|
||||||
@ -88,6 +104,12 @@ class AesTransport(BaseTransport):
|
|||||||
self._default_http_client = httpx.AsyncClient()
|
self._default_http_client = httpx.AsyncClient()
|
||||||
return self._default_http_client
|
return self._default_http_client
|
||||||
|
|
||||||
|
def _get_login_params(self):
|
||||||
|
"""Get the login parameters based on the login_version."""
|
||||||
|
un, pw = self.hash_credentials(self._login_version == 2)
|
||||||
|
password_field_name = "password2" if self._login_version == 2 else "password"
|
||||||
|
return {password_field_name: pw, "username": un}
|
||||||
|
|
||||||
def hash_credentials(self, login_v2):
|
def hash_credentials(self, login_v2):
|
||||||
"""Hash the credentials."""
|
"""Hash the credentials."""
|
||||||
if login_v2:
|
if login_v2:
|
||||||
@ -171,14 +193,12 @@ class AesTransport(BaseTransport):
|
|||||||
resp_dict = json_loads(response)
|
resp_dict = json_loads(response)
|
||||||
return resp_dict
|
return resp_dict
|
||||||
|
|
||||||
async def _perform_login_for_version(self, *, login_version: int = 1):
|
async def perform_login(self):
|
||||||
"""Login to the device."""
|
"""Login to the device."""
|
||||||
self._login_token = None
|
self._login_token = None
|
||||||
un, pw = self.hash_credentials(login_version == 2)
|
|
||||||
password_field_name = "password2" if login_version == 2 else "password"
|
|
||||||
login_request = {
|
login_request = {
|
||||||
"method": "login_device",
|
"method": "login_device",
|
||||||
"params": {password_field_name: pw, "username": un},
|
"params": self._login_params,
|
||||||
"request_time_milis": round(time.time() * 1000),
|
"request_time_milis": round(time.time() * 1000),
|
||||||
}
|
}
|
||||||
request = json_dumps(login_request)
|
request = json_dumps(login_request)
|
||||||
@ -187,15 +207,6 @@ class AesTransport(BaseTransport):
|
|||||||
self._handle_response_error_code(resp_dict, "Error logging in")
|
self._handle_response_error_code(resp_dict, "Error logging in")
|
||||||
self._login_token = resp_dict["result"]["token"]
|
self._login_token = resp_dict["result"]["token"]
|
||||||
|
|
||||||
async def perform_login(self) -> None:
|
|
||||||
"""Login to the device."""
|
|
||||||
try:
|
|
||||||
await self._perform_login_for_version(login_version=2)
|
|
||||||
except AuthenticationException:
|
|
||||||
_LOGGER.warning("Login version 2 failed, trying version 1")
|
|
||||||
await self.perform_handshake()
|
|
||||||
await self._perform_login_for_version(login_version=1)
|
|
||||||
|
|
||||||
async def perform_handshake(self):
|
async def perform_handshake(self):
|
||||||
"""Perform the handshake."""
|
"""Perform the handshake."""
|
||||||
_LOGGER.debug("Will perform handshaking...")
|
_LOGGER.debug("Will perform handshaking...")
|
||||||
|
28
kasa/cli.py
28
kasa/cli.py
@ -184,6 +184,12 @@ def json_formatter_cb(result, **kwargs):
|
|||||||
default=None,
|
default=None,
|
||||||
type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False),
|
type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False),
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--login-version",
|
||||||
|
envvar="KASA_LOGIN_VERSION",
|
||||||
|
default=None,
|
||||||
|
type=int,
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--timeout",
|
"--timeout",
|
||||||
envvar="KASA_TIMEOUT",
|
envvar="KASA_TIMEOUT",
|
||||||
@ -214,6 +220,13 @@ def json_formatter_cb(result, **kwargs):
|
|||||||
envvar="KASA_PASSWORD",
|
envvar="KASA_PASSWORD",
|
||||||
help="Password to use to authenticate to device.",
|
help="Password to use to authenticate to device.",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--credentials-hash",
|
||||||
|
default=None,
|
||||||
|
required=False,
|
||||||
|
envvar="KASA_CREDENTIALS_HASH",
|
||||||
|
help="Hashed credentials used to authenticate to the device.",
|
||||||
|
)
|
||||||
@click.version_option(package_name="python-kasa")
|
@click.version_option(package_name="python-kasa")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
async def cli(
|
async def cli(
|
||||||
@ -227,11 +240,13 @@ async def cli(
|
|||||||
type,
|
type,
|
||||||
encrypt_type,
|
encrypt_type,
|
||||||
device_family,
|
device_family,
|
||||||
|
login_version,
|
||||||
json,
|
json,
|
||||||
timeout,
|
timeout,
|
||||||
discovery_timeout,
|
discovery_timeout,
|
||||||
username,
|
username,
|
||||||
password,
|
password,
|
||||||
|
credentials_hash,
|
||||||
):
|
):
|
||||||
"""A tool for controlling TP-Link smart home devices.""" # noqa
|
"""A tool for controlling TP-Link smart home devices.""" # noqa
|
||||||
# no need to perform any checks if we are just displaying the help
|
# no need to perform any checks if we are just displaying the help
|
||||||
@ -291,7 +306,10 @@ async def cli(
|
|||||||
"username", "Using authentication requires both --username and --password"
|
"username", "Using authentication requires both --username and --password"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if username:
|
||||||
credentials = Credentials(username=username, password=password)
|
credentials = Credentials(username=username, password=password)
|
||||||
|
else:
|
||||||
|
credentials = None
|
||||||
|
|
||||||
if host is None:
|
if host is None:
|
||||||
echo("No host name given, trying discovery..")
|
echo("No host name given, trying discovery..")
|
||||||
@ -300,13 +318,18 @@ async def cli(
|
|||||||
if type is not None:
|
if type is not None:
|
||||||
dev = TYPE_TO_CLASS[type](host)
|
dev = TYPE_TO_CLASS[type](host)
|
||||||
await dev.update()
|
await dev.update()
|
||||||
elif device_family or encrypt_type:
|
elif device_family and encrypt_type:
|
||||||
ctype = ConnectionType(
|
ctype = ConnectionType(
|
||||||
DeviceFamilyType(device_family),
|
DeviceFamilyType(device_family),
|
||||||
EncryptType(encrypt_type),
|
EncryptType(encrypt_type),
|
||||||
|
login_version,
|
||||||
)
|
)
|
||||||
config = DeviceConfig(
|
config = DeviceConfig(
|
||||||
host=host, credentials=credentials, timeout=timeout, connection_type=ctype
|
host=host,
|
||||||
|
credentials=credentials,
|
||||||
|
credentials_hash=credentials_hash,
|
||||||
|
timeout=timeout,
|
||||||
|
connection_type=ctype,
|
||||||
)
|
)
|
||||||
dev = await SmartDevice.connect(config=config)
|
dev = await SmartDevice.connect(config=config)
|
||||||
else:
|
else:
|
||||||
@ -495,6 +518,7 @@ async def state(dev: SmartDevice):
|
|||||||
echo(f"[bold]== {dev.alias} - {dev.model} ==[/bold]")
|
echo(f"[bold]== {dev.alias} - {dev.model} ==[/bold]")
|
||||||
echo(f"\tHost: {dev.host}")
|
echo(f"\tHost: {dev.host}")
|
||||||
echo(f"\tPort: {dev.port}")
|
echo(f"\tPort: {dev.port}")
|
||||||
|
echo(f"\tCredentials hash: {dev.credentials_hash}")
|
||||||
echo(f"\tDevice state: {dev.is_on}")
|
echo(f"\tDevice state: {dev.is_on}")
|
||||||
if dev.is_strip:
|
if dev.is_strip:
|
||||||
echo("\t[bold]== Plugs ==[/bold]")
|
echo("\t[bold]== Plugs ==[/bold]")
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from dataclasses import asdict, dataclass, field, fields, is_dataclass
|
from dataclasses import asdict, dataclass, field, fields, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@ -69,21 +69,25 @@ class ConnectionType:
|
|||||||
|
|
||||||
device_family: DeviceFamilyType
|
device_family: DeviceFamilyType
|
||||||
encryption_type: EncryptType
|
encryption_type: EncryptType
|
||||||
|
login_version: Optional[int] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_values(
|
def from_values(
|
||||||
device_family: str,
|
device_family: str,
|
||||||
encryption_type: str,
|
encryption_type: str,
|
||||||
|
login_version: Optional[int] = None,
|
||||||
) -> "ConnectionType":
|
) -> "ConnectionType":
|
||||||
"""Return connection parameters from string values."""
|
"""Return connection parameters from string values."""
|
||||||
try:
|
try:
|
||||||
return ConnectionType(
|
return ConnectionType(
|
||||||
DeviceFamilyType(device_family),
|
DeviceFamilyType(device_family),
|
||||||
EncryptType(encryption_type),
|
EncryptType(encryption_type),
|
||||||
|
login_version,
|
||||||
)
|
)
|
||||||
except ValueError as ex:
|
except (ValueError, TypeError) as ex:
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Invalid connection parameters for {device_family}.{encryption_type}"
|
f"Invalid connection parameters for {device_family}."
|
||||||
|
+ f"{encryption_type}.{login_version}"
|
||||||
) from ex
|
) from ex
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -94,18 +98,26 @@ class ConnectionType:
|
|||||||
and (device_family := connection_type_dict.get("device_family"))
|
and (device_family := connection_type_dict.get("device_family"))
|
||||||
and (encryption_type := connection_type_dict.get("encryption_type"))
|
and (encryption_type := connection_type_dict.get("encryption_type"))
|
||||||
):
|
):
|
||||||
return ConnectionType.from_values(device_family, encryption_type)
|
if login_version := connection_type_dict.get("login_version"):
|
||||||
|
login_version = int(login_version) # type: ignore[assignment]
|
||||||
|
return ConnectionType.from_values(
|
||||||
|
device_family,
|
||||||
|
encryption_type,
|
||||||
|
login_version, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
raise SmartDeviceException(
|
raise SmartDeviceException(
|
||||||
f"Invalid connection type data for {connection_type_dict}"
|
f"Invalid connection type data for {connection_type_dict}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, str]:
|
def to_dict(self) -> Dict[str, Union[str, int]]:
|
||||||
"""Convert connection params to dict."""
|
"""Convert connection params to dict."""
|
||||||
result = {
|
result: Dict[str, Union[str, int]] = {
|
||||||
"device_family": self.device_family.value,
|
"device_family": self.device_family.value,
|
||||||
"encryption_type": self.encryption_type.value,
|
"encryption_type": self.encryption_type.value,
|
||||||
}
|
}
|
||||||
|
if self.login_version:
|
||||||
|
result["login_version"] = self.login_version
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@ -118,10 +130,11 @@ class DeviceConfig:
|
|||||||
host: str
|
host: str
|
||||||
timeout: Optional[int] = DEFAULT_TIMEOUT
|
timeout: Optional[int] = DEFAULT_TIMEOUT
|
||||||
port_override: Optional[int] = None
|
port_override: Optional[int] = None
|
||||||
credentials: Credentials = field(default_factory=lambda: Credentials())
|
credentials: Optional[Credentials] = None
|
||||||
|
credentials_hash: Optional[str] = None
|
||||||
connection_type: ConnectionType = field(
|
connection_type: ConnectionType = field(
|
||||||
default_factory=lambda: ConnectionType(
|
default_factory=lambda: ConnectionType(
|
||||||
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
|
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor, 1
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -130,15 +143,22 @@ class DeviceConfig:
|
|||||||
http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False)
|
http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.credentials is None:
|
|
||||||
self.credentials = Credentials()
|
|
||||||
if self.connection_type is None:
|
if self.connection_type is None:
|
||||||
self.connection_type = ConnectionType(
|
self.connection_type = ConnectionType(
|
||||||
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
|
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Dict[str, str]]:
|
def to_dict(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
credentials_hash: Optional[str] = None,
|
||||||
|
exclude_credentials: bool = False,
|
||||||
|
) -> Dict[str, Dict[str, str]]:
|
||||||
"""Convert connection params to dict."""
|
"""Convert connection params to dict."""
|
||||||
|
if credentials_hash or exclude_credentials:
|
||||||
|
self.credentials = None
|
||||||
|
if credentials_hash:
|
||||||
|
self.credentials_hash = credentials_hash
|
||||||
return _dataclass_to_dict(self)
|
return _dataclass_to_dict(self)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -422,7 +422,9 @@ class Discover:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
config.connection_type = ConnectionType.from_values(
|
config.connection_type = ConnectionType.from_values(
|
||||||
type_, discovery_result.mgt_encrypt_schm.encrypt_type
|
type_,
|
||||||
|
discovery_result.mgt_encrypt_schm.encrypt_type,
|
||||||
|
discovery_result.mgt_encrypt_schm.lv,
|
||||||
)
|
)
|
||||||
except SmartDeviceException as ex:
|
except SmartDeviceException as ex:
|
||||||
raise UnsupportedDeviceException(
|
raise UnsupportedDeviceException(
|
||||||
|
@ -41,6 +41,7 @@ https://github.com/python-kasa/python-kasa/pull/117
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
@ -99,8 +100,13 @@ class KlapTransport(BaseTransport):
|
|||||||
|
|
||||||
self._default_http_client: Optional[httpx.AsyncClient] = None
|
self._default_http_client: Optional[httpx.AsyncClient] = None
|
||||||
self._local_seed: Optional[bytes] = None
|
self._local_seed: Optional[bytes] = None
|
||||||
|
if not self._credentials and not self._credentials_hash:
|
||||||
|
self._credentials = Credentials()
|
||||||
|
if self._credentials:
|
||||||
self._local_auth_hash = self.generate_auth_hash(self._credentials)
|
self._local_auth_hash = self.generate_auth_hash(self._credentials)
|
||||||
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
|
self._local_auth_owner = self.generate_owner_hash(self._credentials).hex()
|
||||||
|
else:
|
||||||
|
self._local_auth_hash = base64.b64decode(self._credentials_hash.encode()) # type: ignore[union-attr]
|
||||||
self._kasa_setup_auth_hash = None
|
self._kasa_setup_auth_hash = None
|
||||||
self._blank_auth_hash = None
|
self._blank_auth_hash = None
|
||||||
self._handshake_lock = asyncio.Lock()
|
self._handshake_lock = asyncio.Lock()
|
||||||
@ -119,6 +125,11 @@ class KlapTransport(BaseTransport):
|
|||||||
"""Default port for the transport."""
|
"""Default port for the transport."""
|
||||||
return self.DEFAULT_PORT
|
return self.DEFAULT_PORT
|
||||||
|
|
||||||
|
@property
|
||||||
|
def credentials_hash(self) -> str:
|
||||||
|
"""The hashed credentials used by the transport."""
|
||||||
|
return base64.b64encode(self._local_auth_hash).decode()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _http_client(self) -> httpx.AsyncClient:
|
def _http_client(self) -> httpx.AsyncClient:
|
||||||
if self._config.http_client:
|
if self._config.http_client:
|
||||||
|
@ -56,6 +56,7 @@ class BaseTransport(ABC):
|
|||||||
self._host = config.host
|
self._host = config.host
|
||||||
self._port = config.port_override or self.default_port
|
self._port = config.port_override or self.default_port
|
||||||
self._credentials = config.credentials
|
self._credentials = config.credentials
|
||||||
|
self._credentials_hash = config.credentials_hash
|
||||||
self._timeout = config.timeout
|
self._timeout = config.timeout
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -63,6 +64,11 @@ class BaseTransport(ABC):
|
|||||||
def default_port(self) -> int:
|
def default_port(self) -> int:
|
||||||
"""The default port for the transport."""
|
"""The default port for the transport."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def credentials_hash(self) -> str:
|
||||||
|
"""The hashed credentials used by the transport."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def send(self, request: str) -> Dict:
|
async def send(self, request: str) -> Dict:
|
||||||
"""Send a message to the device and return a response."""
|
"""Send a message to the device and return a response."""
|
||||||
@ -120,6 +126,11 @@ class _XorTransport(BaseTransport):
|
|||||||
"""Default port for the transport."""
|
"""Default port for the transport."""
|
||||||
return self.DEFAULT_PORT
|
return self.DEFAULT_PORT
|
||||||
|
|
||||||
|
@property
|
||||||
|
def credentials_hash(self) -> str:
|
||||||
|
"""The hashed credentials used by the transport."""
|
||||||
|
return ""
|
||||||
|
|
||||||
async def send(self, request: str) -> Dict:
|
async def send(self, request: str) -> Dict:
|
||||||
"""Send a message to the device and return a response."""
|
"""Send a message to the device and return a response."""
|
||||||
return {}
|
return {}
|
||||||
|
@ -245,6 +245,11 @@ class SmartDevice:
|
|||||||
"""The device credentials."""
|
"""The device credentials."""
|
||||||
return self.protocol._transport._credentials
|
return self.protocol._transport._credentials
|
||||||
|
|
||||||
|
@property
|
||||||
|
def credentials_hash(self) -> Optional[str]:
|
||||||
|
"""Return the connection parameters the device is using."""
|
||||||
|
return self.protocol._transport.credentials_hash
|
||||||
|
|
||||||
def add_module(self, name: str, module: Module):
|
def add_module(self, name: str, module: Module):
|
||||||
"""Register a module."""
|
"""Register a module."""
|
||||||
if name in self.modules:
|
if name in self.modules:
|
||||||
|
@ -38,8 +38,8 @@ class TapoDevice(SmartDevice):
|
|||||||
|
|
||||||
async def update(self, update_children: bool = True):
|
async def update(self, update_children: bool = True):
|
||||||
"""Update the device."""
|
"""Update the device."""
|
||||||
if self.credentials is None:
|
if self.credentials is None and self.credentials_hash is None:
|
||||||
raise AuthenticationException("Device requires authentication.")
|
raise AuthenticationException("Tapo plug requires authentication.")
|
||||||
|
|
||||||
if self._components_raw is None:
|
if self._components_raw is None:
|
||||||
resp = await self.protocol.query("component_nego")
|
resp = await self.protocol.query("component_nego")
|
||||||
|
@ -430,6 +430,7 @@ def discovery_mock(all_fixture_data, mocker):
|
|||||||
query_data: dict
|
query_data: dict
|
||||||
device_type: str
|
device_type: str
|
||||||
encrypt_type: str
|
encrypt_type: str
|
||||||
|
login_version: Optional[int] = None
|
||||||
port_override: Optional[int] = None
|
port_override: Optional[int] = None
|
||||||
|
|
||||||
if "discovery_result" in all_fixture_data:
|
if "discovery_result" in all_fixture_data:
|
||||||
@ -438,6 +439,9 @@ def discovery_mock(all_fixture_data, mocker):
|
|||||||
encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][
|
encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][
|
||||||
"encrypt_type"
|
"encrypt_type"
|
||||||
]
|
]
|
||||||
|
login_version = all_fixture_data["discovery_result"]["mgt_encrypt_schm"].get(
|
||||||
|
"lv"
|
||||||
|
)
|
||||||
datagram = (
|
datagram = (
|
||||||
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||||
+ json_dumps(discovery_data).encode()
|
+ json_dumps(discovery_data).encode()
|
||||||
@ -450,12 +454,14 @@ def discovery_mock(all_fixture_data, mocker):
|
|||||||
all_fixture_data,
|
all_fixture_data,
|
||||||
device_type,
|
device_type,
|
||||||
encrypt_type,
|
encrypt_type,
|
||||||
|
login_version,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sys_info = all_fixture_data["system"]["get_sysinfo"]
|
sys_info = all_fixture_data["system"]["get_sysinfo"]
|
||||||
discovery_data = {"system": {"get_sysinfo": sys_info}}
|
discovery_data = {"system": {"get_sysinfo": sys_info}}
|
||||||
device_type = sys_info.get("mic_type") or sys_info.get("type")
|
device_type = sys_info.get("mic_type") or sys_info.get("type")
|
||||||
encrypt_type = "XOR"
|
encrypt_type = "XOR"
|
||||||
|
login_version = None
|
||||||
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
|
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
|
||||||
dm = _DiscoveryMock(
|
dm = _DiscoveryMock(
|
||||||
"127.0.0.123",
|
"127.0.0.123",
|
||||||
@ -465,6 +471,7 @@ def discovery_mock(all_fixture_data, mocker):
|
|||||||
all_fixture_data,
|
all_fixture_data,
|
||||||
device_type,
|
device_type,
|
||||||
encrypt_type,
|
encrypt_type,
|
||||||
|
login_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
def mock_discover(self):
|
def mock_discover(self):
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import base64
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
@ -320,6 +321,11 @@ class FakeSmartTransport(BaseTransport):
|
|||||||
"""Default port for the transport."""
|
"""Default port for the transport."""
|
||||||
return 80
|
return 80
|
||||||
|
|
||||||
|
@property
|
||||||
|
def credentials_hash(self):
|
||||||
|
"""The hashed credentials used by the transport."""
|
||||||
|
return self._credentials.username + self._credentials.password + "hash"
|
||||||
|
|
||||||
async def send(self, request: str):
|
async def send(self, request: str):
|
||||||
request_dict = json_loads(request)
|
request_dict = json_loads(request)
|
||||||
method = request_dict["method"]
|
method = request_dict["method"]
|
||||||
|
@ -19,3 +19,30 @@ def test_serialization():
|
|||||||
config2_dict = json_loads(config_json)
|
config2_dict = json_loads(config_json)
|
||||||
config2 = DeviceConfig.from_dict(config2_dict)
|
config2 = DeviceConfig.from_dict(config2_dict)
|
||||||
assert config == config2
|
assert config == config2
|
||||||
|
|
||||||
|
|
||||||
|
def test_credentials_hash():
|
||||||
|
config = DeviceConfig(
|
||||||
|
host="Foo",
|
||||||
|
http_client=httpx.AsyncClient(),
|
||||||
|
credentials=Credentials("foo", "bar"),
|
||||||
|
)
|
||||||
|
config_dict = config.to_dict(credentials_hash="credhash")
|
||||||
|
config_json = json_dumps(config_dict)
|
||||||
|
config2_dict = json_loads(config_json)
|
||||||
|
config2 = DeviceConfig.from_dict(config2_dict)
|
||||||
|
assert config2.credentials_hash == "credhash"
|
||||||
|
assert config2.credentials is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_credentials_serialization():
|
||||||
|
config = DeviceConfig(
|
||||||
|
host="Foo",
|
||||||
|
http_client=httpx.AsyncClient(),
|
||||||
|
credentials=Credentials("foo", "bar"),
|
||||||
|
)
|
||||||
|
config_dict = config.to_dict(exclude_credentials=True)
|
||||||
|
config_json = json_dumps(config_dict)
|
||||||
|
config2_dict = json_loads(config_json)
|
||||||
|
config2 = DeviceConfig.from_dict(config2_dict)
|
||||||
|
assert config2.credentials is None
|
||||||
|
@ -110,11 +110,17 @@ async def test_discover_single(discovery_mock, custom_port, mocker):
|
|||||||
assert update_mock.call_count == 0
|
assert update_mock.call_count == 0
|
||||||
|
|
||||||
ct = ConnectionType.from_values(
|
ct = ConnectionType.from_values(
|
||||||
discovery_mock.device_type, discovery_mock.encrypt_type
|
discovery_mock.device_type,
|
||||||
|
discovery_mock.encrypt_type,
|
||||||
|
discovery_mock.login_version,
|
||||||
)
|
)
|
||||||
uses_http = discovery_mock.default_port == 80
|
uses_http = discovery_mock.default_port == 80
|
||||||
config = DeviceConfig(
|
config = DeviceConfig(
|
||||||
host=host, port_override=custom_port, connection_type=ct, uses_http=uses_http
|
host=host,
|
||||||
|
port_override=custom_port,
|
||||||
|
connection_type=ct,
|
||||||
|
uses_http=uses_http,
|
||||||
|
credentials=Credentials(),
|
||||||
)
|
)
|
||||||
assert x.config == config
|
assert x.config == config
|
||||||
|
|
||||||
|
@ -9,8 +9,11 @@ import sys
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from ..aestransport import AesTransport
|
||||||
|
from ..credentials import Credentials
|
||||||
from ..deviceconfig import DeviceConfig
|
from ..deviceconfig import DeviceConfig
|
||||||
from ..exceptions import SmartDeviceException
|
from ..exceptions import SmartDeviceException
|
||||||
|
from ..klaptransport import KlapTransport, KlapTransportV2
|
||||||
from ..protocol import (
|
from ..protocol import (
|
||||||
BaseTransport,
|
BaseTransport,
|
||||||
TPLinkProtocol,
|
TPLinkProtocol,
|
||||||
@ -298,3 +301,19 @@ def test_transport_init_signature(class_name_obj):
|
|||||||
assert (
|
assert (
|
||||||
params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY
|
params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"transport_class", [AesTransport, KlapTransport, KlapTransportV2, _XorTransport]
|
||||||
|
)
|
||||||
|
async def test_transport_credentials_hash(mocker, transport_class):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
|
||||||
|
credentials = Credentials("Foo", "Bar")
|
||||||
|
config = DeviceConfig(host, credentials=credentials)
|
||||||
|
transport = transport_class(config=config)
|
||||||
|
credentials_hash = transport.credentials_hash
|
||||||
|
config = DeviceConfig(host, credentials_hash=credentials_hash)
|
||||||
|
transport = transport_class(config=config)
|
||||||
|
|
||||||
|
assert transport.credentials_hash == credentials_hash
|
||||||
|
Loading…
Reference in New Issue
Block a user