Cleanup credentials handling (#605)

* credentials: don't allow none to simplify checks

* Implement __bool__ for credentials

* Cleanup klaptransport cred usage

* Cleanup deviceconfig and tapodevice

* fix linting

* Pass dummy credentials for tests

* Remove __bool__ dunder and add docs to credentials

* Check for cred noneness in tapodevice.update()
This commit is contained in:
Teemu R 2024-01-03 19:26:52 +01:00 committed by GitHub
parent 10fc2c3c54
commit 30c4e6a6a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 29 additions and 17 deletions

View File

@ -1,12 +1,13 @@
"""Credentials class for username / passwords.""" """Credentials class for username / passwords."""
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional
@dataclass @dataclass
class Credentials: class Credentials:
"""Credentials for authentication.""" """Credentials for authentication."""
username: Optional[str] = field(default="", repr=False) #: Username (email address) of the cloud account
password: Optional[str] = field(default="", repr=False) username: str = field(default="", repr=False)
#: Password of the cloud account
password: str = field(default="", repr=False)

View File

@ -117,9 +117,7 @@ class DeviceConfig:
host: str host: str
timeout: Optional[int] = DEFAULT_TIMEOUT timeout: Optional[int] = DEFAULT_TIMEOUT
port_override: Optional[int] = None port_override: Optional[int] = None
credentials: Credentials = field( credentials: Credentials = field(default_factory=lambda: Credentials())
default_factory=lambda: Credentials(username="", password="")
)
connection_type: ConnectionType = field( connection_type: ConnectionType = field(
default_factory=lambda: ConnectionType( default_factory=lambda: ConnectionType(
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor
@ -132,7 +130,7 @@ class DeviceConfig:
def __post_init__(self): def __post_init__(self):
if self.credentials is None: if self.credentials is None:
self.credentials = Credentials(username="", password="") self.credentials = Credentials()
if self.connection_type is None: if self.connection_type is None:
self.connection_type = ConnectionType( self.connection_type = ConnectionType(
DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor

View File

@ -26,6 +26,7 @@ class AuthenticationException(SmartDeviceException):
class RetryableException(SmartDeviceException): class RetryableException(SmartDeviceException):
"""Retryable exception for device errors.""" """Retryable exception for device errors."""
class TimeoutException(SmartDeviceException): class TimeoutException(SmartDeviceException):
"""Timeout exception for device errors.""" """Timeout exception for device errors."""

View File

@ -221,7 +221,8 @@ class KlapTransport(BaseTransport):
return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore
# Finally check against blank credentials if not already blank # Finally check against blank credentials if not already blank
if self._credentials != (blank_creds := Credentials(username="", password="")): blank_creds = Credentials()
if self._credentials != blank_creds:
if not self._blank_auth_hash: if not self._blank_auth_hash:
self._blank_auth_hash = self.generate_auth_hash(blank_creds) self._blank_auth_hash = self.generate_auth_hash(blank_creds)
@ -369,8 +370,8 @@ class KlapTransport(BaseTransport):
@staticmethod @staticmethod
def generate_auth_hash(creds: Credentials): def generate_auth_hash(creds: Credentials):
"""Generate an md5 auth hash for the protocol on the supplied credentials.""" """Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username or "" un = creds.username
pw = creds.password or "" pw = creds.password
return md5(md5(un.encode()) + md5(pw.encode())) return md5(md5(un.encode()) + md5(pw.encode()))
@ -391,7 +392,7 @@ class KlapTransport(BaseTransport):
@staticmethod @staticmethod
def generate_owner_hash(creds: Credentials): def generate_owner_hash(creds: Credentials):
"""Return the MD5 hash of the username in this object.""" """Return the MD5 hash of the username in this object."""
un = creds.username or "" un = creds.username
return md5(un.encode()) return md5(un.encode())
@ -401,8 +402,8 @@ class KlapTransportV2(KlapTransport):
@staticmethod @staticmethod
def generate_auth_hash(creds: Credentials): def generate_auth_hash(creds: Credentials):
"""Generate an md5 auth hash for the protocol on the supplied credentials.""" """Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username or "" un = creds.username
pw = creds.password or "" pw = creds.password
return _sha256(_sha1(un.encode()) + _sha1(pw.encode())) return _sha256(_sha1(un.encode()) + _sha1(pw.encode()))

View File

@ -38,8 +38,8 @@ class TapoDevice(SmartDevice):
async def update(self, update_children: bool = True): async def update(self, update_children: bool = True):
"""Update the device.""" """Update the device."""
if self.credentials is None or self.credentials.username is None: if self.credentials is None:
raise AuthenticationException("Tapo plug requires authentication.") raise AuthenticationException("Device requires authentication.")
if self._components_raw is None: if self._components_raw is None:
resp = await self.protocol.query("component_nego") resp = await self.protocol.query("component_nego")

View File

@ -305,7 +305,13 @@ class FakeSmartProtocol(SmartProtocol):
class FakeSmartTransport(BaseTransport): class FakeSmartTransport(BaseTransport):
def __init__(self, info): def __init__(self, info):
super().__init__( super().__init__(
config=DeviceConfig("127.0.0.123", credentials=Credentials()), config=DeviceConfig(
"127.0.0.123",
credentials=Credentials(
username="dummy_user",
password="dummy_password", # noqa: S106
),
),
) )
self.info = info self.info = info

View File

@ -76,7 +76,12 @@ async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port):
host = "127.0.0.1" host = "127.0.0.1"
ctype, _ = _get_connection_type_device_class(all_fixture_data) ctype, _ = _get_connection_type_device_class(all_fixture_data)
config = DeviceConfig(host=host, port_override=custom_port, connection_type=ctype) config = DeviceConfig(
host=host,
port_override=custom_port,
connection_type=ctype,
credentials=Credentials("dummy_user", "dummy_password"),
)
default_port = 80 if "discovery_result" in all_fixture_data else 9999 default_port = 80 if "discovery_result" in all_fixture_data else 9999
ctype, _ = _get_connection_type_device_class(all_fixture_data) ctype, _ = _get_connection_type_device_class(all_fixture_data)