mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-01-09 14:27:10 +00:00
Update SslAesTransport for legacy firmware versions
This commit is contained in:
parent
ed0481918c
commit
4a5bc20ee2
@ -38,7 +38,7 @@ from kasa.feature import Feature
|
|||||||
from kasa.interfaces.light import HSV, ColorTempRange, Light, LightState
|
from kasa.interfaces.light import HSV, ColorTempRange, Light, LightState
|
||||||
from kasa.interfaces.thermostat import Thermostat, ThermostatState
|
from kasa.interfaces.thermostat import Thermostat, ThermostatState
|
||||||
from kasa.module import Module
|
from kasa.module import Module
|
||||||
from kasa.protocols import BaseProtocol, IotProtocol, SmartProtocol
|
from kasa.protocols import BaseProtocol, IotProtocol, SmartCamProtocol, SmartProtocol
|
||||||
from kasa.protocols.iotprotocol import _deprecated_TPLinkSmartHomeProtocol # noqa: F401
|
from kasa.protocols.iotprotocol import _deprecated_TPLinkSmartHomeProtocol # noqa: F401
|
||||||
from kasa.smartcam.modules.camera import StreamResolution
|
from kasa.smartcam.modules.camera import StreamResolution
|
||||||
from kasa.transports import BaseTransport
|
from kasa.transports import BaseTransport
|
||||||
@ -52,6 +52,7 @@ __all__ = [
|
|||||||
"BaseTransport",
|
"BaseTransport",
|
||||||
"IotProtocol",
|
"IotProtocol",
|
||||||
"SmartProtocol",
|
"SmartProtocol",
|
||||||
|
"SmartCamProtocol",
|
||||||
"LightState",
|
"LightState",
|
||||||
"TurnOnBehaviors",
|
"TurnOnBehaviors",
|
||||||
"TurnOnBehavior",
|
"TurnOnBehavior",
|
||||||
|
@ -8,7 +8,7 @@ from typing import Any
|
|||||||
|
|
||||||
from .device import Device
|
from .device import Device
|
||||||
from .device_type import DeviceType
|
from .device_type import DeviceType
|
||||||
from .deviceconfig import DeviceConfig
|
from .deviceconfig import DeviceConfig, DeviceFamily
|
||||||
from .exceptions import KasaException, UnsupportedDeviceError
|
from .exceptions import KasaException, UnsupportedDeviceError
|
||||||
from .iot import (
|
from .iot import (
|
||||||
IotBulb,
|
IotBulb,
|
||||||
@ -180,19 +180,23 @@ def get_protocol(
|
|||||||
config: DeviceConfig,
|
config: DeviceConfig,
|
||||||
) -> BaseProtocol | None:
|
) -> BaseProtocol | None:
|
||||||
"""Return the protocol from the connection name."""
|
"""Return the protocol from the connection name."""
|
||||||
protocol_name = config.connection_type.device_family.value.split(".")[0]
|
|
||||||
ctype = config.connection_type
|
ctype = config.connection_type
|
||||||
|
protocol_name = ctype.device_family.value.split(".")[0]
|
||||||
|
|
||||||
|
if ctype.device_family is DeviceFamily.SmartIpCamera:
|
||||||
|
return SmartCamProtocol(transport=SslAesTransport(config=config))
|
||||||
|
|
||||||
|
if ctype.device_family is DeviceFamily.IotIpCamera:
|
||||||
|
return IotProtocol(transport=LinkieTransportV2(config=config))
|
||||||
|
|
||||||
|
if ctype.device_family is DeviceFamily.SmartTapoRobovac:
|
||||||
|
return SmartProtocol(transport=SslTransport(config=config))
|
||||||
|
|
||||||
protocol_transport_key = (
|
protocol_transport_key = (
|
||||||
protocol_name
|
protocol_name
|
||||||
+ "."
|
+ "."
|
||||||
+ ctype.encryption_type.value
|
+ ctype.encryption_type.value
|
||||||
+ (".HTTPS" if ctype.https else "")
|
+ (".HTTPS" if ctype.https else "")
|
||||||
+ (
|
|
||||||
f".{ctype.login_version}"
|
|
||||||
if ctype.login_version and ctype.login_version > 1
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
_LOGGER.debug("Finding transport for %s", protocol_transport_key)
|
_LOGGER.debug("Finding transport for %s", protocol_transport_key)
|
||||||
@ -201,12 +205,11 @@ def get_protocol(
|
|||||||
] = {
|
] = {
|
||||||
"IOT.XOR": (IotProtocol, XorTransport),
|
"IOT.XOR": (IotProtocol, XorTransport),
|
||||||
"IOT.KLAP": (IotProtocol, KlapTransport),
|
"IOT.KLAP": (IotProtocol, KlapTransport),
|
||||||
"IOT.XOR.HTTPS.2": (IotProtocol, LinkieTransportV2),
|
|
||||||
"SMART.AES": (SmartProtocol, AesTransport),
|
"SMART.AES": (SmartProtocol, AesTransport),
|
||||||
"SMART.AES.2": (SmartProtocol, AesTransport),
|
"SMART.KLAP": (SmartProtocol, KlapTransportV2),
|
||||||
"SMART.KLAP.2": (SmartProtocol, KlapTransportV2),
|
# Still require a lookup for SslAesTransport as H200 has a type of
|
||||||
"SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport),
|
# SMART.TAPOHUB.
|
||||||
"SMART.AES.HTTPS": (SmartProtocol, SslTransport),
|
"SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport),
|
||||||
}
|
}
|
||||||
if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)):
|
if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)):
|
||||||
return None
|
return None
|
||||||
|
@ -775,12 +775,10 @@ class Discover:
|
|||||||
):
|
):
|
||||||
encrypt_type = encrypt_info.sym_schm
|
encrypt_type = encrypt_info.sym_schm
|
||||||
|
|
||||||
if (
|
if not (login_version := encrypt_schm.lv) and (
|
||||||
not (login_version := encrypt_schm.lv)
|
et := discovery_result.encrypt_type
|
||||||
and (et := discovery_result.encrypt_type)
|
|
||||||
and et == ["3"]
|
|
||||||
):
|
):
|
||||||
login_version = 2
|
login_version = max([int(i) for i in et])
|
||||||
|
|
||||||
if not encrypt_type:
|
if not encrypt_type:
|
||||||
raise UnsupportedDeviceError(
|
raise UnsupportedDeviceError(
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from .iotprotocol import IotProtocol
|
from .iotprotocol import IotProtocol
|
||||||
from .protocol import BaseProtocol
|
from .protocol import BaseProtocol
|
||||||
|
from .smartcamprotocol import SmartCamProtocol
|
||||||
from .smartprotocol import SmartErrorCode, SmartProtocol
|
from .smartprotocol import SmartErrorCode, SmartProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -9,4 +10,5 @@ __all__ = [
|
|||||||
"IotProtocol",
|
"IotProtocol",
|
||||||
"SmartErrorCode",
|
"SmartErrorCode",
|
||||||
"SmartProtocol",
|
"SmartProtocol",
|
||||||
|
"SmartCamProtocol",
|
||||||
]
|
]
|
||||||
|
@ -19,7 +19,7 @@ from ..transports.sslaestransport import (
|
|||||||
SMART_RETRYABLE_ERRORS,
|
SMART_RETRYABLE_ERRORS,
|
||||||
SmartErrorCode,
|
SmartErrorCode,
|
||||||
)
|
)
|
||||||
from . import SmartProtocol
|
from .smartprotocol import SmartProtocol
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ from .aestransport import AesEncyptionSession, AesTransport
|
|||||||
from .basetransport import BaseTransport
|
from .basetransport import BaseTransport
|
||||||
from .klaptransport import KlapTransport, KlapTransportV2
|
from .klaptransport import KlapTransport, KlapTransportV2
|
||||||
from .linkietransport import LinkieTransportV2
|
from .linkietransport import LinkieTransportV2
|
||||||
|
from .sslaestransport import SslAesTransport
|
||||||
from .ssltransport import SslTransport
|
from .ssltransport import SslTransport
|
||||||
from .xortransport import XorEncryption, XorTransport
|
from .xortransport import XorEncryption, XorTransport
|
||||||
|
|
||||||
@ -11,6 +12,7 @@ __all__ = [
|
|||||||
"AesTransport",
|
"AesTransport",
|
||||||
"AesEncyptionSession",
|
"AesEncyptionSession",
|
||||||
"SslTransport",
|
"SslTransport",
|
||||||
|
"SslAesTransport",
|
||||||
"BaseTransport",
|
"BaseTransport",
|
||||||
"KlapTransport",
|
"KlapTransport",
|
||||||
"KlapTransportV2",
|
"KlapTransportV2",
|
||||||
|
@ -48,6 +48,10 @@ def _sha256_hash(payload: bytes) -> str:
|
|||||||
return hashlib.sha256(payload).hexdigest().upper() # noqa: S324
|
return hashlib.sha256(payload).hexdigest().upper() # noqa: S324
|
||||||
|
|
||||||
|
|
||||||
|
def _sha1_hash(payload: bytes) -> str:
|
||||||
|
return hashlib.sha1(payload).hexdigest().upper() # noqa: S324
|
||||||
|
|
||||||
|
|
||||||
class TransportState(Enum):
|
class TransportState(Enum):
|
||||||
"""Enum for AES state."""
|
"""Enum for AES state."""
|
||||||
|
|
||||||
@ -107,11 +111,10 @@ class SslAesTransport(BaseTransport):
|
|||||||
self._app_url = URL(f"https://{self._host_port}")
|
self._app_url = URL(f"https://{self._host_port}")
|
||||||
self._token_url: URL | None = None
|
self._token_url: URL | None = None
|
||||||
self._ssl_context: ssl.SSLContext | None = None
|
self._ssl_context: ssl.SSLContext | None = None
|
||||||
ref = str(self._token_url) if self._token_url else str(self._app_url)
|
|
||||||
self._headers = {
|
self._headers = {
|
||||||
**self.COMMON_HEADERS,
|
**self.COMMON_HEADERS,
|
||||||
"Host": self._host_port,
|
"Host": self._host,
|
||||||
"Referer": ref,
|
"Referer": f"https://{self._host}",
|
||||||
}
|
}
|
||||||
self._seq: int | None = None
|
self._seq: int | None = None
|
||||||
self._pwd_hash: str | None = None
|
self._pwd_hash: str | None = None
|
||||||
@ -125,6 +128,7 @@ class SslAesTransport(BaseTransport):
|
|||||||
self._password = ch["pwd"]
|
self._password = ch["pwd"]
|
||||||
self._username = ch["un"]
|
self._username = ch["un"]
|
||||||
self._local_nonce: str | None = None
|
self._local_nonce: str | None = None
|
||||||
|
self._send_secure = True
|
||||||
|
|
||||||
_LOGGER.debug("Created AES transport for %s", self._host)
|
_LOGGER.debug("Created AES transport for %s", self._host)
|
||||||
|
|
||||||
@ -194,6 +198,10 @@ class SslAesTransport(BaseTransport):
|
|||||||
else:
|
else:
|
||||||
url = self._app_url
|
url = self._app_url
|
||||||
|
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Sending secure passthrough from %s",
|
||||||
|
self._host,
|
||||||
|
)
|
||||||
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
|
encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
|
||||||
passthrough_request = {
|
passthrough_request = {
|
||||||
"method": "securePassthrough",
|
"method": "securePassthrough",
|
||||||
@ -254,6 +262,37 @@ class SslAesTransport(BaseTransport):
|
|||||||
) from ex
|
) from ex
|
||||||
return ret_val # type: ignore[return-value]
|
return ret_val # type: ignore[return-value]
|
||||||
|
|
||||||
|
async def send_unencrypted(self, request: str) -> dict[str, Any]:
|
||||||
|
"""Send encrypted message as passthrough."""
|
||||||
|
if self._state is TransportState.ESTABLISHED and self._token_url:
|
||||||
|
url = self._token_url
|
||||||
|
else:
|
||||||
|
url = self._app_url
|
||||||
|
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Sending unencrypted from %s",
|
||||||
|
self._host,
|
||||||
|
)
|
||||||
|
|
||||||
|
status_code, resp_dict = await self._http_client.post(
|
||||||
|
url,
|
||||||
|
json=request,
|
||||||
|
headers=self._headers,
|
||||||
|
ssl=await self._get_ssl_context(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if status_code != 200:
|
||||||
|
raise KasaException(
|
||||||
|
f"{self._host} responded with an unexpected "
|
||||||
|
+ f"status code {status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._handle_response_error_code(resp_dict, "Error sending message")
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
resp_dict = cast(dict[str, Any], resp_dict)
|
||||||
|
return resp_dict
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_confirm_hash(
|
def generate_confirm_hash(
|
||||||
local_nonce: str, server_nonce: str, pwd_hash: str
|
local_nonce: str, server_nonce: str, pwd_hash: str
|
||||||
@ -302,9 +341,53 @@ class SslAesTransport(BaseTransport):
|
|||||||
|
|
||||||
async def perform_handshake(self) -> None:
|
async def perform_handshake(self) -> None:
|
||||||
"""Perform the handshake."""
|
"""Perform the handshake."""
|
||||||
local_nonce, server_nonce, pwd_hash = await self.perform_handshake1()
|
result = await self.perform_handshake1()
|
||||||
|
if result:
|
||||||
|
local_nonce, server_nonce, pwd_hash = result
|
||||||
await self.perform_handshake2(local_nonce, server_nonce, pwd_hash)
|
await self.perform_handshake2(local_nonce, server_nonce, pwd_hash)
|
||||||
|
|
||||||
|
async def try_perform_login(self) -> bool:
|
||||||
|
"""Perform the md5 login."""
|
||||||
|
_LOGGER.debug("Performing insecure login ...")
|
||||||
|
|
||||||
|
pwd_hash = _md5_hash(self._pwd_to_hash().encode())
|
||||||
|
username = self._username
|
||||||
|
body = {
|
||||||
|
"method": "login",
|
||||||
|
"params": {
|
||||||
|
"hashed": True,
|
||||||
|
"password": pwd_hash,
|
||||||
|
"username": username,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
http_client = self._http_client
|
||||||
|
status_code, resp_dict = await http_client.post(
|
||||||
|
self._app_url,
|
||||||
|
json=body,
|
||||||
|
headers=self._headers,
|
||||||
|
ssl=await self._get_ssl_context(),
|
||||||
|
)
|
||||||
|
if status_code != 200:
|
||||||
|
raise KasaException(
|
||||||
|
f"{self._host} responded with an unexpected "
|
||||||
|
+ f"status code {status_code} to handshake2"
|
||||||
|
)
|
||||||
|
resp_dict = cast(dict, resp_dict)
|
||||||
|
if resp_dict.get("error_code") == 0 and (
|
||||||
|
stok := resp_dict.get("result", {}).get("stok")
|
||||||
|
):
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Succesfully logged in to %s with less secure passthrough", self._host
|
||||||
|
)
|
||||||
|
self._send_secure = False
|
||||||
|
self._token_url = URL(f"{str(self._app_url)}/stok={stok}/ds")
|
||||||
|
self._pwd_hash = pwd_hash
|
||||||
|
return True
|
||||||
|
|
||||||
|
_LOGGER.debug("Unable to log in to %s with less secure login", self._host)
|
||||||
|
return False
|
||||||
|
|
||||||
async def perform_handshake2(
|
async def perform_handshake2(
|
||||||
self, local_nonce: str, server_nonce: str, pwd_hash: str
|
self, local_nonce: str, server_nonce: str, pwd_hash: str
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -355,13 +438,42 @@ class SslAesTransport(BaseTransport):
|
|||||||
self._state = TransportState.ESTABLISHED
|
self._state = TransportState.ESTABLISHED
|
||||||
_LOGGER.debug("Handshake2 complete ...")
|
_LOGGER.debug("Handshake2 complete ...")
|
||||||
|
|
||||||
async def perform_handshake1(self) -> tuple[str, str, str]:
|
def _pwd_to_hash(self) -> str:
|
||||||
|
"""Return the password to hash."""
|
||||||
|
if self._credentials and self._credentials != Credentials():
|
||||||
|
return self._credentials.password
|
||||||
|
|
||||||
|
if self._username and self._password:
|
||||||
|
return self._password
|
||||||
|
|
||||||
|
return self._default_credentials.password
|
||||||
|
|
||||||
|
async def perform_handshake1(self) -> tuple[str, str, str] | None:
|
||||||
"""Perform the handshake1."""
|
"""Perform the handshake1."""
|
||||||
resp_dict = None
|
resp_dict = None
|
||||||
if self._username:
|
if self._username:
|
||||||
local_nonce = secrets.token_bytes(8).hex().upper()
|
local_nonce = secrets.token_bytes(8).hex().upper()
|
||||||
resp_dict = await self.try_send_handshake1(self._username, local_nonce)
|
resp_dict = await self.try_send_handshake1(self._username, local_nonce)
|
||||||
|
|
||||||
|
if (
|
||||||
|
resp_dict
|
||||||
|
and (error_code := self._get_response_error(resp_dict))
|
||||||
|
is SmartErrorCode.SESSION_EXPIRED
|
||||||
|
and (
|
||||||
|
encrypt_type := resp_dict.get("result", {})
|
||||||
|
.get("data", {})
|
||||||
|
.get("encrypt_type")
|
||||||
|
)
|
||||||
|
and (encrypt_type != ["3"])
|
||||||
|
):
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Received encrypt_type %s for %s, trying less secure login",
|
||||||
|
encrypt_type,
|
||||||
|
self._host,
|
||||||
|
)
|
||||||
|
if await self.try_perform_login():
|
||||||
|
return None
|
||||||
|
|
||||||
# Try the default username. If it fails raise the original error_code
|
# Try the default username. If it fails raise the original error_code
|
||||||
if (
|
if (
|
||||||
not resp_dict
|
not resp_dict
|
||||||
@ -369,6 +481,7 @@ class SslAesTransport(BaseTransport):
|
|||||||
is not SmartErrorCode.INVALID_NONCE
|
is not SmartErrorCode.INVALID_NONCE
|
||||||
or "nonce" not in resp_dict["result"].get("data", {})
|
or "nonce" not in resp_dict["result"].get("data", {})
|
||||||
):
|
):
|
||||||
|
_LOGGER.debug("Trying default credentials to %s", self._host)
|
||||||
local_nonce = secrets.token_bytes(8).hex().upper()
|
local_nonce = secrets.token_bytes(8).hex().upper()
|
||||||
default_resp_dict = await self.try_send_handshake1(
|
default_resp_dict = await self.try_send_handshake1(
|
||||||
self._default_credentials.username, local_nonce
|
self._default_credentials.username, local_nonce
|
||||||
@ -378,7 +491,7 @@ class SslAesTransport(BaseTransport):
|
|||||||
) is SmartErrorCode.INVALID_NONCE and "nonce" in default_resp_dict[
|
) is SmartErrorCode.INVALID_NONCE and "nonce" in default_resp_dict[
|
||||||
"result"
|
"result"
|
||||||
].get("data", {}):
|
].get("data", {}):
|
||||||
_LOGGER.debug("Connected to {self._host} with default username")
|
_LOGGER.debug("Connected to %s with default username", self._host)
|
||||||
self._username = self._default_credentials.username
|
self._username = self._default_credentials.username
|
||||||
error_code = default_error_code
|
error_code = default_error_code
|
||||||
resp_dict = default_resp_dict
|
resp_dict = default_resp_dict
|
||||||
@ -397,12 +510,8 @@ class SslAesTransport(BaseTransport):
|
|||||||
|
|
||||||
server_nonce = resp_dict["result"]["data"]["nonce"]
|
server_nonce = resp_dict["result"]["data"]["nonce"]
|
||||||
device_confirm = resp_dict["result"]["data"]["device_confirm"]
|
device_confirm = resp_dict["result"]["data"]["device_confirm"]
|
||||||
if self._credentials and self._credentials != Credentials():
|
|
||||||
pwd_hash = _sha256_hash(self._credentials.password.encode())
|
pwd_hash = _sha256_hash(self._pwd_to_hash().encode())
|
||||||
elif self._username and self._password:
|
|
||||||
pwd_hash = _sha256_hash(self._password.encode())
|
|
||||||
else:
|
|
||||||
pwd_hash = _sha256_hash(self._default_credentials.password.encode())
|
|
||||||
|
|
||||||
expected_confirm_sha256 = self.generate_confirm_hash(
|
expected_confirm_sha256 = self.generate_confirm_hash(
|
||||||
local_nonce, server_nonce, pwd_hash
|
local_nonce, server_nonce, pwd_hash
|
||||||
@ -414,7 +523,9 @@ class SslAesTransport(BaseTransport):
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert self._credentials
|
assert self._credentials
|
||||||
assert self._credentials.password
|
assert self._credentials.password
|
||||||
pwd_hash = _md5_hash(self._credentials.password.encode())
|
|
||||||
|
pwd_hash = _md5_hash(self._pwd_to_hash().encode())
|
||||||
|
|
||||||
expected_confirm_md5 = self.generate_confirm_hash(
|
expected_confirm_md5 = self.generate_confirm_hash(
|
||||||
local_nonce, server_nonce, pwd_hash
|
local_nonce, server_nonce, pwd_hash
|
||||||
)
|
)
|
||||||
@ -422,8 +533,17 @@ class SslAesTransport(BaseTransport):
|
|||||||
_LOGGER.debug("Credentials match")
|
_LOGGER.debug("Credentials match")
|
||||||
return local_nonce, server_nonce, pwd_hash
|
return local_nonce, server_nonce, pwd_hash
|
||||||
|
|
||||||
|
for val in {"admin", "tpadmin", "slprealtek"}:
|
||||||
|
for func in {_sha256_hash, _md5_hash, _sha1_hash, lambda x: x.decode()}:
|
||||||
|
pwd_hash = func(val.encode())
|
||||||
|
ec = self.generate_confirm_hash(local_nonce, server_nonce, pwd_hash)
|
||||||
|
if device_confirm == ec:
|
||||||
|
_LOGGER.debug("Credentials match with %s %s", val, func.__name__)
|
||||||
|
return local_nonce, server_nonce, pwd_hash
|
||||||
|
|
||||||
msg = f"Server response doesn't match our challenge on ip {self._host}"
|
msg = f"Server response doesn't match our challenge on ip {self._host}"
|
||||||
_LOGGER.debug(msg)
|
_LOGGER.debug(msg)
|
||||||
|
|
||||||
raise AuthenticationError(msg)
|
raise AuthenticationError(msg)
|
||||||
|
|
||||||
async def try_send_handshake1(self, username: str, local_nonce: str) -> dict:
|
async def try_send_handshake1(self, username: str, local_nonce: str) -> dict:
|
||||||
@ -462,8 +582,11 @@ class SslAesTransport(BaseTransport):
|
|||||||
if self._state is TransportState.HANDSHAKE_REQUIRED:
|
if self._state is TransportState.HANDSHAKE_REQUIRED:
|
||||||
await self.perform_handshake()
|
await self.perform_handshake()
|
||||||
|
|
||||||
|
if self._send_secure:
|
||||||
return await self.send_secure_passthrough(request)
|
return await self.send_secure_passthrough(request)
|
||||||
|
|
||||||
|
return await self.send_unencrypted(request)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the http client and reset internal state."""
|
"""Close the http client and reset internal state."""
|
||||||
await self.reset()
|
await self.reset()
|
||||||
|
@ -13,9 +13,13 @@ import aiohttp
|
|||||||
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
|
||||||
|
|
||||||
from kasa import (
|
from kasa import (
|
||||||
|
BaseProtocol,
|
||||||
Credentials,
|
Credentials,
|
||||||
Discover,
|
Discover,
|
||||||
|
IotProtocol,
|
||||||
KasaException,
|
KasaException,
|
||||||
|
SmartCamProtocol,
|
||||||
|
SmartProtocol,
|
||||||
)
|
)
|
||||||
from kasa.device_factory import (
|
from kasa.device_factory import (
|
||||||
Device,
|
Device,
|
||||||
@ -33,6 +37,16 @@ from kasa.deviceconfig import (
|
|||||||
DeviceFamily,
|
DeviceFamily,
|
||||||
)
|
)
|
||||||
from kasa.discover import DiscoveryResult
|
from kasa.discover import DiscoveryResult
|
||||||
|
from kasa.transports import (
|
||||||
|
AesTransport,
|
||||||
|
BaseTransport,
|
||||||
|
KlapTransport,
|
||||||
|
KlapTransportV2,
|
||||||
|
LinkieTransportV2,
|
||||||
|
SslAesTransport,
|
||||||
|
SslTransport,
|
||||||
|
XorTransport,
|
||||||
|
)
|
||||||
|
|
||||||
from .conftest import DISCOVERY_MOCK_IP
|
from .conftest import DISCOVERY_MOCK_IP
|
||||||
|
|
||||||
@ -203,3 +217,74 @@ async def test_device_class_from_unknown_family(caplog):
|
|||||||
with caplog.at_level(logging.DEBUG):
|
with caplog.at_level(logging.DEBUG):
|
||||||
assert get_device_class_from_family(dummy_name, https=False) == SmartDevice
|
assert get_device_class_from_family(dummy_name, https=False) == SmartDevice
|
||||||
assert f"Unknown SMART device with {dummy_name}" in caplog.text
|
assert f"Unknown SMART device with {dummy_name}" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
# Aliases to make the test params more readable
|
||||||
|
CP = DeviceConnectionParameters
|
||||||
|
DF = DeviceFamily
|
||||||
|
ET = DeviceEncryptionType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("conn_params", "expected_protocol", "expected_transport"),
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
CP(DF.SmartIpCamera, ET.Aes, https=True),
|
||||||
|
SmartCamProtocol,
|
||||||
|
SslAesTransport,
|
||||||
|
id="smartcam",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
CP(DF.SmartTapoHub, ET.Aes, https=True),
|
||||||
|
SmartCamProtocol,
|
||||||
|
SslAesTransport,
|
||||||
|
id="smartcam-hub",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
CP(DF.IotIpCamera, ET.Aes, https=True),
|
||||||
|
IotProtocol,
|
||||||
|
LinkieTransportV2,
|
||||||
|
id="kasacam",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
CP(DF.SmartTapoRobovac, ET.Aes, https=True),
|
||||||
|
SmartProtocol,
|
||||||
|
SslTransport,
|
||||||
|
id="robovac",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
CP(DF.IotSmartPlugSwitch, ET.Klap, https=False),
|
||||||
|
IotProtocol,
|
||||||
|
KlapTransport,
|
||||||
|
id="iot-klap",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
CP(DF.IotSmartPlugSwitch, ET.Xor, https=False),
|
||||||
|
IotProtocol,
|
||||||
|
XorTransport,
|
||||||
|
id="iot-xor",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
CP(DF.SmartTapoPlug, ET.Aes, https=False),
|
||||||
|
SmartProtocol,
|
||||||
|
AesTransport,
|
||||||
|
id="smart-aes",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
CP(DF.SmartTapoPlug, ET.Klap, https=False),
|
||||||
|
SmartProtocol,
|
||||||
|
KlapTransportV2,
|
||||||
|
id="smart-klap",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_get_protocol(
|
||||||
|
conn_params: DeviceConnectionParameters,
|
||||||
|
expected_protocol: type[BaseProtocol],
|
||||||
|
expected_transport: type[BaseTransport],
|
||||||
|
):
|
||||||
|
"""Test get_protocol returns the right protocol."""
|
||||||
|
config = DeviceConfig("127.0.0.1", connection_type=conn_params)
|
||||||
|
protocol = get_protocol(config)
|
||||||
|
assert isinstance(protocol, expected_protocol)
|
||||||
|
assert isinstance(protocol._transport, expected_transport)
|
||||||
|
Loading…
Reference in New Issue
Block a user