python-kasa/tests/test_device_factory.py
Teemu R. 1be87674bf
Initial support for vacuums (clean module) (#944)
Adds support for clean module:
- Show current vacuum state
- Start cleaning (all rooms)
- Return to dock
- Pausing & unpausing
- Controlling the fan speed

---------

Co-authored-by: Steven B <51370195+sdb9696@users.noreply.github.com>
2025-01-14 15:35:09 +01:00

295 lines
9.1 KiB
Python

"""Module for testing device factory.

As this module tests the factory with discovery data and expects update to be
called on devices, it uses the discovery_mock fixture, which handles all the
patching of the query methods without actually replacing the device protocol
class with one of the testing fake protocols.
"""
import logging
from typing import cast
import aiohttp
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import (
BaseProtocol,
Credentials,
Discover,
IotProtocol,
KasaException,
SmartCamProtocol,
SmartProtocol,
)
from kasa.device_factory import (
Device,
IotDevice,
SmartCamDevice,
SmartDevice,
connect,
get_device_class_from_family,
get_protocol,
)
from kasa.deviceconfig import (
DeviceConfig,
DeviceConnectionParameters,
DeviceEncryptionType,
DeviceFamily,
)
from kasa.discover import DiscoveryResult
from kasa.transports import (
AesTransport,
BaseTransport,
KlapTransport,
KlapTransportV2,
LinkieTransportV2,
SslAesTransport,
SslTransport,
XorTransport,
)
from .conftest import DISCOVERY_MOCK_IP
# Factory tests only make sense against dummy (fixture-backed) devices: a run
# against real hardware targets a single device that the factory has already
# created, so there is nothing left for these tests to exercise.
pytestmark = [
    pytest.mark.requires_dummy,
]
def _get_connection_type_device_class(discovery_info):
    """Return ``(connection_type, device_class)`` for a discovery payload.

    A payload carrying a ``"result"`` key is parsed with ``DiscoveryResult`` and
    its encryption scheme (type, login version, https support) is turned into
    connection parameters. Any other payload is mapped to an IOT plug/switch
    using XOR encryption. In both cases the device class comes from
    ``Discover._get_device_class``.
    """
    if "result" not in discovery_info:
        xor_connection = DeviceConnectionParameters.from_values(
            DeviceFamily.IotSmartPlugSwitch.value, DeviceEncryptionType.Xor.value
        )
        return xor_connection, Discover._get_device_class(discovery_info)

    device_class = Discover._get_device_class(discovery_info)
    parsed = DiscoveryResult.from_dict(discovery_info["result"])
    scheme = parsed.mgt_encrypt_schm
    connection_type = DeviceConnectionParameters.from_values(
        parsed.device_type,
        scheme.encrypt_type,
        scheme.lv,
        scheme.is_support_https,
    )
    return connection_type, device_class
async def test_connect(
    discovery_mock,
    mocker,
):
    """Test that connect builds the device and protocol implied by the config.

    The expected device class and connection type are derived from the
    discovery fixture. The protocol's ``close`` must not be called by a
    successful ``connect``; it must be called exactly once by ``disconnect``.
    """
    host = DISCOVERY_MOCK_IP
    ctype, device_class = _get_connection_type_device_class(
        discovery_mock.discovery_data
    )
    config = DeviceConfig(
        host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
    )
    protocol_class = get_protocol(config).__class__
    close_mock = mocker.patch.object(protocol_class, "close")

    dev = await connect(
        config=config,
    )

    assert isinstance(dev, device_class)
    assert isinstance(dev.protocol, protocol_class)
    assert dev.config == config
    # A successful connect must leave the protocol open...
    assert close_mock.call_count == 0
    # ...and disconnect must close it exactly once.
    await dev.disconnect()
    assert close_mock.call_count == 1
@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect_custom_port(discovery_mock, mocker, custom_port):
    """Make sure connect honours ``port_override`` and otherwise uses the default.

    The default port is the ``http_port`` advertised by the discovery result,
    or 9999 for discovery data without a ``result`` section.
    """
    host = DISCOVERY_MOCK_IP
    discovery_data = discovery_mock.discovery_data
    ctype, _ = _get_connection_type_device_class(discovery_data)
    config = DeviceConfig(
        host=host,
        port_override=custom_port,
        connection_type=ctype,
        credentials=Credentials("dummy_user", "dummy_password"),
    )
    default_port = (
        DiscoveryResult.from_dict(discovery_data["result"]).mgt_encrypt_schm.http_port
        if "result" in discovery_data
        else 9999
    )

    dev = await connect(config=config)
    try:
        assert isinstance(dev, Device)
        # Check the exact port: the previous `port == custom or port == default`
        # also passed when a custom port override was silently ignored.
        expected_port = custom_port if custom_port is not None else default_port
        assert dev.port == expected_port
    finally:
        await dev.disconnect()
@pytest.mark.xdist_group(name="caplog")
async def test_connect_logs_connect_time(
    discovery_mock,
    caplog: pytest.LogCaptureFixture,
):
    """Test that the connect time is logged when debug logging is enabled."""
    discovery_data = discovery_mock.discovery_data
    ctype, _ = _get_connection_type_device_class(discovery_data)
    host = DISCOVERY_MOCK_IP
    config = DeviceConfig(
        host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
    )
    # caplog.at_level restores the "kasa" logger level on exit; calling
    # setLevel() directly would leave DEBUG enabled for every later test.
    with caplog.at_level(logging.DEBUG, logger="kasa"):
        dev = await connect(
            config=config,
        )
    await dev.disconnect()
    assert "seconds to update" in caplog.text
async def test_connect_query_fails(discovery_mock, mocker):
    """Check that connect raises on a failing query and closes the protocol.

    Both ``IotProtocol.query`` and ``SmartProtocol.query`` are patched to raise
    ``KasaException``; ``connect`` must propagate it and call the protocol's
    ``close`` exactly once.
    """
    discovery_data = discovery_mock.discovery_data
    for failing_query in ("kasa.IotProtocol.query", "kasa.SmartProtocol.query"):
        mocker.patch(failing_query, side_effect=KasaException)

    ctype, _ = _get_connection_type_device_class(discovery_data)
    config = DeviceConfig(
        host=DISCOVERY_MOCK_IP,
        credentials=Credentials("foor", "bar"),
        connection_type=ctype,
    )
    close_mock = mocker.patch.object(get_protocol(config).__class__, "close")
    assert close_mock.call_count == 0

    with pytest.raises(KasaException):
        await connect(config=config)
    assert close_mock.call_count == 1
async def test_connect_http_client(discovery_mock, mocker):
    """Make sure connect reuses a caller-supplied aiohttp session only when given.

    Without ``http_client`` in the config, the transport's client must not be
    the caller's session. With ``http_client`` set, the transport must use
    exactly that session. The session checks are skipped for XOR-encrypted
    connections.
    """
    host = DISCOVERY_MOCK_IP
    discovery_data = discovery_mock.discovery_data
    ctype, _ = _get_connection_type_device_class(discovery_data)

    # The context manager closes the session even if an assertion below fails.
    async with aiohttp.ClientSession() as http_client:
        config = DeviceConfig(
            host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
        )
        dev = await connect(config=config)
        if ctype.encryption_type != DeviceEncryptionType.Xor:
            assert dev.protocol._transport._http_client.client is not http_client
        await dev.disconnect()

        config = DeviceConfig(
            host=host,
            credentials=Credentials("foor", "bar"),
            connection_type=ctype,
            http_client=http_client,
        )
        dev = await connect(config=config)
        if ctype.encryption_type != DeviceEncryptionType.Xor:
            assert dev.protocol._transport._http_client.client is http_client
        await dev.disconnect()
async def test_device_types(dev: Device):
    """Check ``dev.device_type`` against the type derived from the raw data.

    Camera devices derive it from ``sys_info``, other SMART devices from their
    component list plus the discovered ``device_type``, and IOT devices from
    the last update payload.
    """
    await dev.update()

    if isinstance(dev, SmartCamDevice):
        expected_type = SmartCamDevice._get_device_type_from_sysinfo(dev.sys_info)
    elif isinstance(dev, SmartDevice):
        assert dev._discovery_info
        discovered_type = cast(str, dev._discovery_info["device_type"])
        component_names = list(dev._components.keys())
        expected_type = SmartDevice._get_device_type_from_components(
            component_names, discovered_type
        )
    else:
        expected_type = IotDevice._get_device_type_from_sys_info(dev._last_update)

    assert dev.device_type == expected_type
@pytest.mark.xdist_group(name="caplog")
async def test_device_class_from_unknown_family(caplog):
    """Unknown SMART families fall back to SmartDevice and log that they are unknown."""
    unknown_family = "SMART.foo"
    with caplog.at_level(logging.DEBUG):
        resolved_class = get_device_class_from_family(unknown_family, https=False)
    assert resolved_class == SmartDevice
    assert f"Unknown SMART device with {unknown_family}" in caplog.text
# Short aliases so the test_get_protocol parameter table stays compact.
CP, DF, ET = DeviceConnectionParameters, DeviceFamily, DeviceEncryptionType
@pytest.mark.parametrize(
    ("conn_params", "expected_protocol", "expected_transport"),
    [
        pytest.param(conn, protocol_cls, transport_cls, id=case_id)
        for case_id, conn, protocol_cls, transport_cls in (
            (
                "smartcam",
                CP(DF.SmartIpCamera, ET.Aes, https=True),
                SmartCamProtocol,
                SslAesTransport,
            ),
            (
                "smartcam-hub",
                CP(DF.SmartTapoHub, ET.Aes, https=True),
                SmartCamProtocol,
                SslAesTransport,
            ),
            (
                "kasacam",
                CP(DF.IotIpCamera, ET.Aes, https=True),
                IotProtocol,
                LinkieTransportV2,
            ),
            (
                "robovac",
                CP(DF.SmartTapoRobovac, ET.Aes, https=True),
                SmartProtocol,
                SslTransport,
            ),
            (
                "iot-klap",
                CP(DF.IotSmartPlugSwitch, ET.Klap, https=False),
                IotProtocol,
                KlapTransport,
            ),
            (
                "iot-xor",
                CP(DF.IotSmartPlugSwitch, ET.Xor, https=False),
                IotProtocol,
                XorTransport,
            ),
            (
                "smart-aes",
                CP(DF.SmartTapoPlug, ET.Aes, https=False),
                SmartProtocol,
                AesTransport,
            ),
            (
                "smart-klap",
                CP(DF.SmartTapoPlug, ET.Klap, https=False),
                SmartProtocol,
                KlapTransportV2,
            ),
        )
    ],
)
async def test_get_protocol(
    conn_params: DeviceConnectionParameters,
    expected_protocol: type[BaseProtocol],
    expected_transport: type[BaseTransport],
):
    """Check get_protocol picks the right protocol and transport per connection type."""
    protocol = get_protocol(DeviceConfig("127.0.0.1", connection_type=conn_params))
    assert isinstance(protocol, expected_protocol)
    assert isinstance(protocol._transport, expected_transport)