mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-08-06 18:54:08 +00:00
Add DeviceConfig to allow specifying configuration parameters (#569)
* Add DeviceConfig handling * Update post review * Further update post latest review * Update following latest review * Update docstrings and docs
This commit is contained in:
@@ -388,7 +388,6 @@ async def get_device_for_file(file, protocol):
|
||||
d = device_for_file(model, protocol)(host="127.0.0.123")
|
||||
if protocol == "SMART":
|
||||
d.protocol = FakeSmartProtocol(sysinfo)
|
||||
d.credentials = Credentials("", "")
|
||||
else:
|
||||
d.protocol = FakeTransportProtocol(sysinfo)
|
||||
await _update_and_close(d)
|
||||
@@ -426,28 +425,53 @@ def discovery_mock(all_fixture_data, mocker):
|
||||
class _DiscoveryMock:
|
||||
ip: str
|
||||
default_port: int
|
||||
discovery_port: int
|
||||
discovery_data: dict
|
||||
query_data: dict
|
||||
device_type: str
|
||||
encrypt_type: str
|
||||
port_override: Optional[int] = None
|
||||
|
||||
if "discovery_result" in all_fixture_data:
|
||||
discovery_data = {"result": all_fixture_data["discovery_result"]}
|
||||
device_type = all_fixture_data["discovery_result"]["device_type"]
|
||||
encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][
|
||||
"encrypt_type"
|
||||
]
|
||||
datagram = (
|
||||
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||
+ json_dumps(discovery_data).encode()
|
||||
)
|
||||
dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data, all_fixture_data)
|
||||
dm = _DiscoveryMock(
|
||||
"127.0.0.123",
|
||||
80,
|
||||
20002,
|
||||
discovery_data,
|
||||
all_fixture_data,
|
||||
device_type,
|
||||
encrypt_type,
|
||||
)
|
||||
else:
|
||||
sys_info = all_fixture_data["system"]["get_sysinfo"]
|
||||
discovery_data = {"system": {"get_sysinfo": sys_info}}
|
||||
device_type = sys_info.get("mic_type") or sys_info.get("type")
|
||||
encrypt_type = "XOR"
|
||||
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
|
||||
dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data, all_fixture_data)
|
||||
dm = _DiscoveryMock(
|
||||
"127.0.0.123",
|
||||
9999,
|
||||
9999,
|
||||
discovery_data,
|
||||
all_fixture_data,
|
||||
device_type,
|
||||
encrypt_type,
|
||||
)
|
||||
|
||||
def mock_discover(self):
|
||||
port = (
|
||||
dm.port_override
|
||||
if dm.port_override and dm.default_port != 20002
|
||||
else dm.default_port
|
||||
if dm.port_override and dm.discovery_port != 20002
|
||||
else dm.discovery_port
|
||||
)
|
||||
self.datagram_received(
|
||||
datagram,
|
||||
|
@@ -15,7 +15,9 @@ from voluptuous import (
|
||||
Schema,
|
||||
)
|
||||
|
||||
from ..protocol import BaseTransport, TPLinkSmartHomeProtocol
|
||||
from ..credentials import Credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..protocol import BaseTransport, TPLinkSmartHomeProtocol, _XorTransport
|
||||
from ..smartprotocol import SmartProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -290,7 +292,9 @@ TIME_MODULE = {
|
||||
|
||||
class FakeSmartProtocol(SmartProtocol):
|
||||
def __init__(self, info):
|
||||
super().__init__("127.0.0.123", transport=FakeSmartTransport(info))
|
||||
super().__init__(
|
||||
transport=FakeSmartTransport(info),
|
||||
)
|
||||
|
||||
async def query(self, request, retry_count: int = 3):
|
||||
"""Implement query here so can still patch SmartProtocol.query."""
|
||||
@@ -301,10 +305,15 @@ class FakeSmartProtocol(SmartProtocol):
|
||||
class FakeSmartTransport(BaseTransport):
|
||||
def __init__(self, info):
|
||||
super().__init__(
|
||||
"127.0.0.123",
|
||||
config=DeviceConfig("127.0.0.123", credentials=Credentials()),
|
||||
)
|
||||
self.info = info
|
||||
|
||||
@property
|
||||
def default_port(self):
|
||||
"""Default port for the transport."""
|
||||
return 80
|
||||
|
||||
async def send(self, request: str):
|
||||
request_dict = json_loads(request)
|
||||
method = request_dict["method"]
|
||||
@@ -344,6 +353,11 @@ class FakeSmartTransport(BaseTransport):
|
||||
|
||||
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
|
||||
def __init__(self, info):
|
||||
super().__init__(
|
||||
transport=_XorTransport(
|
||||
config=DeviceConfig("127.0.0.123"),
|
||||
)
|
||||
)
|
||||
self.discovery_data = info
|
||||
self.writer = None
|
||||
self.reader = None
|
||||
|
@@ -12,6 +12,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
|
||||
|
||||
from ..aestransport import AesEncyptionSession, AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import (
|
||||
SMART_RETRYABLE_ERRORS,
|
||||
SMART_TIMEOUT_ERRORS,
|
||||
@@ -58,7 +59,9 @@ async def test_handshake(
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
transport = AesTransport(
|
||||
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
)
|
||||
|
||||
assert transport._encryption_session is None
|
||||
assert transport._handshake_done is False
|
||||
@@ -74,7 +77,9 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
transport = AesTransport(
|
||||
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
)
|
||||
transport._handshake_done = True
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
@@ -91,13 +96,14 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
|
||||
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
transport = AesTransport(
|
||||
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
)
|
||||
transport._handshake_done = True
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
transport._login_token = mock_aes_device.token
|
||||
|
||||
un, pw = transport.hash_credentials(True)
|
||||
request = {
|
||||
"method": "get_device_info",
|
||||
"params": None,
|
||||
@@ -119,7 +125,8 @@ async def test_passthrough_errors(mocker, error_code):
|
||||
mock_aes_device = MockAesDevice(host, 200, error_code, 0)
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
|
||||
|
||||
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
|
||||
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
transport = AesTransport(config=config)
|
||||
transport._handshake_done = True
|
||||
transport._session_expire_at = time.time() + 86400
|
||||
transport._encryption_session = mock_aes_device.encryption_session
|
||||
|
@@ -4,10 +4,26 @@ import asyncclick as click
|
||||
import pytest
|
||||
from asyncclick.testing import CliRunner
|
||||
|
||||
from kasa import AuthenticationException, SmartDevice, UnsupportedDeviceException
|
||||
from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle
|
||||
from kasa.device_factory import DEVICE_TYPE_TO_CLASS
|
||||
from kasa.discover import Discover
|
||||
from kasa import (
|
||||
AuthenticationException,
|
||||
Credentials,
|
||||
SmartDevice,
|
||||
TPLinkSmartHomeProtocol,
|
||||
UnsupportedDeviceException,
|
||||
)
|
||||
from kasa.cli import (
|
||||
TYPE_TO_CLASS,
|
||||
alias,
|
||||
brightness,
|
||||
cli,
|
||||
emeter,
|
||||
raw_command,
|
||||
state,
|
||||
sysinfo,
|
||||
toggle,
|
||||
)
|
||||
from kasa.discover import Discover, DiscoveryResult
|
||||
from kasa.smartprotocol import SmartProtocol
|
||||
|
||||
from .conftest import device_iot, handle_turn_on, new_discovery, turn_on
|
||||
|
||||
@@ -145,9 +161,11 @@ async def test_credentials(discovery_mock, mocker):
|
||||
)
|
||||
|
||||
mocker.patch("kasa.cli.state", new=_state)
|
||||
for subclass in DEVICE_TYPE_TO_CLASS.values():
|
||||
mocker.patch.object(subclass, "update")
|
||||
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=discovery_mock.query_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=discovery_mock.query_data)
|
||||
|
||||
dr = DiscoveryResult(**discovery_mock.discovery_data["result"])
|
||||
runner = CliRunner()
|
||||
res = await runner.invoke(
|
||||
cli,
|
||||
@@ -158,6 +176,10 @@ async def test_credentials(discovery_mock, mocker):
|
||||
"foo",
|
||||
"--password",
|
||||
"bar",
|
||||
"--device-family",
|
||||
dr.device_type,
|
||||
"--encrypt-type",
|
||||
dr.mgt_encrypt_schm.encrypt_type,
|
||||
],
|
||||
)
|
||||
assert res.exit_code == 0
|
||||
@@ -166,7 +188,7 @@ async def test_credentials(discovery_mock, mocker):
|
||||
|
||||
|
||||
@device_iot
|
||||
async def test_without_device_type(discovery_data: dict, dev, mocker):
|
||||
async def test_without_device_type(dev, mocker):
|
||||
"""Test connecting without the device type."""
|
||||
runner = CliRunner()
|
||||
mocker.patch("kasa.discover.Discover.discover_single", return_value=dev)
|
||||
@@ -342,3 +364,27 @@ async def test_host_auth_failed(discovery_mock, mocker):
|
||||
|
||||
assert res.exit_code != 0
|
||||
assert isinstance(res.exception, AuthenticationException)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device_type", list(TYPE_TO_CLASS))
|
||||
async def test_type_param(device_type, mocker):
|
||||
"""Test for handling only one of username or password supplied."""
|
||||
runner = CliRunner()
|
||||
|
||||
result_device = FileNotFoundError
|
||||
pass_dev = click.make_pass_decorator(SmartDevice)
|
||||
|
||||
@pass_dev
|
||||
async def _state(dev: SmartDevice):
|
||||
nonlocal result_device
|
||||
result_device = dev
|
||||
|
||||
mocker.patch("kasa.cli.state", new=_state)
|
||||
expected_type = TYPE_TO_CLASS[device_type]
|
||||
mocker.patch.object(expected_type, "update")
|
||||
res = await runner.invoke(
|
||||
cli,
|
||||
["--type", device_type, "--host", "127.0.0.1"],
|
||||
)
|
||||
assert res.exit_code == 0
|
||||
assert isinstance(result_device, expected_type)
|
||||
|
@@ -2,6 +2,7 @@
|
||||
import logging
|
||||
from typing import Type
|
||||
|
||||
import httpx
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
from kasa import (
|
||||
@@ -15,122 +16,138 @@ from kasa import (
|
||||
SmartLightStrip,
|
||||
SmartPlug,
|
||||
)
|
||||
from kasa.device_factory import (
|
||||
DEVICE_TYPE_TO_CLASS,
|
||||
connect,
|
||||
get_protocol_from_connection_name,
|
||||
from kasa.device_factory import connect, get_protocol
|
||||
from kasa.deviceconfig import (
|
||||
ConnectionType,
|
||||
DeviceConfig,
|
||||
DeviceFamilyType,
|
||||
EncryptType,
|
||||
)
|
||||
from kasa.discover import DiscoveryResult
|
||||
from kasa.iotprotocol import IotProtocol
|
||||
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_port", [123, None])
|
||||
async def test_connect(discovery_data: dict, mocker, custom_port):
|
||||
"""Make sure that connect returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
def _get_connection_type_device_class(the_fixture_data):
|
||||
if "discovery_result" in the_fixture_data:
|
||||
discovery_info = {"result": the_fixture_data["discovery_result"]}
|
||||
device_class = Discover._get_device_class(discovery_info)
|
||||
dr = DiscoveryResult(**discovery_info["result"])
|
||||
|
||||
if "result" in discovery_data:
|
||||
with pytest.raises(SmartDeviceException):
|
||||
dev = await connect(host, port=custom_port)
|
||||
connection_type = ConnectionType.from_values(
|
||||
dr.device_type, dr.mgt_encrypt_schm.encrypt_type
|
||||
)
|
||||
else:
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
dev = await connect(host, port=custom_port)
|
||||
assert issubclass(dev.__class__, SmartDevice)
|
||||
assert dev.port == custom_port or dev.port == 9999
|
||||
connection_type = ConnectionType.from_values(
|
||||
DeviceFamilyType.IotSmartPlugSwitch.value, EncryptType.Xor.value
|
||||
)
|
||||
device_class = Discover._get_device_class(the_fixture_data)
|
||||
|
||||
return connection_type, device_class
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_port", [123, None])
|
||||
@pytest.mark.parametrize(
|
||||
("device_type", "klass"),
|
||||
(
|
||||
(DeviceType.Plug, SmartPlug),
|
||||
(DeviceType.Bulb, SmartBulb),
|
||||
(DeviceType.Dimmer, SmartDimmer),
|
||||
(DeviceType.LightStrip, SmartLightStrip),
|
||||
(DeviceType.Unknown, SmartDevice),
|
||||
),
|
||||
)
|
||||
async def test_connect_passed_device_type(
|
||||
discovery_data: dict,
|
||||
mocker,
|
||||
device_type: DeviceType,
|
||||
klass: Type[SmartDevice],
|
||||
custom_port,
|
||||
):
|
||||
"""Make sure that connect with a passed device type."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
if "result" in discovery_data:
|
||||
with pytest.raises(SmartDeviceException):
|
||||
dev = await connect(host, port=custom_port)
|
||||
else:
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
dev = await connect(host, port=custom_port, device_type=device_type)
|
||||
assert isinstance(dev, klass)
|
||||
assert dev.port == custom_port or dev.port == 9999
|
||||
|
||||
|
||||
async def test_connect_query_fails(discovery_data: dict, mocker):
|
||||
"""Make sure that connect fails when query fails."""
|
||||
host = "127.0.0.1"
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
|
||||
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await connect(host)
|
||||
|
||||
|
||||
async def test_connect_logs_connect_time(
|
||||
discovery_data: dict, caplog: pytest.LogCaptureFixture, mocker
|
||||
):
|
||||
"""Test that the connect time is logged when debug logging is enabled."""
|
||||
host = "127.0.0.1"
|
||||
if "result" in discovery_data:
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await connect(host)
|
||||
else:
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
logging.getLogger("kasa").setLevel(logging.DEBUG)
|
||||
await connect(host)
|
||||
assert "seconds to connect" in caplog.text
|
||||
|
||||
|
||||
async def test_connect_pass_protocol(
|
||||
async def test_connect(
|
||||
all_fixture_data: dict,
|
||||
mocker,
|
||||
):
|
||||
"""Test that if the protocol is passed in it's gets set correctly."""
|
||||
if "discovery_result" in all_fixture_data:
|
||||
discovery_info = {"result": all_fixture_data["discovery_result"]}
|
||||
device_class = Discover._get_device_class(discovery_info)
|
||||
else:
|
||||
device_class = Discover._get_device_class(all_fixture_data)
|
||||
|
||||
device_type = list(DEVICE_TYPE_TO_CLASS.keys())[
|
||||
list(DEVICE_TYPE_TO_CLASS.values()).index(device_class)
|
||||
]
|
||||
"""Test that if the protocol is passed in it gets set correctly."""
|
||||
host = "127.0.0.1"
|
||||
if "discovery_result" in all_fixture_data:
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
ctype, device_class = _get_connection_type_device_class(all_fixture_data)
|
||||
|
||||
dr = DiscoveryResult(**discovery_info["result"])
|
||||
connection_name = (
|
||||
dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type
|
||||
)
|
||||
protocol_class = get_protocol_from_connection_name(
|
||||
connection_name, host
|
||||
).__class__
|
||||
else:
|
||||
mocker.patch(
|
||||
"kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data
|
||||
)
|
||||
protocol_class = TPLinkSmartHomeProtocol
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
|
||||
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
protocol_class = get_protocol(config).__class__
|
||||
|
||||
dev = await connect(
|
||||
host,
|
||||
device_type=device_type,
|
||||
protocol_class=protocol_class,
|
||||
credentials=Credentials("", ""),
|
||||
config=config,
|
||||
)
|
||||
assert isinstance(dev, device_class)
|
||||
assert isinstance(dev.protocol, protocol_class)
|
||||
|
||||
assert dev.config == config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_port", [123, None])
|
||||
async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port):
|
||||
"""Make sure that connect returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
config = DeviceConfig(host=host, port_override=custom_port, connection_type=ctype)
|
||||
default_port = 80 if "discovery_result" in all_fixture_data else 9999
|
||||
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
dev = await connect(config=config)
|
||||
assert issubclass(dev.__class__, SmartDevice)
|
||||
assert dev.port == custom_port or dev.port == default_port
|
||||
|
||||
|
||||
async def test_connect_logs_connect_time(
|
||||
all_fixture_data: dict, caplog: pytest.LogCaptureFixture, mocker
|
||||
):
|
||||
"""Test that the connect time is logged when debug logging is enabled."""
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
|
||||
|
||||
host = "127.0.0.1"
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
logging.getLogger("kasa").setLevel(logging.DEBUG)
|
||||
await connect(
|
||||
config=config,
|
||||
)
|
||||
assert "seconds to update" in caplog.text
|
||||
|
||||
|
||||
async def test_connect_query_fails(all_fixture_data: dict, mocker):
|
||||
"""Make sure that connect fails when query fails."""
|
||||
host = "127.0.0.1"
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
|
||||
mocker.patch("kasa.IotProtocol.query", side_effect=SmartDeviceException)
|
||||
mocker.patch("kasa.SmartProtocol.query", side_effect=SmartDeviceException)
|
||||
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await connect(config=config)
|
||||
|
||||
|
||||
async def test_connect_http_client(all_fixture_data, mocker):
|
||||
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
|
||||
|
||||
http_client = httpx.AsyncClient()
|
||||
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
dev = await connect(config=config)
|
||||
if ctype.encryption_type != EncryptType.Xor:
|
||||
assert dev.protocol._transport._http_client != http_client
|
||||
|
||||
config = DeviceConfig(
|
||||
host=host,
|
||||
credentials=Credentials("foor", "bar"),
|
||||
connection_type=ctype,
|
||||
http_client=http_client,
|
||||
)
|
||||
dev = await connect(config=config)
|
||||
if ctype.encryption_type != EncryptType.Xor:
|
||||
assert dev.protocol._transport._http_client == http_client
|
||||
|
21
kasa/tests/test_deviceconfig.py
Normal file
21
kasa/tests/test_deviceconfig.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
|
||||
import httpx
|
||||
|
||||
from kasa.credentials import Credentials
|
||||
from kasa.deviceconfig import (
|
||||
ConnectionType,
|
||||
DeviceConfig,
|
||||
DeviceFamilyType,
|
||||
EncryptType,
|
||||
)
|
||||
|
||||
|
||||
def test_serialization():
|
||||
config = DeviceConfig(host="Foo", http_client=httpx.AsyncClient())
|
||||
config_dict = config.to_dict()
|
||||
config_json = json_dumps(config_dict)
|
||||
config2_dict = json_loads(config_json)
|
||||
config2 = DeviceConfig.from_dict(config2_dict)
|
||||
assert config == config2
|
@@ -1,21 +1,29 @@
|
||||
# type: ignore
|
||||
import logging
|
||||
import re
|
||||
import socket
|
||||
|
||||
import httpx
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
from kasa import (
|
||||
Credentials,
|
||||
DeviceType,
|
||||
Discover,
|
||||
SmartDevice,
|
||||
SmartDeviceException,
|
||||
SmartStrip,
|
||||
protocol,
|
||||
)
|
||||
from kasa.deviceconfig import (
|
||||
ConnectionType,
|
||||
DeviceConfig,
|
||||
DeviceFamilyType,
|
||||
EncryptType,
|
||||
)
|
||||
from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps
|
||||
from kasa.exceptions import AuthenticationException, UnsupportedDeviceException
|
||||
|
||||
from .conftest import bulb, bulb_iot, dimmer, lightstrip, plug, strip
|
||||
from .conftest import bulb, bulb_iot, dimmer, lightstrip, new_discovery, plug, strip
|
||||
|
||||
UNSUPPORTED = {
|
||||
"result": {
|
||||
@@ -89,13 +97,26 @@ async def test_discover_single(discovery_mock, custom_port, mocker):
|
||||
host = "127.0.0.1"
|
||||
discovery_mock.ip = host
|
||||
discovery_mock.port_override = custom_port
|
||||
update_mock = mocker.patch.object(SmartStrip, "update")
|
||||
|
||||
x = await Discover.discover_single(host, port=custom_port)
|
||||
device_class = Discover._get_device_class(discovery_mock.discovery_data)
|
||||
update_mock = mocker.patch.object(device_class, "update")
|
||||
|
||||
x = await Discover.discover_single(
|
||||
host, port=custom_port, credentials=Credentials()
|
||||
)
|
||||
assert issubclass(x.__class__, SmartDevice)
|
||||
assert x._discovery_info is not None
|
||||
assert x.port == custom_port or x.port == discovery_mock.default_port
|
||||
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
|
||||
assert update_mock.call_count == 0
|
||||
|
||||
ct = ConnectionType.from_values(
|
||||
discovery_mock.device_type, discovery_mock.encrypt_type
|
||||
)
|
||||
uses_http = discovery_mock.default_port == 80
|
||||
config = DeviceConfig(
|
||||
host=host, port_override=custom_port, connection_type=ct, uses_http=uses_http
|
||||
)
|
||||
assert x.config == config
|
||||
|
||||
|
||||
async def test_discover_single_hostname(discovery_mock, mocker):
|
||||
@@ -104,47 +125,39 @@ async def test_discover_single_hostname(discovery_mock, mocker):
|
||||
ip = "127.0.0.1"
|
||||
|
||||
discovery_mock.ip = ip
|
||||
update_mock = mocker.patch.object(SmartStrip, "update")
|
||||
device_class = Discover._get_device_class(discovery_mock.discovery_data)
|
||||
update_mock = mocker.patch.object(device_class, "update")
|
||||
|
||||
x = await Discover.discover_single(host)
|
||||
x = await Discover.discover_single(host, credentials=Credentials())
|
||||
assert issubclass(x.__class__, SmartDevice)
|
||||
assert x._discovery_info is not None
|
||||
assert x.host == host
|
||||
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
|
||||
assert update_mock.call_count == 0
|
||||
|
||||
mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
|
||||
with pytest.raises(SmartDeviceException):
|
||||
x = await Discover.discover_single(host)
|
||||
x = await Discover.discover_single(host, credentials=Credentials())
|
||||
|
||||
|
||||
async def test_discover_single_unsupported(mocker):
|
||||
async def test_discover_single_unsupported(unsupported_device_info, mocker):
|
||||
"""Make sure that discover_single handles unsupported devices correctly."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
def mock_discover(self):
|
||||
if discovery_data:
|
||||
data = (
|
||||
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||
+ json_dumps(discovery_data).encode()
|
||||
)
|
||||
self.datagram_received(data, (host, 20002))
|
||||
|
||||
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
|
||||
|
||||
# Test with a valid unsupported response
|
||||
discovery_data = UNSUPPORTED
|
||||
with pytest.raises(
|
||||
UnsupportedDeviceException,
|
||||
match=f"Unsupported device {host} of type SMART.TAPOXMASTREE: {re.escape(str(UNSUPPORTED))}",
|
||||
):
|
||||
await Discover.discover_single(host)
|
||||
|
||||
# Test with no response
|
||||
discovery_data = None
|
||||
|
||||
async def test_discover_single_no_response(mocker):
|
||||
"""Make sure that discover_single handles no response correctly."""
|
||||
host = "127.0.0.1"
|
||||
mocker.patch.object(_DiscoverProtocol, "do_discover")
|
||||
with pytest.raises(
|
||||
SmartDeviceException, match=f"Timed out getting discovery response for {host}"
|
||||
):
|
||||
await Discover.discover_single(host, timeout=0.001)
|
||||
await Discover.discover_single(host, discovery_timeout=0)
|
||||
|
||||
|
||||
INVALIDS = [
|
||||
@@ -241,52 +254,82 @@ AUTHENTICATION_DATA_KLAP = {
|
||||
}
|
||||
|
||||
|
||||
async def test_discover_single_authentication(mocker):
|
||||
@new_discovery
|
||||
async def test_discover_single_authentication(discovery_mock, mocker):
|
||||
"""Make sure that discover_single handles authenticating devices correctly."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
def mock_discover(self):
|
||||
if discovery_data:
|
||||
data = (
|
||||
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
|
||||
+ json_dumps(discovery_data).encode()
|
||||
)
|
||||
self.datagram_received(data, (host, 20002))
|
||||
|
||||
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
|
||||
discovery_mock.ip = host
|
||||
device_class = Discover._get_device_class(discovery_mock.discovery_data)
|
||||
mocker.patch.object(
|
||||
SmartDevice,
|
||||
device_class,
|
||||
"update",
|
||||
side_effect=AuthenticationException("Failed to authenticate"),
|
||||
)
|
||||
|
||||
# Test with a valid unsupported response
|
||||
discovery_data = AUTHENTICATION_DATA_KLAP
|
||||
with pytest.raises(
|
||||
AuthenticationException,
|
||||
match="Failed to authenticate",
|
||||
):
|
||||
device = await Discover.discover_single(host)
|
||||
device = await Discover.discover_single(
|
||||
host, credentials=Credentials("foo", "bar")
|
||||
)
|
||||
await device.update()
|
||||
|
||||
mocker.patch.object(SmartDevice, "update")
|
||||
device = await Discover.discover_single(host)
|
||||
mocker.patch.object(device_class, "update")
|
||||
device = await Discover.discover_single(host, credentials=Credentials("foo", "bar"))
|
||||
await device.update()
|
||||
assert device.device_type == DeviceType.Plug
|
||||
assert isinstance(device, device_class)
|
||||
|
||||
|
||||
async def test_device_update_from_new_discovery_info():
|
||||
@new_discovery
|
||||
async def test_device_update_from_new_discovery_info(discovery_data):
|
||||
device = SmartDevice("127.0.0.7")
|
||||
discover_info = DiscoveryResult(**AUTHENTICATION_DATA_KLAP["result"])
|
||||
discover_info = DiscoveryResult(**discovery_data["result"])
|
||||
discover_dump = discover_info.get_dict()
|
||||
discover_dump["alias"] = "foobar"
|
||||
discover_dump["model"] = discover_dump["device_model"]
|
||||
device.update_from_discover_info(discover_dump)
|
||||
|
||||
assert device.alias == discover_dump["alias"]
|
||||
assert device.alias == "foobar"
|
||||
assert device.mac == discover_dump["mac"].replace("-", ":")
|
||||
assert device.model == discover_dump["model"]
|
||||
assert device.model == discover_dump["device_model"]
|
||||
|
||||
with pytest.raises(
|
||||
SmartDeviceException,
|
||||
match=re.escape("You need to await update() to access the data"),
|
||||
):
|
||||
assert device.supported_modules
|
||||
|
||||
|
||||
async def test_discover_single_http_client(discovery_mock, mocker):
|
||||
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
discovery_mock.ip = host
|
||||
|
||||
http_client = httpx.AsyncClient()
|
||||
|
||||
x: SmartDevice = await Discover.discover_single(host)
|
||||
|
||||
assert x.config.uses_http == (discovery_mock.default_port == 80)
|
||||
|
||||
if discovery_mock.default_port == 80:
|
||||
assert x.protocol._transport._http_client != http_client
|
||||
x.config.http_client = http_client
|
||||
assert x.protocol._transport._http_client == http_client
|
||||
|
||||
|
||||
async def test_discover_http_client(discovery_mock, mocker):
|
||||
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
discovery_mock.ip = host
|
||||
|
||||
http_client = httpx.AsyncClient()
|
||||
|
||||
devices = await Discover.discover(discovery_timeout=0)
|
||||
x: SmartDevice = devices[host]
|
||||
assert x.config.uses_http == (discovery_mock.default_port == 80)
|
||||
|
||||
if discovery_mock.default_port == 80:
|
||||
assert x.protocol._transport._http_client != http_client
|
||||
x.config.http_client = http_client
|
||||
assert x.protocol._transport._http_client == http_client
|
||||
|
@@ -12,9 +12,15 @@ import pytest
|
||||
|
||||
from ..aestransport import AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import AuthenticationException, SmartDeviceException
|
||||
from ..iotprotocol import IotProtocol
|
||||
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
|
||||
from ..klaptransport import (
|
||||
KlapEncryptionSession,
|
||||
KlapTransport,
|
||||
KlapTransportV2,
|
||||
_sha256,
|
||||
)
|
||||
from ..smartprotocol import SmartProtocol
|
||||
|
||||
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
|
||||
@@ -31,8 +37,9 @@ class _mock_response:
|
||||
[
|
||||
(Exception("dummy exception"), True),
|
||||
(SmartDeviceException("dummy exception"), False),
|
||||
(httpx.ConnectError("dummy exception"), True),
|
||||
],
|
||||
ids=("Exception", "SmartDeviceException"),
|
||||
ids=("Exception", "SmartDeviceException", "httpx.ConnectError"),
|
||||
)
|
||||
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
|
||||
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
|
||||
@@ -42,8 +49,10 @@ async def test_protocol_retries(
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error)
|
||||
|
||||
config = DeviceConfig(host)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await protocol_class(host, transport=transport_class(host)).query(
|
||||
await protocol_class(transport=transport_class(config=config)).query(
|
||||
DUMMY_QUERY, retry_count=retry_count
|
||||
)
|
||||
|
||||
@@ -60,10 +69,11 @@ async def test_protocol_no_retry_on_connection_error(
|
||||
conn = mocker.patch.object(
|
||||
httpx.AsyncClient,
|
||||
"post",
|
||||
side_effect=httpx.ConnectError("foo"),
|
||||
side_effect=AuthenticationException("foo"),
|
||||
)
|
||||
config = DeviceConfig(host)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await protocol_class(host, transport=transport_class(host)).query(
|
||||
await protocol_class(transport=transport_class(config=config)).query(
|
||||
DUMMY_QUERY, retry_count=5
|
||||
)
|
||||
|
||||
@@ -81,8 +91,9 @@ async def test_protocol_retry_recoverable_error(
|
||||
"post",
|
||||
side_effect=httpx.CloseError("foo"),
|
||||
)
|
||||
config = DeviceConfig(host)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await protocol_class(host, transport=transport_class(host)).query(
|
||||
await protocol_class(transport=transport_class(config=config)).query(
|
||||
DUMMY_QUERY, retry_count=5
|
||||
)
|
||||
|
||||
@@ -115,7 +126,8 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
|
||||
side_effect=_fail_one_less_than_retry_count,
|
||||
)
|
||||
|
||||
response = await protocol_class(host, transport=transport_class(host)).query(
|
||||
config = DeviceConfig(host)
|
||||
response = await protocol_class(transport=transport_class(config=config)).query(
|
||||
DUMMY_QUERY, retry_count=retry_count
|
||||
)
|
||||
assert "result" in response or "foobar" in response
|
||||
@@ -136,7 +148,9 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
||||
seed = secrets.token_bytes(16)
|
||||
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
|
||||
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
|
||||
protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1"))
|
||||
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
protocol = IotProtocol(transport=KlapTransport(config=config))
|
||||
|
||||
protocol._transport._handshake_done = True
|
||||
protocol._transport._session_expire_at = time.time() + 86400
|
||||
@@ -181,7 +195,7 @@ def test_encrypt_unicode():
|
||||
"device_credentials, expectation",
|
||||
[
|
||||
(Credentials("foo", "bar"), does_not_raise()),
|
||||
(Credentials("", ""), does_not_raise()),
|
||||
(Credentials(), does_not_raise()),
|
||||
(
|
||||
Credentials(
|
||||
KlapTransport.KASA_SETUP_EMAIL,
|
||||
@@ -196,30 +210,37 @@ def test_encrypt_unicode():
|
||||
],
|
||||
ids=("client", "blank", "kasa_setup", "shouldfail"),
|
||||
)
|
||||
async def test_handshake1(mocker, device_credentials, expectation):
|
||||
@pytest.mark.parametrize(
|
||||
"transport_class, seed_auth_hash_calc",
|
||||
[
|
||||
pytest.param(KlapTransport, lambda c, s, a: c + a, id="KLAP"),
|
||||
pytest.param(KlapTransportV2, lambda c, s, a: c + s + a, id="KLAPV2"),
|
||||
],
|
||||
)
|
||||
async def test_handshake1(
|
||||
mocker, device_credentials, expectation, transport_class, seed_auth_hash_calc
|
||||
):
|
||||
async def _return_handshake1_response(url, params=None, data=None, *_, **__):
|
||||
nonlocal client_seed, server_seed, device_auth_hash
|
||||
|
||||
client_seed = data
|
||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||
|
||||
return _mock_response(200, server_seed + client_seed_auth_hash)
|
||||
seed_auth_hash = _sha256(
|
||||
seed_auth_hash_calc(client_seed, server_seed, device_auth_hash)
|
||||
)
|
||||
return _mock_response(200, server_seed + seed_auth_hash)
|
||||
|
||||
client_seed = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(device_credentials)
|
||||
device_auth_hash = transport_class.generate_auth_hash(device_credentials)
|
||||
|
||||
mocker.patch.object(
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
|
||||
)
|
||||
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(transport=transport_class(config=config))
|
||||
|
||||
protocol._transport.http_client = httpx.AsyncClient()
|
||||
with expectation:
|
||||
(
|
||||
local_seed,
|
||||
@@ -233,31 +254,51 @@ async def test_handshake1(mocker, device_credentials, expectation):
|
||||
await protocol.close()
|
||||
|
||||
|
||||
async def test_handshake(mocker):
|
||||
@pytest.mark.parametrize(
|
||||
"transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2",
|
||||
[
|
||||
pytest.param(
|
||||
KlapTransport, lambda c, s, a: c + a, lambda c, s, a: s + a, id="KLAP"
|
||||
),
|
||||
pytest.param(
|
||||
KlapTransportV2,
|
||||
lambda c, s, a: c + s + a,
|
||||
lambda c, s, a: s + c + a,
|
||||
id="KLAPV2",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_handshake(
|
||||
mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2
|
||||
):
|
||||
async def _return_handshake_response(url, params=None, data=None, *_, **__):
|
||||
nonlocal response_status, client_seed, server_seed, device_auth_hash
|
||||
nonlocal client_seed, server_seed, device_auth_hash
|
||||
|
||||
if url == "http://127.0.0.1/app/handshake1":
|
||||
client_seed = data
|
||||
client_seed_auth_hash = _sha256(data + device_auth_hash)
|
||||
seed_auth_hash = _sha256(
|
||||
seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
|
||||
)
|
||||
|
||||
return _mock_response(200, server_seed + client_seed_auth_hash)
|
||||
return _mock_response(200, server_seed + seed_auth_hash)
|
||||
elif url == "http://127.0.0.1/app/handshake2":
|
||||
seed_auth_hash = _sha256(
|
||||
seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
|
||||
)
|
||||
assert data == seed_auth_hash
|
||||
return _mock_response(response_status, b"")
|
||||
|
||||
client_seed = None
|
||||
server_seed = secrets.token_bytes(16)
|
||||
client_credentials = Credentials("foo", "bar")
|
||||
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
|
||||
device_auth_hash = transport_class.generate_auth_hash(client_credentials)
|
||||
|
||||
mocker.patch.object(
|
||||
httpx.AsyncClient, "post", side_effect=_return_handshake_response
|
||||
)
|
||||
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(transport=transport_class(config=config))
|
||||
protocol._transport.http_client = httpx.AsyncClient()
|
||||
|
||||
response_status = 200
|
||||
@@ -273,7 +314,7 @@ async def test_handshake(mocker):
|
||||
|
||||
async def test_query(mocker):
|
||||
async def _return_response(url, params=None, data=None, *_, **__):
|
||||
nonlocal client_seed, server_seed, device_auth_hash, protocol, seq
|
||||
nonlocal client_seed, server_seed, device_auth_hash, seq
|
||||
|
||||
if url == "http://127.0.0.1/app/handshake1":
|
||||
client_seed = data
|
||||
@@ -303,10 +344,8 @@ async def test_query(mocker):
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(transport=KlapTransport(config=config))
|
||||
|
||||
for _ in range(10):
|
||||
resp = await protocol.query({})
|
||||
@@ -350,10 +389,8 @@ async def test_authentication_failures(mocker, response_status, expectation):
|
||||
|
||||
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
|
||||
|
||||
protocol = IotProtocol(
|
||||
"127.0.0.1",
|
||||
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
|
||||
protocol = IotProtocol(transport=KlapTransport(config=config))
|
||||
|
||||
with expectation:
|
||||
await protocol.query({})
|
||||
|
@@ -9,6 +9,7 @@ import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import SmartDeviceException
|
||||
from ..protocol import (
|
||||
BaseTransport,
|
||||
@@ -31,10 +32,11 @@ async def test_protocol_retries(mocker, retry_count):
|
||||
return reader, writer
|
||||
|
||||
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=retry_count)
|
||||
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
|
||||
{}, retry_count=retry_count
|
||||
)
|
||||
|
||||
assert conn.call_count == retry_count + 1
|
||||
|
||||
@@ -44,10 +46,11 @@ async def test_protocol_no_retry_on_unreachable(mocker):
|
||||
"asyncio.open_connection",
|
||||
side_effect=OSError(errno.EHOSTUNREACH, "No route to host"),
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=5)
|
||||
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
|
||||
{}, retry_count=5
|
||||
)
|
||||
|
||||
assert conn.call_count == 1
|
||||
|
||||
@@ -57,10 +60,11 @@ async def test_protocol_no_retry_connection_refused(mocker):
|
||||
"asyncio.open_connection",
|
||||
side_effect=ConnectionRefusedError,
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=5)
|
||||
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
|
||||
{}, retry_count=5
|
||||
)
|
||||
|
||||
assert conn.call_count == 1
|
||||
|
||||
@@ -70,10 +74,11 @@ async def test_protocol_retry_recoverable_error(mocker):
|
||||
"asyncio.open_connection",
|
||||
side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"),
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
).query({}, retry_count=5)
|
||||
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
|
||||
{}, retry_count=5
|
||||
)
|
||||
|
||||
assert conn.call_count == 6
|
||||
|
||||
@@ -107,9 +112,8 @@ async def test_protocol_reconnect(mocker, retry_count):
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
response = await protocol.query({}, retry_count=retry_count)
|
||||
assert response == {"great": "success"}
|
||||
@@ -137,9 +141,8 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1")
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1")
|
||||
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
@@ -173,9 +176,8 @@ async def test_protocol_custom_port(mocker, custom_port):
|
||||
mocker.patch.object(reader, "readexactly", _mock_read)
|
||||
return reader, writer
|
||||
|
||||
protocol = TPLinkSmartHomeProtocol(
|
||||
"127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port)
|
||||
)
|
||||
config = DeviceConfig("127.0.0.1", port_override=custom_port)
|
||||
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
|
||||
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
|
||||
response = await protocol.query({})
|
||||
assert response == {"great": "success"}
|
||||
@@ -271,18 +273,14 @@ def _get_subclasses(of_class):
|
||||
def test_protocol_init_signature(class_name_obj):
|
||||
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
|
||||
|
||||
assert len(params) == 3
|
||||
assert len(params) == 2
|
||||
assert (
|
||||
params[0].name == "self"
|
||||
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert (
|
||||
params[1].name == "host"
|
||||
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert (
|
||||
params[2].name == "transport"
|
||||
and params[2].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
params[1].name == "transport"
|
||||
and params[1].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
)
|
||||
|
||||
|
||||
@@ -292,20 +290,11 @@ def test_protocol_init_signature(class_name_obj):
|
||||
def test_transport_init_signature(class_name_obj):
|
||||
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
|
||||
|
||||
assert len(params) == 5
|
||||
assert len(params) == 2
|
||||
assert (
|
||||
params[0].name == "self"
|
||||
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert (
|
||||
params[1].name == "host"
|
||||
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
assert (
|
||||
params[3].name == "credentials"
|
||||
and params[3].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
)
|
||||
assert (
|
||||
params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY
|
||||
)
|
||||
|
@@ -5,8 +5,7 @@ from unittest.mock import Mock, patch
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
import kasa
|
||||
from kasa import Credentials, SmartDevice, SmartDeviceException
|
||||
from kasa.smartdevice import DeviceType
|
||||
from kasa import Credentials, DeviceConfig, SmartDevice, SmartDeviceException
|
||||
|
||||
from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on
|
||||
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
|
||||
@@ -215,7 +214,8 @@ def test_device_class_ctors(device_class):
|
||||
host = "127.0.0.2"
|
||||
port = 1234
|
||||
credentials = Credentials("foo", "bar")
|
||||
dev = device_class(host, port=port, credentials=credentials)
|
||||
config = DeviceConfig(host, port_override=port, credentials=credentials)
|
||||
dev = device_class(host, config=config)
|
||||
assert dev.host == host
|
||||
assert dev.port == port
|
||||
assert dev.credentials == credentials
|
||||
@@ -231,29 +231,27 @@ async def test_modules_preserved(dev: SmartDevice):
|
||||
|
||||
async def test_create_smart_device_with_timeout():
|
||||
"""Make sure timeout is passed to the protocol."""
|
||||
dev = SmartDevice(host="127.0.0.1", timeout=100)
|
||||
host = "127.0.0.1"
|
||||
dev = SmartDevice(host, config=DeviceConfig(host, timeout=100))
|
||||
assert dev.protocol._transport._timeout == 100
|
||||
|
||||
|
||||
async def test_create_thin_wrapper():
|
||||
"""Make sure thin wrapper is created with the correct device type."""
|
||||
mock = Mock()
|
||||
config = DeviceConfig(
|
||||
host="test_host",
|
||||
port_override=1234,
|
||||
timeout=100,
|
||||
credentials=Credentials("username", "password"),
|
||||
)
|
||||
with patch("kasa.device_factory.connect", return_value=mock) as connect:
|
||||
dev = await SmartDevice.connect(
|
||||
host="test_host",
|
||||
port=1234,
|
||||
timeout=100,
|
||||
credentials=Credentials("username", "password"),
|
||||
device_type=DeviceType.Strip,
|
||||
)
|
||||
dev = await SmartDevice.connect(config=config)
|
||||
assert dev is mock
|
||||
|
||||
connect.assert_called_once_with(
|
||||
host="test_host",
|
||||
port=1234,
|
||||
timeout=100,
|
||||
credentials=Credentials("username", "password"),
|
||||
device_type=DeviceType.Strip,
|
||||
host=None,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
|
@@ -13,6 +13,7 @@ import pytest
|
||||
|
||||
from ..aestransport import AesTransport
|
||||
from ..credentials import Credentials
|
||||
from ..deviceconfig import DeviceConfig
|
||||
from ..exceptions import (
|
||||
SMART_RETRYABLE_ERRORS,
|
||||
SMART_TIMEOUT_ERRORS,
|
||||
@@ -37,7 +38,8 @@ async def test_smart_device_errors(mocker, error_code):
|
||||
|
||||
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
|
||||
|
||||
protocol = SmartProtocol(host, transport=AesTransport(host))
|
||||
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
protocol = SmartProtocol(transport=AesTransport(config=config))
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await protocol.query(DUMMY_QUERY, retry_count=2)
|
||||
|
||||
@@ -70,8 +72,8 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code):
|
||||
mocker.patch.object(AesTransport, "perform_login")
|
||||
|
||||
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
|
||||
|
||||
protocol = SmartProtocol(host, transport=AesTransport(host))
|
||||
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
|
||||
protocol = SmartProtocol(transport=AesTransport(config=config))
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await protocol.query(DUMMY_QUERY, retry_count=2)
|
||||
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
|
||||
|
Reference in New Issue
Block a user