Fix httpclient exceptions on read and improve error info (#655)

This commit is contained in:
Steven B 2024-01-19 20:06:50 +00:00 committed by GitHub
parent 0647adaba0
commit 38159140fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 121 additions and 11 deletions

View File

@ -1,4 +1,5 @@
"""python-kasa exceptions."""
from asyncio import TimeoutError
from enum import IntEnum
from typing import Optional
@ -27,9 +28,15 @@ class RetryableException(SmartDeviceException):
"""Retryable exception for device errors."""
class TimeoutException(SmartDeviceException):
class TimeoutException(SmartDeviceException, TimeoutError):
"""Timeout exception for device errors."""
def __repr__(self):
return SmartDeviceException.__repr__(self)
def __str__(self):
return SmartDeviceException.__str__(self)
class ConnectionException(SmartDeviceException):
"""Connection exception for device errors."""

View File

@ -1,4 +1,5 @@
"""Module for HttpClientSession class."""
import asyncio
from typing import Any, Dict, Optional, Tuple, Union
import aiohttp
@ -58,25 +59,27 @@ class HttpClient:
cookies=cookies_dict,
headers=headers,
)
async with resp:
if resp.status == 200:
response_data = await resp.read()
if json:
response_data = json_loads(response_data.decode())
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
raise ConnectionException(
f"Unable to connect to the device: {self._config.host}: {ex}"
f"Unable to connect to the device: {self._config.host}: {ex}", ex
) from ex
except aiohttp.ServerTimeoutError as ex:
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as ex:
raise TimeoutException(
"Unable to query the device, " + f"timed out: {self._config.host}: {ex}"
"Unable to query the device, "
+ f"timed out: {self._config.host}: {ex}",
ex,
) from ex
except Exception as ex:
raise SmartDeviceException(
f"Unable to query the device: {self._config.host}: {ex}"
f"Unable to query the device: {self._config.host}: {ex}", ex
) from ex
async with resp:
if resp.status == 200:
response_data = await resp.read()
if json:
response_data = json_loads(response_data.decode())
return resp.status, response_data
def get_cookie(self, cookie_name: str) -> Optional[str]:

View File

@ -0,0 +1,100 @@
import asyncio
import re
import aiohttp
import pytest
from ..deviceconfig import DeviceConfig
from ..exceptions import (
ConnectionException,
SmartDeviceException,
TimeoutException,
)
from ..httpclient import HttpClient
@pytest.mark.parametrize(
"error, error_raises, error_message",
[
(
aiohttp.ServerDisconnectedError(),
ConnectionException,
"Unable to connect to the device: ",
),
(
aiohttp.ClientOSError(),
ConnectionException,
"Unable to connect to the device: ",
),
(
aiohttp.ServerTimeoutError(),
TimeoutException,
"Unable to query the device, timed out: ",
),
(
asyncio.TimeoutError(),
TimeoutException,
"Unable to query the device, timed out: ",
),
(Exception(), SmartDeviceException, "Unable to query the device: "),
(
aiohttp.ServerFingerprintMismatch("exp", "got", "host", 1),
SmartDeviceException,
"Unable to query the device: ",
),
],
ids=(
"ServerDisconnectedError",
"ClientOSError",
"ServerTimeoutError",
"TimeoutError",
"Exception",
"ServerFingerprintMismatch",
),
)
@pytest.mark.parametrize("mock_read", (False, True), ids=("post", "read"))
async def test_httpclient_errors(mocker, error, error_raises, error_message, mock_read):
class _mock_response:
def __init__(self, status, error):
self.status = status
self.error = error
self.call_count = 0
async def __aenter__(self):
return self
async def __aexit__(self, exc_t, exc_v, exc_tb):
pass
async def read(self):
self.call_count += 1
raise self.error
mock_response = _mock_response(200, error)
async def _post(url, *_, **__):
nonlocal mock_response
return mock_response
host = "127.0.0.1"
side_effect = _post if mock_read else error
conn = mocker.patch.object(aiohttp.ClientSession, "post", side_effect=side_effect)
client = HttpClient(DeviceConfig(host))
# Exceptions with parameters print with double quotes, without use single quotes
full_msg = (
"\("
+ "['\"]"
+ re.escape(f"{error_message}{host}: {error}")
+ "['\"]"
+ re.escape(f", {repr(error)})")
)
with pytest.raises(error_raises, match=error_message) as exc_info:
await client.post("http://foobar")
assert re.match(full_msg, str(exc_info.value))
if mock_read:
assert mock_response.call_count == 1
else:
assert conn.call_count == 1