Add DeviceConfig to allow specifying configuration parameters (#569)

* Add DeviceConfig handling

* Update post review

* Further update post latest review

* Update following latest review

* Update docstrings and docs
This commit is contained in:
sdb9696
2023-12-29 19:17:15 +00:00
committed by GitHub
parent ec3ea39a37
commit f6fd898faf
33 changed files with 1032 additions and 589 deletions

View File

@@ -388,7 +388,6 @@ async def get_device_for_file(file, protocol):
d = device_for_file(model, protocol)(host="127.0.0.123")
if protocol == "SMART":
d.protocol = FakeSmartProtocol(sysinfo)
d.credentials = Credentials("", "")
else:
d.protocol = FakeTransportProtocol(sysinfo)
await _update_and_close(d)
@@ -426,28 +425,53 @@ def discovery_mock(all_fixture_data, mocker):
class _DiscoveryMock:
ip: str
default_port: int
discovery_port: int
discovery_data: dict
query_data: dict
device_type: str
encrypt_type: str
port_override: Optional[int] = None
if "discovery_result" in all_fixture_data:
discovery_data = {"result": all_fixture_data["discovery_result"]}
device_type = all_fixture_data["discovery_result"]["device_type"]
encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][
"encrypt_type"
]
datagram = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data, all_fixture_data)
dm = _DiscoveryMock(
"127.0.0.123",
80,
20002,
discovery_data,
all_fixture_data,
device_type,
encrypt_type,
)
else:
sys_info = all_fixture_data["system"]["get_sysinfo"]
discovery_data = {"system": {"get_sysinfo": sys_info}}
device_type = sys_info.get("mic_type") or sys_info.get("type")
encrypt_type = "XOR"
datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:]
dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data, all_fixture_data)
dm = _DiscoveryMock(
"127.0.0.123",
9999,
9999,
discovery_data,
all_fixture_data,
device_type,
encrypt_type,
)
def mock_discover(self):
port = (
dm.port_override
if dm.port_override and dm.default_port != 20002
else dm.default_port
if dm.port_override and dm.discovery_port != 20002
else dm.discovery_port
)
self.datagram_received(
datagram,

View File

@@ -15,7 +15,9 @@ from voluptuous import (
Schema,
)
from ..protocol import BaseTransport, TPLinkSmartHomeProtocol
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..protocol import BaseTransport, TPLinkSmartHomeProtocol, _XorTransport
from ..smartprotocol import SmartProtocol
_LOGGER = logging.getLogger(__name__)
@@ -290,7 +292,9 @@ TIME_MODULE = {
class FakeSmartProtocol(SmartProtocol):
def __init__(self, info):
super().__init__("127.0.0.123", transport=FakeSmartTransport(info))
super().__init__(
transport=FakeSmartTransport(info),
)
async def query(self, request, retry_count: int = 3):
"""Implement query here so can still patch SmartProtocol.query."""
@@ -301,10 +305,15 @@ class FakeSmartProtocol(SmartProtocol):
class FakeSmartTransport(BaseTransport):
def __init__(self, info):
super().__init__(
"127.0.0.123",
config=DeviceConfig("127.0.0.123", credentials=Credentials()),
)
self.info = info
@property
def default_port(self):
"""Default port for the transport."""
return 80
async def send(self, request: str):
request_dict = json_loads(request)
method = request_dict["method"]
@@ -344,6 +353,11 @@ class FakeSmartTransport(BaseTransport):
class FakeTransportProtocol(TPLinkSmartHomeProtocol):
def __init__(self, info):
super().__init__(
transport=_XorTransport(
config=DeviceConfig("127.0.0.123"),
)
)
self.discovery_data = info
self.writer = None
self.reader = None

View File

@@ -12,6 +12,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padd
from ..aestransport import AesEncyptionSession, AesTransport
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import (
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
@@ -58,7 +59,9 @@ async def test_handshake(
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
assert transport._encryption_session is None
assert transport._handshake_done is False
@@ -74,7 +77,9 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
@@ -91,13 +96,14 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._login_token = mock_aes_device.token
un, pw = transport.hash_credentials(True)
request = {
"method": "get_device_info",
"params": None,
@@ -119,7 +125,8 @@ async def test_passthrough_errors(mocker, error_code):
mock_aes_device = MockAesDevice(host, 200, error_code, 0)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)
transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
transport = AesTransport(config=config)
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session

View File

@@ -4,10 +4,26 @@ import asyncclick as click
import pytest
from asyncclick.testing import CliRunner
from kasa import AuthenticationException, SmartDevice, UnsupportedDeviceException
from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle
from kasa.device_factory import DEVICE_TYPE_TO_CLASS
from kasa.discover import Discover
from kasa import (
AuthenticationException,
Credentials,
SmartDevice,
TPLinkSmartHomeProtocol,
UnsupportedDeviceException,
)
from kasa.cli import (
TYPE_TO_CLASS,
alias,
brightness,
cli,
emeter,
raw_command,
state,
sysinfo,
toggle,
)
from kasa.discover import Discover, DiscoveryResult
from kasa.smartprotocol import SmartProtocol
from .conftest import device_iot, handle_turn_on, new_discovery, turn_on
@@ -145,9 +161,11 @@ async def test_credentials(discovery_mock, mocker):
)
mocker.patch("kasa.cli.state", new=_state)
for subclass in DEVICE_TYPE_TO_CLASS.values():
mocker.patch.object(subclass, "update")
mocker.patch("kasa.IotProtocol.query", return_value=discovery_mock.query_data)
mocker.patch("kasa.SmartProtocol.query", return_value=discovery_mock.query_data)
dr = DiscoveryResult(**discovery_mock.discovery_data["result"])
runner = CliRunner()
res = await runner.invoke(
cli,
@@ -158,6 +176,10 @@ async def test_credentials(discovery_mock, mocker):
"foo",
"--password",
"bar",
"--device-family",
dr.device_type,
"--encrypt-type",
dr.mgt_encrypt_schm.encrypt_type,
],
)
assert res.exit_code == 0
@@ -166,7 +188,7 @@ async def test_credentials(discovery_mock, mocker):
@device_iot
async def test_without_device_type(discovery_data: dict, dev, mocker):
async def test_without_device_type(dev, mocker):
"""Test connecting without the device type."""
runner = CliRunner()
mocker.patch("kasa.discover.Discover.discover_single", return_value=dev)
@@ -342,3 +364,27 @@ async def test_host_auth_failed(discovery_mock, mocker):
assert res.exit_code != 0
assert isinstance(res.exception, AuthenticationException)
@pytest.mark.parametrize("device_type", list(TYPE_TO_CLASS))
async def test_type_param(device_type, mocker):
"""Test for handling only one of username or password supplied."""
runner = CliRunner()
result_device = FileNotFoundError
pass_dev = click.make_pass_decorator(SmartDevice)
@pass_dev
async def _state(dev: SmartDevice):
nonlocal result_device
result_device = dev
mocker.patch("kasa.cli.state", new=_state)
expected_type = TYPE_TO_CLASS[device_type]
mocker.patch.object(expected_type, "update")
res = await runner.invoke(
cli,
["--type", device_type, "--host", "127.0.0.1"],
)
assert res.exit_code == 0
assert isinstance(result_device, expected_type)

View File

@@ -2,6 +2,7 @@
import logging
from typing import Type
import httpx
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import (
@@ -15,122 +16,138 @@ from kasa import (
SmartLightStrip,
SmartPlug,
)
from kasa.device_factory import (
DEVICE_TYPE_TO_CLASS,
connect,
get_protocol_from_connection_name,
from kasa.device_factory import connect, get_protocol
from kasa.deviceconfig import (
ConnectionType,
DeviceConfig,
DeviceFamilyType,
EncryptType,
)
from kasa.discover import DiscoveryResult
from kasa.iotprotocol import IotProtocol
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect(discovery_data: dict, mocker, custom_port):
"""Make sure that connect returns an initialized SmartDevice instance."""
host = "127.0.0.1"
def _get_connection_type_device_class(the_fixture_data):
if "discovery_result" in the_fixture_data:
discovery_info = {"result": the_fixture_data["discovery_result"]}
device_class = Discover._get_device_class(discovery_info)
dr = DiscoveryResult(**discovery_info["result"])
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
dev = await connect(host, port=custom_port)
connection_type = ConnectionType.from_values(
dr.device_type, dr.mgt_encrypt_schm.encrypt_type
)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await connect(host, port=custom_port)
assert issubclass(dev.__class__, SmartDevice)
assert dev.port == custom_port or dev.port == 9999
connection_type = ConnectionType.from_values(
DeviceFamilyType.IotSmartPlugSwitch.value, EncryptType.Xor.value
)
device_class = Discover._get_device_class(the_fixture_data)
return connection_type, device_class
@pytest.mark.parametrize("custom_port", [123, None])
@pytest.mark.parametrize(
("device_type", "klass"),
(
(DeviceType.Plug, SmartPlug),
(DeviceType.Bulb, SmartBulb),
(DeviceType.Dimmer, SmartDimmer),
(DeviceType.LightStrip, SmartLightStrip),
(DeviceType.Unknown, SmartDevice),
),
)
async def test_connect_passed_device_type(
discovery_data: dict,
mocker,
device_type: DeviceType,
klass: Type[SmartDevice],
custom_port,
):
"""Make sure that connect with a passed device type."""
host = "127.0.0.1"
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
dev = await connect(host, port=custom_port)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
dev = await connect(host, port=custom_port, device_type=device_type)
assert isinstance(dev, klass)
assert dev.port == custom_port or dev.port == 9999
async def test_connect_query_fails(discovery_data: dict, mocker):
"""Make sure that connect fails when query fails."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
with pytest.raises(SmartDeviceException):
await connect(host)
async def test_connect_logs_connect_time(
discovery_data: dict, caplog: pytest.LogCaptureFixture, mocker
):
"""Test that the connect time is logged when debug logging is enabled."""
host = "127.0.0.1"
if "result" in discovery_data:
with pytest.raises(SmartDeviceException):
await connect(host)
else:
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
logging.getLogger("kasa").setLevel(logging.DEBUG)
await connect(host)
assert "seconds to connect" in caplog.text
async def test_connect_pass_protocol(
async def test_connect(
all_fixture_data: dict,
mocker,
):
"""Test that if the protocol is passed in it's gets set correctly."""
if "discovery_result" in all_fixture_data:
discovery_info = {"result": all_fixture_data["discovery_result"]}
device_class = Discover._get_device_class(discovery_info)
else:
device_class = Discover._get_device_class(all_fixture_data)
device_type = list(DEVICE_TYPE_TO_CLASS.keys())[
list(DEVICE_TYPE_TO_CLASS.values()).index(device_class)
]
"""Test that if the protocol is passed in it gets set correctly."""
host = "127.0.0.1"
if "discovery_result" in all_fixture_data:
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
ctype, device_class = _get_connection_type_device_class(all_fixture_data)
dr = DiscoveryResult(**discovery_info["result"])
connection_name = (
dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type
)
protocol_class = get_protocol_from_connection_name(
connection_name, host
).__class__
else:
mocker.patch(
"kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data
)
protocol_class = TPLinkSmartHomeProtocol
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
protocol_class = get_protocol(config).__class__
dev = await connect(
host,
device_type=device_type,
protocol_class=protocol_class,
credentials=Credentials("", ""),
config=config,
)
assert isinstance(dev, device_class)
assert isinstance(dev.protocol, protocol_class)
assert dev.config == config
@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port):
"""Make sure that connect returns an initialized SmartDevice instance."""
host = "127.0.0.1"
ctype, _ = _get_connection_type_device_class(all_fixture_data)
config = DeviceConfig(host=host, port_override=custom_port, connection_type=ctype)
default_port = 80 if "discovery_result" in all_fixture_data else 9999
ctype, _ = _get_connection_type_device_class(all_fixture_data)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
dev = await connect(config=config)
assert issubclass(dev.__class__, SmartDevice)
assert dev.port == custom_port or dev.port == default_port
async def test_connect_logs_connect_time(
all_fixture_data: dict, caplog: pytest.LogCaptureFixture, mocker
):
"""Test that the connect time is logged when debug logging is enabled."""
ctype, _ = _get_connection_type_device_class(all_fixture_data)
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
host = "127.0.0.1"
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
logging.getLogger("kasa").setLevel(logging.DEBUG)
await connect(
config=config,
)
assert "seconds to update" in caplog.text
async def test_connect_query_fails(all_fixture_data: dict, mocker):
"""Make sure that connect fails when query fails."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
mocker.patch("kasa.IotProtocol.query", side_effect=SmartDeviceException)
mocker.patch("kasa.SmartProtocol.query", side_effect=SmartDeviceException)
ctype, _ = _get_connection_type_device_class(all_fixture_data)
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
with pytest.raises(SmartDeviceException):
await connect(config=config)
async def test_connect_http_client(all_fixture_data, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
ctype, _ = _get_connection_type_device_class(all_fixture_data)
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
http_client = httpx.AsyncClient()
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
dev = await connect(config=config)
if ctype.encryption_type != EncryptType.Xor:
assert dev.protocol._transport._http_client != http_client
config = DeviceConfig(
host=host,
credentials=Credentials("foor", "bar"),
connection_type=ctype,
http_client=http_client,
)
dev = await connect(config=config)
if ctype.encryption_type != EncryptType.Xor:
assert dev.protocol._transport._http_client == http_client

View File

@@ -0,0 +1,21 @@
from json import dumps as json_dumps
from json import loads as json_loads
import httpx
from kasa.credentials import Credentials
from kasa.deviceconfig import (
ConnectionType,
DeviceConfig,
DeviceFamilyType,
EncryptType,
)
def test_serialization():
config = DeviceConfig(host="Foo", http_client=httpx.AsyncClient())
config_dict = config.to_dict()
config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict)
assert config == config2

View File

@@ -1,21 +1,29 @@
# type: ignore
import logging
import re
import socket
import httpx
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import (
Credentials,
DeviceType,
Discover,
SmartDevice,
SmartDeviceException,
SmartStrip,
protocol,
)
from kasa.deviceconfig import (
ConnectionType,
DeviceConfig,
DeviceFamilyType,
EncryptType,
)
from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps
from kasa.exceptions import AuthenticationException, UnsupportedDeviceException
from .conftest import bulb, bulb_iot, dimmer, lightstrip, plug, strip
from .conftest import bulb, bulb_iot, dimmer, lightstrip, new_discovery, plug, strip
UNSUPPORTED = {
"result": {
@@ -89,13 +97,26 @@ async def test_discover_single(discovery_mock, custom_port, mocker):
host = "127.0.0.1"
discovery_mock.ip = host
discovery_mock.port_override = custom_port
update_mock = mocker.patch.object(SmartStrip, "update")
x = await Discover.discover_single(host, port=custom_port)
device_class = Discover._get_device_class(discovery_mock.discovery_data)
update_mock = mocker.patch.object(device_class, "update")
x = await Discover.discover_single(
host, port=custom_port, credentials=Credentials()
)
assert issubclass(x.__class__, SmartDevice)
assert x._discovery_info is not None
assert x.port == custom_port or x.port == discovery_mock.default_port
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
assert update_mock.call_count == 0
ct = ConnectionType.from_values(
discovery_mock.device_type, discovery_mock.encrypt_type
)
uses_http = discovery_mock.default_port == 80
config = DeviceConfig(
host=host, port_override=custom_port, connection_type=ct, uses_http=uses_http
)
assert x.config == config
async def test_discover_single_hostname(discovery_mock, mocker):
@@ -104,47 +125,39 @@ async def test_discover_single_hostname(discovery_mock, mocker):
ip = "127.0.0.1"
discovery_mock.ip = ip
update_mock = mocker.patch.object(SmartStrip, "update")
device_class = Discover._get_device_class(discovery_mock.discovery_data)
update_mock = mocker.patch.object(device_class, "update")
x = await Discover.discover_single(host)
x = await Discover.discover_single(host, credentials=Credentials())
assert issubclass(x.__class__, SmartDevice)
assert x._discovery_info is not None
assert x.host == host
assert (update_mock.call_count > 0) == isinstance(x, SmartStrip)
assert update_mock.call_count == 0
mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
with pytest.raises(SmartDeviceException):
x = await Discover.discover_single(host)
x = await Discover.discover_single(host, credentials=Credentials())
async def test_discover_single_unsupported(mocker):
async def test_discover_single_unsupported(unsupported_device_info, mocker):
"""Make sure that discover_single handles unsupported devices correctly."""
host = "127.0.0.1"
def mock_discover(self):
if discovery_data:
data = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
self.datagram_received(data, (host, 20002))
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
# Test with a valid unsupported response
discovery_data = UNSUPPORTED
with pytest.raises(
UnsupportedDeviceException,
match=f"Unsupported device {host} of type SMART.TAPOXMASTREE: {re.escape(str(UNSUPPORTED))}",
):
await Discover.discover_single(host)
# Test with no response
discovery_data = None
async def test_discover_single_no_response(mocker):
"""Make sure that discover_single handles no response correctly."""
host = "127.0.0.1"
mocker.patch.object(_DiscoverProtocol, "do_discover")
with pytest.raises(
SmartDeviceException, match=f"Timed out getting discovery response for {host}"
):
await Discover.discover_single(host, timeout=0.001)
await Discover.discover_single(host, discovery_timeout=0)
INVALIDS = [
@@ -241,52 +254,82 @@ AUTHENTICATION_DATA_KLAP = {
}
async def test_discover_single_authentication(mocker):
@new_discovery
async def test_discover_single_authentication(discovery_mock, mocker):
"""Make sure that discover_single handles authenticating devices correctly."""
host = "127.0.0.1"
def mock_discover(self):
if discovery_data:
data = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
self.datagram_received(data, (host, 20002))
mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
discovery_mock.ip = host
device_class = Discover._get_device_class(discovery_mock.discovery_data)
mocker.patch.object(
SmartDevice,
device_class,
"update",
side_effect=AuthenticationException("Failed to authenticate"),
)
# Test with a valid unsupported response
discovery_data = AUTHENTICATION_DATA_KLAP
with pytest.raises(
AuthenticationException,
match="Failed to authenticate",
):
device = await Discover.discover_single(host)
device = await Discover.discover_single(
host, credentials=Credentials("foo", "bar")
)
await device.update()
mocker.patch.object(SmartDevice, "update")
device = await Discover.discover_single(host)
mocker.patch.object(device_class, "update")
device = await Discover.discover_single(host, credentials=Credentials("foo", "bar"))
await device.update()
assert device.device_type == DeviceType.Plug
assert isinstance(device, device_class)
async def test_device_update_from_new_discovery_info():
@new_discovery
async def test_device_update_from_new_discovery_info(discovery_data):
device = SmartDevice("127.0.0.7")
discover_info = DiscoveryResult(**AUTHENTICATION_DATA_KLAP["result"])
discover_info = DiscoveryResult(**discovery_data["result"])
discover_dump = discover_info.get_dict()
discover_dump["alias"] = "foobar"
discover_dump["model"] = discover_dump["device_model"]
device.update_from_discover_info(discover_dump)
assert device.alias == discover_dump["alias"]
assert device.alias == "foobar"
assert device.mac == discover_dump["mac"].replace("-", ":")
assert device.model == discover_dump["model"]
assert device.model == discover_dump["device_model"]
with pytest.raises(
SmartDeviceException,
match=re.escape("You need to await update() to access the data"),
):
assert device.supported_modules
async def test_discover_single_http_client(discovery_mock, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
discovery_mock.ip = host
http_client = httpx.AsyncClient()
x: SmartDevice = await Discover.discover_single(host)
assert x.config.uses_http == (discovery_mock.default_port == 80)
if discovery_mock.default_port == 80:
assert x.protocol._transport._http_client != http_client
x.config.http_client = http_client
assert x.protocol._transport._http_client == http_client
async def test_discover_http_client(discovery_mock, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
discovery_mock.ip = host
http_client = httpx.AsyncClient()
devices = await Discover.discover(discovery_timeout=0)
x: SmartDevice = devices[host]
assert x.config.uses_http == (discovery_mock.default_port == 80)
if discovery_mock.default_port == 80:
assert x.protocol._transport._http_client != http_client
x.config.http_client = http_client
assert x.protocol._transport._http_client == http_client

View File

@@ -12,9 +12,15 @@ import pytest
from ..aestransport import AesTransport
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import AuthenticationException, SmartDeviceException
from ..iotprotocol import IotProtocol
from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256
from ..klaptransport import (
KlapEncryptionSession,
KlapTransport,
KlapTransportV2,
_sha256,
)
from ..smartprotocol import SmartProtocol
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
@@ -31,8 +37,9 @@ class _mock_response:
[
(Exception("dummy exception"), True),
(SmartDeviceException("dummy exception"), False),
(httpx.ConnectError("dummy exception"), True),
],
ids=("Exception", "SmartDeviceException"),
ids=("Exception", "SmartDeviceException", "httpx.ConnectError"),
)
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
@@ -42,8 +49,10 @@ async def test_protocol_retries(
):
host = "127.0.0.1"
conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error)
config = DeviceConfig(host)
with pytest.raises(SmartDeviceException):
await protocol_class(host, transport=transport_class(host)).query(
await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=retry_count
)
@@ -60,10 +69,11 @@ async def test_protocol_no_retry_on_connection_error(
conn = mocker.patch.object(
httpx.AsyncClient,
"post",
side_effect=httpx.ConnectError("foo"),
side_effect=AuthenticationException("foo"),
)
config = DeviceConfig(host)
with pytest.raises(SmartDeviceException):
await protocol_class(host, transport=transport_class(host)).query(
await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=5
)
@@ -81,8 +91,9 @@ async def test_protocol_retry_recoverable_error(
"post",
side_effect=httpx.CloseError("foo"),
)
config = DeviceConfig(host)
with pytest.raises(SmartDeviceException):
await protocol_class(host, transport=transport_class(host)).query(
await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=5
)
@@ -115,7 +126,8 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
side_effect=_fail_one_less_than_retry_count,
)
response = await protocol_class(host, transport=transport_class(host)).query(
config = DeviceConfig(host)
response = await protocol_class(transport=transport_class(config=config)).query(
DUMMY_QUERY, retry_count=retry_count
)
assert "result" in response or "foobar" in response
@@ -136,7 +148,9 @@ async def test_protocol_logging(mocker, caplog, log_level):
seed = secrets.token_bytes(16)
auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar"))
encryption_session = KlapEncryptionSession(seed, seed, auth_hash)
protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1"))
config = DeviceConfig("127.0.0.1")
protocol = IotProtocol(transport=KlapTransport(config=config))
protocol._transport._handshake_done = True
protocol._transport._session_expire_at = time.time() + 86400
@@ -181,7 +195,7 @@ def test_encrypt_unicode():
"device_credentials, expectation",
[
(Credentials("foo", "bar"), does_not_raise()),
(Credentials("", ""), does_not_raise()),
(Credentials(), does_not_raise()),
(
Credentials(
KlapTransport.KASA_SETUP_EMAIL,
@@ -196,30 +210,37 @@ def test_encrypt_unicode():
],
ids=("client", "blank", "kasa_setup", "shouldfail"),
)
async def test_handshake1(mocker, device_credentials, expectation):
@pytest.mark.parametrize(
"transport_class, seed_auth_hash_calc",
[
pytest.param(KlapTransport, lambda c, s, a: c + a, id="KLAP"),
pytest.param(KlapTransportV2, lambda c, s, a: c + s + a, id="KLAPV2"),
],
)
async def test_handshake1(
mocker, device_credentials, expectation, transport_class, seed_auth_hash_calc
):
async def _return_handshake1_response(url, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash
client_seed = data
client_seed_auth_hash = _sha256(data + device_auth_hash)
return _mock_response(200, server_seed + client_seed_auth_hash)
seed_auth_hash = _sha256(
seed_auth_hash_calc(client_seed, server_seed, device_auth_hash)
)
return _mock_response(200, server_seed + seed_auth_hash)
client_seed = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = KlapTransport.generate_auth_hash(device_credentials)
device_auth_hash = transport_class.generate_auth_hash(device_credentials)
mocker.patch.object(
httpx.AsyncClient, "post", side_effect=_return_handshake1_response
)
protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=transport_class(config=config))
protocol._transport.http_client = httpx.AsyncClient()
with expectation:
(
local_seed,
@@ -233,31 +254,51 @@ async def test_handshake1(mocker, device_credentials, expectation):
await protocol.close()
async def test_handshake(mocker):
@pytest.mark.parametrize(
"transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2",
[
pytest.param(
KlapTransport, lambda c, s, a: c + a, lambda c, s, a: s + a, id="KLAP"
),
pytest.param(
KlapTransportV2,
lambda c, s, a: c + s + a,
lambda c, s, a: s + c + a,
id="KLAPV2",
),
],
)
async def test_handshake(
mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2
):
async def _return_handshake_response(url, params=None, data=None, *_, **__):
nonlocal response_status, client_seed, server_seed, device_auth_hash
nonlocal client_seed, server_seed, device_auth_hash
if url == "http://127.0.0.1/app/handshake1":
client_seed = data
client_seed_auth_hash = _sha256(data + device_auth_hash)
seed_auth_hash = _sha256(
seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
)
return _mock_response(200, server_seed + client_seed_auth_hash)
return _mock_response(200, server_seed + seed_auth_hash)
elif url == "http://127.0.0.1/app/handshake2":
seed_auth_hash = _sha256(
seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
)
assert data == seed_auth_hash
return _mock_response(response_status, b"")
client_seed = None
server_seed = secrets.token_bytes(16)
client_credentials = Credentials("foo", "bar")
device_auth_hash = KlapTransport.generate_auth_hash(client_credentials)
device_auth_hash = transport_class.generate_auth_hash(client_credentials)
mocker.patch.object(
httpx.AsyncClient, "post", side_effect=_return_handshake_response
)
protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=transport_class(config=config))
protocol._transport.http_client = httpx.AsyncClient()
response_status = 200
@@ -273,7 +314,7 @@ async def test_handshake(mocker):
async def test_query(mocker):
async def _return_response(url, params=None, data=None, *_, **__):
nonlocal client_seed, server_seed, device_auth_hash, protocol, seq
nonlocal client_seed, server_seed, device_auth_hash, seq
if url == "http://127.0.0.1/app/handshake1":
client_seed = data
@@ -303,10 +344,8 @@ async def test_query(mocker):
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=KlapTransport(config=config))
for _ in range(10):
resp = await protocol.query({})
@@ -350,10 +389,8 @@ async def test_authentication_failures(mocker, response_status, expectation):
mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response)
protocol = IotProtocol(
"127.0.0.1",
transport=KlapTransport("127.0.0.1", credentials=client_credentials),
)
config = DeviceConfig("127.0.0.1", credentials=client_credentials)
protocol = IotProtocol(transport=KlapTransport(config=config))
with expectation:
await protocol.query({})

View File

@@ -9,6 +9,7 @@ import sys
import pytest
from ..deviceconfig import DeviceConfig
from ..exceptions import SmartDeviceException
from ..protocol import (
BaseTransport,
@@ -31,10 +32,11 @@ async def test_protocol_retries(mocker, retry_count):
return reader, writer
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=retry_count)
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
{}, retry_count=retry_count
)
assert conn.call_count == retry_count + 1
@@ -44,10 +46,11 @@ async def test_protocol_no_retry_on_unreachable(mocker):
"asyncio.open_connection",
side_effect=OSError(errno.EHOSTUNREACH, "No route to host"),
)
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=5)
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
{}, retry_count=5
)
assert conn.call_count == 1
@@ -57,10 +60,11 @@ async def test_protocol_no_retry_connection_refused(mocker):
"asyncio.open_connection",
side_effect=ConnectionRefusedError,
)
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=5)
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
{}, retry_count=5
)
assert conn.call_count == 1
@@ -70,10 +74,11 @@ async def test_protocol_retry_recoverable_error(mocker):
"asyncio.open_connection",
side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"),
)
config = DeviceConfig("127.0.0.1")
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
).query({}, retry_count=5)
await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query(
{}, retry_count=5
)
assert conn.call_count == 6
@@ -107,9 +112,8 @@ async def test_protocol_reconnect(mocker, retry_count):
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
protocol = TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
)
config = DeviceConfig("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({}, retry_count=retry_count)
assert response == {"great": "success"}
@@ -137,9 +141,8 @@ async def test_protocol_logging(mocker, caplog, log_level):
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
protocol = TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1")
)
config = DeviceConfig("127.0.0.1")
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({})
assert response == {"great": "success"}
@@ -173,9 +176,8 @@ async def test_protocol_custom_port(mocker, custom_port):
mocker.patch.object(reader, "readexactly", _mock_read)
return reader, writer
protocol = TPLinkSmartHomeProtocol(
"127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port)
)
config = DeviceConfig("127.0.0.1", port_override=custom_port)
protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config))
mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
response = await protocol.query({})
assert response == {"great": "success"}
@@ -271,18 +273,14 @@ def _get_subclasses(of_class):
def test_protocol_init_signature(class_name_obj):
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
assert len(params) == 3
assert len(params) == 2
assert (
params[0].name == "self"
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert (
params[1].name == "host"
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert (
params[2].name == "transport"
and params[2].kind == inspect.Parameter.KEYWORD_ONLY
params[1].name == "transport"
and params[1].kind == inspect.Parameter.KEYWORD_ONLY
)
@@ -292,20 +290,11 @@ def test_protocol_init_signature(class_name_obj):
def test_transport_init_signature(class_name_obj):
params = list(inspect.signature(class_name_obj[1].__init__).parameters.values())
assert len(params) == 5
assert len(params) == 2
assert (
params[0].name == "self"
and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert (
params[1].name == "host"
and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY
assert (
params[3].name == "credentials"
and params[3].kind == inspect.Parameter.KEYWORD_ONLY
)
assert (
params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY
params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY
)

View File

@@ -5,8 +5,7 @@ from unittest.mock import Mock, patch
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
import kasa
from kasa import Credentials, SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType
from kasa import Credentials, DeviceConfig, SmartDevice, SmartDeviceException
from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on
from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol
@@ -215,7 +214,8 @@ def test_device_class_ctors(device_class):
host = "127.0.0.2"
port = 1234
credentials = Credentials("foo", "bar")
dev = device_class(host, port=port, credentials=credentials)
config = DeviceConfig(host, port_override=port, credentials=credentials)
dev = device_class(host, config=config)
assert dev.host == host
assert dev.port == port
assert dev.credentials == credentials
@@ -231,29 +231,27 @@ async def test_modules_preserved(dev: SmartDevice):
async def test_create_smart_device_with_timeout():
"""Make sure timeout is passed to the protocol."""
dev = SmartDevice(host="127.0.0.1", timeout=100)
host = "127.0.0.1"
dev = SmartDevice(host, config=DeviceConfig(host, timeout=100))
assert dev.protocol._transport._timeout == 100
async def test_create_thin_wrapper():
"""Make sure thin wrapper is created with the correct device type."""
mock = Mock()
config = DeviceConfig(
host="test_host",
port_override=1234,
timeout=100,
credentials=Credentials("username", "password"),
)
with patch("kasa.device_factory.connect", return_value=mock) as connect:
dev = await SmartDevice.connect(
host="test_host",
port=1234,
timeout=100,
credentials=Credentials("username", "password"),
device_type=DeviceType.Strip,
)
dev = await SmartDevice.connect(config=config)
assert dev is mock
connect.assert_called_once_with(
host="test_host",
port=1234,
timeout=100,
credentials=Credentials("username", "password"),
device_type=DeviceType.Strip,
host=None,
config=config,
)

View File

@@ -13,6 +13,7 @@ import pytest
from ..aestransport import AesTransport
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import (
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
@@ -37,7 +38,8 @@ async def test_smart_device_errors(mocker, error_code):
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
protocol = SmartProtocol(host, transport=AesTransport(host))
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
protocol = SmartProtocol(transport=AesTransport(config=config))
with pytest.raises(SmartDeviceException):
await protocol.query(DUMMY_QUERY, retry_count=2)
@@ -70,8 +72,8 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code):
mocker.patch.object(AesTransport, "perform_login")
send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response)
protocol = SmartProtocol(host, transport=AesTransport(host))
config = DeviceConfig(host, credentials=Credentials("foo", "bar"))
protocol = SmartProtocol(transport=AesTransport(config=config))
with pytest.raises(SmartDeviceException):
await protocol.query(DUMMY_QUERY, retry_count=2)
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):