Add support for the protocol used by TAPO devices and some newer KASA devices. (#552)

* Add Tapo protocol support

* Update get_device_instance and test_unsupported following review
This commit is contained in:
sdb9696 2023-11-30 12:10:49 +00:00 committed by GitHub
parent 9de3f69033
commit 63d64ad920
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 775 additions and 24 deletions

498
kasa/aesprotocol.py Normal file
View File

@ -0,0 +1,498 @@
"""Implementation of the TP-Link AES Protocol.
Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""
import asyncio
import base64
import hashlib
import logging
import time
import uuid
from pprint import pformat as pf
from typing import Dict, Optional, Union
import httpx
from cryptography.hazmat.primitives import hashes, padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps
from .json import loads as json_loads
from .protocol import TPLinkProtocol
_LOGGER = logging.getLogger(__name__)
logging.getLogger("httpx").propagate = False
def _md5(payload: bytes) -> bytes:
digest = hashes.Hash(hashes.MD5()) # noqa: S303
digest.update(payload)
hash = digest.finalize()
return hash
def _sha1(payload: bytes) -> str:
sha1_algo = hashlib.sha1() # noqa: S324
sha1_algo.update(payload)
return sha1_algo.hexdigest()
class TPLinkAes(TPLinkProtocol):
"""Implementation of the AES encryption protocol.
AES is the name used in device discovery for TP-Link's TAPO encryption
protocol, sometimes used by newer firmware versions on kasa devices.
"""
DEFAULT_PORT = 80
DEFAULT_TIMEOUT = 5
SESSION_COOKIE_NAME = "TP_SESSIONID"
COMMON_HEADERS = {
"Content-Type": "application/json",
"requestByApp": "true",
"Accept": "application/json",
}
def __init__(
self,
host: str,
*,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT)
self.credentials = (
credentials
if credentials and credentials.username and credentials.password
else Credentials(username="", password="")
)
self._local_seed: Optional[bytes] = None
self.local_auth_hash = self.generate_auth_hash(self.credentials)
self.local_auth_owner = self.generate_owner_hash(self.credentials).hex()
self.kasa_setup_auth_hash = None
self.blank_auth_hash = None
self.handshake_lock = asyncio.Lock()
self.query_lock = asyncio.Lock()
self.handshake_done = False
self.encryption_session: Optional[AesEncyptionSession] = None
self.session_expire_at: Optional[float] = None
self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT
self.session_cookie = None
self.terminal_uuid = None
self.http_client: Optional[httpx.AsyncClient] = None
self.request_id_generator = SnowflakeId(1, 1)
self.login_token = None
_LOGGER.debug("Created AES object for %s", self.host)
def hash_credentials(self, credentials, try_login_version2):
"""Hash the credentials."""
if try_login_version2:
un = base64.b64encode(
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(
_sha1(credentials.password.encode()).encode()
).decode()
else:
un = base64.b64encode(
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(credentials.password.encode()).decode()
return un, pw
async def client_post(self, url, params=None, data=None, json=None, headers=None):
"""Send an http post request to the device."""
response_data = None
cookies = None
if self.session_cookie:
cookies = httpx.Cookies()
cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie)
self.http_client.cookies.clear()
resp = await self.http_client.post(
url,
params=params,
data=data,
json=json,
timeout=self.timeout,
cookies=cookies,
headers=self.COMMON_HEADERS,
)
if resp.status_code == 200:
response_data = resp.json()
return resp.status_code, response_data
async def send_secure_passthrough(self, request):
"""Send encrypted message as passthrough."""
url = f"http://{self.host}/app"
if self.login_token:
url += f"?token={self.login_token}"
raw_request = json_dumps(request)
encrypted_payload = self.encryption_session.encrypt(raw_request.encode())
passthrough_request = {
"method": "securePassthrough",
"params": {"request": encrypted_payload.decode()},
}
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
if status_code == 200 and resp_dict["error_code"] == 0:
response = self.encryption_session.decrypt(
resp_dict["result"]["response"].encode()
)
resp_dict = json_loads(response)
if resp_dict["error_code"] != 0:
raise SmartDeviceException(
f"Could not complete send, response was {resp_dict}",
)
if "result" in resp_dict:
return resp_dict["result"]
else:
raise AuthenticationException("Could not complete send")
def get_aes_request(self, method, params=None):
"""Get a request message."""
request = {
"method": method,
"params": params,
"requestID": self.request_id_generator.generate_id(),
"request_time_milis": round(time.time() * 1000),
"terminal_uuid": self.terminal_uuid,
}
return request
async def perform_login(self, login_v2):
"""Login to the device."""
self.login_token = None
un, pw = self.hash_credentials(self.credentials, login_v2)
params = {"password": pw, "username": un}
request = self.get_aes_request("login_device", params)
try:
result = await self.send_secure_passthrough(request)
except SmartDeviceException as ex:
raise AuthenticationException(ex) from ex
self.login_token = result["token"]
async def perform_handshake(self):
"""Perform the handshake."""
_LOGGER.debug("Will perform handshaking...")
_LOGGER.debug("Generating keypair")
self.handshake_done = False
self.session_expire_at = None
self.session_cookie = None
url = f"http://{self.host}/app"
key_pair = KeyPair.create_key_pair()
pub_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ key_pair.get_public_key()
+ "\n-----END PUBLIC KEY-----\n"
)
handshake_params = {"key": pub_key}
_LOGGER.debug(f"Handshake params: {handshake_params}")
request_body = {"method": "handshake", "params": handshake_params}
_LOGGER.debug(f"Request {request_body}")
status_code, resp_dict = await self.client_post(url, json=request_body)
_LOGGER.debug(f"Device responded with: {resp_dict}")
if status_code == 200 and resp_dict["error_code"] == 0:
_LOGGER.debug("Decoding handshake key...")
handshake_key = resp_dict["result"]["key"]
self.session_cookie = self.http_client.cookies.get( # type: ignore
self.SESSION_COOKIE_NAME
)
if not self.session_cookie:
self.session_cookie = self.http_client.cookies.get( # type: ignore
"SESSIONID"
)
self.session_expire_at = time.time() + 86400
self.encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, key_pair
)
self.terminal_uuid = base64.b64encode(_md5(uuid.uuid4().bytes)).decode(
"UTF-8"
)
self.handshake_done = True
_LOGGER.debug("Handshake with %s complete", self.host)
else:
raise AuthenticationException("Could not complete handshake")
def handshake_session_expired(self):
"""Return true if session has expired."""
return (
self.session_expire_at is None or self.session_expire_at - time.time() <= 0
)
@staticmethod
def generate_auth_hash(creds: Credentials):
"""Generate an md5 auth hash for the protocol on the supplied credentials."""
un = creds.username or ""
pw = creds.password or ""
return _md5(_md5(un.encode()) + _md5(pw.encode()))
@staticmethod
def generate_owner_hash(creds: Credentials):
"""Return the MD5 hash of the username in this object."""
un = creds.username or ""
return _md5(un.encode())
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device retrying for retry_count on failure."""
async with self.query_lock:
return await self._query(request, retry_count)
async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
except httpx.CloseError as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}"
) from cex
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self.host}: {tex}"
) from tex
except AuthenticationException as auex:
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
raise auex
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
) from ex
continue
# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
_LOGGER.debug(
"%s >> %s",
self.host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(request),
)
if not self.http_client:
self.http_client = httpx.AsyncClient()
if not self.handshake_done or self.handshake_session_expired():
try:
await self.perform_handshake()
await self.perform_login(False)
except AuthenticationException:
await self.perform_handshake()
await self.perform_login(True)
if isinstance(request, dict):
aes_method = next(iter(request))
aes_params = request[aes_method]
else:
aes_method = request
aes_params = None
aes_request = self.get_aes_request(aes_method, aes_params)
response_data = await self.send_secure_passthrough(aes_request)
_LOGGER.debug(
"%s << %s",
self.host,
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
)
return response_data
async def close(self) -> None:
"""Close the protocol."""
client = self.http_client
self.http_client = None
if client:
await client.aclose()
class AesEncyptionSession:
"""Class for an AES encryption session."""
@staticmethod
def create_from_keypair(handshake_key: str, keypair):
"""Create the encryption session."""
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
private_key = serialization.load_der_private_key(private_key_data, None, None)
key_and_iv = private_key.decrypt(
handshake_key_bytes, asymmetric_padding.PKCS1v15()
)
if key_and_iv is None:
raise ValueError("Decryption failed!")
return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])
def __init__(self, key, iv):
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
def encrypt(self, data) -> bytes:
"""Encrypt the message."""
encryptor = self.cipher.encryptor()
padder = self.padding_strategy.padder()
padded_data = padder.update(data) + padder.finalize()
encrypted = encryptor.update(padded_data) + encryptor.finalize()
return base64.b64encode(encrypted)
def decrypt(self, data) -> str:
"""Decrypt the message."""
decryptor = self.cipher.decryptor()
unpadder = self.padding_strategy.unpadder()
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
unpadded_data = unpadder.update(decrypted) + unpadder.finalize()
return unpadded_data.decode()
class KeyPair:
"""Class for generating key pairs."""
@staticmethod
def create_key_pair(key_size: int = 1024):
"""Create a key pair."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
public_key = private_key.public_key()
private_key_bytes = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
public_key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return KeyPair(
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
)
def __init__(self, private_key: str, public_key: str):
self.private_key = private_key
self.public_key = public_key
def get_private_key(self) -> str:
"""Get the private key."""
return self.private_key
def get_public_key(self) -> str:
"""Get the public key."""
return self.public_key
class SnowflakeId:
"""Class for generating snowflake ids."""
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
WORKER_ID_BITS = 5
DATA_CENTER_ID_BITS = 5
SEQUENCE_BITS = 12
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
def __init__(self, worker_id, data_center_id):
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
raise ValueError(
"Worker ID can't be greater than "
+ str(SnowflakeId.MAX_WORKER_ID)
+ " or less than 0"
)
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
raise ValueError(
"Data center ID can't be greater than "
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
+ " or less than 0"
)
self.worker_id = worker_id
self.data_center_id = data_center_id
self.sequence = 0
self.last_timestamp = -1
def generate_id(self):
"""Generate a snowflake id."""
timestamp = self._current_millis()
if timestamp < self.last_timestamp:
raise ValueError("Clock moved backwards. Refusing to generate ID.")
if timestamp == self.last_timestamp:
# Within the same millisecond, increment the sequence number
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
if self.sequence == 0:
# Sequence exceeds its bit range, wait until the next millisecond
timestamp = self._wait_next_millis(self.last_timestamp)
else:
# New millisecond, reset the sequence number
self.sequence = 0
# Update the last timestamp
self.last_timestamp = timestamp
# Generate and return the final ID
return (
(
(timestamp - SnowflakeId.EPOCH)
<< (
SnowflakeId.WORKER_ID_BITS
+ SnowflakeId.SEQUENCE_BITS
+ SnowflakeId.DATA_CENTER_ID_BITS
)
)
| (
self.data_center_id
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
)
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
| self.sequence
)
def _current_millis(self):
return round(time.time() * 1000)
def _wait_next_millis(self, last_timestamp):
timestamp = self._current_millis()
while timestamp <= last_timestamp:
timestamp = self._current_millis()
return timestamp

View File

@ -14,6 +14,7 @@ from .smartdimmer import SmartDimmer
from .smartlightstrip import SmartLightStrip from .smartlightstrip import SmartLightStrip
from .smartplug import SmartPlug from .smartplug import SmartPlug
from .smartstrip import SmartStrip from .smartstrip import SmartStrip
from .tapo.tapoplug import TapoPlug
DEVICE_TYPE_TO_CLASS = { DEVICE_TYPE_TO_CLASS = {
DeviceType.Plug: SmartPlug, DeviceType.Plug: SmartPlug,
@ -21,6 +22,7 @@ DEVICE_TYPE_TO_CLASS = {
DeviceType.Strip: SmartStrip, DeviceType.Strip: SmartStrip,
DeviceType.Dimmer: SmartDimmer, DeviceType.Dimmer: SmartDimmer,
DeviceType.LightStrip: SmartLightStrip, DeviceType.LightStrip: SmartLightStrip,
DeviceType.TapoPlug: TapoPlug,
} }
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View File

@ -14,6 +14,7 @@ class DeviceType(Enum):
StripSocket = "stripsocket" StripSocket = "stripsocket"
Dimmer = "dimmer" Dimmer = "dimmer"
LightStrip = "lightstrip" LightStrip = "lightstrip"
TapoPlug = "tapoplug"
Unknown = "unknown" Unknown = "unknown"
@staticmethod @staticmethod

View File

@ -15,14 +15,16 @@ try:
except ImportError: except ImportError:
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from kasa.aesprotocol import TPLinkAes
from kasa.credentials import Credentials from kasa.credentials import Credentials
from kasa.exceptions import UnsupportedDeviceException from kasa.exceptions import UnsupportedDeviceException
from kasa.json import dumps as json_dumps from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads from kasa.json import loads as json_loads
from kasa.klapprotocol import TPLinkKlap from kasa.klapprotocol import TPLinkKlap
from kasa.protocol import TPLinkSmartHomeProtocol from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol
from kasa.smartdevice import SmartDevice, SmartDeviceException from kasa.smartdevice import SmartDevice, SmartDeviceException
from kasa.smartplug import SmartPlug from kasa.smartplug import SmartPlug
from kasa.tapo.tapoplug import TapoPlug
from .device_factory import get_device_class_from_info from .device_factory import get_device_class_from_info
@ -378,27 +380,38 @@ class Discover:
f"Unable to read response from device: {ip}: {ex}" f"Unable to read response from device: {ip}: {ex}"
) from ex ) from ex
if ( type_ = discovery_result.device_type
discovery_result.mgt_encrypt_schm.encrypt_type == "KLAP" encrypt_type_ = (
and discovery_result.mgt_encrypt_schm.lv is None f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}"
): )
type_ = discovery_result.device_type device_class = None
device_class = None
if type_.upper() == "IOT.SMARTPLUGSWITCH":
device_class = SmartPlug
if device_class: supported_device_types: dict[str, Type[SmartDevice]] = {
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info) "SMART.TAPOPLUG": TapoPlug,
device = device_class(ip, port=port, credentials=credentials) "SMART.KASAPLUG": TapoPlug,
device.update_from_discover_info(discovery_result.get_dict()) "IOT.SMARTPLUGSWITCH": SmartPlug,
device.protocol = TPLinkKlap(ip, credentials=credentials) }
return device supported_device_protocols: dict[str, Type[TPLinkProtocol]] = {
else: "IOT.KLAP": TPLinkKlap,
raise UnsupportedDeviceException( "SMART.AES": TPLinkAes,
f"Unsupported device {ip} of type {type_}: {info}" }
)
else: if (device_class := supported_device_types.get(type_)) is None:
raise UnsupportedDeviceException(f"Unsupported device {ip}: {info}") _LOGGER.warning("Got unsupported device type: %s", type_)
raise UnsupportedDeviceException(
f"Unsupported device {ip} of type {type_}: {info}"
)
if (protocol_class := supported_device_protocols.get(encrypt_type_)) is None:
_LOGGER.warning("Got unsupported device type: %s", encrypt_type_)
raise UnsupportedDeviceException(
f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}"
)
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
device = device_class(ip, port=port, credentials=credentials)
device.protocol = protocol_class(ip, credentials=credentials)
device.update_from_discover_info(discovery_result.get_dict())
return device
class DiscoveryResult(BaseModel): class DiscoveryResult(BaseModel):
@ -415,7 +428,7 @@ class DiscoveryResult(BaseModel):
is_support_https: Optional[bool] = None is_support_https: Optional[bool] = None
encrypt_type: Optional[str] = None encrypt_type: Optional[str] = None
http_port: Optional[int] = None http_port: Optional[int] = None
lv: Optional[int] = None lv: Optional[int] = 1
device_type: str = Field(alias="device_type_text") device_type: str = Field(alias="device_type_text")
device_model: str = Field(alias="model") device_model: str = Field(alias="model")

164
kasa/tapo/tapodevice.py Normal file
View File

@ -0,0 +1,164 @@
"""Module for a TAPO device."""
import base64
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Set, cast
from ..aesprotocol import TPLinkAes
from ..credentials import Credentials
from ..exceptions import AuthenticationException
from ..smartdevice import SmartDevice
_LOGGER = logging.getLogger(__name__)
class TapoDevice(SmartDevice):
"""Base class to represent a TAPO device."""
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
self._state_information: Dict[str, Any] = {}
self._discovery_info: Optional[Dict[str, Any]] = None
self.protocol = TPLinkAes(host, credentials=credentials, timeout=timeout)
async def update(self, update_children: bool = True):
"""Update the device."""
if self.credentials is None or self.credentials.username is None:
raise AuthenticationException("Tapo plug requires authentication.")
self._info = await self.protocol.query("get_device_info")
self._usage = await self.protocol.query("get_device_usage")
self._time = await self.protocol.query("get_device_time")
self._last_update = self._data = {
"info": self._info,
"usage": self._usage,
"time": self._time,
}
_LOGGER.debug("Got an update: %s", self._data)
@property
def sys_info(self) -> Dict[str, Any]:
"""Returns the device info."""
return self._info
@property
def model(self) -> str:
"""Returns the device model."""
return str(self._info.get("model"))
@property
def alias(self) -> str:
"""Returns the device alias or nickname."""
return base64.b64decode(str(self._info.get("nickname"))).decode()
@property
def time(self) -> datetime:
"""Return the time."""
td = timedelta(minutes=cast(float, self._time.get("time_diff")))
if self._time.get("region"):
tz = timezone(td, str(self._time.get("region")))
else:
# in case the device returns a blank region this will result in the
# tzname being a UTC offset
tz = timezone(td)
return datetime.fromtimestamp(
cast(float, self._time.get("timestamp")),
tz=tz,
)
@property
def timezone(self) -> Dict:
"""Return the timezone and time_difference."""
ti = self.time
return {"timezone": ti.tzname()}
@property
def hw_info(self) -> Dict:
"""Return hardware info for the device."""
return {
"sw_ver": self._info.get("fw_ver"),
"hw_ver": self._info.get("hw_ver"),
"mac": self._info.get("mac"),
"type": self._info.get("type"),
"hwId": self._info.get("device_id"),
"dev_name": self.alias,
"oemId": self._info.get("oem_id"),
}
@property
def location(self) -> Dict:
"""Return the device location."""
loc = {
"latitude": cast(float, self._info.get("latitude")) / 10_000,
"longitude": cast(float, self._info.get("longitude")) / 10_000,
}
return loc
@property
def rssi(self) -> Optional[int]:
"""Return the rssi."""
rssi = self._info.get("rssi")
return int(rssi) if rssi else None
@property
def mac(self) -> str:
"""Return the mac formatted with colons."""
return str(self._info.get("mac")).replace("-", ":")
@property
def device_id(self) -> str:
"""Return the device id."""
return str(self._info.get("device_id"))
@property
def internal_state(self) -> Any:
"""Return all the internal state data."""
return self._data
async def _query_helper(
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
) -> Any:
res = await self.protocol.query({cmd: arg})
return res
@property
def state_information(self) -> Dict[str, Any]:
"""Return the key state information."""
return {
"overheated": self._info.get("overheated"),
"signal_level": self._info.get("signal_level"),
"SSID": base64.b64decode(str(self._info.get("ssid"))).decode(),
}
@property
def features(self) -> Set[str]:
"""Return the list of supported features."""
# TODO:
return set()
@property
def is_on(self) -> bool:
"""Return true if the device is on."""
return bool(self._info.get("device_on"))
async def turn_on(self, **kwargs):
"""Turn on the device."""
await self.protocol.query({"set_device_info": {"device_on": True}})
async def turn_off(self, **kwargs):
"""Turn off the device."""
await self.protocol.query({"set_device_info": {"device_on": False}})
def update_from_discover_info(self, info):
"""Update state from info from the discover call."""
self._discovery_info = info

73
kasa/tapo/tapoplug.py Normal file
View File

@ -0,0 +1,73 @@
"""Module for a TAPO Plug."""
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, cast
from ..credentials import Credentials
from ..emeterstatus import EmeterStatus
from ..smartdevice import DeviceType
from .tapodevice import TapoDevice
_LOGGER = logging.getLogger(__name__)
class TapoPlug(TapoDevice):
"""Class to represent a TAPO Plug."""
def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host, port=port, credentials=credentials, timeout=timeout)
self._device_type = DeviceType.Plug
async def update(self, update_children: bool = True):
"""Call the device endpoint and update the device data."""
await super().update(update_children)
self._energy = await self.protocol.query("get_energy_usage")
self._emeter = await self.protocol.query("get_current_power")
self._data["energy"] = self._energy
self._data["emeter"] = self._emeter
_LOGGER.debug("Got an update: %s %s", self._energy, self._emeter)
@property
def state_information(self) -> Dict[str, Any]:
"""Return the key state information."""
return {
**super().state_information,
**{
"On since": self.on_since,
"auto_off_status": self._info.get("auto_off_status"),
"auto_off_remain_time": self._info.get("auto_off_remain_time"),
},
}
@property
def emeter_realtime(self) -> EmeterStatus:
"""Get the emeter status."""
return EmeterStatus({"power_mw": self._energy.get("current_power")})
@property
def emeter_today(self) -> Optional[float]:
"""Get the emeter value for today."""
return None
@property
def emeter_this_month(self) -> Optional[float]:
"""Get the emeter value for this month."""
return None
@property
def on_since(self) -> Optional[datetime]:
"""Return the time that the device was turned on or None if turned off."""
if not self._info.get("device_on"):
return None
on_time = cast(float, self._info.get("on_time"))
return datetime.now().replace(microsecond=0) - timedelta(seconds=on_time)

View File

@ -114,7 +114,7 @@ UNSUPPORTED = {
"result": { "result": {
"device_id": "xx", "device_id": "xx",
"owner": "xx", "owner": "xx",
"device_type": "SMART.TAPOPLUG", "device_type": "SMART.TAPOXMASTREE",
"device_model": "P110(EU)", "device_model": "P110(EU)",
"ip": "127.0.0.1", "ip": "127.0.0.1",
"mac": "48-22xxx", "mac": "48-22xxx",
@ -150,7 +150,7 @@ async def test_discover_single_unsupported(mocker):
discovery_data = UNSUPPORTED discovery_data = UNSUPPORTED
with pytest.raises( with pytest.raises(
UnsupportedDeviceException, UnsupportedDeviceException,
match=f"Unsupported device {host}: {re.escape(str(UNSUPPORTED))}", match=f"Unsupported device {host} of type SMART.TAPOXMASTREE: {re.escape(str(UNSUPPORTED))}",
): ):
await Discover.discover_single(host) await Discover.discover_single(host)