Merge remote-tracking branch 'upstream/master' into feat/dev_descriptors

This commit is contained in:
Teemu Rytilahti 2024-02-15 15:17:48 +01:00
commit 5baaa84a1d
11 changed files with 403 additions and 53 deletions

View File

@ -12,6 +12,7 @@ import collections.abc
import json import json
import logging import logging
import re import re
import traceback
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
from pathlib import Path from pathlib import Path
from pprint import pprint from pprint import pprint
@ -19,7 +20,7 @@ from typing import Dict, List, Union
import asyncclick as click import asyncclick as click
from devtools.helpers.smartrequests import COMPONENT_REQUESTS, SmartRequest from devtools.helpers.smartrequests import SmartRequest, get_component_requests
from kasa import ( from kasa import (
AuthenticationException, AuthenticationException,
Credentials, Credentials,
@ -35,6 +36,8 @@ from kasa.smart import SmartDevice
Call = namedtuple("Call", "module method") Call = namedtuple("Call", "module method")
SmartCall = namedtuple("SmartCall", "module request should_succeed") SmartCall = namedtuple("SmartCall", "module request should_succeed")
_LOGGER = logging.getLogger(__name__)
def scrub(res): def scrub(res):
"""Remove identifiers from the given dict.""" """Remove identifiers from the given dict."""
@ -228,6 +231,8 @@ async def get_legacy_fixture(device):
else: else:
click.echo(click.style("OK", fg="green")) click.echo(click.style("OK", fg="green"))
successes.append((test_call, info)) successes.append((test_call, info))
finally:
await device.protocol.close()
final_query = defaultdict(defaultdict) final_query = defaultdict(defaultdict)
final = defaultdict(defaultdict) final = defaultdict(defaultdict)
@ -241,7 +246,8 @@ async def get_legacy_fixture(device):
final = await device.protocol.query(final_query) final = await device.protocol.query(final_query)
except Exception as ex: except Exception as ex:
_echo_error(f"Unable to query all successes at once: {ex}", bold=True, fg="red") _echo_error(f"Unable to query all successes at once: {ex}", bold=True, fg="red")
finally:
await device.protocol.close()
if device._discovery_info and not device._discovery_info.get("system"): if device._discovery_info and not device._discovery_info.get("system"):
# Need to recreate a DiscoverResult here because we don't want the aliases # Need to recreate a DiscoverResult here because we don't want the aliases
# in the fixture, we want the actual field names as returned by the device. # in the fixture, we want the actual field names as returned by the device.
@ -316,7 +322,11 @@ async def _make_requests_or_exit(
_echo_error( _echo_error(
f"Unexpected exception querying {name} at once: {ex}", f"Unexpected exception querying {name} at once: {ex}",
) )
if _LOGGER.isEnabledFor(logging.DEBUG):
traceback.print_stack()
exit(1) exit(1)
finally:
await device.protocol.close()
async def get_smart_fixture(device: SmartDevice, batch_size: int): async def get_smart_fixture(device: SmartDevice, batch_size: int):
@ -367,14 +377,15 @@ async def get_smart_fixture(device: SmartDevice, batch_size: int):
for item in component_info_response["component_list"]: for item in component_info_response["component_list"]:
component_id = item["id"] component_id = item["id"]
if requests := COMPONENT_REQUESTS.get(component_id): ver_code = item["ver_code"]
if (requests := get_component_requests(component_id, ver_code)) is not None:
component_test_calls = [ component_test_calls = [
SmartCall(module=component_id, request=request, should_succeed=True) SmartCall(module=component_id, request=request, should_succeed=True)
for request in requests for request in requests
] ]
test_calls.extend(component_test_calls) test_calls.extend(component_test_calls)
should_succeed.extend(component_test_calls) should_succeed.extend(component_test_calls)
elif component_id not in COMPONENT_REQUESTS: else:
click.echo(f"Skipping {component_id}..", nl=False) click.echo(f"Skipping {component_id}..", nl=False)
click.echo(click.style("UNSUPPORTED", fg="yellow")) click.echo(click.style("UNSUPPORTED", fg="yellow"))
@ -396,7 +407,11 @@ async def get_smart_fixture(device: SmartDevice, batch_size: int):
if ( if (
not test_call.should_succeed not test_call.should_succeed
and hasattr(ex, "error_code") and hasattr(ex, "error_code")
and ex.error_code == SmartErrorCode.UNKNOWN_METHOD_ERROR and ex.error_code
in [
SmartErrorCode.UNKNOWN_METHOD_ERROR,
SmartErrorCode.TRANSPORT_NOT_AVAILABLE_ERROR,
]
): ):
click.echo(click.style("FAIL - EXPECTED", fg="green")) click.echo(click.style("FAIL - EXPECTED", fg="green"))
else: else:
@ -410,6 +425,8 @@ async def get_smart_fixture(device: SmartDevice, batch_size: int):
else: else:
click.echo(click.style("OK", fg="green")) click.echo(click.style("OK", fg="green"))
successes.append(test_call) successes.append(test_call)
finally:
await device.protocol.close()
requests = [] requests = []
for succ in successes: for succ in successes:

View File

@ -133,11 +133,14 @@ class SmartRequest:
return SmartRequest("get_device_usage") return SmartRequest("get_device_usage")
@staticmethod @staticmethod
def device_info_list() -> List["SmartRequest"]: def device_info_list(ver_code) -> List["SmartRequest"]:
"""Get device info list.""" """Get device info list."""
if ver_code == 1:
return [SmartRequest.get_device_info()]
return [ return [
SmartRequest.get_device_info(), SmartRequest.get_device_info(),
SmartRequest.get_device_usage(), SmartRequest.get_device_usage(),
SmartRequest.get_auto_update_info(),
] ]
@staticmethod @staticmethod
@ -149,7 +152,6 @@ class SmartRequest:
def firmware_info_list() -> List["SmartRequest"]: def firmware_info_list() -> List["SmartRequest"]:
"""Get info list.""" """Get info list."""
return [ return [
SmartRequest.get_auto_update_info(),
SmartRequest.get_raw_request("get_fw_download_state"), SmartRequest.get_raw_request("get_fw_download_state"),
SmartRequest.get_raw_request("get_latest_fw"), SmartRequest.get_raw_request("get_latest_fw"),
] ]
@ -165,9 +167,13 @@ class SmartRequest:
return SmartRequest("get_device_time") return SmartRequest("get_device_time")
@staticmethod @staticmethod
def get_wireless_scan_info() -> "SmartRequest": def get_wireless_scan_info(
params: Optional[GetRulesParams] = None
) -> "SmartRequest":
"""Get wireless scan info.""" """Get wireless scan info."""
return SmartRequest("get_wireless_scan_info") return SmartRequest(
"get_wireless_scan_info", params or SmartRequest.GetRulesParams()
)
@staticmethod @staticmethod
def get_schedule_rules(params: Optional[GetRulesParams] = None) -> "SmartRequest": def get_schedule_rules(params: Optional[GetRulesParams] = None) -> "SmartRequest":
@ -294,9 +300,13 @@ class SmartRequest:
@staticmethod @staticmethod
def get_component_info_requests(component_nego_response) -> List["SmartRequest"]: def get_component_info_requests(component_nego_response) -> List["SmartRequest"]:
"""Get a list of requests based on the component info response.""" """Get a list of requests based on the component info response."""
request_list = [] request_list: List["SmartRequest"] = []
for component in component_nego_response["component_list"]: for component in component_nego_response["component_list"]:
if requests := COMPONENT_REQUESTS.get(component["id"]): if (
requests := get_component_requests(
component["id"], int(component["ver_code"])
)
) is not None:
request_list.extend(requests) request_list.extend(requests)
return request_list return request_list
@ -314,8 +324,17 @@ class SmartRequest:
return request return request
def get_component_requests(component_id, ver_code):
"""Get the requests supported by the component and version."""
if (cr := COMPONENT_REQUESTS.get(component_id)) is None:
return None
if callable(cr):
return cr(ver_code)
return cr
COMPONENT_REQUESTS = { COMPONENT_REQUESTS = {
"device": SmartRequest.device_info_list(), "device": SmartRequest.device_info_list,
"firmware": SmartRequest.firmware_info_list(), "firmware": SmartRequest.firmware_info_list(),
"quick_setup": [SmartRequest.qs_component_nego()], "quick_setup": [SmartRequest.qs_component_nego()],
"inherit": [SmartRequest.get_raw_request("get_inherit_info")], "inherit": [SmartRequest.get_raw_request("get_inherit_info")],
@ -324,33 +343,33 @@ COMPONENT_REQUESTS = {
"schedule": SmartRequest.schedule_info_list(), "schedule": SmartRequest.schedule_info_list(),
"countdown": [SmartRequest.get_countdown_rules()], "countdown": [SmartRequest.get_countdown_rules()],
"antitheft": [SmartRequest.get_antitheft_rules()], "antitheft": [SmartRequest.get_antitheft_rules()],
"account": None, "account": [],
"synchronize": None, # sync_env "synchronize": [], # sync_env
"sunrise_sunset": None, # for schedules "sunrise_sunset": [], # for schedules
"led": [SmartRequest.get_led_info()], "led": [SmartRequest.get_led_info()],
"cloud_connect": [SmartRequest.get_raw_request("get_connect_cloud_state")], "cloud_connect": [SmartRequest.get_raw_request("get_connect_cloud_state")],
"iot_cloud": None, "iot_cloud": [],
"device_local_time": None, "device_local_time": [],
"default_states": None, # in device_info "default_states": [], # in device_info
"auto_off": [SmartRequest.get_auto_off_config()], "auto_off": [SmartRequest.get_auto_off_config()],
"localSmart": None, "localSmart": [],
"energy_monitoring": SmartRequest.energy_monitoring_list(), "energy_monitoring": SmartRequest.energy_monitoring_list(),
"power_protection": SmartRequest.power_protection_list(), "power_protection": SmartRequest.power_protection_list(),
"current_protection": None, # overcurrent in device_info "current_protection": [], # overcurrent in device_info
"matter": None, "matter": [],
"preset": [SmartRequest.get_preset_rules()], "preset": [SmartRequest.get_preset_rules()],
"brightness": None, # in device_info "brightness": [], # in device_info
"color": None, # in device_info "color": [], # in device_info
"color_temperature": None, # in device_info "color_temperature": [], # in device_info
"auto_light": [SmartRequest.get_auto_light_info()], "auto_light": [SmartRequest.get_auto_light_info()],
"light_effect": [SmartRequest.get_dynamic_light_effect_rules()], "light_effect": [SmartRequest.get_dynamic_light_effect_rules()],
"bulb_quick_control": None, "bulb_quick_control": [],
"on_off_gradually": [SmartRequest.get_raw_request("get_on_off_gradually_info")], "on_off_gradually": [SmartRequest.get_raw_request("get_on_off_gradually_info")],
"light_strip": None, "light_strip": [],
"light_strip_lighting_effect": [ "light_strip_lighting_effect": [
SmartRequest.get_raw_request("get_lighting_effect") SmartRequest.get_raw_request("get_lighting_effect")
], ],
"music_rhythm": None, # music_rhythm_enable in device_info "music_rhythm": [], # music_rhythm_enable in device_info
"segment": [SmartRequest.get_raw_request("get_device_segment")], "segment": [SmartRequest.get_raw_request("get_device_segment")],
"segment_effect": [SmartRequest.get_raw_request("get_segment_effect_rule")], "segment_effect": [SmartRequest.get_raw_request("get_segment_effect_rule")],
} }

View File

@ -3,7 +3,7 @@
Based on the work of https://github.com/petretiandrea/plugp100 Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license. under compatible GNU GPL3 license.
""" """
import asyncio
import base64 import base64
import hashlib import hashlib
import logging import logging
@ -39,6 +39,7 @@ _LOGGER = logging.getLogger(__name__)
ONE_DAY_SECONDS = 86400 ONE_DAY_SECONDS = 86400
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20 SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
BACKOFF_SECONDS_AFTER_LOGIN_ERROR = 1
def _sha1(payload: bytes) -> str: def _sha1(payload: bytes) -> str:
@ -184,8 +185,24 @@ class AesTransport(BaseTransport):
assert self._encryption_session is not None assert self._encryption_session is not None
raw_response: str = resp_dict["result"]["response"] raw_response: str = resp_dict["result"]["response"]
response = self._encryption_session.decrypt(raw_response.encode())
return json_loads(response) # type: ignore[return-value] try:
response = self._encryption_session.decrypt(raw_response.encode())
ret_val = json_loads(response)
except Exception as ex:
try:
ret_val = json_loads(raw_response)
_LOGGER.debug(
"Received unencrypted response over secure passthrough from %s",
self._host,
)
except Exception:
raise SmartDeviceException(
f"Unable to decrypt response from {self._host}, "
+ f"error: {ex}, response: {raw_response}",
ex,
) from ex
return ret_val # type: ignore[return-value]
async def perform_login(self): async def perform_login(self):
"""Login to the device.""" """Login to the device."""
@ -199,6 +216,7 @@ class AesTransport(BaseTransport):
self._default_credentials = get_default_credentials( self._default_credentials = get_default_credentials(
DEFAULT_CREDENTIALS["TAPO"] DEFAULT_CREDENTIALS["TAPO"]
) )
await asyncio.sleep(BACKOFF_SECONDS_AFTER_LOGIN_ERROR)
await self.perform_handshake() await self.perform_handshake()
await self.try_login(self._get_login_params(self._default_credentials)) await self.try_login(self._get_login_params(self._default_credentials))
_LOGGER.debug( _LOGGER.debug(

View File

@ -5,6 +5,7 @@ import json
import logging import logging
import re import re
import sys import sys
from contextlib import asynccontextmanager
from functools import singledispatch, wraps from functools import singledispatch, wraps
from pprint import pformat as pf from pprint import pformat as pf
from typing import Any, Dict, cast from typing import Any, Dict, cast
@ -217,7 +218,7 @@ def json_formatter_cb(result, **kwargs):
@click.option( @click.option(
"--discovery-timeout", "--discovery-timeout",
envvar="KASA_DISCOVERY_TIMEOUT", envvar="KASA_DISCOVERY_TIMEOUT",
default=3, default=5,
required=False, required=False,
show_default=True, show_default=True,
help="Timeout for discovery.", help="Timeout for discovery.",
@ -349,11 +350,16 @@ async def cli(
) )
dev = await Device.connect(config=config) dev = await Device.connect(config=config)
else: else:
echo("No --type or --device-family and --encrypt-type defined, discovering..") echo(
"No --type or --device-family and --encrypt-type defined, "
+ f"discovering for {discovery_timeout} seconds.."
)
dev = await Discover.discover_single( dev = await Discover.discover_single(
host, host,
port=port, port=port,
credentials=credentials, credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
) )
# Skip update on specific commands, or if device factory, # Skip update on specific commands, or if device factory,
@ -361,7 +367,14 @@ async def cli(
if ctx.invoked_subcommand not in SKIP_UPDATE_COMMANDS and not device_family: if ctx.invoked_subcommand not in SKIP_UPDATE_COMMANDS and not device_family:
await dev.update() await dev.update()
ctx.obj = dev @asynccontextmanager
async def async_wrapped_device(device: Device):
try:
yield device
finally:
await device.disconnect()
ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev))
if ctx.invoked_subcommand is None: if ctx.invoked_subcommand is None:
return await ctx.invoke(state) return await ctx.invoke(state)

View File

@ -49,6 +49,20 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic
if host: if host:
config = DeviceConfig(host=host) config = DeviceConfig(host=host)
if (protocol := get_protocol(config=config)) is None:
raise UnsupportedDeviceException(
f"Unsupported device for {config.host}: "
+ f"{config.connection_type.device_family.value}"
)
try:
return await _connect(config, protocol)
except:
await protocol.close()
raise
async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> "Device":
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_enabled: if debug_enabled:
start_time = time.perf_counter() start_time = time.perf_counter()
@ -63,12 +77,6 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic
) )
start_time = time.perf_counter() start_time = time.perf_counter()
if (protocol := get_protocol(config=config)) is None:
raise UnsupportedDeviceException(
f"Unsupported device for {config.host}: "
+ f"{config.connection_type.device_family.value}"
)
device_class: Optional[Type[Device]] device_class: Optional[Type[Device]]
device: Optional[Device] = None device: Optional[Device] = None

View File

@ -70,7 +70,7 @@ class SmartDevice(Device):
resp = await self.protocol.query("component_nego") resp = await self.protocol.query("component_nego")
self._components_raw = resp["component_nego"] self._components_raw = resp["component_nego"]
self._components = { self._components = {
comp["id"]: comp["ver_code"] comp["id"]: int(comp["ver_code"])
for comp in self._components_raw["component_list"] for comp in self._components_raw["component_list"]
} }
await self._initialize_modules() await self._initialize_modules()
@ -87,9 +87,14 @@ class SmartDevice(Device):
"get_current_power": None, "get_current_power": None,
} }
if self._components["device"] >= 2:
extra_reqs = {
**extra_reqs,
"get_device_usage": None,
}
req = { req = {
"get_device_info": None, "get_device_info": None,
"get_device_usage": None,
"get_device_time": None, "get_device_time": None,
**extra_reqs, **extra_reqs,
} }
@ -97,8 +102,9 @@ class SmartDevice(Device):
resp = await self.protocol.query(req) resp = await self.protocol.query(req)
self._info = resp["get_device_info"] self._info = resp["get_device_info"]
self._usage = resp["get_device_usage"]
self._time = resp["get_device_time"] self._time = resp["get_device_time"]
# Device usage is not available on older firmware versions
self._usage = resp.get("get_device_usage", {})
# Emeter is not always available, but we set them still for now. # Emeter is not always available, but we set them still for now.
self._energy = resp.get("get_energy_usage", {}) self._energy = resp.get("get_energy_usage", {})
self._emeter = resp.get("get_current_power", {}) self._emeter = resp.get("get_current_power", {})

View File

@ -82,6 +82,7 @@ class SmartProtocol(BaseProtocol):
if retry >= retry_count: if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry) _LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue continue
except TimeoutException as ex: except TimeoutException as ex:
await self._transport.reset() await self._transport.reset()

View File

@ -0,0 +1,173 @@
{
"component_nego": {
"component_list": [
{
"id": "device",
"ver_code": 1
},
{
"id": "firmware",
"ver_code": 1
},
{
"id": "quick_setup",
"ver_code": 1
},
{
"id": "time",
"ver_code": 1
},
{
"id": "wireless",
"ver_code": 1
},
{
"id": "schedule",
"ver_code": 1
},
{
"id": "countdown",
"ver_code": 1
},
{
"id": "antitheft",
"ver_code": 1
},
{
"id": "account",
"ver_code": 1
},
{
"id": "synchronize",
"ver_code": 1
},
{
"id": "sunrise_sunset",
"ver_code": 1
},
{
"id": "led",
"ver_code": 1
},
{
"id": "cloud_connect",
"ver_code": 1
}
]
},
"discovery_result": {
"device_id": "00000000000000000000000000000000",
"device_model": "P100",
"device_type": "SMART.TAPOPLUG",
"factory_default": false,
"ip": "127.0.0.123",
"mac": "1C-3B-F3-00-00-00",
"mgt_encrypt_schm": {
"encrypt_type": "AES",
"http_port": 80,
"is_support_https": false
},
"owner": "00000000000000000000000000000000"
},
"get_antitheft_rules": {
"antitheft_rule_max_count": 1,
"enable": false,
"rule_list": []
},
"get_connect_cloud_state": {
"status": -1001
},
"get_countdown_rules": {
"countdown_rule_max_count": 1,
"enable": false,
"rule_list": []
},
"get_device_info": {
"avatar": "plug",
"device_id": "0000000000000000000000000000000000000000",
"device_on": true,
"fw_id": "00000000000000000000000000000000",
"fw_ver": "1.1.3 Build 20191017 Rel. 57937",
"has_set_location_info": true,
"hw_id": "00000000000000000000000000000000",
"hw_ver": "1.0.0",
"ip": "127.0.0.123",
"latitude": 0,
"location": "hallway",
"longitude": 0,
"mac": "1C-3B-F3-00-00-00",
"model": "P100",
"nickname": "I01BU0tFRF9OQU1FIw==",
"oem_id": "00000000000000000000000000000000",
"on_time": 6868,
"overheated": false,
"signal_level": 2,
"specs": "US",
"ssid": "I01BU0tFRF9TU0lEIw==",
"time_usage_past30": 114,
"time_usage_past7": 114,
"time_usage_today": 114,
"type": "SMART.TAPOPLUG"
},
"get_device_time": {
"region": "Europe/London",
"time_diff": 0,
"timestamp": 1707905077
},
"get_fw_download_state": {
"download_progress": 0,
"reboot_time": 10,
"status": 0,
"upgrade_time": 0
},
"get_latest_fw": {
"fw_size": 786432,
"fw_ver": "1.3.7 Build 20230711 Rel.61904",
"hw_id": "00000000000000000000000000000000",
"need_to_upgrade": true,
"oem_id": "00000000000000000000000000000000",
"release_date": "2023-07-26",
"release_note": "Modifications and Bug fixes:\nEnhanced device security.",
"type": 3
},
"get_led_info": {
"led_rule": "always",
"led_status": true
},
"get_next_event": {
"action": -1,
"e_time": 0,
"id": "0",
"s_time": 0,
"type": 0
},
"get_schedule_rules": {
"enable": false,
"rule_list": [],
"schedule_rule_max_count": 20,
"start_index": 0,
"sum": 0
},
"get_wireless_scan_info": {
"ap_list": [],
"start_index": 0,
"sum": 0,
"wep_supported": false
},
"qs_component_nego": {
"component_list": [
{
"id": "quick_setup",
"ver_code": 1
},
{
"id": "sunrise_sunset",
"ver_code": 1
}
],
"extra_info": {
"device_model": "P100",
"device_type": "SMART.TAPOPLUG"
}
}
}

View File

@ -1,5 +1,6 @@
import base64 import base64
import json import json
import logging
import random import random
import string import string
import time import time
@ -180,6 +181,67 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
assert "result" in res assert "result" in res
async def test_unencrypted_response(mocker, caplog):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, 200, 0, 0, do_not_encrypt_response=True)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._state = TransportState.ESTABLISHED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
caplog.set_level(logging.DEBUG)
res = await transport.send(json_dumps(request))
assert "result" in res
assert (
"Received unencrypted response over secure passthrough from 127.0.0.1"
in caplog.text
)
async def test_unencrypted_response_invalid_json(mocker, caplog):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(
host, 200, 0, 0, do_not_encrypt_response=True, send_response=b"Foobar"
)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._state = TransportState.ESTABLISHED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
caplog.set_level(logging.DEBUG)
msg = f"Unable to decrypt response from {host}, error: Incorrect padding, response: Foobar"
with pytest.raises(SmartDeviceException, match=msg):
await transport.send(json_dumps(request))
ERRORS = [e for e in SmartErrorCode if e != 0] ERRORS = [e for e in SmartErrorCode if e != 0]
@ -233,15 +295,28 @@ class MockAesDevice:
pass pass
async def read(self): async def read(self):
return json_dumps(self._json).encode() if isinstance(self._json, dict):
return json_dumps(self._json).encode()
return self._json
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:]) encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
def __init__(self, host, status_code=200, error_code=0, inner_error_code=0): def __init__(
self,
host,
status_code=200,
error_code=0,
inner_error_code=0,
*,
do_not_encrypt_response=False,
send_response=None,
):
self.host = host self.host = host
self.status_code = status_code self.status_code = status_code
self.error_code = error_code self.error_code = error_code
self._inner_error_code = inner_error_code self._inner_error_code = inner_error_code
self.do_not_encrypt_response = do_not_encrypt_response
self.send_response = send_response
self.http_client = HttpClient(DeviceConfig(self.host)) self.http_client = HttpClient(DeviceConfig(self.host))
self.inner_call_count = 0 self.inner_call_count = 0
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311 self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
@ -289,13 +364,15 @@ class MockAesDevice:
decrypted_request_dict = json_loads(decrypted_request) decrypted_request_dict = json_loads(decrypted_request)
decrypted_response = await self._post(url, decrypted_request_dict) decrypted_response = await self._post(url, decrypted_request_dict)
async with decrypted_response: async with decrypted_response:
response_data = await decrypted_response.read() decrypted_response_data = await decrypted_response.read()
decrypted_response_dict = json_loads(response_data.decode()) encrypted_response = self.encryption_session.encrypt(decrypted_response_data)
encrypted_response = self.encryption_session.encrypt( response = (
json_dumps(decrypted_response_dict).encode() decrypted_response_data
if self.do_not_encrypt_response
else encrypted_response
) )
result = { result = {
"result": {"response": encrypted_response.decode()}, "result": {"response": response.decode()},
"error_code": self.error_code, "error_code": self.error_code,
} }
return self._mock_response(self.status_code, result) return self._mock_response(self.status_code, result)
@ -310,5 +387,6 @@ class MockAesDevice:
async def _return_send_response(self, url: URL, json: Dict[str, Any]): async def _return_send_response(self, url: URL, json: Dict[str, Any]):
result = {"result": {"method": None}, "error_code": self.inner_error_code} result = {"result": {"method": None}, "error_code": self.inner_error_code}
response = self.send_response if self.send_response else result
self.inner_call_count += 1 self.inner_call_count += 1
return self._mock_response(self.status_code, result) return self._mock_response(self.status_code, response)

View File

@ -7,6 +7,7 @@ from asyncclick.testing import CliRunner
from kasa import ( from kasa import (
AuthenticationException, AuthenticationException,
Credentials,
Device, Device,
EmeterStatus, EmeterStatus,
SmartDeviceException, SmartDeviceException,
@ -351,7 +352,9 @@ async def test_credentials(discovery_mock, mocker):
async def test_without_device_type(dev, mocker): async def test_without_device_type(dev, mocker):
"""Test connecting without the device type.""" """Test connecting without the device type."""
runner = CliRunner() runner = CliRunner()
mocker.patch("kasa.discover.Discover.discover_single", return_value=dev) discovery_mock = mocker.patch(
"kasa.discover.Discover.discover_single", return_value=dev
)
# These will mock the features to avoid accessing non-existing # These will mock the features to avoid accessing non-existing
mocker.patch("kasa.device.Device.features", return_value={}) mocker.patch("kasa.device.Device.features", return_value={})
mocker.patch("kasa.iot.iotdevice.IotDevice.features", return_value={}) mocker.patch("kasa.iot.iotdevice.IotDevice.features", return_value={})
@ -365,9 +368,18 @@ async def test_without_device_type(dev, mocker):
"foo", "foo",
"--password", "--password",
"bar", "bar",
"--discovery-timeout",
"7",
], ],
) )
assert res.exit_code == 0 assert res.exit_code == 0
discovery_mock.assert_called_once_with(
"127.0.0.1",
port=None,
credentials=Credentials("foo", "bar"),
timeout=5,
discovery_timeout=7,
)
@pytest.mark.parametrize("auth_param", ["--username", "--password"]) @pytest.mark.parametrize("auth_param", ["--username", "--password"])

View File

@ -53,7 +53,7 @@ async def test_connect(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
) )
protocol_class = get_protocol(config).__class__ protocol_class = get_protocol(config).__class__
close_mock = mocker.patch.object(protocol_class, "close")
dev = await connect( dev = await connect(
config=config, config=config,
) )
@ -61,8 +61,9 @@ async def test_connect(
assert isinstance(dev.protocol, protocol_class) assert isinstance(dev.protocol, protocol_class)
assert dev.config == config assert dev.config == config
assert close_mock.call_count == 0
await dev.disconnect() await dev.disconnect()
assert close_mock.call_count == 1
@pytest.mark.parametrize("custom_port", [123, None]) @pytest.mark.parametrize("custom_port", [123, None])
@ -116,8 +117,12 @@ async def test_connect_query_fails(all_fixture_data: dict, mocker):
config = DeviceConfig( config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
) )
protocol_class = get_protocol(config).__class__
close_mock = mocker.patch.object(protocol_class, "close")
assert close_mock.call_count == 0
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await connect(config=config) await connect(config=config)
assert close_mock.call_count == 1
async def test_connect_http_client(all_fixture_data, mocker): async def test_connect_http_client(all_fixture_data, mocker):