Initial support for vacuums (clean module) (#944)

Adds support for clean module:
- Show current vacuum state
- Start cleaning (all rooms)
- Return to dock
- Pausing & unpausing
- Controlling the fan speed

---------

Co-authored-by: Steven B <51370195+sdb9696@users.noreply.github.com>
This commit is contained in:
Teemu R. 2025-01-14 15:35:09 +01:00 committed by GitHub
parent be34dbd387
commit 1be87674bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 799 additions and 13 deletions

View File

@ -204,6 +204,7 @@ The following devices have been tested and confirmed as working. If your device
- **Cameras**: C100, C210, C225, C325WB, C520WS, C720, D230, TC65, TC70
- **Hubs**: H100, H200
- **Hub-Connected Devices[^3]**: S200B, S200D, T100, T110, T300, T310, T315
- **Vacuums**: RV20 Max Plus
<!--SUPPORTED_END-->
[^1]: Model requires authentication

View File

@ -324,6 +324,11 @@ All Tapo devices require authentication.<br>Hub-Connected Devices may work acros
- Hardware: 1.0 (EU) / Firmware: 1.7.0
- Hardware: 1.0 (US) / Firmware: 1.8.0
### Vacuums
- **RV20 Max Plus**
- Hardware: 1.0 (EU) / Firmware: 1.0.7
<!--SUPPORTED_END-->
[^1]: Model requires authentication

View File

@ -39,6 +39,7 @@ DEVICE_TYPE_TO_PRODUCT_GROUP = {
DeviceType.Hub: "Hubs",
DeviceType.Sensor: "Hub-Connected Devices",
DeviceType.Thermostat: "Hub-Connected Devices",
DeviceType.Vacuum: "Vacuums",
}

View File

@ -118,6 +118,16 @@ class SmartRequest:
enable: bool
id: str | None = None
@dataclass
class GetCleanAttrParams(SmartRequestParams):
"""CleanAttr params.
Decides which cleaning settings are requested
"""
#: type can be global or pose
type: str = "global"
@staticmethod
def get_raw_request(
method: str, params: SmartRequestParams | None = None
@ -429,6 +439,8 @@ COMPONENT_REQUESTS = {
"clean": [
SmartRequest.get_raw_request("getCleanRecords"),
SmartRequest.get_raw_request("getVacStatus"),
SmartRequest.get_raw_request("getCleanStatus"),
SmartRequest("getCleanAttr", SmartRequest.GetCleanAttrParams()),
],
"battery": [SmartRequest.get_raw_request("getBatteryInfo")],
"consumables": [SmartRequest.get_raw_request("getConsumablesInfo")],

View File

@ -159,7 +159,7 @@ def get_device_class_from_family(
"SMART.KASAHUB": SmartDevice,
"SMART.KASASWITCH": SmartDevice,
"SMART.IPCAMERA.HTTPS": SmartCamDevice,
"SMART.TAPOROBOVAC": SmartDevice,
"SMART.TAPOROBOVAC.HTTPS": SmartDevice,
"IOT.SMARTPLUGSWITCH": IotPlug,
"IOT.SMARTBULB": IotBulb,
"IOT.IPCAMERA": IotCamera,
@ -173,6 +173,9 @@ def get_device_class_from_family(
_LOGGER.debug("Unknown SMART device with %s, using SmartDevice", device_type)
cls = SmartDevice
if cls is not None:
_LOGGER.debug("Using %s for %s", cls.__name__, device_type)
return cls
@ -188,6 +191,7 @@ def get_protocol(config: DeviceConfig, *, strict: bool = False) -> BaseProtocol
"""
ctype = config.connection_type
protocol_name = ctype.device_family.value.split(".")[0]
_LOGGER.debug("Finding protocol for %s", ctype.device_family)
if ctype.device_family is DeviceFamily.SmartIpCamera:
if strict and ctype.encryption_type is not DeviceEncryptionType.Aes:

View File

@ -676,9 +676,14 @@ class Discover:
for key, val in candidates.items():
try:
prot, config = val
_LOGGER.debug("Trying to connect with %s", prot.__class__.__name__)
dev = await _connect(config, prot)
except Exception:
_LOGGER.debug("Unable to connect with %s", prot)
except Exception as ex:
_LOGGER.debug(
"Unable to connect with %s: %s",
prot.__class__.__name__,
ex,
)
if on_attempt:
ca = tuple.__new__(ConnectAttempt, key)
on_attempt(ca, False)
@ -686,6 +691,7 @@ class Discover:
if on_attempt:
ca = tuple.__new__(ConnectAttempt, key)
on_attempt(ca, True)
_LOGGER.debug("Found working protocol %s", prot.__class__.__name__)
return dev
finally:
await prot.close()

View File

@ -127,6 +127,8 @@ class SmartErrorCode(IntEnum):
DST_ERROR = -2301
DST_SAVE_ERROR = -2302
VACUUM_BATTERY_LOW = -3001
SYSTEM_ERROR = -40101
INVALID_ARGUMENTS = -40209

View File

@ -161,6 +161,9 @@ class Module(ABC):
Camera: Final[ModuleName[smartcam.Camera]] = ModuleName("Camera")
LensMask: Final[ModuleName[smartcam.LensMask]] = ModuleName("LensMask")
# Vacuum modules
Clean: Final[ModuleName[smart.Clean]] = ModuleName("Clean")
def __init__(self, device: Device, module: str) -> None:
self._device = device
self._module = module

View File

@ -7,6 +7,7 @@ from .batterysensor import BatterySensor
from .brightness import Brightness
from .childdevice import ChildDevice
from .childprotection import ChildProtection
from .clean import Clean
from .cloud import Cloud
from .color import Color
from .colortemperature import ColorTemperature
@ -66,6 +67,7 @@ __all__ = [
"TriggerLogs",
"FrostProtection",
"Thermostat",
"Clean",
"SmartLightEffect",
"OverheatProtection",
"HomeKit",

267
kasa/smart/modules/clean.py Normal file
View File

@ -0,0 +1,267 @@
"""Implementation of vacuum clean module."""
from __future__ import annotations
import logging
from enum import IntEnum
from typing import Annotated
from ...feature import Feature
from ...module import FeatureAttribute
from ..smartmodule import SmartModule
_LOGGER = logging.getLogger(__name__)
class Status(IntEnum):
"""Status of vacuum."""
Idle = 0
Cleaning = 1
Mapping = 2
GoingHome = 4
Charging = 5
Charged = 6
Paused = 7
Undocked = 8
Error = 100
UnknownInternal = -1000
class ErrorCode(IntEnum):
"""Error codes for vacuum."""
Ok = 0
SideBrushStuck = 2
MainBrushStuck = 3
WheelBlocked = 4
DustBinRemoved = 14
UnableToMove = 15
LidarBlocked = 16
UnableToFindDock = 21
BatteryLow = 22
UnknownInternal = -1000
class FanSpeed(IntEnum):
"""Fan speed level."""
Quiet = 1
Standard = 2
Turbo = 3
Max = 4
class Clean(SmartModule):
"""Implementation of vacuum clean module."""
REQUIRED_COMPONENT = "clean"
_error_code = ErrorCode.Ok
def _initialize_features(self) -> None:
"""Initialize features."""
self._add_feature(
Feature(
self._device,
id="vacuum_return_home",
name="Return home",
container=self,
attribute_setter="return_home",
category=Feature.Category.Primary,
type=Feature.Action,
)
)
self._add_feature(
Feature(
self._device,
id="vacuum_start",
name="Start cleaning",
container=self,
attribute_setter="start",
category=Feature.Category.Primary,
type=Feature.Action,
)
)
self._add_feature(
Feature(
self._device,
id="vacuum_pause",
name="Pause",
container=self,
attribute_setter="pause",
category=Feature.Category.Primary,
type=Feature.Action,
)
)
self._add_feature(
Feature(
self._device,
id="vacuum_status",
name="Vacuum status",
container=self,
attribute_getter="status",
category=Feature.Category.Primary,
type=Feature.Type.Sensor,
)
)
self._add_feature(
Feature(
self._device,
id="vacuum_error",
name="Error",
container=self,
attribute_getter="error",
category=Feature.Category.Info,
type=Feature.Type.Sensor,
)
)
self._add_feature(
Feature(
self._device,
id="battery_level",
name="Battery level",
container=self,
attribute_getter="battery",
icon="mdi:battery",
unit_getter=lambda: "%",
category=Feature.Category.Info,
type=Feature.Type.Sensor,
)
)
self._add_feature(
Feature(
self._device,
id="vacuum_fan_speed",
name="Fan speed",
container=self,
attribute_getter="fan_speed_preset",
attribute_setter="set_fan_speed_preset",
icon="mdi:fan",
choices_getter=lambda: list(FanSpeed.__members__),
category=Feature.Category.Primary,
type=Feature.Type.Choice,
)
)
async def _post_update_hook(self) -> None:
"""Set error code after update."""
errors = self._vac_status.get("err_status")
if errors is None or not errors:
self._error_code = ErrorCode.Ok
return
if len(errors) > 1:
_LOGGER.warning(
"Multiple error codes, using the first one only: %s", errors
)
error = errors.pop(0)
try:
self._error_code = ErrorCode(error)
except ValueError:
_LOGGER.warning(
"Unknown error code, please create an issue describing the error: %s",
error,
)
self._error_code = ErrorCode.UnknownInternal
def query(self) -> dict:
"""Query to execute during the update cycle."""
return {
"getVacStatus": None,
"getBatteryInfo": None,
"getCleanStatus": None,
"getCleanAttr": {"type": "global"},
}
async def start(self) -> dict:
"""Start cleaning."""
# If we are paused, do not restart cleaning
if self.status is Status.Paused:
return await self.resume()
return await self.call(
"setSwitchClean",
{
"clean_mode": 0,
"clean_on": True,
"clean_order": True,
"force_clean": False,
},
)
async def pause(self) -> dict:
"""Pause cleaning."""
if self.status is Status.GoingHome:
return await self.set_return_home(False)
return await self.set_pause(True)
async def resume(self) -> dict:
"""Resume cleaning."""
return await self.set_pause(False)
async def set_pause(self, enabled: bool) -> dict:
"""Pause or resume cleaning."""
return await self.call("setRobotPause", {"pause": enabled})
async def return_home(self) -> dict:
"""Return home."""
return await self.set_return_home(True)
async def set_return_home(self, enabled: bool) -> dict:
"""Return home / pause returning."""
return await self.call("setSwitchCharge", {"switch_charge": enabled})
@property
def error(self) -> ErrorCode:
"""Return error."""
return self._error_code
@property
def fan_speed_preset(self) -> Annotated[str, FeatureAttribute()]:
"""Return fan speed preset."""
return FanSpeed(self._settings["suction"]).name
async def set_fan_speed_preset(
self, speed: str
) -> Annotated[dict, FeatureAttribute]:
"""Set fan speed preset."""
name_to_value = {x.name: x.value for x in FanSpeed}
if speed not in name_to_value:
raise ValueError("Invalid fan speed %s, available %s", speed, name_to_value)
return await self.call(
"setCleanAttr", {"suction": name_to_value[speed], "type": "global"}
)
@property
def battery(self) -> int:
"""Return battery level."""
return self.data["getBatteryInfo"]["battery_percentage"]
@property
def _vac_status(self) -> dict:
"""Return vac status container."""
return self.data["getVacStatus"]
@property
def _settings(self) -> dict:
"""Return cleaning settings."""
return self.data["getCleanAttr"]
@property
def status(self) -> Status:
"""Return current status."""
if self._error_code is not ErrorCode.Ok:
return Status.Error
status_code = self._vac_status["status"]
try:
return Status(status_code)
except ValueError:
_LOGGER.warning("Got unknown status code: %s (%s)", status_code, self.data)
return Status.UnknownInternal

View File

@ -134,6 +134,8 @@ SENSORS_SMART = {
}
THERMOSTATS_SMART = {"KE100"}
VACUUMS_SMART = {"RV20"}
WITH_EMETER_IOT = {"HS110", "HS300", "KP115", "KP125", *BULBS_IOT}
WITH_EMETER_SMART = {"P110", "P110M", "P115", "KP125M", "EP25", "P304M"}
WITH_EMETER = {*WITH_EMETER_IOT, *WITH_EMETER_SMART}
@ -151,6 +153,7 @@ ALL_DEVICES_SMART = (
.union(SENSORS_SMART)
.union(SWITCHES_SMART)
.union(THERMOSTATS_SMART)
.union(VACUUMS_SMART)
)
ALL_DEVICES = ALL_DEVICES_IOT.union(ALL_DEVICES_SMART)
@ -342,6 +345,7 @@ hub_smartcam = parametrize(
device_type_filter=[DeviceType.Hub],
protocol_filter={"SMARTCAM"},
)
vacuum = parametrize("vacuums", device_type_filter=[DeviceType.Vacuum])
def check_categories():
@ -360,6 +364,7 @@ def check_categories():
+ thermostats_smart.args[1]
+ camera_smartcam.args[1]
+ hub_smartcam.args[1]
+ vacuum.args[1]
)
diffs: set[FixtureInfo] = set(FIXTURE_DATA) - set(categorized_fixtures)
if diffs:

View File

@ -383,8 +383,8 @@ class FakeSmartTransport(BaseTransport):
result = copy.deepcopy(info[child_method])
retval = {"result": result, "error_code": 0}
return retval
elif child_method[:4] == "set_":
target_method = f"get_{child_method[4:]}"
elif child_method[:3] == "set":
target_method = f"get{child_method[3:]}"
if target_method not in child_device_calls:
raise RuntimeError(
f"No {target_method} in child info, calling set before get not supported."
@ -549,7 +549,7 @@ class FakeSmartTransport(BaseTransport):
return await self._handle_control_child(request_dict["params"])
params = request_dict.get("params", {})
if method in {"component_nego", "qs_component_nego"} or method[:4] == "get_":
if method in {"component_nego", "qs_component_nego"} or method[:3] == "get":
if method in info:
result = copy.deepcopy(info[method])
if result and "start_index" in result and "sum" in result:
@ -637,9 +637,14 @@ class FakeSmartTransport(BaseTransport):
return self._set_on_off_gradually_info(info, params)
elif method == "set_child_protection":
return self._update_sysinfo_key(info, "child_protection", params["enable"])
elif method[:4] == "set_":
target_method = f"get_{method[4:]}"
elif method[:3] == "set":
target_method = f"get{method[3:]}"
# Some vacuum commands do not have a getter
if method in ["setRobotPause", "setSwitchClean", "setSwitchCharge"]:
return {"error_code": 0}
info[target_method].update(params)
return {"error_code": 0}
async def close(self) -> None:

View File

@ -0,0 +1,310 @@
{
"component_nego": {
"component_list": [
{
"id": "device",
"ver_code": 1
},
{
"id": "iot_cloud",
"ver_code": 1
},
{
"id": "time",
"ver_code": 1
},
{
"id": "firmware",
"ver_code": 1
},
{
"id": "quick_setup",
"ver_code": 1
},
{
"id": "clean",
"ver_code": 3
},
{
"id": "battery",
"ver_code": 1
},
{
"id": "consumables",
"ver_code": 2
},
{
"id": "direction_control",
"ver_code": 1
},
{
"id": "button_and_led",
"ver_code": 1
},
{
"id": "speaker",
"ver_code": 3
},
{
"id": "schedule",
"ver_code": 3
},
{
"id": "wireless",
"ver_code": 1
},
{
"id": "map",
"ver_code": 2
},
{
"id": "auto_change_map",
"ver_code": -1
},
{
"id": "ble_whole_setup",
"ver_code": 1
},
{
"id": "dust_bucket",
"ver_code": 1
},
{
"id": "inherit",
"ver_code": 1
},
{
"id": "mop",
"ver_code": 1
},
{
"id": "do_not_disturb",
"ver_code": 1
},
{
"id": "device_local_time",
"ver_code": 1
},
{
"id": "charge_pose_clean",
"ver_code": 1
},
{
"id": "continue_breakpoint_sweep",
"ver_code": 1
},
{
"id": "goto_point",
"ver_code": 1
},
{
"id": "furniture",
"ver_code": 1
},
{
"id": "map_cloud_backup",
"ver_code": 1
},
{
"id": "dev_log",
"ver_code": 1
},
{
"id": "map_lock",
"ver_code": 1
},
{
"id": "carpet_area",
"ver_code": 1
},
{
"id": "clean_angle",
"ver_code": 1
},
{
"id": "clean_percent",
"ver_code": 1
},
{
"id": "no_pose_config",
"ver_code": 1
}
]
},
"discovery_result": {
"error_code": 0,
"result": {
"device_id": "00000000000000000000000000000000",
"device_model": "RV20 Max Plus(EU)",
"device_type": "SMART.TAPOROBOVAC",
"factory_default": false,
"ip": "127.0.0.123",
"is_support_iot_cloud": true,
"mac": "B0-19-21-00-00-00",
"mgt_encrypt_schm": {
"encrypt_type": "AES",
"http_port": 4433,
"is_support_https": true
},
"obd_src": "tplink",
"owner": "00000000000000000000000000000000"
}
},
"getAutoChangeMap": {
"auto_change_map": false
},
"getAutoDustCollection": {
"auto_dust_collection": 1
},
"getBatteryInfo": {
"battery_percentage": 75
},
"getCleanAttr": {"suction": 2, "cistern": 2, "clean_number": 1},
"getCleanStatus": {"getCleanStatus": {"clean_status": 0, "is_working": false, "is_mapping": false, "is_relocating": false}},
"getCleanRecords": {
"lastest_day_record": [
0,
0,
0,
0
],
"record_list": [],
"record_list_num": 0,
"total_area": 0,
"total_number": 0,
"total_time": 0
},
"getConsumablesInfo": {
"charge_contact_time": 0,
"edge_brush_time": 0,
"filter_time": 0,
"main_brush_lid_time": 0,
"rag_time": 0,
"roll_brush_time": 0,
"sensor_time": 0
},
"getCurrentVoiceLanguage": {
"name": "2",
"version": 1
},
"getDoNotDisturb": {
"do_not_disturb": true,
"e_min": 480,
"s_min": 1320
},
"getMapInfo": {
"auto_change_map": false,
"current_map_id": 0,
"map_list": [],
"map_num": 0,
"version": "LDS"
},
"getMopState": {
"mop_state": false
},
"getVacStatus": {
"err_status": [
0
],
"errorCode_id": [
0
],
"prompt": [],
"promptCode_id": [],
"status": 5
},
"get_device_info": {
"auto_pack_ver": "0.0.1.1771",
"avatar": "",
"board_sn": "000000000000",
"custom_sn": "000000000000",
"device_id": "0000000000000000000000000000000000000000",
"fw_id": "00000000000000000000000000000000",
"fw_ver": "1.0.7 Build 240828 Rel.205951",
"has_set_location_info": true,
"hw_id": "00000000000000000000000000000000",
"hw_ver": "1.0",
"ip": "127.0.0.123",
"lang": "",
"latitude": 0,
"linux_ver": "V21.198.1708420747",
"location": "",
"longitude": 0,
"mac": "B0-19-21-00-00-00",
"mcu_ver": "1.1.2563.5",
"model": "RV20 Max Plus",
"nickname": "I01BU0tFRF9OQU1FIw==",
"oem_id": "00000000000000000000000000000000",
"overheated": false,
"region": "Europe/Berlin",
"rssi": -59,
"signal_level": 2,
"specs": "",
"ssid": "I01BU0tFRF9TU0lEIw==",
"sub_ver": "0.0.1.1771-1.1.34",
"time_diff": 60,
"total_ver": "1.1.34",
"type": "SMART.TAPOROBOVAC"
},
"get_device_time": {
"region": "Europe/Berlin",
"time_diff": 60,
"timestamp": 1736598518
},
"get_fw_download_state": {
"auto_upgrade": false,
"download_progress": 0,
"reboot_time": 5,
"status": 0,
"upgrade_time": 5
},
"get_inherit_info": null,
"get_next_event": {},
"get_schedule_rules": {
"enable": false,
"rule_list": [],
"schedule_rule_max_count": 32,
"start_index": 0,
"sum": 0
},
"get_wireless_scan_info": {
"ap_list": [
{
"key_type": "wpa2_psk",
"signal_level": 2,
"ssid": "I01BU0tFRF9TU0lEIw=="
}
],
"start_index": 0,
"sum": 1,
"wep_supported": true
},
"qs_component_nego": {
"component_list": [
{
"id": "quick_setup",
"ver_code": 1
},
{
"id": "iot_cloud",
"ver_code": 1
},
{
"id": "firmware",
"ver_code": 1
},
{
"id": "ble_whole_setup",
"ver_code": 1
},
{
"id": "inherit",
"ver_code": 1
}
],
"extra_info": {
"device_model": "RV20 Max Plus",
"device_type": "SMART.TAPOROBOVAC"
}
}
}

View File

@ -0,0 +1,146 @@
from __future__ import annotations
import logging
import pytest
from pytest_mock import MockerFixture
from kasa import Module
from kasa.smart import SmartDevice
from kasa.smart.modules.clean import ErrorCode, Status
from ...device_fixtures import get_parent_and_child_modules, parametrize
clean = parametrize("clean module", component_filter="clean", protocol_filter={"SMART"})
@clean
@pytest.mark.parametrize(
("feature", "prop_name", "type"),
[
("vacuum_status", "status", Status),
("vacuum_error", "error", ErrorCode),
("vacuum_fan_speed", "fan_speed_preset", str),
("battery_level", "battery", int),
],
)
async def test_features(dev: SmartDevice, feature: str, prop_name: str, type: type):
"""Test that features are registered and work as expected."""
clean = next(get_parent_and_child_modules(dev, Module.Clean))
assert clean is not None
prop = getattr(clean, prop_name)
assert isinstance(prop, type)
feat = clean._device.features[feature]
assert feat.value == prop
assert isinstance(feat.value, type)
@pytest.mark.parametrize(
("feature", "value", "method", "params"),
[
pytest.param(
"vacuum_start",
1,
"setSwitchClean",
{
"clean_mode": 0,
"clean_on": True,
"clean_order": True,
"force_clean": False,
},
id="vacuum_start",
),
pytest.param(
"vacuum_pause", 1, "setRobotPause", {"pause": True}, id="vacuum_pause"
),
pytest.param(
"vacuum_return_home",
1,
"setSwitchCharge",
{"switch_charge": True},
id="vacuum_return_home",
),
pytest.param(
"vacuum_fan_speed",
"Quiet",
"setCleanAttr",
{"suction": 1, "type": "global"},
id="vacuum_fan_speed",
),
],
)
@clean
async def test_actions(
dev: SmartDevice,
mocker: MockerFixture,
feature: str,
value: str | int,
method: str,
params: dict,
):
"""Test the clean actions."""
clean = next(get_parent_and_child_modules(dev, Module.Clean))
call = mocker.spy(clean, "call")
await dev.features[feature].set_value(value)
call.assert_called_with(method, params)
@pytest.mark.parametrize(
("err_status", "error"),
[
pytest.param([], ErrorCode.Ok, id="empty error"),
pytest.param([0], ErrorCode.Ok, id="no error"),
pytest.param([3], ErrorCode.MainBrushStuck, id="known error"),
pytest.param([123], ErrorCode.UnknownInternal, id="unknown error"),
pytest.param([3, 4], ErrorCode.MainBrushStuck, id="multi-error"),
],
)
@clean
async def test_post_update_hook(dev: SmartDevice, err_status: list, error: ErrorCode):
"""Test that post update hook sets error states correctly."""
clean = next(get_parent_and_child_modules(dev, Module.Clean))
clean.data["getVacStatus"]["err_status"] = err_status
await clean._post_update_hook()
assert clean._error_code is error
if error is not ErrorCode.Ok:
assert clean.status is Status.Error
@clean
async def test_resume(dev: SmartDevice, mocker: MockerFixture):
"""Test that start calls resume if the state is paused."""
clean = next(get_parent_and_child_modules(dev, Module.Clean))
call = mocker.spy(clean, "call")
resume = mocker.spy(clean, "resume")
mocker.patch.object(
type(clean),
"status",
new_callable=mocker.PropertyMock,
return_value=Status.Paused,
)
await clean.start()
call.assert_called_with("setRobotPause", {"pause": False})
resume.assert_awaited()
@clean
async def test_unknown_status(
dev: SmartDevice, mocker: MockerFixture, caplog: pytest.LogCaptureFixture
):
"""Test that unknown status is logged."""
clean = next(get_parent_and_child_modules(dev, Module.Clean))
caplog.set_level(logging.DEBUG)
clean.data["getVacStatus"]["status"] = 123
assert clean.status is Status.UnknownInternal
assert "Got unknown status code: 123" in caplog.text

View File

@ -117,7 +117,11 @@ async def test_connect_custom_port(discovery_mock, mocker, custom_port):
connection_type=ctype,
credentials=Credentials("dummy_user", "dummy_password"),
)
default_port = 80 if "result" in discovery_data else 9999
default_port = (
DiscoveryResult.from_dict(discovery_data["result"]).mgt_encrypt_schm.http_port
if "result" in discovery_data
else 9999
)
ctype, _ = _get_connection_type_device_class(discovery_data)

View File

@ -134,7 +134,14 @@ async def test_discover_single(discovery_mock, custom_port, mocker):
discovery_mock.ip = host
discovery_mock.port_override = custom_port
device_class = Discover._get_device_class(discovery_mock.discovery_data)
disco_data = discovery_mock.discovery_data
device_class = Discover._get_device_class(disco_data)
http_port = (
DiscoveryResult.from_dict(disco_data["result"]).mgt_encrypt_schm.http_port
if "result" in disco_data
else None
)
# discovery_mock patches protocol query methods so use spy here.
update_mock = mocker.spy(device_class, "update")
@ -143,7 +150,11 @@ async def test_discover_single(discovery_mock, custom_port, mocker):
)
assert issubclass(x.__class__, Device)
assert x._discovery_info is not None
assert x.port == custom_port or x.port == discovery_mock.default_port
assert (
x.port == custom_port
or x.port == discovery_mock.default_port
or x.port == http_port
)
# Make sure discovery does not call update()
assert update_mock.call_count == 0
if discovery_mock.default_port == 80:
@ -153,6 +164,7 @@ async def test_discover_single(discovery_mock, custom_port, mocker):
discovery_mock.device_type,
discovery_mock.encrypt_type,
discovery_mock.login_version,
discovery_mock.https,
)
config = DeviceConfig(
host=host,
@ -681,7 +693,7 @@ async def test_discover_try_connect_all(discovery_mock, mocker):
and self._transport.__class__ is transport_class
):
return discovery_mock.query_data
raise KasaException()
raise KasaException("Unable to execute query")
async def _update(self, *args, **kwargs):
if (
@ -689,7 +701,8 @@ async def test_discover_try_connect_all(discovery_mock, mocker):
and self.protocol._transport.__class__ is transport_class
):
return
raise KasaException()
raise KasaException("Unable to execute update")
mocker.patch("kasa.IotProtocol.query", new=_query)
mocker.patch("kasa.SmartProtocol.query", new=_query)