diff --git a/kasa/exceptions.py b/kasa/exceptions.py index 49e4e2c8..c0ef23b6 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -1,4 +1,5 @@ """python-kasa exceptions.""" +from asyncio import TimeoutError from enum import IntEnum from typing import Optional @@ -27,9 +28,15 @@ class RetryableException(SmartDeviceException): """Retryable exception for device errors.""" -class TimeoutException(SmartDeviceException): +class TimeoutException(SmartDeviceException, TimeoutError): """Timeout exception for device errors.""" + def __repr__(self): + return SmartDeviceException.__repr__(self) + + def __str__(self): + return SmartDeviceException.__str__(self) + class ConnectionException(SmartDeviceException): """Connection exception for device errors.""" diff --git a/kasa/httpclient.py b/kasa/httpclient.py index 91c444ae..26b8d6a7 100644 --- a/kasa/httpclient.py +++ b/kasa/httpclient.py @@ -1,4 +1,5 @@ """Module for HttpClientSession class.""" +import asyncio from typing import Any, Dict, Optional, Tuple, Union import aiohttp @@ -58,25 +59,27 @@ class HttpClient: cookies=cookies_dict, headers=headers, ) + async with resp: + if resp.status == 200: + response_data = await resp.read() + if json: + response_data = json_loads(response_data.decode()) + except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex: raise ConnectionException( - f"Unable to connect to the device: {self._config.host}: {ex}" + f"Unable to connect to the device: {self._config.host}: {ex}", ex ) from ex - except aiohttp.ServerTimeoutError as ex: + except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as ex: raise TimeoutException( - "Unable to query the device, " + f"timed out: {self._config.host}: {ex}" + "Unable to query the device, " + + f"timed out: {self._config.host}: {ex}", + ex, ) from ex except Exception as ex: raise SmartDeviceException( - f"Unable to query the device: {self._config.host}: {ex}" + f"Unable to query the device: {self._config.host}: {ex}", ex ) from ex - async with resp: - if resp.status == 200: - response_data = await resp.read() - if json: - response_data = json_loads(response_data.decode()) - return resp.status, response_data def get_cookie(self, cookie_name: str) -> Optional[str]: diff --git a/kasa/tests/test_httpclient.py b/kasa/tests/test_httpclient.py new file mode 100644 index 00000000..0a6c2beb --- /dev/null +++ b/kasa/tests/test_httpclient.py @@ -0,0 +1,100 @@ +import asyncio +import re + +import aiohttp +import pytest + +from ..deviceconfig import DeviceConfig +from ..exceptions import ( + ConnectionException, + SmartDeviceException, + TimeoutException, +) +from ..httpclient import HttpClient + + +@pytest.mark.parametrize( + "error, error_raises, error_message", + [ + ( + aiohttp.ServerDisconnectedError(), + ConnectionException, + "Unable to connect to the device: ", + ), + ( + aiohttp.ClientOSError(), + ConnectionException, + "Unable to connect to the device: ", + ), + ( + aiohttp.ServerTimeoutError(), + TimeoutException, + "Unable to query the device, timed out: ", + ), + ( + asyncio.TimeoutError(), + TimeoutException, + "Unable to query the device, timed out: ", + ), + (Exception(), SmartDeviceException, "Unable to query the device: "), + ( + aiohttp.ServerFingerprintMismatch("exp", "got", "host", 1), + SmartDeviceException, + "Unable to query the device: ", + ), + ], + ids=( + "ServerDisconnectedError", + "ClientOSError", + "ServerTimeoutError", + "TimeoutError", + "Exception", + "ServerFingerprintMismatch", + ), +) +@pytest.mark.parametrize("mock_read", (False, True), ids=("post", "read")) +async def test_httpclient_errors(mocker, error, error_raises, error_message, mock_read): + class _mock_response: + def __init__(self, status, error): + self.status = status + self.error = error + self.call_count = 0 + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_t, exc_v, exc_tb): + pass + + async def read(self): + self.call_count += 1 + raise self.error + + mock_response = _mock_response(200, error) + + async def _post(url, *_, **__): + nonlocal mock_response + return mock_response + + host = "127.0.0.1" + + side_effect = _post if mock_read else error + + conn = mocker.patch.object(aiohttp.ClientSession, "post", side_effect=side_effect) + client = HttpClient(DeviceConfig(host)) + # Exceptions with parameters print with double quotes, without use single quotes + full_msg = ( + "\(" + + "['\"]" + + re.escape(f"{error_message}{host}: {error}") + + "['\"]" + + re.escape(f", {repr(error)})") + ) + with pytest.raises(error_raises, match=error_message) as exc_info: + await client.post("http://foobar") + + assert re.match(full_msg, str(exc_info.value)) + if mock_read: + assert mock_response.call_count == 1 + else: + assert conn.call_count == 1