Merge remote-tracking branch 'upstream/master' into feat/dev_descriptors

This commit is contained in:
Teemu Rytilahti
2024-02-15 15:17:48 +01:00
11 changed files with 403 additions and 53 deletions

View File

@@ -3,7 +3,7 @@
Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""
import asyncio
import base64
import hashlib
import logging
@@ -39,6 +39,7 @@ _LOGGER = logging.getLogger(__name__)
ONE_DAY_SECONDS = 86400
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
BACKOFF_SECONDS_AFTER_LOGIN_ERROR = 1
def _sha1(payload: bytes) -> str:
@@ -184,8 +185,24 @@ class AesTransport(BaseTransport):
assert self._encryption_session is not None
raw_response: str = resp_dict["result"]["response"]
response = self._encryption_session.decrypt(raw_response.encode())
return json_loads(response) # type: ignore[return-value]
try:
response = self._encryption_session.decrypt(raw_response.encode())
ret_val = json_loads(response)
except Exception as ex:
try:
ret_val = json_loads(raw_response)
_LOGGER.debug(
"Received unencrypted response over secure passthrough from %s",
self._host,
)
except Exception:
raise SmartDeviceException(
f"Unable to decrypt response from {self._host}, "
+ f"error: {ex}, response: {raw_response}",
ex,
) from ex
return ret_val # type: ignore[return-value]
async def perform_login(self):
"""Login to the device."""
@@ -199,6 +216,7 @@ class AesTransport(BaseTransport):
self._default_credentials = get_default_credentials(
DEFAULT_CREDENTIALS["TAPO"]
)
await asyncio.sleep(BACKOFF_SECONDS_AFTER_LOGIN_ERROR)
await self.perform_handshake()
await self.try_login(self._get_login_params(self._default_credentials))
_LOGGER.debug(

View File

@@ -5,6 +5,7 @@ import json
import logging
import re
import sys
from contextlib import asynccontextmanager
from functools import singledispatch, wraps
from pprint import pformat as pf
from typing import Any, Dict, cast
@@ -217,7 +218,7 @@ def json_formatter_cb(result, **kwargs):
@click.option(
"--discovery-timeout",
envvar="KASA_DISCOVERY_TIMEOUT",
default=3,
default=5,
required=False,
show_default=True,
help="Timeout for discovery.",
@@ -349,11 +350,16 @@ async def cli(
)
dev = await Device.connect(config=config)
else:
echo("No --type or --device-family and --encrypt-type defined, discovering..")
echo(
"No --type or --device-family and --encrypt-type defined, "
+ f"discovering for {discovery_timeout} seconds.."
)
dev = await Discover.discover_single(
host,
port=port,
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
)
# Skip update on specific commands, or if device factory,
@@ -361,7 +367,14 @@ async def cli(
if ctx.invoked_subcommand not in SKIP_UPDATE_COMMANDS and not device_family:
await dev.update()
ctx.obj = dev
@asynccontextmanager
async def async_wrapped_device(device: Device):
try:
yield device
finally:
await device.disconnect()
ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev))
if ctx.invoked_subcommand is None:
return await ctx.invoke(state)

View File

@@ -49,6 +49,20 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic
if host:
config = DeviceConfig(host=host)
if (protocol := get_protocol(config=config)) is None:
raise UnsupportedDeviceException(
f"Unsupported device for {config.host}: "
+ f"{config.connection_type.device_family.value}"
)
try:
return await _connect(config, protocol)
except:
await protocol.close()
raise
async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> "Device":
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_enabled:
start_time = time.perf_counter()
@@ -63,12 +77,6 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic
)
start_time = time.perf_counter()
if (protocol := get_protocol(config=config)) is None:
raise UnsupportedDeviceException(
f"Unsupported device for {config.host}: "
+ f"{config.connection_type.device_family.value}"
)
device_class: Optional[Type[Device]]
device: Optional[Device] = None

View File

@@ -70,7 +70,7 @@ class SmartDevice(Device):
resp = await self.protocol.query("component_nego")
self._components_raw = resp["component_nego"]
self._components = {
comp["id"]: comp["ver_code"]
comp["id"]: int(comp["ver_code"])
for comp in self._components_raw["component_list"]
}
await self._initialize_modules()
@@ -87,9 +87,14 @@ class SmartDevice(Device):
"get_current_power": None,
}
if self._components["device"] >= 2:
extra_reqs = {
**extra_reqs,
"get_device_usage": None,
}
req = {
"get_device_info": None,
"get_device_usage": None,
"get_device_time": None,
**extra_reqs,
}
@@ -97,8 +102,9 @@ class SmartDevice(Device):
resp = await self.protocol.query(req)
self._info = resp["get_device_info"]
self._usage = resp["get_device_usage"]
self._time = resp["get_device_time"]
# Device usage is not available on older firmware versions
self._usage = resp.get("get_device_usage", {})
# Emeter is not always available, but we set them still for now.
self._energy = resp.get("get_energy_usage", {})
self._emeter = resp.get("get_current_power", {})

View File

@@ -82,6 +82,7 @@ class SmartProtocol(BaseProtocol):
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except TimeoutException as ex:
await self._transport.reset()

View File

@@ -0,0 +1,173 @@
{
"component_nego": {
"component_list": [
{
"id": "device",
"ver_code": 1
},
{
"id": "firmware",
"ver_code": 1
},
{
"id": "quick_setup",
"ver_code": 1
},
{
"id": "time",
"ver_code": 1
},
{
"id": "wireless",
"ver_code": 1
},
{
"id": "schedule",
"ver_code": 1
},
{
"id": "countdown",
"ver_code": 1
},
{
"id": "antitheft",
"ver_code": 1
},
{
"id": "account",
"ver_code": 1
},
{
"id": "synchronize",
"ver_code": 1
},
{
"id": "sunrise_sunset",
"ver_code": 1
},
{
"id": "led",
"ver_code": 1
},
{
"id": "cloud_connect",
"ver_code": 1
}
]
},
"discovery_result": {
"device_id": "00000000000000000000000000000000",
"device_model": "P100",
"device_type": "SMART.TAPOPLUG",
"factory_default": false,
"ip": "127.0.0.123",
"mac": "1C-3B-F3-00-00-00",
"mgt_encrypt_schm": {
"encrypt_type": "AES",
"http_port": 80,
"is_support_https": false
},
"owner": "00000000000000000000000000000000"
},
"get_antitheft_rules": {
"antitheft_rule_max_count": 1,
"enable": false,
"rule_list": []
},
"get_connect_cloud_state": {
"status": -1001
},
"get_countdown_rules": {
"countdown_rule_max_count": 1,
"enable": false,
"rule_list": []
},
"get_device_info": {
"avatar": "plug",
"device_id": "0000000000000000000000000000000000000000",
"device_on": true,
"fw_id": "00000000000000000000000000000000",
"fw_ver": "1.1.3 Build 20191017 Rel. 57937",
"has_set_location_info": true,
"hw_id": "00000000000000000000000000000000",
"hw_ver": "1.0.0",
"ip": "127.0.0.123",
"latitude": 0,
"location": "hallway",
"longitude": 0,
"mac": "1C-3B-F3-00-00-00",
"model": "P100",
"nickname": "I01BU0tFRF9OQU1FIw==",
"oem_id": "00000000000000000000000000000000",
"on_time": 6868,
"overheated": false,
"signal_level": 2,
"specs": "US",
"ssid": "I01BU0tFRF9TU0lEIw==",
"time_usage_past30": 114,
"time_usage_past7": 114,
"time_usage_today": 114,
"type": "SMART.TAPOPLUG"
},
"get_device_time": {
"region": "Europe/London",
"time_diff": 0,
"timestamp": 1707905077
},
"get_fw_download_state": {
"download_progress": 0,
"reboot_time": 10,
"status": 0,
"upgrade_time": 0
},
"get_latest_fw": {
"fw_size": 786432,
"fw_ver": "1.3.7 Build 20230711 Rel.61904",
"hw_id": "00000000000000000000000000000000",
"need_to_upgrade": true,
"oem_id": "00000000000000000000000000000000",
"release_date": "2023-07-26",
"release_note": "Modifications and Bug fixes:\nEnhanced device security.",
"type": 3
},
"get_led_info": {
"led_rule": "always",
"led_status": true
},
"get_next_event": {
"action": -1,
"e_time": 0,
"id": "0",
"s_time": 0,
"type": 0
},
"get_schedule_rules": {
"enable": false,
"rule_list": [],
"schedule_rule_max_count": 20,
"start_index": 0,
"sum": 0
},
"get_wireless_scan_info": {
"ap_list": [],
"start_index": 0,
"sum": 0,
"wep_supported": false
},
"qs_component_nego": {
"component_list": [
{
"id": "quick_setup",
"ver_code": 1
},
{
"id": "sunrise_sunset",
"ver_code": 1
}
],
"extra_info": {
"device_model": "P100",
"device_type": "SMART.TAPOPLUG"
}
}
}

View File

@@ -1,5 +1,6 @@
import base64
import json
import logging
import random
import string
import time
@@ -180,6 +181,67 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
assert "result" in res
async def test_unencrypted_response(mocker, caplog):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, 200, 0, 0, do_not_encrypt_response=True)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._state = TransportState.ESTABLISHED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
caplog.set_level(logging.DEBUG)
res = await transport.send(json_dumps(request))
assert "result" in res
assert (
"Received unencrypted response over secure passthrough from 127.0.0.1"
in caplog.text
)
async def test_unencrypted_response_invalid_json(mocker, caplog):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(
host, 200, 0, 0, do_not_encrypt_response=True, send_response=b"Foobar"
)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._state = TransportState.ESTABLISHED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)
request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
caplog.set_level(logging.DEBUG)
msg = f"Unable to decrypt response from {host}, error: Incorrect padding, response: Foobar"
with pytest.raises(SmartDeviceException, match=msg):
await transport.send(json_dumps(request))
ERRORS = [e for e in SmartErrorCode if e != 0]
@@ -233,15 +295,28 @@ class MockAesDevice:
pass
async def read(self):
return json_dumps(self._json).encode()
if isinstance(self._json, dict):
return json_dumps(self._json).encode()
return self._json
encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
def __init__(self, host, status_code=200, error_code=0, inner_error_code=0):
def __init__(
self,
host,
status_code=200,
error_code=0,
inner_error_code=0,
*,
do_not_encrypt_response=False,
send_response=None,
):
self.host = host
self.status_code = status_code
self.error_code = error_code
self._inner_error_code = inner_error_code
self.do_not_encrypt_response = do_not_encrypt_response
self.send_response = send_response
self.http_client = HttpClient(DeviceConfig(self.host))
self.inner_call_count = 0
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
@@ -289,13 +364,15 @@ class MockAesDevice:
decrypted_request_dict = json_loads(decrypted_request)
decrypted_response = await self._post(url, decrypted_request_dict)
async with decrypted_response:
response_data = await decrypted_response.read()
decrypted_response_dict = json_loads(response_data.decode())
encrypted_response = self.encryption_session.encrypt(
json_dumps(decrypted_response_dict).encode()
decrypted_response_data = await decrypted_response.read()
encrypted_response = self.encryption_session.encrypt(decrypted_response_data)
response = (
decrypted_response_data
if self.do_not_encrypt_response
else encrypted_response
)
result = {
"result": {"response": encrypted_response.decode()},
"result": {"response": response.decode()},
"error_code": self.error_code,
}
return self._mock_response(self.status_code, result)
@@ -310,5 +387,6 @@ class MockAesDevice:
async def _return_send_response(self, url: URL, json: Dict[str, Any]):
result = {"result": {"method": None}, "error_code": self.inner_error_code}
response = self.send_response if self.send_response else result
self.inner_call_count += 1
return self._mock_response(self.status_code, result)
return self._mock_response(self.status_code, response)

View File

@@ -7,6 +7,7 @@ from asyncclick.testing import CliRunner
from kasa import (
AuthenticationException,
Credentials,
Device,
EmeterStatus,
SmartDeviceException,
@@ -351,7 +352,9 @@ async def test_credentials(discovery_mock, mocker):
async def test_without_device_type(dev, mocker):
"""Test connecting without the device type."""
runner = CliRunner()
mocker.patch("kasa.discover.Discover.discover_single", return_value=dev)
discovery_mock = mocker.patch(
"kasa.discover.Discover.discover_single", return_value=dev
)
# These will mock the features to avoid accessing non-existing
mocker.patch("kasa.device.Device.features", return_value={})
mocker.patch("kasa.iot.iotdevice.IotDevice.features", return_value={})
@@ -365,9 +368,18 @@ async def test_without_device_type(dev, mocker):
"foo",
"--password",
"bar",
"--discovery-timeout",
"7",
],
)
assert res.exit_code == 0
discovery_mock.assert_called_once_with(
"127.0.0.1",
port=None,
credentials=Credentials("foo", "bar"),
timeout=5,
discovery_timeout=7,
)
@pytest.mark.parametrize("auth_param", ["--username", "--password"])

View File

@@ -53,7 +53,7 @@ async def test_connect(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
protocol_class = get_protocol(config).__class__
close_mock = mocker.patch.object(protocol_class, "close")
dev = await connect(
config=config,
)
@@ -61,8 +61,9 @@ async def test_connect(
assert isinstance(dev.protocol, protocol_class)
assert dev.config == config
assert close_mock.call_count == 0
await dev.disconnect()
assert close_mock.call_count == 1
@pytest.mark.parametrize("custom_port", [123, None])
@@ -116,8 +117,12 @@ async def test_connect_query_fails(all_fixture_data: dict, mocker):
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
protocol_class = get_protocol(config).__class__
close_mock = mocker.patch.object(protocol_class, "close")
assert close_mock.call_count == 0
with pytest.raises(SmartDeviceException):
await connect(config=config)
assert close_mock.call_count == 1
async def test_connect_http_client(all_fixture_data, mocker):