mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-08-09 20:24:02 +00:00
Add DeviceConfig to allow specifying configuration parameters (#569)
* Add DeviceConfig handling * Update post review * Further update post latest review * Update following latest review * Update docstrings and docs
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
import logging
|
||||
from typing import Type
|
||||
|
||||
import httpx
|
||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||
|
||||
from kasa import (
|
||||
@@ -15,122 +16,138 @@ from kasa import (
|
||||
SmartLightStrip,
|
||||
SmartPlug,
|
||||
)
|
||||
from kasa.device_factory import (
|
||||
DEVICE_TYPE_TO_CLASS,
|
||||
connect,
|
||||
get_protocol_from_connection_name,
|
||||
from kasa.device_factory import connect, get_protocol
|
||||
from kasa.deviceconfig import (
|
||||
ConnectionType,
|
||||
DeviceConfig,
|
||||
DeviceFamilyType,
|
||||
EncryptType,
|
||||
)
|
||||
from kasa.discover import DiscoveryResult
|
||||
from kasa.iotprotocol import IotProtocol
|
||||
from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_port", [123, None])
|
||||
async def test_connect(discovery_data: dict, mocker, custom_port):
|
||||
"""Make sure that connect returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
def _get_connection_type_device_class(the_fixture_data):
|
||||
if "discovery_result" in the_fixture_data:
|
||||
discovery_info = {"result": the_fixture_data["discovery_result"]}
|
||||
device_class = Discover._get_device_class(discovery_info)
|
||||
dr = DiscoveryResult(**discovery_info["result"])
|
||||
|
||||
if "result" in discovery_data:
|
||||
with pytest.raises(SmartDeviceException):
|
||||
dev = await connect(host, port=custom_port)
|
||||
connection_type = ConnectionType.from_values(
|
||||
dr.device_type, dr.mgt_encrypt_schm.encrypt_type
|
||||
)
|
||||
else:
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
dev = await connect(host, port=custom_port)
|
||||
assert issubclass(dev.__class__, SmartDevice)
|
||||
assert dev.port == custom_port or dev.port == 9999
|
||||
connection_type = ConnectionType.from_values(
|
||||
DeviceFamilyType.IotSmartPlugSwitch.value, EncryptType.Xor.value
|
||||
)
|
||||
device_class = Discover._get_device_class(the_fixture_data)
|
||||
|
||||
return connection_type, device_class
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_port", [123, None])
|
||||
@pytest.mark.parametrize(
|
||||
("device_type", "klass"),
|
||||
(
|
||||
(DeviceType.Plug, SmartPlug),
|
||||
(DeviceType.Bulb, SmartBulb),
|
||||
(DeviceType.Dimmer, SmartDimmer),
|
||||
(DeviceType.LightStrip, SmartLightStrip),
|
||||
(DeviceType.Unknown, SmartDevice),
|
||||
),
|
||||
)
|
||||
async def test_connect_passed_device_type(
|
||||
discovery_data: dict,
|
||||
mocker,
|
||||
device_type: DeviceType,
|
||||
klass: Type[SmartDevice],
|
||||
custom_port,
|
||||
):
|
||||
"""Make sure that connect with a passed device type."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
if "result" in discovery_data:
|
||||
with pytest.raises(SmartDeviceException):
|
||||
dev = await connect(host, port=custom_port)
|
||||
else:
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
dev = await connect(host, port=custom_port, device_type=device_type)
|
||||
assert isinstance(dev, klass)
|
||||
assert dev.port == custom_port or dev.port == 9999
|
||||
|
||||
|
||||
async def test_connect_query_fails(discovery_data: dict, mocker):
|
||||
"""Make sure that connect fails when query fails."""
|
||||
host = "127.0.0.1"
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
|
||||
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await connect(host)
|
||||
|
||||
|
||||
async def test_connect_logs_connect_time(
|
||||
discovery_data: dict, caplog: pytest.LogCaptureFixture, mocker
|
||||
):
|
||||
"""Test that the connect time is logged when debug logging is enabled."""
|
||||
host = "127.0.0.1"
|
||||
if "result" in discovery_data:
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await connect(host)
|
||||
else:
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
|
||||
logging.getLogger("kasa").setLevel(logging.DEBUG)
|
||||
await connect(host)
|
||||
assert "seconds to connect" in caplog.text
|
||||
|
||||
|
||||
async def test_connect_pass_protocol(
|
||||
async def test_connect(
|
||||
all_fixture_data: dict,
|
||||
mocker,
|
||||
):
|
||||
"""Test that if the protocol is passed in it's gets set correctly."""
|
||||
if "discovery_result" in all_fixture_data:
|
||||
discovery_info = {"result": all_fixture_data["discovery_result"]}
|
||||
device_class = Discover._get_device_class(discovery_info)
|
||||
else:
|
||||
device_class = Discover._get_device_class(all_fixture_data)
|
||||
|
||||
device_type = list(DEVICE_TYPE_TO_CLASS.keys())[
|
||||
list(DEVICE_TYPE_TO_CLASS.values()).index(device_class)
|
||||
]
|
||||
"""Test that if the protocol is passed in it gets set correctly."""
|
||||
host = "127.0.0.1"
|
||||
if "discovery_result" in all_fixture_data:
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
ctype, device_class = _get_connection_type_device_class(all_fixture_data)
|
||||
|
||||
dr = DiscoveryResult(**discovery_info["result"])
|
||||
connection_name = (
|
||||
dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type
|
||||
)
|
||||
protocol_class = get_protocol_from_connection_name(
|
||||
connection_name, host
|
||||
).__class__
|
||||
else:
|
||||
mocker.patch(
|
||||
"kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data
|
||||
)
|
||||
protocol_class = TPLinkSmartHomeProtocol
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
|
||||
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
protocol_class = get_protocol(config).__class__
|
||||
|
||||
dev = await connect(
|
||||
host,
|
||||
device_type=device_type,
|
||||
protocol_class=protocol_class,
|
||||
credentials=Credentials("", ""),
|
||||
config=config,
|
||||
)
|
||||
assert isinstance(dev, device_class)
|
||||
assert isinstance(dev.protocol, protocol_class)
|
||||
|
||||
assert dev.config == config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_port", [123, None])
|
||||
async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port):
|
||||
"""Make sure that connect returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
config = DeviceConfig(host=host, port_override=custom_port, connection_type=ctype)
|
||||
default_port = 80 if "discovery_result" in all_fixture_data else 9999
|
||||
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
dev = await connect(config=config)
|
||||
assert issubclass(dev.__class__, SmartDevice)
|
||||
assert dev.port == custom_port or dev.port == default_port
|
||||
|
||||
|
||||
async def test_connect_logs_connect_time(
|
||||
all_fixture_data: dict, caplog: pytest.LogCaptureFixture, mocker
|
||||
):
|
||||
"""Test that the connect time is logged when debug logging is enabled."""
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
|
||||
|
||||
host = "127.0.0.1"
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
logging.getLogger("kasa").setLevel(logging.DEBUG)
|
||||
await connect(
|
||||
config=config,
|
||||
)
|
||||
assert "seconds to update" in caplog.text
|
||||
|
||||
|
||||
async def test_connect_query_fails(all_fixture_data: dict, mocker):
|
||||
"""Make sure that connect fails when query fails."""
|
||||
host = "127.0.0.1"
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
|
||||
mocker.patch("kasa.IotProtocol.query", side_effect=SmartDeviceException)
|
||||
mocker.patch("kasa.SmartProtocol.query", side_effect=SmartDeviceException)
|
||||
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await connect(config=config)
|
||||
|
||||
|
||||
async def test_connect_http_client(all_fixture_data, mocker):
|
||||
"""Make sure that discover_single returns an initialized SmartDevice instance."""
|
||||
host = "127.0.0.1"
|
||||
|
||||
ctype, _ = _get_connection_type_device_class(all_fixture_data)
|
||||
|
||||
mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data)
|
||||
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data)
|
||||
|
||||
http_client = httpx.AsyncClient()
|
||||
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
dev = await connect(config=config)
|
||||
if ctype.encryption_type != EncryptType.Xor:
|
||||
assert dev.protocol._transport._http_client != http_client
|
||||
|
||||
config = DeviceConfig(
|
||||
host=host,
|
||||
credentials=Credentials("foor", "bar"),
|
||||
connection_type=ctype,
|
||||
http_client=http_client,
|
||||
)
|
||||
dev = await connect(config=config)
|
||||
if ctype.encryption_type != EncryptType.Xor:
|
||||
assert dev.protocol._transport._http_client == http_client
|
||||
|
Reference in New Issue
Block a user