diff --git a/kasa/cli/wifi.py b/kasa/cli/wifi.py index 924e83f1..0fc7bdd6 100644 --- a/kasa/cli/wifi.py +++ b/kasa/cli/wifi.py @@ -6,6 +6,7 @@ import asyncclick as click from kasa import ( Device, + KasaException, ) from .common import ( @@ -15,8 +16,7 @@ from .common import ( @click.group() -@pass_dev -def wifi(dev) -> None: +def wifi() -> None: """Commands to control wifi settings.""" @@ -35,13 +35,23 @@ async def scan(dev): @wifi.command() @click.argument("ssid") -@click.option("--keytype", prompt=True) +@click.option( + "--keytype", + default="", + help="KeyType (Not needed for SmartCamDevice).", +) @click.option("--password", prompt=True, hide_input=True) @pass_dev async def join(dev: Device, ssid: str, password: str, keytype: str): """Join the given wifi network.""" echo(f"Asking the device to connect to {ssid}..") - res = await dev.wifi_join(ssid, password, keytype=keytype) + try: + res = await dev.wifi_join(ssid, password, keytype=keytype) + except KasaException as e: + if type(e) is KasaException: + echo(str(e)) + return + raise echo( f"Response: {res} - if the device is not able to join the network, " f"it will revert back to its previous state." diff --git a/kasa/device.py b/kasa/device.py index 45763db3..efd74c13 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -138,15 +138,18 @@ class WifiNetwork: """Wifi network container.""" ssid: str - key_type: int + # This is available on both netif and on softaponboarding + key_type: int | None = None # These are available only on softaponboarding cipher_type: int | None = None - bssid: str | None = None channel: int | None = None + # These are available on softaponboarding, SMART, and SMARTCAM devices + bssid: str | None = None rssi: int | None = None - - # For SMART devices + # These are available on both SMART and SMARTCAM devices signal_level: int | None = None + auth: int | None = None + encryption: int | None = None _LOGGER = logging.getLogger(__name__) diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index 36aba3e5..90bac705 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -688,6 +688,9 @@ class IotDevice(Device): async def _join(target: str, payload: dict) -> dict: return await self._query_helper(target, "set_stainfo", payload) + if not keytype: + raise KasaException("KeyType is required for this device.") + payload = {"ssid": ssid, "password": password, "key_type": int(keytype)} try: return await _join("netif", payload) diff --git a/kasa/protocols/smartprotocol.py b/kasa/protocols/smartprotocol.py index 5539de77..ad3e7331 100644 --- a/kasa/protocols/smartprotocol.py +++ b/kasa/protocols/smartprotocol.py @@ -92,6 +92,7 @@ REDACTORS: dict[str, Callable[[Any], Any] | None] = { # Queries that are known not to work properly when sent as a # multiRequest. They will not return the `method` key. FORCE_SINGLE_REQUEST = { + "connectAp", "getConnectStatus", "scanApList", } diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 87aa628d..6be2392c 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -769,6 +769,9 @@ class SmartDevice(Device): if not self.credentials: raise AuthenticationError("Device requires authentication.") + if not keytype: + raise KasaException("KeyType is required for this device.") + payload = { "account": { "username": base64.b64encode( diff --git a/kasa/smartcam/smartcamdevice.py b/kasa/smartcam/smartcamdevice.py index 3beda36b..7bc6184f 100644 --- a/kasa/smartcam/smartcamdevice.py +++ b/kasa/smartcam/smartcamdevice.py @@ -2,12 +2,20 @@ from __future__ import annotations +import base64 import logging from typing import Any, cast -from ..device import DeviceInfo +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey + +from ..device import DeviceInfo, WifiNetwork from ..device_type import DeviceType +from ..deviceconfig import DeviceConfig +from ..exceptions import AuthenticationError, DeviceError, KasaException from ..module import Module +from ..protocols import SmartProtocol from ..protocols.smartcamprotocol import _ChildCameraProtocolWrapper from ..smart import SmartChildDevice, SmartDevice from ..smart.smartdevice import ComponentsRaw @@ -23,6 +31,24 @@ class SmartCamDevice(SmartDevice): # Modules that are called as part of the init procedure on first update FIRST_UPDATE_MODULES = {DeviceModule, ChildDevice} + STATIC_PUBLIC_KEY_B64 = ( + "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC4D6i0oD/Ga5qb//RfSe8MrPVI" + "rMIGecCxkcGWGj9kxxk74qQNq8XUuXoy2PczQ30BpiRHrlkbtBEPeWLpq85tfubT" + "UjhBz1NPNvWrC88uaYVGvzNpgzZOqDC35961uPTuvdUa8vztcUQjEZy16WbmetRj" + "URFIiWJgFCmemyYVbQIDAQAB" + ) + + def __init__( + self, + host: str, + *, + config: DeviceConfig | None = None, + protocol: SmartProtocol | None = None, + ) -> None: + super().__init__(host, config=config, protocol=protocol) + self._public_key: str | None = None + self._networks: list[WifiNetwork] = [] + @staticmethod def _get_device_type_from_sysinfo(sysinfo: dict[str, Any]) -> DeviceType: """Find type to be displayed as a supported device category.""" @@ -288,3 +314,79 @@ class SmartCamDevice(SmartDevice): def rssi(self) -> int | None: """Return the device id.""" return self.modules[SmartCamModule.SmartCamDeviceModule].rssi + + async def wifi_scan(self) -> list[WifiNetwork]: + """Scan for available wifi networks.""" + + def _net_for_scan_info(res: dict) -> WifiNetwork: + return WifiNetwork( + ssid=res["ssid"], + auth=res["auth"], + encryption=res["encryption"], + rssi=res["rssi"], + bssid=res["bssid"], + ) + + _LOGGER.debug("Querying networks") + + resp = await self._query_helper("scanApList", {"onboarding": {"scan": {}}}) + scan_data: dict = resp["scanApList"]["onboarding"]["scan"] + self._public_key = scan_data.get("publicKey", "") + self._networks = [_net_for_scan_info(net) for net in scan_data["ap_list"]] + return self._networks + + async def wifi_join( + self, ssid: str, password: str, keytype: str = "wpa2_psk" + ) -> dict: + """Join the given wifi network. + + This method returns nothing as the device tries to activate the new + settings immediately instead of responding to the request. + + If joining the network fails, the device will return to the previous state + after some delay. + """ + if not self.credentials: + raise AuthenticationError("Device requires authentication.") + + if not self._networks: + await self.wifi_scan() + net = next( + (n for n in self._networks if getattr(n, "ssid", None) == ssid), None + ) + if net is None: + raise DeviceError(f"Network with SSID '{ssid}' not found.") + + public_key_b64 = self._public_key or self.STATIC_PUBLIC_KEY_B64 + key_bytes = base64.b64decode(public_key_b64) + public_key = serialization.load_der_public_key(key_bytes) + if not isinstance(public_key, RSAPublicKey): + raise TypeError("Loaded public key is not an RSA public key") + encrypted = public_key.encrypt(password.encode(), padding.PKCS1v15()) + encrypted_password = base64.b64encode(encrypted).decode() + + payload = { + "onboarding": { + "connect": { + "auth": net.auth, + "bssid": net.bssid, + "encryption": net.encryption, + "password": encrypted_password, + "rssi": net.rssi, + "ssid": net.ssid, + } + } + } + + # The device does not respond to the request but changes the settings + # immediately which causes us to timeout. + # Thus, We limit retries and suppress the raised exception as useless. + try: + return await self.protocol.query({"connectAp": payload}, retry_count=0) + except DeviceError: + raise # Re-raise on device-reported errors + except KasaException: + _LOGGER.debug( + "Received a kasa exception for wifi join, but this is expected" + ) + return {} diff --git a/tests/fakeprotocol_smartcam.py b/tests/fakeprotocol_smartcam.py index 5cd291b3..17160233 100644 --- a/tests/fakeprotocol_smartcam.py +++ b/tests/fakeprotocol_smartcam.py @@ -256,6 +256,52 @@ class FakeSmartCamTransport(BaseTransport): method = request_dict["method"] info = self.info + if method == "connectAp": + if self.verbatim: + return {"error_code": -1} + return {"result": {}, "error_code": 0} + if method == "scanApList": + if method in info: + result = self._get_method_from_info(method, request_dict.get("params")) + if not self.verbatim: + scan = ( + result.get("result", {}).get("onboarding", {}).get("scan", {}) + ) + ap_list = scan.get("ap_list") + if isinstance(ap_list, list) and not any( + ap.get("ssid") == "FOOBAR" for ap in ap_list + ): + ap_list.append( + { + "ssid": "FOOBAR", + "auth": 3, + "encryption": 3, + "rssi": -40, + "bssid": "00:00:00:00:00:00", + } + ) + return result + if self.verbatim: + return {"error_code": -1} + return { + "result": { + "onboarding": { + "scan": { + "publicKey": "", + "ap_list": [ + { + "ssid": "FOOBAR", + "auth": 3, + "encryption": 3, + "rssi": -40, + "bssid": "00:00:00:00:00:00", + } + ], + } + } + }, + "error_code": 0, + } if method == "controlChild": return await self._handle_control_child( request_dict["params"]["childControl"] diff --git a/tests/smartcam/test_smartcamdevice.py b/tests/smartcam/test_smartcamdevice.py index 8675b693..58ab2fd9 100644 --- a/tests/smartcam/test_smartcamdevice.py +++ b/tests/smartcam/test_smartcamdevice.py @@ -2,12 +2,16 @@ from __future__ import annotations +import base64 from datetime import UTC, datetime +from unittest.mock import AsyncMock, PropertyMock, patch import pytest from freezegun.api import FrozenDateTimeFactory from kasa import Device, DeviceType, Module +from kasa.exceptions import AuthenticationError, DeviceError, KasaException +from kasa.smartcam import SmartCamDevice from ..conftest import device_smartcam, hub_smartcam @@ -34,7 +38,7 @@ async def test_state(dev: Device): @device_smartcam -async def test_alias(dev): +async def test_alias(dev: Device): test_alias = "TEST1234" original = dev.alias @@ -49,7 +53,7 @@ async def test_alias(dev): @hub_smartcam -async def test_hub(dev): +async def test_hub(dev: Device): assert dev.children for child in dev.children: assert child.modules @@ -60,6 +64,95 @@ async def test_hub(dev): assert child.device_id +@device_smartcam +async def test_wifi_scan(dev: SmartCamDevice): + fake_scan_data = { + "scanApList": { + "onboarding": { + "scan": { + "publicKey": base64.b64encode(b"fakekey").decode(), + "ap_list": [ + { + "ssid": "TestSSID", + "auth": "WPA2", + "encryption": "AES", + "rssi": -40, + "bssid": "00:11:22:33:44:55", + } + ], + } + } + } + } + with patch.object(dev, "_query_helper", AsyncMock(return_value=fake_scan_data)): + networks = await dev.wifi_scan() + assert len(networks) == 1 + net = networks[0] + assert net.ssid == "TestSSID" + assert net.auth == "WPA2" + assert net.encryption == "AES" + assert net.rssi == -40 + assert net.bssid == "00:11:22:33:44:55" + assert dev._public_key == base64.b64encode(b"fakekey").decode() + + +@device_smartcam +async def test_wifi_join_success_and_errors(dev: SmartCamDevice): + dev._networks = [ + type( + "WifiNetwork", + (), + { + "ssid": "TestSSID", + "auth": "WPA2", + "encryption": "AES", + "rssi": -40, + "bssid": "00:11:22:33:44:55", + }, + )() + ] + with patch.object(type(dev), "credentials", new_callable=PropertyMock) as cred_mock: + cred_mock.return_value = object() + with patch.object(dev.protocol, "query", AsyncMock(return_value={})): + result = await dev.wifi_join("TestSSID", "password123") + assert isinstance(result, dict) + cred_mock.return_value = None + with pytest.raises(AuthenticationError): + await dev.wifi_join("TestSSID", "password123") + cred_mock.return_value = object() + dev._networks = [] + with ( + patch.object(dev, "wifi_scan", AsyncMock(return_value=[])), + pytest.raises(DeviceError), + ): + await dev.wifi_join("TestSSID", "password123") + dev._networks = [ + type( + "WifiNetwork", + (), + { + "ssid": "TestSSID", + "auth": "WPA2", + "encryption": "AES", + "rssi": -40, + "bssid": "00:11:22:33:44:55", + }, + )() + ] + with ( + patch.object( + dev.protocol, "query", AsyncMock(side_effect=DeviceError("fail")) + ), + pytest.raises(DeviceError), + ): + await dev.wifi_join("TestSSID", "password123") + with patch.object( + dev.protocol, "query", AsyncMock(side_effect=KasaException("fail")) + ): + result = await dev.wifi_join("TestSSID", "password123") + assert result == {} + + @device_smartcam async def test_device_time(dev: Device, freezer: FrozenDateTimeFactory): """Test a child device gets the time from it's parent module.""" @@ -69,3 +162,36 @@ async def test_device_time(dev: Device, freezer: FrozenDateTimeFactory): await module.set_time(fallback_time) await dev.update() assert dev.time == fallback_time + + +@device_smartcam +async def test_wifi_join_typeerror_on_non_rsa_key(dev: SmartCamDevice): + dev._networks = [ + type( + "WifiNetwork", + (), + { + "ssid": "TestSSID", + "auth": "WPA2", + "encryption": "AES", + "rssi": -40, + "bssid": "00:11:22:33:44:55", + }, + )() + ] + with patch.object(type(dev), "credentials", new_callable=PropertyMock) as cred_mock: + cred_mock.return_value = object() + with ( + patch( + "cryptography.hazmat.primitives.serialization.load_der_public_key", + return_value=object(), + ), + patch( + "kasa.smartcam.smartcamdevice.RSAPublicKey", + new=type("FakeRSA", (), {}), + ), + pytest.raises( + TypeError, match="Loaded public key is not an RSA public key" + ), + ): + await dev.wifi_join("TestSSID", "password123") diff --git a/tests/test_cli.py b/tests/test_cli.py index 627959e7..c5c06b7b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -49,10 +49,13 @@ from kasa.smart import SmartDevice from kasa.smartcam import SmartCamDevice from .conftest import ( + device_iot, device_smart, + device_smartcam, get_device_for_fixture_protocol, handle_turn_on, new_discovery, + parametrize_combine, turn_on, ) @@ -359,12 +362,47 @@ async def test_wifi_scan(dev, runner): assert re.search(r"Found [\d]+ wifi networks!", res.output) -@device_smart +@parametrize_combine([device_smart, device_iot]) async def test_wifi_join(dev, mocker, runner): update = mocker.patch.object(dev, "update") res = await runner.invoke( wifi, - ["join", "FOOBAR", "--keytype", "wpa_psk", "--password", "foobar"], + ["join", "FOOBAR", "--keytype", "3", "--password", "foobar"], + obj=dev, + ) + + # Make sure that update was not called for wifi + with pytest.raises(AssertionError): + update.assert_called() + + assert res.exit_code == 0 + assert "Asking the device to connect to FOOBAR" in res.output + + +@parametrize_combine([device_smart, device_iot]) +async def test_wifi_join_missing_keytype(dev, mocker, runner): + """Test that missing keytype raises KasaException and CLI echoes the message.""" + update = mocker.patch.object(dev, "update") + res = await runner.invoke( + wifi, + ["join", "FOOBAR", "--password", "foobar"], + obj=dev, + ) + + # Make sure that update was not called for wifi + with pytest.raises(AssertionError): + update.assert_called() + + assert res.exit_code == 0 + assert "KeyType is required for this device." in res.output + + +@device_smartcam +async def test_wifi_join_smartcam(dev, mocker, runner): + update = mocker.patch.object(dev, "update") + res = await runner.invoke( + wifi, + ["join", "FOOBAR", "--password", "foobar"], obj=dev, )