diff --git a/kasa/aesprotocol.py b/kasa/aesprotocol.py new file mode 100644 index 00000000..98776ce2 --- /dev/null +++ b/kasa/aesprotocol.py @@ -0,0 +1,498 @@ +"""Implementation of the TP-Link AES Protocol. + +Based on the work of https://github.com/petretiandrea/plugp100 +under compatible GNU GPL3 license. +""" + +import asyncio +import base64 +import hashlib +import logging +import time +import uuid +from pprint import pformat as pf +from typing import Dict, Optional, Union + +import httpx +from cryptography.hazmat.primitives import hashes, padding, serialization +from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + +from .credentials import Credentials +from .exceptions import AuthenticationException, SmartDeviceException +from .json import dumps as json_dumps +from .json import loads as json_loads +from .protocol import TPLinkProtocol + +_LOGGER = logging.getLogger(__name__) +logging.getLogger("httpx").propagate = False + + +def _md5(payload: bytes) -> bytes: + digest = hashes.Hash(hashes.MD5()) # noqa: S303 + digest.update(payload) + hash = digest.finalize() + return hash + + +def _sha1(payload: bytes) -> str: + sha1_algo = hashlib.sha1() # noqa: S324 + sha1_algo.update(payload) + return sha1_algo.hexdigest() + + +class TPLinkAes(TPLinkProtocol): + """Implementation of the AES encryption protocol. + + AES is the name used in device discovery for TP-Link's TAPO encryption + protocol, sometimes used by newer firmware versions on kasa devices. + """ + + DEFAULT_PORT = 80 + DEFAULT_TIMEOUT = 5 + SESSION_COOKIE_NAME = "TP_SESSIONID" + COMMON_HEADERS = { + "Content-Type": "application/json", + "requestByApp": "true", + "Accept": "application/json", + } + + def __init__( + self, + host: str, + *, + credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, + ) -> None: + super().__init__(host=host, port=self.DEFAULT_PORT) + + self.credentials = ( + credentials + if credentials and credentials.username and credentials.password + else Credentials(username="", password="") + ) + + self._local_seed: Optional[bytes] = None + self.local_auth_hash = self.generate_auth_hash(self.credentials) + self.local_auth_owner = self.generate_owner_hash(self.credentials).hex() + self.kasa_setup_auth_hash = None + self.blank_auth_hash = None + self.handshake_lock = asyncio.Lock() + self.query_lock = asyncio.Lock() + self.handshake_done = False + + self.encryption_session: Optional[AesEncyptionSession] = None + self.session_expire_at: Optional[float] = None + + self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT + self.session_cookie = None + self.terminal_uuid = None + self.http_client: Optional[httpx.AsyncClient] = None + self.request_id_generator = SnowflakeId(1, 1) + self.login_token = None + + _LOGGER.debug("Created AES object for %s", self.host) + + def hash_credentials(self, credentials, try_login_version2): + """Hash the credentials.""" + if try_login_version2: + un = base64.b64encode( + _sha1(credentials.username.encode()).encode() + ).decode() + pw = base64.b64encode( + _sha1(credentials.password.encode()).encode() + ).decode() + else: + un = base64.b64encode( + _sha1(credentials.username.encode()).encode() + ).decode() + pw = base64.b64encode(credentials.password.encode()).decode() + return un, pw + + async def client_post(self, url, params=None, data=None, json=None, headers=None): + """Send an http post request to the device.""" + response_data = None + cookies = None + if self.session_cookie: + cookies = httpx.Cookies() + cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie) + self.http_client.cookies.clear() + resp = await self.http_client.post( + url, + params=params, + data=data, + json=json, + timeout=self.timeout, + cookies=cookies, + headers=self.COMMON_HEADERS, + ) + if resp.status_code == 200: + response_data = resp.json() + + return resp.status_code, response_data + + async def send_secure_passthrough(self, request): + """Send encrypted message as passthrough.""" + url = f"http://{self.host}/app" + if self.login_token: + url += f"?token={self.login_token}" + raw_request = json_dumps(request) + encrypted_payload = self.encryption_session.encrypt(raw_request.encode()) + passthrough_request = { + "method": "securePassthrough", + "params": {"request": encrypted_payload.decode()}, + } + status_code, resp_dict = await self.client_post(url, json=passthrough_request) + if status_code == 200 and resp_dict["error_code"] == 0: + response = self.encryption_session.decrypt( + resp_dict["result"]["response"].encode() + ) + resp_dict = json_loads(response) + if resp_dict["error_code"] != 0: + raise SmartDeviceException( + f"Could not complete send, response was {resp_dict}", + ) + if "result" in resp_dict: + return resp_dict["result"] + else: + raise AuthenticationException("Could not complete send") + + def get_aes_request(self, method, params=None): + """Get a request message.""" + request = { + "method": method, + "params": params, + "requestID": self.request_id_generator.generate_id(), + "request_time_milis": round(time.time() * 1000), + "terminal_uuid": self.terminal_uuid, + } + return request + + async def perform_login(self, login_v2): + """Login to the device.""" + self.login_token = None + + un, pw = self.hash_credentials(self.credentials, login_v2) + params = {"password": pw, "username": un} + request = self.get_aes_request("login_device", params) + try: + result = await self.send_secure_passthrough(request) + except SmartDeviceException as ex: + raise AuthenticationException(ex) from ex + self.login_token = result["token"] + + async def perform_handshake(self): + """Perform the handshake.""" + _LOGGER.debug("Will perform handshaking...") + _LOGGER.debug("Generating keypair") + + self.handshake_done = False + self.session_expire_at = None + self.session_cookie = None + + url = f"http://{self.host}/app" + key_pair = KeyPair.create_key_pair() + + pub_key = ( + "-----BEGIN PUBLIC KEY-----\n" + + key_pair.get_public_key() + + "\n-----END PUBLIC KEY-----\n" + ) + handshake_params = {"key": pub_key} + _LOGGER.debug(f"Handshake params: {handshake_params}") + + request_body = {"method": "handshake", "params": handshake_params} + + _LOGGER.debug(f"Request {request_body}") + + status_code, resp_dict = await self.client_post(url, json=request_body) + + _LOGGER.debug(f"Device responded with: {resp_dict}") + + if status_code == 200 and resp_dict["error_code"] == 0: + _LOGGER.debug("Decoding handshake key...") + handshake_key = resp_dict["result"]["key"] + + self.session_cookie = self.http_client.cookies.get( # type: ignore + self.SESSION_COOKIE_NAME + ) + if not self.session_cookie: + self.session_cookie = self.http_client.cookies.get( # type: ignore + "SESSIONID" + ) + + self.session_expire_at = time.time() + 86400 + self.encryption_session = AesEncyptionSession.create_from_keypair( + handshake_key, key_pair + ) + + self.terminal_uuid = base64.b64encode(_md5(uuid.uuid4().bytes)).decode( + "UTF-8" + ) + self.handshake_done = True + + _LOGGER.debug("Handshake with %s complete", self.host) + + else: + raise AuthenticationException("Could not complete handshake") + + def handshake_session_expired(self): + """Return true if session has expired.""" + return ( + self.session_expire_at is None or self.session_expire_at - time.time() <= 0 + ) + + @staticmethod + def generate_auth_hash(creds: Credentials): + """Generate an md5 auth hash for the protocol on the supplied credentials.""" + un = creds.username or "" + pw = creds.password or "" + return _md5(_md5(un.encode()) + _md5(pw.encode())) + + @staticmethod + def generate_owner_hash(creds: Credentials): + """Return the MD5 hash of the username in this object.""" + un = creds.username or "" + return _md5(un.encode()) + + async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: + """Query the device retrying for retry_count on failure.""" + async with self.query_lock: + return await self._query(request, retry_count) + + async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: + for retry in range(retry_count + 1): + try: + return await self._execute_query(request, retry) + except httpx.CloseError as sdex: + await self.close() + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {sdex}" + ) from sdex + continue + except httpx.ConnectError as cex: + await self.close() + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {cex}" + ) from cex + except TimeoutError as tex: + await self.close() + raise SmartDeviceException( + f"Unable to connect to the device, timed out: {self.host}: {tex}" + ) from tex + except AuthenticationException as auex: + _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) + raise auex + except Exception as ex: + await self.close() + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {ex}" + ) from ex + continue + + # make mypy happy, this should never be reached.. + raise SmartDeviceException("Query reached somehow to unreachable") + + async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict: + _LOGGER.debug( + "%s >> %s", + self.host, + _LOGGER.isEnabledFor(logging.DEBUG) and pf(request), + ) + + if not self.http_client: + self.http_client = httpx.AsyncClient() + + if not self.handshake_done or self.handshake_session_expired(): + try: + await self.perform_handshake() + await self.perform_login(False) + except AuthenticationException: + await self.perform_handshake() + await self.perform_login(True) + + if isinstance(request, dict): + aes_method = next(iter(request)) + aes_params = request[aes_method] + else: + aes_method = request + aes_params = None + + aes_request = self.get_aes_request(aes_method, aes_params) + response_data = await self.send_secure_passthrough(aes_request) + + _LOGGER.debug( + "%s << %s", + self.host, + _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data), + ) + + return response_data + + async def close(self) -> None: + """Close the protocol.""" + client = self.http_client + self.http_client = None + if client: + await client.aclose() + + +class AesEncyptionSession: + """Class for an AES encryption session.""" + + @staticmethod + def create_from_keypair(handshake_key: str, keypair): + """Create the encryption session.""" + handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8")) + private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8")) + + private_key = serialization.load_der_private_key(private_key_data, None, None) + key_and_iv = private_key.decrypt( + handshake_key_bytes, asymmetric_padding.PKCS1v15() + ) + if key_and_iv is None: + raise ValueError("Decryption failed!") + + return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:]) + + def __init__(self, key, iv): + self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv)) + self.padding_strategy = padding.PKCS7(algorithms.AES.block_size) + + def encrypt(self, data) -> bytes: + """Encrypt the message.""" + encryptor = self.cipher.encryptor() + padder = self.padding_strategy.padder() + padded_data = padder.update(data) + padder.finalize() + encrypted = encryptor.update(padded_data) + encryptor.finalize() + return base64.b64encode(encrypted) + + def decrypt(self, data) -> str: + """Decrypt the message.""" + decryptor = self.cipher.decryptor() + unpadder = self.padding_strategy.unpadder() + decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize() + unpadded_data = unpadder.update(decrypted) + unpadder.finalize() + return unpadded_data.decode() + + +class KeyPair: + """Class for generating key pairs.""" + + @staticmethod + def create_key_pair(key_size: int = 1024): + """Create a key pair.""" + private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) + public_key = private_key.public_key() + + private_key_bytes = private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + public_key_bytes = public_key.public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + return KeyPair( + private_key=base64.b64encode(private_key_bytes).decode("UTF-8"), + public_key=base64.b64encode(public_key_bytes).decode("UTF-8"), + ) + + def __init__(self, private_key: str, public_key: str): + self.private_key = private_key + self.public_key = public_key + + def get_private_key(self) -> str: + """Get the private key.""" + return self.private_key + + def get_public_key(self) -> str: + """Get the public key.""" + return self.public_key + + +class SnowflakeId: + """Class for generating snowflake ids.""" + + EPOCH = 1420041600000 # Custom epoch (in milliseconds) + WORKER_ID_BITS = 5 + DATA_CENTER_ID_BITS = 5 + SEQUENCE_BITS = 12 + + MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1 + MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1 + + SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1 + + def __init__(self, worker_id, data_center_id): + if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0: + raise ValueError( + "Worker ID can't be greater than " + + str(SnowflakeId.MAX_WORKER_ID) + + " or less than 0" + ) + if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0: + raise ValueError( + "Data center ID can't be greater than " + + str(SnowflakeId.MAX_DATA_CENTER_ID) + + " or less than 0" + ) + + self.worker_id = worker_id + self.data_center_id = data_center_id + self.sequence = 0 + self.last_timestamp = -1 + + def generate_id(self): + """Generate a snowflake id.""" + timestamp = self._current_millis() + + if timestamp < self.last_timestamp: + raise ValueError("Clock moved backwards. Refusing to generate ID.") + + if timestamp == self.last_timestamp: + # Within the same millisecond, increment the sequence number + self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK + if self.sequence == 0: + # Sequence exceeds its bit range, wait until the next millisecond + timestamp = self._wait_next_millis(self.last_timestamp) + else: + # New millisecond, reset the sequence number + self.sequence = 0 + + # Update the last timestamp + self.last_timestamp = timestamp + + # Generate and return the final ID + return ( + ( + (timestamp - SnowflakeId.EPOCH) + << ( + SnowflakeId.WORKER_ID_BITS + + SnowflakeId.SEQUENCE_BITS + + SnowflakeId.DATA_CENTER_ID_BITS + ) + ) + | ( + self.data_center_id + << (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS) + ) + | (self.worker_id << SnowflakeId.SEQUENCE_BITS) + | self.sequence + ) + + def _current_millis(self): + return round(time.time() * 1000) + + def _wait_next_millis(self, last_timestamp): + timestamp = self._current_millis() + while timestamp <= last_timestamp: + timestamp = self._current_millis() + return timestamp diff --git a/kasa/device_factory.py b/kasa/device_factory.py index 049969fb..9122003c 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -14,6 +14,7 @@ from .smartdimmer import SmartDimmer from .smartlightstrip import SmartLightStrip from .smartplug import SmartPlug from .smartstrip import SmartStrip +from .tapo.tapoplug import TapoPlug DEVICE_TYPE_TO_CLASS = { DeviceType.Plug: SmartPlug, @@ -21,6 +22,7 @@ DEVICE_TYPE_TO_CLASS = { DeviceType.Strip: SmartStrip, DeviceType.Dimmer: SmartDimmer, DeviceType.LightStrip: SmartLightStrip, + DeviceType.TapoPlug: TapoPlug, } _LOGGER = logging.getLogger(__name__) diff --git a/kasa/device_type.py b/kasa/device_type.py index 162fc4f2..c8657306 100755 --- a/kasa/device_type.py +++ b/kasa/device_type.py @@ -14,6 +14,7 @@ class DeviceType(Enum): StripSocket = "stripsocket" Dimmer = "dimmer" LightStrip = "lightstrip" + TapoPlug = "tapoplug" Unknown = "unknown" @staticmethod diff --git a/kasa/discover.py b/kasa/discover.py index 61faeaff..59849bc0 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -15,14 +15,16 @@ try: except ImportError: from pydantic import BaseModel, Field +from kasa.aesprotocol import TPLinkAes from kasa.credentials import Credentials from kasa.exceptions import UnsupportedDeviceException from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads from kasa.klapprotocol import TPLinkKlap -from kasa.protocol import TPLinkSmartHomeProtocol +from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol from kasa.smartdevice import SmartDevice, SmartDeviceException from kasa.smartplug import SmartPlug +from kasa.tapo.tapoplug import TapoPlug from .device_factory import get_device_class_from_info @@ -378,27 +380,38 @@ class Discover: f"Unable to read response from device: {ip}: {ex}" ) from ex - if ( - discovery_result.mgt_encrypt_schm.encrypt_type == "KLAP" - and discovery_result.mgt_encrypt_schm.lv is None - ): - type_ = discovery_result.device_type - device_class = None - if type_.upper() == "IOT.SMARTPLUGSWITCH": - device_class = SmartPlug + type_ = discovery_result.device_type + encrypt_type_ = ( + f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}" + ) + device_class = None - if device_class: - _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) - device = device_class(ip, port=port, credentials=credentials) - device.update_from_discover_info(discovery_result.get_dict()) - device.protocol = TPLinkKlap(ip, credentials=credentials) - return device - else: - raise UnsupportedDeviceException( - f"Unsupported device {ip} of type {type_}: {info}" - ) - else: - raise UnsupportedDeviceException(f"Unsupported device {ip}: {info}") + supported_device_types: dict[str, Type[SmartDevice]] = { + "SMART.TAPOPLUG": TapoPlug, + "SMART.KASAPLUG": TapoPlug, + "IOT.SMARTPLUGSWITCH": SmartPlug, + } + supported_device_protocols: dict[str, Type[TPLinkProtocol]] = { + "IOT.KLAP": TPLinkKlap, + "SMART.AES": TPLinkAes, + } + + if (device_class := supported_device_types.get(type_)) is None: + _LOGGER.warning("Got unsupported device type: %s", type_) + raise UnsupportedDeviceException( + f"Unsupported device {ip} of type {type_}: {info}" + ) + if (protocol_class := supported_device_protocols.get(encrypt_type_)) is None: + _LOGGER.warning("Got unsupported device type: %s", encrypt_type_) + raise UnsupportedDeviceException( + f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}" + ) + + _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) + device = device_class(ip, port=port, credentials=credentials) + device.protocol = protocol_class(ip, credentials=credentials) + device.update_from_discover_info(discovery_result.get_dict()) + return device class DiscoveryResult(BaseModel): @@ -415,7 +428,7 @@ class DiscoveryResult(BaseModel): is_support_https: Optional[bool] = None encrypt_type: Optional[str] = None http_port: Optional[int] = None - lv: Optional[int] = None + lv: Optional[int] = 1 device_type: str = Field(alias="device_type_text") device_model: str = Field(alias="model") diff --git a/kasa/tapo/tapodevice.py b/kasa/tapo/tapodevice.py new file mode 100644 index 00000000..e02983d3 --- /dev/null +++ b/kasa/tapo/tapodevice.py @@ -0,0 +1,164 @@ +"""Module for a TAPO device.""" +import base64 +import logging +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Optional, Set, cast + +from ..aesprotocol import TPLinkAes +from ..credentials import Credentials +from ..exceptions import AuthenticationException +from ..smartdevice import SmartDevice + +_LOGGER = logging.getLogger(__name__) + + +class TapoDevice(SmartDevice): + """Base class to represent a TAPO device.""" + + def __init__( + self, + host: str, + *, + port: Optional[int] = None, + credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, + ) -> None: + super().__init__(host, port=port, credentials=credentials, timeout=timeout) + self._state_information: Dict[str, Any] = {} + self._discovery_info: Optional[Dict[str, Any]] = None + self.protocol = TPLinkAes(host, credentials=credentials, timeout=timeout) + + async def update(self, update_children: bool = True): + """Update the device.""" + if self.credentials is None or self.credentials.username is None: + raise AuthenticationException("Tapo plug requires authentication.") + + self._info = await self.protocol.query("get_device_info") + self._usage = await self.protocol.query("get_device_usage") + self._time = await self.protocol.query("get_device_time") + + self._last_update = self._data = { + "info": self._info, + "usage": self._usage, + "time": self._time, + } + + _LOGGER.debug("Got an update: %s", self._data) + + @property + def sys_info(self) -> Dict[str, Any]: + """Returns the device info.""" + return self._info + + @property + def model(self) -> str: + """Returns the device model.""" + return str(self._info.get("model")) + + @property + def alias(self) -> str: + """Returns the device alias or nickname.""" + return base64.b64decode(str(self._info.get("nickname"))).decode() + + @property + def time(self) -> datetime: + """Return the time.""" + td = timedelta(minutes=cast(float, self._time.get("time_diff"))) + if self._time.get("region"): + tz = timezone(td, str(self._time.get("region"))) + else: + # in case the device returns a blank region this will result in the + # tzname being a UTC offset + tz = timezone(td) + return datetime.fromtimestamp( + cast(float, self._time.get("timestamp")), + tz=tz, + ) + + @property + def timezone(self) -> Dict: + """Return the timezone and time_difference.""" + ti = self.time + return {"timezone": ti.tzname()} + + @property + def hw_info(self) -> Dict: + """Return hardware info for the device.""" + return { + "sw_ver": self._info.get("fw_ver"), + "hw_ver": self._info.get("hw_ver"), + "mac": self._info.get("mac"), + "type": self._info.get("type"), + "hwId": self._info.get("device_id"), + "dev_name": self.alias, + "oemId": self._info.get("oem_id"), + } + + @property + def location(self) -> Dict: + """Return the device location.""" + loc = { + "latitude": cast(float, self._info.get("latitude")) / 10_000, + "longitude": cast(float, self._info.get("longitude")) / 10_000, + } + return loc + + @property + def rssi(self) -> Optional[int]: + """Return the rssi.""" + rssi = self._info.get("rssi") + return int(rssi) if rssi else None + + @property + def mac(self) -> str: + """Return the mac formatted with colons.""" + return str(self._info.get("mac")).replace("-", ":") + + @property + def device_id(self) -> str: + """Return the device id.""" + return str(self._info.get("device_id")) + + @property + def internal_state(self) -> Any: + """Return all the internal state data.""" + return self._data + + async def _query_helper( + self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None + ) -> Any: + res = await self.protocol.query({cmd: arg}) + + return res + + @property + def state_information(self) -> Dict[str, Any]: + """Return the key state information.""" + return { + "overheated": self._info.get("overheated"), + "signal_level": self._info.get("signal_level"), + "SSID": base64.b64decode(str(self._info.get("ssid"))).decode(), + } + + @property + def features(self) -> Set[str]: + """Return the list of supported features.""" + # TODO: + return set() + + @property + def is_on(self) -> bool: + """Return true if the device is on.""" + return bool(self._info.get("device_on")) + + async def turn_on(self, **kwargs): + """Turn on the device.""" + await self.protocol.query({"set_device_info": {"device_on": True}}) + + async def turn_off(self, **kwargs): + """Turn off the device.""" + await self.protocol.query({"set_device_info": {"device_on": False}}) + + def update_from_discover_info(self, info): + """Update state from info from the discover call.""" + self._discovery_info = info diff --git a/kasa/tapo/tapoplug.py b/kasa/tapo/tapoplug.py new file mode 100644 index 00000000..c291c7d1 --- /dev/null +++ b/kasa/tapo/tapoplug.py @@ -0,0 +1,73 @@ +"""Module for a TAPO Plug.""" +import logging +from datetime import datetime, timedelta +from typing import Any, Dict, Optional, cast + +from ..credentials import Credentials +from ..emeterstatus import EmeterStatus +from ..smartdevice import DeviceType +from .tapodevice import TapoDevice + +_LOGGER = logging.getLogger(__name__) + + +class TapoPlug(TapoDevice): + """Class to represent a TAPO Plug.""" + + def __init__( + self, + host: str, + *, + port: Optional[int] = None, + credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, + ) -> None: + super().__init__(host, port=port, credentials=credentials, timeout=timeout) + self._device_type = DeviceType.Plug + + async def update(self, update_children: bool = True): + """Call the device endpoint and update the device data.""" + await super().update(update_children) + + self._energy = await self.protocol.query("get_energy_usage") + self._emeter = await self.protocol.query("get_current_power") + + self._data["energy"] = self._energy + self._data["emeter"] = self._emeter + + _LOGGER.debug("Got an update: %s %s", self._energy, self._emeter) + + @property + def state_information(self) -> Dict[str, Any]: + """Return the key state information.""" + return { + **super().state_information, + **{ + "On since": self.on_since, + "auto_off_status": self._info.get("auto_off_status"), + "auto_off_remain_time": self._info.get("auto_off_remain_time"), + }, + } + + @property + def emeter_realtime(self) -> EmeterStatus: + """Get the emeter status.""" + return EmeterStatus({"power_mw": self._energy.get("current_power")}) + + @property + def emeter_today(self) -> Optional[float]: + """Get the emeter value for today.""" + return None + + @property + def emeter_this_month(self) -> Optional[float]: + """Get the emeter value for this month.""" + return None + + @property + def on_since(self) -> Optional[datetime]: + """Return the time that the device was turned on or None if turned off.""" + if not self._info.get("device_on"): + return None + on_time = cast(float, self._info.get("on_time")) + return datetime.now().replace(microsecond=0) - timedelta(seconds=on_time) diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 7e1dabc0..626afd18 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -114,7 +114,7 @@ UNSUPPORTED = { "result": { "device_id": "xx", "owner": "xx", - "device_type": "SMART.TAPOPLUG", + "device_type": "SMART.TAPOXMASTREE", "device_model": "P110(EU)", "ip": "127.0.0.1", "mac": "48-22xxx", @@ -150,7 +150,7 @@ async def test_discover_single_unsupported(mocker): discovery_data = UNSUPPORTED with pytest.raises( UnsupportedDeviceException, - match=f"Unsupported device {host}: {re.escape(str(UNSUPPORTED))}", + match=f"Unsupported device {host} of type SMART.TAPOXMASTREE: {re.escape(str(UNSUPPORTED))}", ): await Discover.discover_single(host)