mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-04-26 16:46:23 +00:00
Do not regenerate aes key pair (#1114)
And read it from `device_config` if provided. This is required as key generation can eat up cpu when a device is not fully available and the library is retrying.
This commit is contained in:
parent
2a89e58ae0
commit
fcf8f07232
@ -106,6 +106,9 @@ class AesTransport(BaseTransport):
|
|||||||
self._session_cookie: dict[str, str] | None = None
|
self._session_cookie: dict[str, str] | None = None
|
||||||
|
|
||||||
self._key_pair: KeyPair | None = None
|
self._key_pair: KeyPair | None = None
|
||||||
|
if config.aes_keys:
|
||||||
|
aes_keys = config.aes_keys
|
||||||
|
self._key_pair = KeyPair(aes_keys["private"], aes_keys["public"])
|
||||||
self._app_url = URL(f"http://{self._host}:{self._port}/app")
|
self._app_url = URL(f"http://{self._host}:{self._port}/app")
|
||||||
self._token_url: URL | None = None
|
self._token_url: URL | None = None
|
||||||
|
|
||||||
@ -271,7 +274,14 @@ class AesTransport(BaseTransport):
|
|||||||
can be made to the device.
|
can be made to the device.
|
||||||
"""
|
"""
|
||||||
_LOGGER.debug("Generating keypair")
|
_LOGGER.debug("Generating keypair")
|
||||||
self._key_pair = KeyPair.create_key_pair()
|
if not self._key_pair:
|
||||||
|
kp = KeyPair.create_key_pair()
|
||||||
|
self._config.aes_keys = {
|
||||||
|
"private": kp.get_private_key(),
|
||||||
|
"public": kp.get_public_key(),
|
||||||
|
}
|
||||||
|
self._key_pair = kp
|
||||||
|
|
||||||
pub_key = (
|
pub_key = (
|
||||||
"-----BEGIN PUBLIC KEY-----\n"
|
"-----BEGIN PUBLIC KEY-----\n"
|
||||||
+ self._key_pair.get_public_key() # type: ignore[union-attr]
|
+ self._key_pair.get_public_key() # type: ignore[union-attr]
|
||||||
@ -286,7 +296,6 @@ class AesTransport(BaseTransport):
|
|||||||
"""Perform the handshake."""
|
"""Perform the handshake."""
|
||||||
_LOGGER.debug("Will perform handshaking...")
|
_LOGGER.debug("Will perform handshaking...")
|
||||||
|
|
||||||
self._key_pair = None
|
|
||||||
self._token_url = None
|
self._token_url = None
|
||||||
self._session_expire_at = None
|
self._session_expire_at = None
|
||||||
self._session_cookie = None
|
self._session_cookie = None
|
||||||
|
@ -34,7 +34,7 @@ Living Room Bulb
|
|||||||
import logging
|
import logging
|
||||||
from dataclasses import asdict, dataclass, field, fields, is_dataclass
|
from dataclasses import asdict, dataclass, field, fields, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Union
|
from typing import TYPE_CHECKING, Dict, Optional, TypedDict, Union
|
||||||
|
|
||||||
from .credentials import Credentials
|
from .credentials import Credentials
|
||||||
from .exceptions import KasaException
|
from .exceptions import KasaException
|
||||||
@ -45,6 +45,13 @@ if TYPE_CHECKING:
|
|||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KeyPairDict(TypedDict):
|
||||||
|
"""Class to represent a public/private key pair."""
|
||||||
|
|
||||||
|
private: str
|
||||||
|
public: str
|
||||||
|
|
||||||
|
|
||||||
class DeviceEncryptionType(Enum):
|
class DeviceEncryptionType(Enum):
|
||||||
"""Encrypt type enum."""
|
"""Encrypt type enum."""
|
||||||
|
|
||||||
@ -182,7 +189,7 @@ class DeviceConfig:
|
|||||||
#: The batch size for protoools supporting multiple request batches.
|
#: The batch size for protoools supporting multiple request batches.
|
||||||
connection_type: DeviceConnectionParameters = field(
|
connection_type: DeviceConnectionParameters = field(
|
||||||
default_factory=lambda: DeviceConnectionParameters(
|
default_factory=lambda: DeviceConnectionParameters(
|
||||||
DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor, 1
|
DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
#: True if the device uses http. Consumers should retrieve rather than set this
|
#: True if the device uses http. Consumers should retrieve rather than set this
|
||||||
@ -193,6 +200,8 @@ class DeviceConfig:
|
|||||||
#: Set a custom http_client for the device to use.
|
#: Set a custom http_client for the device to use.
|
||||||
http_client: Optional["ClientSession"] = field(default=None, compare=False)
|
http_client: Optional["ClientSession"] = field(default=None, compare=False)
|
||||||
|
|
||||||
|
aes_keys: Optional[KeyPairDict] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.connection_type is None:
|
if self.connection_type is None:
|
||||||
self.connection_type = DeviceConnectionParameters(
|
self.connection_type = DeviceConnectionParameters(
|
||||||
|
@ -80,6 +80,29 @@ async def test_handshake(
|
|||||||
assert transport._state is TransportState.LOGIN_REQUIRED
|
assert transport._state is TransportState.LOGIN_REQUIRED
|
||||||
|
|
||||||
|
|
||||||
|
async def test_handshake_with_keys(mocker):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
mock_aes_device = MockAesDevice(host)
|
||||||
|
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
|
||||||
|
|
||||||
|
test_keys = {
|
||||||
|
"private": "MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAMo/JQpXIbP2M3bLOKyfEVCURFCxHIXv4HDME8J58AL4BwGDXf0oQycgj9nV+T/MzgEd/4iVysYuYfLuIEKXADP7Lby6AfA/dbcinZZ7bLUNMNa7TaylIvVKtSfR0LV8AmG0jdQYkr4cTzLAEd+AEs/wG3nMQNEcoQRVY+svLPDjAgMBAAECgYBCsDOch0KbvrEVmMklUoY5Fcq4+M249HIDf6d8VwznTbWxsAmL8nzCKCCG6eF4QiYjhCrAdPQaCS1PF2oXywbLhngid/9W9gz4CKKDJChs1X8KvLi+TLg1jgJUXvq9yVNh1CB+lS2ho4gdDDCbVmiVOZR5TDfEf0xeJ+Zz3zlUEQJBAPkhuNdc3yRue8huFZbrWwikURQPYBxLOYfVTDsfV9mZGSkGoWS1FPDsxrqSXugTmcTRuw+lrXKDabJ72kqywA8CQQDP0oaGh5r7F12Xzcwb7X9JkTvyr+rO8YgVtKNBaNVOPabAzysNwOlvH/sNCVQcRj8rn5LNXitgLx6T+Q5uqa3tAkA7J0elUzbkhps7ju/vYri9x448zh3K+g2R9BJio2GPmCuCM0HVEK4FOqNBH4oLXsQPGKFq6LLTUuKg74l4XRL/AkBHBO6r8pNn0yhMxCtIL/UbsuIFoVBgv/F9WWmg5K5gOnlN0n4oCRC8xPUKE3IG54qW4cVNIS05hWCxuJ7R+nJRAkByt/+kX1nQxis2wIXj90fztXG3oSmoVaieYxaXPxlWvX3/Q5kslFF5UsGy9gcK0v2PXhqjTbhud3/X0Er6YP4v",
|
||||||
|
"public": "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDKPyUKVyGz9jN2yzisnxFQlERQsRyF7+BwzBPCefAC+AcBg139KEMnII/Z1fk/zM4BHf+IlcrGLmHy7iBClwAz+y28ugHwP3W3Ip2We2y1DTDWu02spSL1SrUn0dC1fAJhtI3UGJK+HE8ywBHfgBLP8Bt5zEDRHKEEVWPrLyzw4wIDAQAB",
|
||||||
|
}
|
||||||
|
transport = AesTransport(
|
||||||
|
config=DeviceConfig(
|
||||||
|
host, credentials=Credentials("foo", "bar"), aes_keys=test_keys
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert transport._encryption_session is None
|
||||||
|
assert transport._state is TransportState.HANDSHAKE_REQUIRED
|
||||||
|
|
||||||
|
await transport.perform_handshake()
|
||||||
|
assert transport._key_pair.get_private_key() == test_keys["private"]
|
||||||
|
assert transport._key_pair.get_public_key() == test_keys["public"]
|
||||||
|
|
||||||
|
|
||||||
@status_parameters
|
@status_parameters
|
||||||
async def test_login(mocker, status_code, error_code, inner_error_code, expectation):
|
async def test_login(mocker, status_code, error_code, inner_error_code, expectation):
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
@ -97,6 +120,7 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
|
|||||||
with expectation:
|
with expectation:
|
||||||
await transport.perform_login()
|
await transport.perform_login()
|
||||||
assert mock_aes_device.token in str(transport._token_url)
|
assert mock_aes_device.token in str(transport._token_url)
|
||||||
|
assert transport._config.aes_keys == transport._key_pair
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user