mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Encapsulate http client dependency (#642)
* Encapsulate http client dependency * Store cookie dict as variable * Update post-review
This commit is contained in:
parent
4623434eb4
commit
3b1b0a3c21
@ -30,7 +30,6 @@ from dataclasses import asdict, dataclass
|
|||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
logging.getLogger("httpx").propagate = False
|
|
||||||
|
|
||||||
|
|
||||||
class SmartRequest:
|
class SmartRequest:
|
||||||
|
@ -8,9 +8,8 @@ import base64
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Dict, Optional, cast
|
||||||
|
|
||||||
import httpx
|
|
||||||
from cryptography.hazmat.primitives import padding, serialization
|
from cryptography.hazmat.primitives import padding, serialization
|
||||||
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
@ -28,6 +27,7 @@ from .exceptions import (
|
|||||||
SmartErrorCode,
|
SmartErrorCode,
|
||||||
TimeoutException,
|
TimeoutException,
|
||||||
)
|
)
|
||||||
|
from .httpclient import HttpClient
|
||||||
from .json import dumps as json_dumps
|
from .json import dumps as json_dumps
|
||||||
from .json import loads as json_loads
|
from .json import loads as json_loads
|
||||||
from .protocol import BaseTransport
|
from .protocol import BaseTransport
|
||||||
@ -75,14 +75,14 @@ class AesTransport(BaseTransport):
|
|||||||
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
|
base64.b64decode(self._credentials_hash.encode()).decode() # type: ignore[union-attr]
|
||||||
)
|
)
|
||||||
|
|
||||||
self._default_http_client: Optional[httpx.AsyncClient] = None
|
self._http_client: HttpClient = HttpClient(config)
|
||||||
|
|
||||||
self._handshake_done = False
|
self._handshake_done = False
|
||||||
|
|
||||||
self._encryption_session: Optional[AesEncyptionSession] = None
|
self._encryption_session: Optional[AesEncyptionSession] = None
|
||||||
self._session_expire_at: Optional[float] = None
|
self._session_expire_at: Optional[float] = None
|
||||||
|
|
||||||
self._session_cookie = None
|
self._session_cookie: Optional[Dict[str, str]] = None
|
||||||
|
|
||||||
self._login_token = None
|
self._login_token = None
|
||||||
|
|
||||||
@ -98,14 +98,6 @@ class AesTransport(BaseTransport):
|
|||||||
"""The hashed credentials used by the transport."""
|
"""The hashed credentials used by the transport."""
|
||||||
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
|
return base64.b64encode(json_dumps(self._login_params).encode()).decode()
|
||||||
|
|
||||||
@property
|
|
||||||
def _http_client(self) -> httpx.AsyncClient:
|
|
||||||
if self._config.http_client:
|
|
||||||
return self._config.http_client
|
|
||||||
if not self._default_http_client:
|
|
||||||
self._default_http_client = httpx.AsyncClient()
|
|
||||||
return self._default_http_client
|
|
||||||
|
|
||||||
def _get_login_params(self):
|
def _get_login_params(self):
|
||||||
"""Get the login parameters based on the login_version."""
|
"""Get the login parameters based on the login_version."""
|
||||||
un, pw = self.hash_credentials(self._login_version == 2)
|
un, pw = self.hash_credentials(self._login_version == 2)
|
||||||
@ -128,28 +120,6 @@ class AesTransport(BaseTransport):
|
|||||||
pw = base64.b64encode(self._credentials.password.encode()).decode()
|
pw = base64.b64encode(self._credentials.password.encode()).decode()
|
||||||
return un, pw
|
return un, pw
|
||||||
|
|
||||||
async def client_post(self, url, params=None, data=None, json=None, headers=None):
|
|
||||||
"""Send an http post request to the device."""
|
|
||||||
response_data = None
|
|
||||||
cookies = None
|
|
||||||
if self._session_cookie:
|
|
||||||
cookies = httpx.Cookies()
|
|
||||||
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
|
|
||||||
self._http_client.cookies.clear()
|
|
||||||
resp = await self._http_client.post(
|
|
||||||
url,
|
|
||||||
params=params,
|
|
||||||
data=data,
|
|
||||||
json=json,
|
|
||||||
timeout=self._timeout,
|
|
||||||
cookies=cookies,
|
|
||||||
headers=self.COMMON_HEADERS,
|
|
||||||
)
|
|
||||||
if resp.status_code == 200:
|
|
||||||
response_data = resp.json()
|
|
||||||
|
|
||||||
return resp.status_code, response_data
|
|
||||||
|
|
||||||
def _handle_response_error_code(self, resp_dict: dict, msg: str):
|
def _handle_response_error_code(self, resp_dict: dict, msg: str):
|
||||||
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
|
||||||
if error_code == SmartErrorCode.SUCCESS:
|
if error_code == SmartErrorCode.SUCCESS:
|
||||||
@ -176,7 +146,12 @@ class AesTransport(BaseTransport):
|
|||||||
"method": "securePassthrough",
|
"method": "securePassthrough",
|
||||||
"params": {"request": encrypted_payload.decode()},
|
"params": {"request": encrypted_payload.decode()},
|
||||||
}
|
}
|
||||||
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
|
status_code, resp_dict = await self._http_client.post(
|
||||||
|
url,
|
||||||
|
json=passthrough_request,
|
||||||
|
headers=self.COMMON_HEADERS,
|
||||||
|
cookies_dict=self._session_cookie,
|
||||||
|
)
|
||||||
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
|
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
|
||||||
|
|
||||||
if status_code != 200:
|
if status_code != 200:
|
||||||
@ -185,6 +160,7 @@ class AesTransport(BaseTransport):
|
|||||||
+ f"status code {status_code} to passthrough"
|
+ f"status code {status_code} to passthrough"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
resp_dict = cast(Dict, resp_dict)
|
||||||
self._handle_response_error_code(
|
self._handle_response_error_code(
|
||||||
resp_dict, "Error sending secure_passthrough message"
|
resp_dict, "Error sending secure_passthrough message"
|
||||||
)
|
)
|
||||||
@ -233,7 +209,12 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
_LOGGER.debug(f"Request {request_body}")
|
_LOGGER.debug(f"Request {request_body}")
|
||||||
|
|
||||||
status_code, resp_dict = await self.client_post(url, json=request_body)
|
status_code, resp_dict = await self._http_client.post(
|
||||||
|
url,
|
||||||
|
json=request_body,
|
||||||
|
headers=self.COMMON_HEADERS,
|
||||||
|
cookies_dict=self._session_cookie,
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
_LOGGER.debug(f"Device responded with: {resp_dict}")
|
||||||
|
|
||||||
@ -247,13 +228,16 @@ class AesTransport(BaseTransport):
|
|||||||
|
|
||||||
handshake_key = resp_dict["result"]["key"]
|
handshake_key = resp_dict["result"]["key"]
|
||||||
|
|
||||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
if (
|
||||||
self.SESSION_COOKIE_NAME
|
cookie := self._http_client.get_cookie( # type: ignore
|
||||||
)
|
self.SESSION_COOKIE_NAME
|
||||||
if not self._session_cookie:
|
)
|
||||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
) or (
|
||||||
|
cookie := self._http_client.get_cookie( # type: ignore
|
||||||
"SESSIONID"
|
"SESSIONID"
|
||||||
)
|
)
|
||||||
|
):
|
||||||
|
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
|
||||||
|
|
||||||
self._session_expire_at = time.time() + 86400
|
self._session_expire_at = time.time() + 86400
|
||||||
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
self._encryption_session = AesEncyptionSession.create_from_keypair(
|
||||||
@ -281,13 +265,10 @@ class AesTransport(BaseTransport):
|
|||||||
return await self.send_secure_passthrough(request)
|
return await self.send_secure_passthrough(request)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the protocol."""
|
"""Close the transport."""
|
||||||
client = self._default_http_client
|
|
||||||
self._default_http_client = None
|
|
||||||
self._handshake_done = False
|
self._handshake_done = False
|
||||||
self._login_token = None
|
self._login_token = None
|
||||||
if client:
|
await self._http_client.close()
|
||||||
await client.aclose()
|
|
||||||
|
|
||||||
|
|
||||||
class AesEncyptionSession:
|
class AesEncyptionSession:
|
||||||
|
@ -2,13 +2,14 @@
|
|||||||
import logging
|
import logging
|
||||||
from dataclasses import asdict, dataclass, field, fields, is_dataclass
|
from dataclasses import asdict, dataclass, field, fields, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional, Union
|
from typing import TYPE_CHECKING, Dict, Optional, Union
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from .credentials import Credentials
|
from .credentials import Credentials
|
||||||
from .exceptions import SmartDeviceException
|
from .exceptions import SmartDeviceException
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -150,7 +151,7 @@ class DeviceConfig:
|
|||||||
|
|
||||||
# compare=False will be excluded from the serialization and object comparison.
|
# compare=False will be excluded from the serialization and object comparison.
|
||||||
#: Set a custom http_client for the device to use.
|
#: Set a custom http_client for the device to use.
|
||||||
http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False)
|
http_client: Optional["AsyncClient"] = field(default=None, compare=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.connection_type is None:
|
if self.connection_type is None:
|
||||||
|
@ -31,6 +31,10 @@ class TimeoutException(SmartDeviceException):
|
|||||||
"""Timeout exception for device errors."""
|
"""Timeout exception for device errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionException(SmartDeviceException):
|
||||||
|
"""Connection exception for device errors."""
|
||||||
|
|
||||||
|
|
||||||
class SmartErrorCode(IntEnum):
|
class SmartErrorCode(IntEnum):
|
||||||
"""Enum for SMART Error Codes."""
|
"""Enum for SMART Error Codes."""
|
||||||
|
|
||||||
|
89
kasa/httpclient.py
Normal file
89
kasa/httpclient.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
"""Module for HttpClientSession class."""
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .deviceconfig import DeviceConfig
|
||||||
|
from .exceptions import ConnectionException, SmartDeviceException, TimeoutException
|
||||||
|
|
||||||
|
logging.getLogger("httpx").propagate = False
|
||||||
|
|
||||||
|
InnerHttpType = Type[httpx.AsyncClient]
|
||||||
|
|
||||||
|
|
||||||
|
class HttpClient:
|
||||||
|
"""HttpClient Class."""
|
||||||
|
|
||||||
|
def __init__(self, config: DeviceConfig) -> None:
|
||||||
|
self._config = config
|
||||||
|
self._client: httpx.AsyncClient = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> httpx.AsyncClient:
|
||||||
|
"""Return the underlying http client."""
|
||||||
|
if self._config.http_client and issubclass(
|
||||||
|
self._config.http_client.__class__, httpx.AsyncClient
|
||||||
|
):
|
||||||
|
return self._config.http_client
|
||||||
|
|
||||||
|
if not self._client:
|
||||||
|
self._client = httpx.AsyncClient()
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
async def post(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
data: Optional[bytes] = None,
|
||||||
|
json: Optional[Dict] = None,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
cookies_dict: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Tuple[int, Optional[Union[Dict, bytes]]]:
|
||||||
|
"""Send an http post request to the device."""
|
||||||
|
response_data = None
|
||||||
|
cookies = None
|
||||||
|
if cookies_dict:
|
||||||
|
cookies = httpx.Cookies()
|
||||||
|
for name, value in cookies_dict.items():
|
||||||
|
cookies.set(name, value)
|
||||||
|
self.client.cookies.clear()
|
||||||
|
try:
|
||||||
|
resp = await self.client.post(
|
||||||
|
url,
|
||||||
|
params=params,
|
||||||
|
data=data,
|
||||||
|
json=json,
|
||||||
|
timeout=self._config.timeout,
|
||||||
|
cookies=cookies,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
except httpx.ConnectError as ex:
|
||||||
|
raise ConnectionException(
|
||||||
|
f"Unable to connect to the device: {self._config.host}: {ex}"
|
||||||
|
) from ex
|
||||||
|
except httpx.TimeoutException as ex:
|
||||||
|
raise TimeoutException(
|
||||||
|
"Unable to query the device, " + f"timed out: {self._config.host}: {ex}"
|
||||||
|
) from ex
|
||||||
|
except Exception as ex:
|
||||||
|
raise SmartDeviceException(
|
||||||
|
f"Unable to query the device: {self._config.host}: {ex}"
|
||||||
|
) from ex
|
||||||
|
|
||||||
|
if resp.status_code == 200:
|
||||||
|
response_data = resp.json() if json else resp.content
|
||||||
|
|
||||||
|
return resp.status_code, response_data
|
||||||
|
|
||||||
|
def get_cookie(self, cookie_name: str) -> str:
|
||||||
|
"""Return the cookie with cookie_name."""
|
||||||
|
return self._client.cookies.get(cookie_name)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the protocol."""
|
||||||
|
client = self._client
|
||||||
|
self._client = None
|
||||||
|
if client:
|
||||||
|
await client.aclose()
|
@ -3,9 +3,13 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
from typing import Dict, Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
import httpx
|
from .exceptions import (
|
||||||
|
AuthenticationException,
|
||||||
from .exceptions import AuthenticationException, SmartDeviceException
|
ConnectionException,
|
||||||
|
RetryableException,
|
||||||
|
SmartDeviceException,
|
||||||
|
TimeoutException,
|
||||||
|
)
|
||||||
from .json import dumps as json_dumps
|
from .json import dumps as json_dumps
|
||||||
from .protocol import BaseTransport, TPLinkProtocol
|
from .protocol import BaseTransport, TPLinkProtocol
|
||||||
|
|
||||||
@ -15,6 +19,8 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
class IotProtocol(TPLinkProtocol):
|
class IotProtocol(TPLinkProtocol):
|
||||||
"""Class for the legacy TPLink IOT KASA Protocol."""
|
"""Class for the legacy TPLink IOT KASA Protocol."""
|
||||||
|
|
||||||
|
BACKOFF_SECONDS_AFTER_TIMEOUT = 1
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@ -38,40 +44,39 @@ class IotProtocol(TPLinkProtocol):
|
|||||||
for retry in range(retry_count + 1):
|
for retry in range(retry_count + 1):
|
||||||
try:
|
try:
|
||||||
return await self._execute_query(request, retry)
|
return await self._execute_query(request, retry)
|
||||||
except httpx.ConnectError as sdex:
|
except ConnectionException as sdex:
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
await self.close()
|
await self.close()
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
raise SmartDeviceException(
|
raise sdex
|
||||||
f"Unable to connect to the device: {self._host}: {sdex}"
|
|
||||||
) from sdex
|
|
||||||
continue
|
continue
|
||||||
except TimeoutError as tex:
|
|
||||||
await self.close()
|
|
||||||
raise SmartDeviceException(
|
|
||||||
f"Unable to connect to the device, timed out: {self._host}: {tex}"
|
|
||||||
) from tex
|
|
||||||
except AuthenticationException as auex:
|
except AuthenticationException as auex:
|
||||||
await self.close()
|
await self.close()
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Unable to authenticate with %s, not retrying", self._host
|
"Unable to authenticate with %s, not retrying", self._host
|
||||||
)
|
)
|
||||||
raise auex
|
raise auex
|
||||||
|
except RetryableException as ex:
|
||||||
|
if retry >= retry_count:
|
||||||
|
await self.close()
|
||||||
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
|
raise ex
|
||||||
|
continue
|
||||||
|
except TimeoutException as ex:
|
||||||
|
if retry >= retry_count:
|
||||||
|
await self.close()
|
||||||
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
|
raise ex
|
||||||
|
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
|
||||||
|
continue
|
||||||
except SmartDeviceException as ex:
|
except SmartDeviceException as ex:
|
||||||
|
await self.close()
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Unable to connect to the device: %s, not retrying: %s",
|
"Unable to query the device: %s, not retrying: %s",
|
||||||
self._host,
|
self._host,
|
||||||
ex,
|
ex,
|
||||||
)
|
)
|
||||||
raise ex
|
raise ex
|
||||||
except Exception as ex:
|
|
||||||
if retry >= retry_count:
|
|
||||||
await self.close()
|
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
|
||||||
raise SmartDeviceException(
|
|
||||||
f"Unable to connect to the device: {self._host}: {ex}"
|
|
||||||
) from ex
|
|
||||||
continue
|
|
||||||
|
|
||||||
# make mypy happy, this should never be reached..
|
# make mypy happy, this should never be reached..
|
||||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||||
|
@ -48,20 +48,19 @@ import logging
|
|||||||
import secrets
|
import secrets
|
||||||
import time
|
import time
|
||||||
from pprint import pformat as pf
|
from pprint import pformat as pf
|
||||||
from typing import Any, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple, cast
|
||||||
|
|
||||||
import httpx
|
|
||||||
from cryptography.hazmat.primitives import hashes, padding
|
from cryptography.hazmat.primitives import hashes, padding
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
|
||||||
from .credentials import Credentials
|
from .credentials import Credentials
|
||||||
from .deviceconfig import DeviceConfig
|
from .deviceconfig import DeviceConfig
|
||||||
from .exceptions import AuthenticationException, SmartDeviceException
|
from .exceptions import AuthenticationException, SmartDeviceException
|
||||||
|
from .httpclient import HttpClient
|
||||||
from .json import loads as json_loads
|
from .json import loads as json_loads
|
||||||
from .protocol import BaseTransport, md5
|
from .protocol import BaseTransport, md5
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
logging.getLogger("httpx").propagate = False
|
|
||||||
|
|
||||||
|
|
||||||
def _sha256(payload: bytes) -> bytes:
|
def _sha256(payload: bytes) -> bytes:
|
||||||
@ -98,7 +97,7 @@ class KlapTransport(BaseTransport):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
|
|
||||||
self._default_http_client: Optional[httpx.AsyncClient] = None
|
self._http_client = HttpClient(config)
|
||||||
self._local_seed: Optional[bytes] = None
|
self._local_seed: Optional[bytes] = None
|
||||||
if (
|
if (
|
||||||
not self._credentials or self._credentials.username is None
|
not self._credentials or self._credentials.username is None
|
||||||
@ -118,7 +117,7 @@ class KlapTransport(BaseTransport):
|
|||||||
self._encryption_session: Optional[KlapEncryptionSession] = None
|
self._encryption_session: Optional[KlapEncryptionSession] = None
|
||||||
self._session_expire_at: Optional[float] = None
|
self._session_expire_at: Optional[float] = None
|
||||||
|
|
||||||
self._session_cookie = None
|
self._session_cookie: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
_LOGGER.debug("Created KLAP transport for %s", self._host)
|
_LOGGER.debug("Created KLAP transport for %s", self._host)
|
||||||
|
|
||||||
@ -132,34 +131,6 @@ class KlapTransport(BaseTransport):
|
|||||||
"""The hashed credentials used by the transport."""
|
"""The hashed credentials used by the transport."""
|
||||||
return base64.b64encode(self._local_auth_hash).decode()
|
return base64.b64encode(self._local_auth_hash).decode()
|
||||||
|
|
||||||
@property
|
|
||||||
def _http_client(self) -> httpx.AsyncClient:
|
|
||||||
if self._config.http_client:
|
|
||||||
return self._config.http_client
|
|
||||||
if not self._default_http_client:
|
|
||||||
self._default_http_client = httpx.AsyncClient()
|
|
||||||
return self._default_http_client
|
|
||||||
|
|
||||||
async def client_post(self, url, params=None, data=None):
|
|
||||||
"""Send an http post request to the device."""
|
|
||||||
response_data = None
|
|
||||||
cookies = None
|
|
||||||
if self._session_cookie:
|
|
||||||
cookies = httpx.Cookies()
|
|
||||||
cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie)
|
|
||||||
self._http_client.cookies.clear()
|
|
||||||
resp = await self._http_client.post(
|
|
||||||
url,
|
|
||||||
params=params,
|
|
||||||
data=data,
|
|
||||||
timeout=self._timeout,
|
|
||||||
cookies=cookies,
|
|
||||||
)
|
|
||||||
if resp.status_code == 200:
|
|
||||||
response_data = resp.content
|
|
||||||
|
|
||||||
return resp.status_code, response_data
|
|
||||||
|
|
||||||
async def perform_handshake1(self) -> Tuple[bytes, bytes, bytes]:
|
async def perform_handshake1(self) -> Tuple[bytes, bytes, bytes]:
|
||||||
"""Perform handshake1."""
|
"""Perform handshake1."""
|
||||||
local_seed: bytes = secrets.token_bytes(16)
|
local_seed: bytes = secrets.token_bytes(16)
|
||||||
@ -172,7 +143,7 @@ class KlapTransport(BaseTransport):
|
|||||||
|
|
||||||
url = f"http://{self._host}/app/handshake1"
|
url = f"http://{self._host}/app/handshake1"
|
||||||
|
|
||||||
response_status, response_data = await self.client_post(url, data=payload)
|
response_status, response_data = await self._http_client.post(url, data=payload)
|
||||||
|
|
||||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
@ -189,6 +160,7 @@ class KlapTransport(BaseTransport):
|
|||||||
f"Device {self._host} responded with {response_status} to handshake1"
|
f"Device {self._host} responded with {response_status} to handshake1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response_data = cast(bytes, response_data)
|
||||||
remote_seed: bytes = response_data[0:16]
|
remote_seed: bytes = response_data[0:16]
|
||||||
server_hash = response_data[16:]
|
server_hash = response_data[16:]
|
||||||
|
|
||||||
@ -268,7 +240,11 @@ class KlapTransport(BaseTransport):
|
|||||||
|
|
||||||
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
|
payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash)
|
||||||
|
|
||||||
response_status, response_data = await self.client_post(url, data=payload)
|
response_status, _ = await self._http_client.post(
|
||||||
|
url,
|
||||||
|
data=payload,
|
||||||
|
cookies_dict=self._session_cookie,
|
||||||
|
)
|
||||||
|
|
||||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
@ -298,9 +274,10 @@ class KlapTransport(BaseTransport):
|
|||||||
self._session_cookie = None
|
self._session_cookie = None
|
||||||
|
|
||||||
local_seed, remote_seed, auth_hash = await self.perform_handshake1()
|
local_seed, remote_seed, auth_hash = await self.perform_handshake1()
|
||||||
self._session_cookie = self._http_client.cookies.get( # type: ignore
|
if cookie := self._http_client.get_cookie( # type: ignore
|
||||||
self.SESSION_COOKIE_NAME
|
self.SESSION_COOKIE_NAME
|
||||||
)
|
):
|
||||||
|
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
|
||||||
# The device returns a TIMEOUT cookie on handshake1 which
|
# The device returns a TIMEOUT cookie on handshake1 which
|
||||||
# it doesn't like to get back so we store the one we want
|
# it doesn't like to get back so we store the one we want
|
||||||
|
|
||||||
@ -330,10 +307,11 @@ class KlapTransport(BaseTransport):
|
|||||||
|
|
||||||
url = f"http://{self._host}/app/request"
|
url = f"http://{self._host}/app/request"
|
||||||
|
|
||||||
response_status, response_data = await self.client_post(
|
response_status, response_data = await self._http_client.post(
|
||||||
url,
|
url,
|
||||||
params={"seq": seq},
|
params={"seq": seq},
|
||||||
data=payload,
|
data=payload,
|
||||||
|
cookies_dict=self._session_cookie,
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = (
|
msg = (
|
||||||
@ -374,11 +352,8 @@ class KlapTransport(BaseTransport):
|
|||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the transport."""
|
"""Close the transport."""
|
||||||
client = self._default_http_client
|
|
||||||
self._default_http_client = None
|
|
||||||
self._handshake_done = False
|
self._handshake_done = False
|
||||||
if client:
|
await self._http_client.close()
|
||||||
await client.aclose()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_auth_hash(creds: Credentials):
|
def generate_auth_hash(creds: Credentials):
|
||||||
|
@ -12,13 +12,12 @@ import uuid
|
|||||||
from pprint import pformat as pf
|
from pprint import pformat as pf
|
||||||
from typing import Dict, Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
SMART_AUTHENTICATION_ERRORS,
|
SMART_AUTHENTICATION_ERRORS,
|
||||||
SMART_RETRYABLE_ERRORS,
|
SMART_RETRYABLE_ERRORS,
|
||||||
SMART_TIMEOUT_ERRORS,
|
SMART_TIMEOUT_ERRORS,
|
||||||
AuthenticationException,
|
AuthenticationException,
|
||||||
|
ConnectionException,
|
||||||
RetryableException,
|
RetryableException,
|
||||||
SmartDeviceException,
|
SmartDeviceException,
|
||||||
SmartErrorCode,
|
SmartErrorCode,
|
||||||
@ -28,13 +27,12 @@ from .json import dumps as json_dumps
|
|||||||
from .protocol import BaseTransport, TPLinkProtocol, md5
|
from .protocol import BaseTransport, TPLinkProtocol, md5
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
logging.getLogger("httpx").propagate = False
|
|
||||||
|
|
||||||
|
|
||||||
class SmartProtocol(TPLinkProtocol):
|
class SmartProtocol(TPLinkProtocol):
|
||||||
"""Class for the new TPLink SMART protocol."""
|
"""Class for the new TPLink SMART protocol."""
|
||||||
|
|
||||||
SLEEP_SECONDS_AFTER_TIMEOUT = 1
|
BACKOFF_SECONDS_AFTER_TIMEOUT = 1
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -67,22 +65,11 @@ class SmartProtocol(TPLinkProtocol):
|
|||||||
for retry in range(retry_count + 1):
|
for retry in range(retry_count + 1):
|
||||||
try:
|
try:
|
||||||
return await self._execute_query(request, retry)
|
return await self._execute_query(request, retry)
|
||||||
except httpx.ConnectError as sdex:
|
except ConnectionException as sdex:
|
||||||
if retry >= retry_count:
|
if retry >= retry_count:
|
||||||
await self.close()
|
await self.close()
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
raise SmartDeviceException(
|
raise sdex
|
||||||
f"Unable to connect to the device: {self._host}: {sdex}"
|
|
||||||
) from sdex
|
|
||||||
continue
|
|
||||||
except TimeoutError as tex:
|
|
||||||
if retry >= retry_count:
|
|
||||||
await self.close()
|
|
||||||
raise SmartDeviceException(
|
|
||||||
"Unable to connect to the device, "
|
|
||||||
+ f"timed out: {self._host}: {tex}"
|
|
||||||
) from tex
|
|
||||||
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
|
||||||
continue
|
continue
|
||||||
except AuthenticationException as auex:
|
except AuthenticationException as auex:
|
||||||
await self.close()
|
await self.close()
|
||||||
@ -101,24 +88,16 @@ class SmartProtocol(TPLinkProtocol):
|
|||||||
await self.close()
|
await self.close()
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
||||||
raise ex
|
raise ex
|
||||||
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
|
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
|
||||||
continue
|
continue
|
||||||
except SmartDeviceException as ex:
|
except SmartDeviceException as ex:
|
||||||
# Transport would have raised RetryableException if retry makes sense.
|
|
||||||
await self.close()
|
await self.close()
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
|
||||||
raise ex
|
|
||||||
except Exception as ex:
|
|
||||||
if retry >= retry_count:
|
|
||||||
await self.close()
|
|
||||||
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
|
|
||||||
raise SmartDeviceException(
|
|
||||||
f"Unable to connect to the device: {self._host}: {ex}"
|
|
||||||
) from ex
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Unable to query the device %s, retrying: %s", self._host, ex
|
"Unable to query the device: %s, not retrying: %s",
|
||||||
|
self._host,
|
||||||
|
ex,
|
||||||
)
|
)
|
||||||
continue
|
raise ex
|
||||||
|
|
||||||
# make mypy happy, this should never be reached..
|
# make mypy happy, this should never be reached..
|
||||||
raise SmartDeviceException("Query reached somehow to unreachable")
|
raise SmartDeviceException("Query reached somehow to unreachable")
|
||||||
|
@ -145,7 +145,7 @@ async def test_connect_http_client(all_fixture_data, mocker):
|
|||||||
)
|
)
|
||||||
dev = await connect(config=config)
|
dev = await connect(config=config)
|
||||||
if ctype.encryption_type != EncryptType.Xor:
|
if ctype.encryption_type != EncryptType.Xor:
|
||||||
assert dev.protocol._transport._http_client != http_client
|
assert dev.protocol._transport._http_client.client != http_client
|
||||||
|
|
||||||
config = DeviceConfig(
|
config = DeviceConfig(
|
||||||
host=host,
|
host=host,
|
||||||
@ -155,4 +155,4 @@ async def test_connect_http_client(all_fixture_data, mocker):
|
|||||||
)
|
)
|
||||||
dev = await connect(config=config)
|
dev = await connect(config=config)
|
||||||
if ctype.encryption_type != EncryptType.Xor:
|
if ctype.encryption_type != EncryptType.Xor:
|
||||||
assert dev.protocol._transport._http_client == http_client
|
assert dev.protocol._transport._http_client.client == http_client
|
||||||
|
@ -321,9 +321,9 @@ async def test_discover_single_http_client(discovery_mock, mocker):
|
|||||||
assert x.config.uses_http == (discovery_mock.default_port == 80)
|
assert x.config.uses_http == (discovery_mock.default_port == 80)
|
||||||
|
|
||||||
if discovery_mock.default_port == 80:
|
if discovery_mock.default_port == 80:
|
||||||
assert x.protocol._transport._http_client != http_client
|
assert x.protocol._transport._http_client.client != http_client
|
||||||
x.config.http_client = http_client
|
x.config.http_client = http_client
|
||||||
assert x.protocol._transport._http_client == http_client
|
assert x.protocol._transport._http_client.client == http_client
|
||||||
|
|
||||||
|
|
||||||
async def test_discover_http_client(discovery_mock, mocker):
|
async def test_discover_http_client(discovery_mock, mocker):
|
||||||
@ -338,6 +338,6 @@ async def test_discover_http_client(discovery_mock, mocker):
|
|||||||
assert x.config.uses_http == (discovery_mock.default_port == 80)
|
assert x.config.uses_http == (discovery_mock.default_port == 80)
|
||||||
|
|
||||||
if discovery_mock.default_port == 80:
|
if discovery_mock.default_port == 80:
|
||||||
assert x.protocol._transport._http_client != http_client
|
assert x.protocol._transport._http_client.client != http_client
|
||||||
x.config.http_client = http_client
|
x.config.http_client = http_client
|
||||||
assert x.protocol._transport._http_client == http_client
|
assert x.protocol._transport._http_client.client == http_client
|
||||||
|
@ -13,7 +13,12 @@ import pytest
|
|||||||
from ..aestransport import AesTransport
|
from ..aestransport import AesTransport
|
||||||
from ..credentials import Credentials
|
from ..credentials import Credentials
|
||||||
from ..deviceconfig import DeviceConfig
|
from ..deviceconfig import DeviceConfig
|
||||||
from ..exceptions import AuthenticationException, SmartDeviceException
|
from ..exceptions import (
|
||||||
|
AuthenticationException,
|
||||||
|
ConnectionException,
|
||||||
|
SmartDeviceException,
|
||||||
|
)
|
||||||
|
from ..httpclient import HttpClient
|
||||||
from ..iotprotocol import IotProtocol
|
from ..iotprotocol import IotProtocol
|
||||||
from ..klaptransport import (
|
from ..klaptransport import (
|
||||||
KlapEncryptionSession,
|
KlapEncryptionSession,
|
||||||
@ -35,8 +40,8 @@ class _mock_response:
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"error, retry_expectation",
|
"error, retry_expectation",
|
||||||
[
|
[
|
||||||
(Exception("dummy exception"), True),
|
(Exception("dummy exception"), False),
|
||||||
(SmartDeviceException("dummy exception"), False),
|
(httpx.TimeoutException("dummy exception"), True),
|
||||||
(httpx.ConnectError("dummy exception"), True),
|
(httpx.ConnectError("dummy exception"), True),
|
||||||
],
|
],
|
||||||
ids=("Exception", "SmartDeviceException", "httpx.ConnectError"),
|
ids=("Exception", "SmartDeviceException", "httpx.ConnectError"),
|
||||||
@ -89,7 +94,7 @@ async def test_protocol_retry_recoverable_error(
|
|||||||
conn = mocker.patch.object(
|
conn = mocker.patch.object(
|
||||||
httpx.AsyncClient,
|
httpx.AsyncClient,
|
||||||
"post",
|
"post",
|
||||||
side_effect=httpx.CloseError("foo"),
|
side_effect=httpx.ConnectError("foo"),
|
||||||
)
|
)
|
||||||
config = DeviceConfig(host)
|
config = DeviceConfig(host)
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
@ -112,7 +117,7 @@ async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport
|
|||||||
nonlocal remaining
|
nonlocal remaining
|
||||||
remaining -= 1
|
remaining -= 1
|
||||||
if remaining:
|
if remaining:
|
||||||
raise Exception("Simulated post failure")
|
raise ConnectionException("Simulated connection failure")
|
||||||
|
|
||||||
return mock_response
|
return mock_response
|
||||||
|
|
||||||
@ -155,7 +160,7 @@ async def test_protocol_logging(mocker, caplog, log_level):
|
|||||||
protocol._transport._handshake_done = True
|
protocol._transport._handshake_done = True
|
||||||
protocol._transport._session_expire_at = time.time() + 86400
|
protocol._transport._session_expire_at = time.time() + 86400
|
||||||
protocol._transport._encryption_session = encryption_session
|
protocol._transport._encryption_session = encryption_session
|
||||||
mocker.patch.object(KlapTransport, "client_post", side_effect=_return_encrypted)
|
mocker.patch.object(HttpClient, "post", side_effect=_return_encrypted)
|
||||||
|
|
||||||
response = await protocol.query({})
|
response = await protocol.query({})
|
||||||
assert response == {"great": "success"}
|
assert response == {"great": "success"}
|
||||||
|
Loading…
Reference in New Issue
Block a user