Update light transition module to work with child devices (#1017)

Fixes module to work with child devices, i.e. ks240
Interrogates the data to see whether maximums are available.
Fixes a bug whereby setting a duration while the feature is not
enabled does not actually enable it.
This commit is contained in:
Steven B 2024-06-27 18:52:54 +01:00 committed by GitHub
parent cf24a94526
commit 2a62849987
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 304 additions and 77 deletions

View File

@ -107,6 +107,8 @@ class Feature:
Number = Type.Number Number = Type.Number
Choice = Type.Choice Choice = Type.Choice
DEFAULT_MAX = 2**16 # Arbitrary max
class Category(Enum): class Category(Enum):
"""Category hint to allow feature grouping.""" """Category hint to allow feature grouping."""
@ -155,7 +157,7 @@ class Feature:
#: Minimum value #: Minimum value
minimum_value: int = 0 minimum_value: int = 0
#: Maximum value #: Maximum value
maximum_value: int = 2**16 # Arbitrary max maximum_value: int = DEFAULT_MAX
#: Attribute containing the name of the range getter property. #: Attribute containing the name of the range getter property.
#: If set, this property will be used to set *minimum_value* and *maximum_value*. #: If set, this property will be used to set *minimum_value* and *maximum_value*.
range_getter: str | None = None range_getter: str | None = None

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, TypedDict
from ...exceptions import KasaException from ...exceptions import KasaException
from ...feature import Feature from ...feature import Feature
@ -12,6 +12,12 @@ if TYPE_CHECKING:
from ..smartdevice import SmartDevice from ..smartdevice import SmartDevice
class _State(TypedDict):
duration: int
enable: bool
max_duration: int
class LightTransition(SmartModule): class LightTransition(SmartModule):
"""Implementation of gradual on/off.""" """Implementation of gradual on/off."""
@ -19,14 +25,30 @@ class LightTransition(SmartModule):
QUERY_GETTER_NAME = "get_on_off_gradually_info" QUERY_GETTER_NAME = "get_on_off_gradually_info"
MAXIMUM_DURATION = 60 MAXIMUM_DURATION = 60
# Key in sysinfo that indicates state can be retrieved from there.
# Usually only for child lights, i.e, ks240.
SYS_INFO_STATE_KEYS = (
"gradually_on_mode",
"gradually_off_mode",
"fade_on_time",
"fade_off_time",
)
_on_state: _State
_off_state: _State
_enabled: bool
def __init__(self, device: SmartDevice, module: str): def __init__(self, device: SmartDevice, module: str):
super().__init__(device, module) super().__init__(device, module)
self._create_features() self._state_in_sysinfo = all(
key in device.sys_info for key in self.SYS_INFO_STATE_KEYS
)
self._supports_on_and_off: bool = self.supported_version > 1
def _create_features(self): def _initialize_features(self):
"""Create features based on the available version.""" """Initialize features."""
icon = "mdi:transition" icon = "mdi:transition"
if self.supported_version == 1: if not self._supports_on_and_off:
self._add_feature( self._add_feature(
Feature( Feature(
device=self._device, device=self._device,
@ -34,16 +56,12 @@ class LightTransition(SmartModule):
id="smooth_transitions", id="smooth_transitions",
name="Smooth transitions", name="Smooth transitions",
icon=icon, icon=icon,
attribute_getter="enabled_v1", attribute_getter="enabled",
attribute_setter="set_enabled_v1", attribute_setter="set_enabled",
type=Feature.Type.Switch, type=Feature.Type.Switch,
) )
) )
elif self.supported_version >= 2: else:
# v2 adds separate on & off states
# v3 adds max_duration
# TODO: note, hardcoding the maximums for now as the features get
# initialized before the first update.
self._add_feature( self._add_feature(
Feature( Feature(
self._device, self._device,
@ -54,9 +72,9 @@ class LightTransition(SmartModule):
attribute_setter="set_turn_on_transition", attribute_setter="set_turn_on_transition",
icon=icon, icon=icon,
type=Feature.Type.Number, type=Feature.Type.Number,
maximum_value=self.MAXIMUM_DURATION, maximum_value=self._turn_on_transition_max,
) )
) # self._turn_on_transition_max )
self._add_feature( self._add_feature(
Feature( Feature(
self._device, self._device,
@ -67,38 +85,74 @@ class LightTransition(SmartModule):
attribute_setter="set_turn_off_transition", attribute_setter="set_turn_off_transition",
icon=icon, icon=icon,
type=Feature.Type.Number, type=Feature.Type.Number,
maximum_value=self.MAXIMUM_DURATION, maximum_value=self._turn_off_transition_max,
) )
) # self._turn_off_transition_max )
@property def _post_update_hook(self) -> None:
def _turn_on(self): """Update the states."""
"""Internal getter for turn on settings.""" # Assumes any device with state in sysinfo supports on and off and
if "on_state" not in self.data: # has maximum values for both.
# v2 adds separate on & off states
# v3 adds max_duration except for ks240 which is v2 but supports it
if not self._supports_on_and_off:
self._enabled = self.data["enable"]
return
if self._state_in_sysinfo:
on_max = self._device.sys_info.get(
"max_fade_on_time", self.MAXIMUM_DURATION
)
off_max = self._device.sys_info.get(
"max_fade_off_time", self.MAXIMUM_DURATION
)
on_enabled = bool(self._device.sys_info["gradually_on_mode"])
off_enabled = bool(self._device.sys_info["gradually_off_mode"])
on_duration = self._device.sys_info["fade_on_time"]
off_duration = self._device.sys_info["fade_off_time"]
elif (on_state := self.data.get("on_state")) and (
off_state := self.data.get("off_state")
):
on_max = on_state.get("max_duration", self.MAXIMUM_DURATION)
off_max = off_state.get("max_duration", self.MAXIMUM_DURATION)
on_enabled = on_state["enable"]
off_enabled = off_state["enable"]
on_duration = on_state["duration"]
off_duration = off_state["duration"]
else:
raise KasaException( raise KasaException(
f"Unsupported for {self.REQUIRED_COMPONENT} v{self.supported_version}" f"Unsupported for {self.REQUIRED_COMPONENT} v{self.supported_version}"
) )
return self.data["on_state"] self._enabled = on_enabled or off_enabled
self._on_state = {
"duration": on_duration,
"enable": on_enabled,
"max_duration": on_max,
}
self._off_state = {
"duration": off_duration,
"enable": off_enabled,
"max_duration": off_max,
}
@property async def set_enabled(self, enable: bool):
def _turn_off(self):
"""Internal getter for turn off settings."""
if "off_state" not in self.data:
raise KasaException(
f"Unsupported for {self.REQUIRED_COMPONENT} v{self.supported_version}"
)
return self.data["off_state"]
async def set_enabled_v1(self, enable: bool):
"""Enable gradual on/off.""" """Enable gradual on/off."""
return await self.call("set_on_off_gradually_info", {"enable": enable}) if not self._supports_on_and_off:
return await self.call("set_on_off_gradually_info", {"enable": enable})
else:
on = await self.call(
"set_on_off_gradually_info", {"on_state": {"enable": enable}}
)
off = await self.call(
"set_on_off_gradually_info", {"off_state": {"enable": enable}}
)
return {**on, **off}
@property @property
def enabled_v1(self) -> bool: def enabled(self) -> bool:
"""Return True if gradual on/off is enabled.""" """Return True if gradual on/off is enabled."""
return bool(self.data["enable"]) return self._enabled
@property @property
def turn_on_transition(self) -> int: def turn_on_transition(self) -> int:
@ -106,15 +160,13 @@ class LightTransition(SmartModule):
Available only from v2. Available only from v2.
""" """
if "fade_on_time" in self._device.sys_info: return self._on_state["duration"] if self._on_state["enable"] else 0
return self._device.sys_info["fade_on_time"]
return self._turn_on["duration"]
@property @property
def _turn_on_transition_max(self) -> int: def _turn_on_transition_max(self) -> int:
"""Maximum turn on duration.""" """Maximum turn on duration."""
# v3 added max_duration, we default to 60 when it's not available # v3 added max_duration, we default to 60 when it's not available
return self._turn_on.get("max_duration", 60) return self._on_state["max_duration"]
async def set_turn_on_transition(self, seconds: int): async def set_turn_on_transition(self, seconds: int):
"""Set turn on transition in seconds. """Set turn on transition in seconds.
@ -129,12 +181,12 @@ class LightTransition(SmartModule):
if seconds <= 0: if seconds <= 0:
return await self.call( return await self.call(
"set_on_off_gradually_info", "set_on_off_gradually_info",
{"on_state": {**self._turn_on, "enable": False}}, {"on_state": {"enable": False}},
) )
return await self.call( return await self.call(
"set_on_off_gradually_info", "set_on_off_gradually_info",
{"on_state": {**self._turn_on, "duration": seconds}}, {"on_state": {"enable": True, "duration": seconds}},
) )
@property @property
@ -143,15 +195,13 @@ class LightTransition(SmartModule):
Available only from v2. Available only from v2.
""" """
if "fade_off_time" in self._device.sys_info: return self._off_state["duration"] if self._off_state["enable"] else 0
return self._device.sys_info["fade_off_time"]
return self._turn_off["duration"]
@property @property
def _turn_off_transition_max(self) -> int: def _turn_off_transition_max(self) -> int:
"""Maximum turn on duration.""" """Maximum turn on duration."""
# v3 added max_duration, we default to 60 when it's not available # v3 added max_duration, we default to 60 when it's not available
return self._turn_off.get("max_duration", 60) return self._off_state["max_duration"]
async def set_turn_off_transition(self, seconds: int): async def set_turn_off_transition(self, seconds: int):
"""Set turn on transition in seconds. """Set turn on transition in seconds.
@ -166,26 +216,24 @@ class LightTransition(SmartModule):
if seconds <= 0: if seconds <= 0:
return await self.call( return await self.call(
"set_on_off_gradually_info", "set_on_off_gradually_info",
{"off_state": {**self._turn_off, "enable": False}}, {"off_state": {"enable": False}},
) )
return await self.call( return await self.call(
"set_on_off_gradually_info", "set_on_off_gradually_info",
{"off_state": {**self._turn_on, "duration": seconds}}, {"off_state": {"enable": True, "duration": seconds}},
) )
def query(self) -> dict: def query(self) -> dict:
"""Query to execute during the update cycle.""" """Query to execute during the update cycle."""
# Some devices have the required info in the device info. # Some devices have the required info in the device info.
if "gradually_on_mode" in self._device.sys_info: if self._state_in_sysinfo:
return {} return {}
else: else:
return {self.QUERY_GETTER_NAME: None} return {self.QUERY_GETTER_NAME: None}
async def _check_supported(self): async def _check_supported(self):
"""Additional check to see if the module is supported by the device.""" """Additional check to see if the module is supported by the device."""
# TODO Temporarily disabled on child light devices until module fixed # For devices that report child components on the parent that are not
# to support updates # actually supported by the parent.
if self._device._parent is not None:
return False
return "brightness" in self._device.sys_info return "brightness" in self._device.sys_info

View File

@ -15,7 +15,13 @@ from kasa.smart import SmartDevice
from .fakeprotocol_iot import FakeIotProtocol from .fakeprotocol_iot import FakeIotProtocol
from .fakeprotocol_smart import FakeSmartProtocol from .fakeprotocol_smart import FakeSmartProtocol
from .fixtureinfo import FIXTURE_DATA, FixtureInfo, filter_fixtures, idgenerator from .fixtureinfo import (
FIXTURE_DATA,
ComponentFilter,
FixtureInfo,
filter_fixtures,
idgenerator,
)
# Tapo bulbs # Tapo bulbs
BULBS_SMART_VARIABLE_TEMP = {"L530E", "L930-5"} BULBS_SMART_VARIABLE_TEMP = {"L530E", "L930-5"}
@ -175,7 +181,7 @@ def parametrize(
*, *,
model_filter=None, model_filter=None,
protocol_filter=None, protocol_filter=None,
component_filter=None, component_filter: str | ComponentFilter | None = None,
data_root_filter=None, data_root_filter=None,
device_type_filter=None, device_type_filter=None,
ids=None, ids=None,

View File

@ -12,6 +12,8 @@ from .fakeprotocol_iot import FakeIotProtocol
from .fakeprotocol_smart import FakeSmartProtocol, FakeSmartTransport from .fakeprotocol_smart import FakeSmartProtocol, FakeSmartTransport
from .fixtureinfo import FixtureInfo, filter_fixtures, idgenerator from .fixtureinfo import FixtureInfo, filter_fixtures, idgenerator
DISCOVERY_MOCK_IP = "127.0.0.123"
def _make_unsupported(device_family, encrypt_type): def _make_unsupported(device_family, encrypt_type):
return { return {
@ -73,7 +75,7 @@ new_discovery = parametrize_discovery(
async def discovery_mock(request, mocker): async def discovery_mock(request, mocker):
"""Mock discovery and patch protocol queries to use Fake protocols.""" """Mock discovery and patch protocol queries to use Fake protocols."""
fixture_info: FixtureInfo = request.param fixture_info: FixtureInfo = request.param
yield patch_discovery({"127.0.0.123": fixture_info}, mocker) yield patch_discovery({DISCOVERY_MOCK_IP: fixture_info}, mocker)
def create_discovery_mock(ip: str, fixture_data: dict): def create_discovery_mock(ip: str, fixture_data: dict):

View File

@ -78,7 +78,6 @@ class FakeSmartTransport(BaseTransport):
}, },
}, },
), ),
"get_on_off_gradually_info": ("on_off_gradually", {"enable": True}),
"get_latest_fw": ( "get_latest_fw": (
"firmware", "firmware",
{ {
@ -164,6 +163,8 @@ class FakeSmartTransport(BaseTransport):
return {"error_code": 0} return {"error_code": 0}
elif child_method == "set_preset_rules": elif child_method == "set_preset_rules":
return self._set_child_preset_rules(info, child_params) return self._set_child_preset_rules(info, child_params)
elif child_method == "set_on_off_gradually_info":
return self._set_on_off_gradually_info(info, child_params)
elif child_method in child_device_calls: elif child_method in child_device_calls:
result = copy.deepcopy(child_device_calls[child_method]) result = copy.deepcopy(child_device_calls[child_method])
return {"result": result, "error_code": 0} return {"result": result, "error_code": 0}
@ -200,6 +201,49 @@ class FakeSmartTransport(BaseTransport):
"Method %s not implemented for children" % child_method "Method %s not implemented for children" % child_method
) )
def _get_on_off_gradually_info(self, info, params):
if self.components["on_off_gradually"] == 1:
info["get_on_off_gradually_info"] = {"enable": True}
else:
info["get_on_off_gradually_info"] = {
"off_state": {"duration": 5, "enable": False, "max_duration": 60},
"on_state": {"duration": 5, "enable": False, "max_duration": 60},
}
return copy.deepcopy(info["get_on_off_gradually_info"])
def _set_on_off_gradually_info(self, info, params):
# Child devices can have the required properties directly in info
if self.components["on_off_gradually"] == 1:
info["get_on_off_gradually_info"] = {"enable": params["enable"]}
elif on_state := params.get("on_state"):
if "fade_on_time" in info and "gradually_on_mode" in info:
info["gradually_on_mode"] = 1 if on_state["enable"] else 0
if "duration" in on_state:
info["fade_on_time"] = on_state["duration"]
else:
info["get_on_off_gradually_info"]["on_state"]["enable"] = on_state[
"enable"
]
if "duration" in on_state:
info["get_on_off_gradually_info"]["on_state"]["duration"] = (
on_state["duration"]
)
elif off_state := params.get("off_state"):
if "fade_off_time" in info and "gradually_off_mode" in info:
info["gradually_off_mode"] = 1 if off_state["enable"] else 0
if "duration" in off_state:
info["fade_off_time"] = off_state["duration"]
else:
info["get_on_off_gradually_info"]["off_state"]["enable"] = off_state[
"enable"
]
if "duration" in off_state:
info["get_on_off_gradually_info"]["off_state"]["duration"] = (
off_state["duration"]
)
return {"error_code": 0}
def _set_dynamic_light_effect(self, info, params): def _set_dynamic_light_effect(self, info, params):
"""Set or remove values as per the device behaviour.""" """Set or remove values as per the device behaviour."""
info["get_device_info"]["dynamic_light_effect_enable"] = params["enable"] info["get_device_info"]["dynamic_light_effect_enable"] = params["enable"]
@ -294,6 +338,13 @@ class FakeSmartTransport(BaseTransport):
info[method] = copy.deepcopy(missing_result[1]) info[method] = copy.deepcopy(missing_result[1])
result = copy.deepcopy(info[method]) result = copy.deepcopy(info[method])
retval = {"result": result, "error_code": 0} retval = {"result": result, "error_code": 0}
elif (
method == "get_on_off_gradually_info"
and "on_off_gradually" in self.components
):
# Need to call a method here to determine which version schema to return
result = self._get_on_off_gradually_info(info, params)
return {"result": result, "error_code": 0}
else: else:
# PARAMS error returned for KS240 when get_device_usage called # PARAMS error returned for KS240 when get_device_usage called
# on parent device. Could be any error code though. # on parent device. Could be any error code though.
@ -324,6 +375,8 @@ class FakeSmartTransport(BaseTransport):
return self._set_preset_rules(info, params) return self._set_preset_rules(info, params)
elif method == "edit_preset_rules": elif method == "edit_preset_rules":
return self._edit_preset_rules(info, params) return self._edit_preset_rules(info, params)
elif method == "set_on_off_gradually_info":
return self._set_on_off_gradually_info(info, params)
elif method[:4] == "set_": elif method[:4] == "set_":
target_method = f"get_{method[4:]}" target_method = f"get_{method[4:]}"
info[target_method].update(params) info[target_method].update(params)

View File

@ -17,6 +17,12 @@ class FixtureInfo(NamedTuple):
data: dict data: dict
class ComponentFilter(NamedTuple):
component_name: str
minimum_version: int = 0
maximum_version: int | None = None
FixtureInfo.__hash__ = lambda self: hash((self.name, self.protocol)) # type: ignore[attr-defined, method-assign] FixtureInfo.__hash__ = lambda self: hash((self.name, self.protocol)) # type: ignore[attr-defined, method-assign]
FixtureInfo.__eq__ = lambda x, y: hash(x) == hash(y) # type: ignore[method-assign] FixtureInfo.__eq__ = lambda x, y: hash(x) == hash(y) # type: ignore[method-assign]
@ -88,7 +94,7 @@ def filter_fixtures(
data_root_filter: str | None = None, data_root_filter: str | None = None,
protocol_filter: set[str] | None = None, protocol_filter: set[str] | None = None,
model_filter: set[str] | None = None, model_filter: set[str] | None = None,
component_filter: str | None = None, component_filter: str | ComponentFilter | None = None,
device_type_filter: list[DeviceType] | None = None, device_type_filter: list[DeviceType] | None = None,
): ):
"""Filter the fixtures based on supplied parameters. """Filter the fixtures based on supplied parameters.
@ -106,14 +112,26 @@ def filter_fixtures(
file_model = file_model_region.split("(")[0] file_model = file_model_region.split("(")[0]
return file_model in model_filter return file_model in model_filter
def _component_match(fixture_data: FixtureInfo, component_filter): def _component_match(
fixture_data: FixtureInfo, component_filter: str | ComponentFilter
):
if (component_nego := fixture_data.data.get("component_nego")) is None: if (component_nego := fixture_data.data.get("component_nego")) is None:
return False return False
components = { components = {
component["id"]: component["ver_code"] component["id"]: component["ver_code"]
for component in component_nego["component_list"] for component in component_nego["component_list"]
} }
return component_filter in components if isinstance(component_filter, str):
return component_filter in components
else:
return (
(ver_code := components.get(component_filter.component_name))
and ver_code >= component_filter.minimum_version
and (
component_filter.maximum_version is None
or ver_code <= component_filter.maximum_version
)
)
def _device_type_match(fixture_data: FixtureInfo, device_type): def _device_type_match(fixture_data: FixtureInfo, device_type):
if (component_nego := fixture_data.data.get("component_nego")) is None: if (component_nego := fixture_data.data.get("component_nego")) is None:

View File

@ -0,0 +1,80 @@
from pytest_mock import MockerFixture
from kasa import Feature, Module
from kasa.smart import SmartDevice
from kasa.tests.device_fixtures import get_parent_and_child_modules, parametrize
from kasa.tests.fixtureinfo import ComponentFilter
light_transition_v1 = parametrize(
"has light transition",
component_filter=ComponentFilter(
component_name="on_off_gradually", maximum_version=1
),
protocol_filter={"SMART"},
)
light_transition_gt_v1 = parametrize(
"has light transition",
component_filter=ComponentFilter(
component_name="on_off_gradually", minimum_version=2
),
protocol_filter={"SMART"},
)
@light_transition_v1
async def test_module_v1(dev: SmartDevice, mocker: MockerFixture):
"""Test light transition module."""
assert isinstance(dev, SmartDevice)
light_transition = next(get_parent_and_child_modules(dev, Module.LightTransition))
assert light_transition
assert "smooth_transitions" in light_transition._module_features
assert "smooth_transition_on" not in light_transition._module_features
assert "smooth_transition_off" not in light_transition._module_features
await light_transition.set_enabled(True)
await dev.update()
assert light_transition.enabled is True
await light_transition.set_enabled(False)
await dev.update()
assert light_transition.enabled is False
@light_transition_gt_v1
async def test_module_gt_v1(dev: SmartDevice, mocker: MockerFixture):
"""Test light transition module."""
assert isinstance(dev, SmartDevice)
light_transition = next(get_parent_and_child_modules(dev, Module.LightTransition))
assert light_transition
assert "smooth_transitions" not in light_transition._module_features
assert "smooth_transition_on" in light_transition._module_features
assert "smooth_transition_off" in light_transition._module_features
await light_transition.set_enabled(True)
await dev.update()
assert light_transition.enabled is True
await light_transition.set_enabled(False)
await dev.update()
assert light_transition.enabled is False
await light_transition.set_turn_on_transition(5)
await dev.update()
assert light_transition.turn_on_transition == 5
# enabled is true if either on or off is enabled
assert light_transition.enabled is True
await light_transition.set_turn_off_transition(10)
await dev.update()
assert light_transition.turn_off_transition == 10
assert light_transition.enabled is True
max_on = light_transition._module_features["smooth_transition_on"].maximum_value
assert max_on < Feature.DEFAULT_MAX
max_off = light_transition._module_features["smooth_transition_off"].maximum_value
assert max_off < Feature.DEFAULT_MAX
await light_transition.set_turn_on_transition(0)
await light_transition.set_turn_off_transition(0)
await dev.update()
assert light_transition.enabled is False

View File

@ -1,16 +1,25 @@
# type: ignore """Module for testing device factory.
As this module tests the factory with discovery data and expects update to be
called on devices it uses the discovery_mock handles all the patching of the
query methods without actually replacing the device protocol class with one of
the testing fake protocols.
"""
import logging import logging
from typing import cast
import aiohttp import aiohttp
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from kasa import ( from kasa import (
Credentials, Credentials,
Device,
Discover, Discover,
KasaException, KasaException,
) )
from kasa.device_factory import ( from kasa.device_factory import (
Device,
SmartDevice,
_get_device_type_from_sys_info, _get_device_type_from_sys_info,
connect, connect,
get_device_class_from_family, get_device_class_from_family,
@ -23,7 +32,8 @@ from kasa.deviceconfig import (
DeviceFamily, DeviceFamily,
) )
from kasa.discover import DiscoveryResult from kasa.discover import DiscoveryResult
from kasa.smart.smartdevice import SmartDevice
from .conftest import DISCOVERY_MOCK_IP
def _get_connection_type_device_class(discovery_info): def _get_connection_type_device_class(discovery_info):
@ -44,18 +54,22 @@ def _get_connection_type_device_class(discovery_info):
async def test_connect( async def test_connect(
discovery_data, discovery_mock,
mocker, mocker,
): ):
"""Test that if the protocol is passed in it gets set correctly.""" """Test that if the protocol is passed in it gets set correctly."""
host = "127.0.0.1" host = DISCOVERY_MOCK_IP
ctype, device_class = _get_connection_type_device_class(discovery_data) ctype, device_class = _get_connection_type_device_class(
discovery_mock.discovery_data
)
config = DeviceConfig( config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
) )
protocol_class = get_protocol(config).__class__ protocol_class = get_protocol(config).__class__
close_mock = mocker.patch.object(protocol_class, "close") close_mock = mocker.patch.object(protocol_class, "close")
# mocker.patch.object(SmartDevice, "update")
# mocker.patch.object(Device, "update")
dev = await connect( dev = await connect(
config=config, config=config,
) )
@ -69,10 +83,11 @@ async def test_connect(
@pytest.mark.parametrize("custom_port", [123, None]) @pytest.mark.parametrize("custom_port", [123, None])
async def test_connect_custom_port(discovery_data: dict, mocker, custom_port): async def test_connect_custom_port(discovery_mock, mocker, custom_port):
"""Make sure that connect returns an initialized SmartDevice instance.""" """Make sure that connect returns an initialized SmartDevice instance."""
host = "127.0.0.1" host = DISCOVERY_MOCK_IP
discovery_data = discovery_mock.discovery_data
ctype, _ = _get_connection_type_device_class(discovery_data) ctype, _ = _get_connection_type_device_class(discovery_data)
config = DeviceConfig( config = DeviceConfig(
host=host, host=host,
@ -90,13 +105,14 @@ async def test_connect_custom_port(discovery_data: dict, mocker, custom_port):
async def test_connect_logs_connect_time( async def test_connect_logs_connect_time(
discovery_data: dict, discovery_mock,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
): ):
"""Test that the connect time is logged when debug logging is enabled.""" """Test that the connect time is logged when debug logging is enabled."""
discovery_data = discovery_mock.discovery_data
ctype, _ = _get_connection_type_device_class(discovery_data) ctype, _ = _get_connection_type_device_class(discovery_data)
host = "127.0.0.1" host = DISCOVERY_MOCK_IP
config = DeviceConfig( config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
) )
@ -107,9 +123,10 @@ async def test_connect_logs_connect_time(
assert "seconds to update" in caplog.text assert "seconds to update" in caplog.text
async def test_connect_query_fails(discovery_data, mocker): async def test_connect_query_fails(discovery_mock, mocker):
"""Make sure that connect fails when query fails.""" """Make sure that connect fails when query fails."""
host = "127.0.0.1" host = DISCOVERY_MOCK_IP
discovery_data = discovery_mock.discovery_data
mocker.patch("kasa.IotProtocol.query", side_effect=KasaException) mocker.patch("kasa.IotProtocol.query", side_effect=KasaException)
mocker.patch("kasa.SmartProtocol.query", side_effect=KasaException) mocker.patch("kasa.SmartProtocol.query", side_effect=KasaException)
@ -125,10 +142,10 @@ async def test_connect_query_fails(discovery_data, mocker):
assert close_mock.call_count == 1 assert close_mock.call_count == 1
async def test_connect_http_client(discovery_data, mocker): async def test_connect_http_client(discovery_mock, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance.""" """Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1" host = DISCOVERY_MOCK_IP
discovery_data = discovery_mock.discovery_data
ctype, _ = _get_connection_type_device_class(discovery_data) ctype, _ = _get_connection_type_device_class(discovery_data)
http_client = aiohttp.ClientSession() http_client = aiohttp.ClientSession()
@ -157,9 +174,10 @@ async def test_connect_http_client(discovery_data, mocker):
async def test_device_types(dev: Device): async def test_device_types(dev: Device):
await dev.update() await dev.update()
if isinstance(dev, SmartDevice): if isinstance(dev, SmartDevice):
device_type = dev._discovery_info["result"]["device_type"] assert dev._discovery_info
device_type = cast(str, dev._discovery_info["result"]["device_type"])
res = SmartDevice._get_device_type_from_components( res = SmartDevice._get_device_type_from_components(
dev._components.keys(), device_type list(dev._components.keys()), device_type
) )
else: else:
res = _get_device_type_from_sys_info(dev._last_update) res = _get_device_type_from_sys_info(dev._last_update)