"""Module for testing device factory. As this module tests the factory with discovery data and expects update to be called on devices it uses the discovery_mock handles all the patching of the query methods without actually replacing the device protocol class with one of the testing fake protocols. """ import logging from typing import cast import aiohttp import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 from kasa import ( BaseProtocol, Credentials, Discover, IotProtocol, KasaException, SmartCamProtocol, SmartProtocol, ) from kasa.device_factory import ( Device, IotDevice, SmartCamDevice, SmartDevice, connect, get_device_class_from_family, get_protocol, ) from kasa.deviceconfig import ( DeviceConfig, DeviceConnectionParameters, DeviceEncryptionType, DeviceFamily, ) from kasa.discover import DiscoveryResult from kasa.transports import ( AesTransport, BaseTransport, KlapTransport, KlapTransportV2, LinkieTransportV2, SslAesTransport, SslTransport, XorTransport, ) from .conftest import DISCOVERY_MOCK_IP # Device Factory tests are not relevant for real devices which run against # a single device that has already been created via the factory. 
pytestmark = [pytest.mark.requires_dummy]


def _get_connection_type_device_class(discovery_info):
    """Return the connection parameters and device class for discovery data.

    Payloads containing a ``"result"`` key are parsed as a ``DiscoveryResult``
    and the connection parameters are built from its device type and
    management encryption scheme. Payloads without that key are given the
    IOT plug/switch family with XOR encryption.

    :param discovery_info: Discovery payload dictionary of the mocked device.
    :return: Tuple of ``(connection_type, device_class)``.
    """
    # The device class lookup is the same for both payload shapes.
    device_class = Discover._get_device_class(discovery_info)

    if "result" in discovery_info:
        dr = DiscoveryResult.from_dict(discovery_info["result"])
        connection_type = DeviceConnectionParameters.from_values(
            dr.device_type,
            dr.mgt_encrypt_schm.encrypt_type,
            dr.mgt_encrypt_schm.lv,
            dr.mgt_encrypt_schm.is_support_https,
        )
    else:
        connection_type = DeviceConnectionParameters.from_values(
            DeviceFamily.IotSmartPlugSwitch.value, DeviceEncryptionType.Xor.value
        )

    return connection_type, device_class


async def test_connect(
    discovery_mock,
    mocker,
):
    """Test that connect builds the expected device and protocol.

    Also verifies that the protocol is only closed once the device is
    disconnected.
    """
    host = DISCOVERY_MOCK_IP
    ctype, device_class = _get_connection_type_device_class(
        discovery_mock.discovery_data
    )

    config = DeviceConfig(
        host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
    )
    protocol_class = get_protocol(config).__class__
    close_mock = mocker.patch.object(protocol_class, "close")

    dev = await connect(
        config=config,
    )
    assert isinstance(dev, device_class)
    assert isinstance(dev.protocol, protocol_class)

    assert dev.config == config
    # Connecting successfully must not close the protocol.
    assert close_mock.call_count == 0
    await dev.disconnect()
    assert close_mock.call_count == 1


@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect_custom_port(discovery_mock, mocker, custom_port):
    """Make sure that connect honours the port override when one is given."""
    host = DISCOVERY_MOCK_IP

    discovery_data = discovery_mock.discovery_data
    ctype, _ = _get_connection_type_device_class(discovery_data)
    config = DeviceConfig(
        host=host,
        port_override=custom_port,
        connection_type=ctype,
        credentials=Credentials("dummy_user", "dummy_password"),
    )
    # Payloads without a "result" key are expected to end up on port 9999.
    default_port = (
        DiscoveryResult.from_dict(discovery_data["result"]).mgt_encrypt_schm.http_port
        if "result" in discovery_data
        else 9999
    )

    dev = await connect(config=config)
    assert issubclass(dev.__class__, Device)
    # Check against the single expected port so that an ignored override
    # cannot pass by accidentally matching the default.
    expected_port = custom_port if custom_port is not None else default_port
    assert dev.port == expected_port


@pytest.mark.xdist_group(name="caplog")
async def test_connect_logs_connect_time(
    discovery_mock,
    caplog: pytest.LogCaptureFixture,
):
    """Test that the connect time is logged when debug logging is enabled."""
    discovery_data = discovery_mock.discovery_data
    ctype, _ = _get_connection_type_device_class(discovery_data)

    host = DISCOVERY_MOCK_IP
    config = DeviceConfig(
        host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
    )
    # caplog restores the logger level at teardown, so enabling debug output
    # here does not leak into other tests.
    caplog.set_level(logging.DEBUG, logger="kasa")
    await connect(
        config=config,
    )
    assert "seconds to update" in caplog.text


async def test_connect_query_fails(discovery_mock, mocker):
    """Make sure that connect fails when query fails."""
    host = DISCOVERY_MOCK_IP
    discovery_data = discovery_mock.discovery_data
    mocker.patch("kasa.IotProtocol.query", side_effect=KasaException)
    mocker.patch("kasa.SmartProtocol.query", side_effect=KasaException)

    ctype, _ = _get_connection_type_device_class(discovery_data)
    config = DeviceConfig(
        host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
    )
    protocol_class = get_protocol(config).__class__
    close_mock = mocker.patch.object(protocol_class, "close")
    assert close_mock.call_count == 0
    with pytest.raises(KasaException):
        await connect(config=config)
    # A failed connect must close the protocol it created.
    assert close_mock.call_count == 1


async def test_connect_http_client(discovery_mock, mocker):
    """Make sure that a custom http client is used only when it is passed in.

    The check is skipped for XOR encrypted devices, presumably because that
    transport does not use an HTTP client.
    """
    host = DISCOVERY_MOCK_IP

    discovery_data = discovery_mock.discovery_data
    ctype, _ = _get_connection_type_device_class(discovery_data)
    checks_http_client = ctype.encryption_type != DeviceEncryptionType.Xor

    # The context manager closes the session even if an assertion fails.
    async with aiohttp.ClientSession() as http_client:
        config = DeviceConfig(
            host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
        )
        dev = await connect(config=config)
        try:
            if checks_http_client:
                assert dev.protocol._transport._http_client.client != http_client
        finally:
            await dev.disconnect()

        config = DeviceConfig(
            host=host,
            credentials=Credentials("foor", "bar"),
            connection_type=ctype,
            http_client=http_client,
        )
        dev = await connect(config=config)
        try:
            if checks_http_client:
                assert dev.protocol._transport._http_client.client == http_client
        finally:
            await dev.disconnect()


async def test_device_types(dev: Device):
    """Test that device_type matches the type derived from the device data."""
    await dev.update()
    if isinstance(dev, SmartCamDevice):
        res = SmartCamDevice._get_device_type_from_sysinfo(dev.sys_info)
    elif isinstance(dev, SmartDevice):
        assert dev._discovery_info
        device_type = cast(str, dev._discovery_info["device_type"])
        res = SmartDevice._get_device_type_from_components(
            list(dev._components.keys()), device_type
        )
    else:
        res = IotDevice._get_device_type_from_sys_info(dev._last_update)

    assert dev.device_type == res


@pytest.mark.xdist_group(name="caplog")
async def test_device_class_from_unknown_family(caplog):
    """Verify that unknown SMART devices yield a warning and fallback to SmartDevice."""
    dummy_name = "SMART.foo"
    with caplog.at_level(logging.DEBUG):
        assert get_device_class_from_family(dummy_name, https=False) == SmartDevice
    assert f"Unknown SMART device with {dummy_name}" in caplog.text


# Aliases to make the test params more readable
CP = DeviceConnectionParameters
DF = DeviceFamily
ET = DeviceEncryptionType


@pytest.mark.parametrize(
    ("conn_params", "expected_protocol", "expected_transport"),
    [
        pytest.param(
            CP(DF.SmartIpCamera, ET.Aes, https=True),
            SmartCamProtocol,
            SslAesTransport,
            id="smartcam",
        ),
        pytest.param(
            CP(DF.SmartTapoHub, ET.Aes, https=True),
            SmartCamProtocol,
            SslAesTransport,
            id="smartcam-hub",
        ),
        pytest.param(
            CP(DF.IotIpCamera, ET.Aes, https=True),
            IotProtocol,
            LinkieTransportV2,
            id="kasacam",
        ),
        pytest.param(
            CP(DF.SmartTapoRobovac, ET.Aes, https=True),
            SmartProtocol,
            SslTransport,
            id="robovac",
        ),
        pytest.param(
            CP(DF.IotSmartPlugSwitch, ET.Klap, https=False),
            IotProtocol,
            KlapTransport,
            id="iot-klap",
        ),
        pytest.param(
            CP(DF.IotSmartPlugSwitch, ET.Xor, https=False),
            IotProtocol,
            XorTransport,
            id="iot-xor",
        ),
        pytest.param(
            CP(DF.SmartTapoPlug, ET.Aes, https=False),
            SmartProtocol,
            AesTransport,
            id="smart-aes",
        ),
        pytest.param(
            CP(DF.SmartTapoPlug, ET.Klap, https=False),
            SmartProtocol,
            KlapTransportV2,
            id="smart-klap",
        ),
    ],
)
async def test_get_protocol(
    conn_params: DeviceConnectionParameters,
    expected_protocol: type[BaseProtocol],
    expected_transport: type[BaseTransport],
):
    """Test get_protocol returns the right protocol."""
    config = DeviceConfig("127.0.0.1", connection_type=conn_params)
    protocol = get_protocol(config)
    assert isinstance(protocol, expected_protocol)
    assert isinstance(protocol._transport, expected_transport)