mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-11-03 22:22:06 +00:00 
			
		
		
		
	Redact sensitive info from debug logs (#1069)
Redacts sensitive data when debug logging device responses such as mac, location and usernames
This commit is contained in:
		@@ -87,7 +87,8 @@ import ipaddress
 | 
			
		||||
import logging
 | 
			
		||||
import socket
 | 
			
		||||
from collections.abc import Awaitable
 | 
			
		||||
from typing import Callable, Dict, Optional, Type, cast
 | 
			
		||||
from pprint import pformat as pf
 | 
			
		||||
from typing import Any, Callable, Dict, Optional, Type, cast
 | 
			
		||||
 | 
			
		||||
# When support for cpython older than 3.11 is dropped
 | 
			
		||||
# async_timeout can be replaced with asyncio.timeout
 | 
			
		||||
@@ -112,8 +113,10 @@ from kasa.exceptions import (
 | 
			
		||||
    UnsupportedDeviceError,
 | 
			
		||||
)
 | 
			
		||||
from kasa.iot.iotdevice import IotDevice
 | 
			
		||||
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
 | 
			
		||||
from kasa.json import dumps as json_dumps
 | 
			
		||||
from kasa.json import loads as json_loads
 | 
			
		||||
from kasa.protocol import mask_mac, redact_data
 | 
			
		||||
from kasa.xortransport import XorEncryption
 | 
			
		||||
 | 
			
		||||
_LOGGER = logging.getLogger(__name__)
 | 
			
		||||
@@ -123,6 +126,12 @@ OnDiscoveredCallable = Callable[[Device], Awaitable[None]]
 | 
			
		||||
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]]
 | 
			
		||||
DeviceDict = Dict[str, Device]
 | 
			
		||||
 | 
			
		||||
NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
 | 
			
		||||
    "device_id": lambda x: "REDACTED_" + x[9::],
 | 
			
		||||
    "owner": lambda x: "REDACTED_" + x[9::],
 | 
			
		||||
    "mac": mask_mac,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _DiscoverProtocol(asyncio.DatagramProtocol):
 | 
			
		||||
    """Implementation of the discovery protocol handler.
 | 
			
		||||
@@ -293,6 +302,8 @@ class Discover:
 | 
			
		||||
    DISCOVERY_PORT_2 = 20002
 | 
			
		||||
    DISCOVERY_QUERY_2 = binascii.unhexlify("020000010000000000000000463cb5d3")
 | 
			
		||||
 | 
			
		||||
    _redact_data = True
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    async def discover(
 | 
			
		||||
        *,
 | 
			
		||||
@@ -484,7 +495,9 @@ class Discover:
 | 
			
		||||
                f"Unable to read response from device: {config.host}: {ex}"
 | 
			
		||||
            ) from ex
 | 
			
		||||
 | 
			
		||||
        _LOGGER.debug("[DISCOVERY] %s << %s", config.host, info)
 | 
			
		||||
        if _LOGGER.isEnabledFor(logging.DEBUG):
 | 
			
		||||
            data = redact_data(info, IOT_REDACTORS) if Discover._redact_data else info
 | 
			
		||||
            _LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data))
 | 
			
		||||
 | 
			
		||||
        device_class = cast(Type[IotDevice], Discover._get_device_class(info))
 | 
			
		||||
        device = device_class(config.host, config=config)
 | 
			
		||||
@@ -504,6 +517,7 @@ class Discover:
 | 
			
		||||
        config: DeviceConfig,
 | 
			
		||||
    ) -> Device:
 | 
			
		||||
        """Get SmartDevice from the new 20002 response."""
 | 
			
		||||
        debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
 | 
			
		||||
        try:
 | 
			
		||||
            info = json_loads(data[16:])
 | 
			
		||||
        except Exception as ex:
 | 
			
		||||
@@ -514,9 +528,17 @@ class Discover:
 | 
			
		||||
        try:
 | 
			
		||||
            discovery_result = DiscoveryResult(**info["result"])
 | 
			
		||||
        except ValidationError as ex:
 | 
			
		||||
            _LOGGER.debug(
 | 
			
		||||
                "Unable to parse discovery from device %s: %s", config.host, info
 | 
			
		||||
            )
 | 
			
		||||
            if debug_enabled:
 | 
			
		||||
                data = (
 | 
			
		||||
                    redact_data(info, NEW_DISCOVERY_REDACTORS)
 | 
			
		||||
                    if Discover._redact_data
 | 
			
		||||
                    else info
 | 
			
		||||
                )
 | 
			
		||||
                _LOGGER.debug(
 | 
			
		||||
                    "Unable to parse discovery from device %s: %s",
 | 
			
		||||
                    config.host,
 | 
			
		||||
                    pf(data),
 | 
			
		||||
                )
 | 
			
		||||
            raise UnsupportedDeviceError(
 | 
			
		||||
                f"Unable to parse discovery from device: {config.host}: {ex}"
 | 
			
		||||
            ) from ex
 | 
			
		||||
@@ -551,7 +573,13 @@ class Discover:
 | 
			
		||||
                discovery_result=discovery_result.get_dict(),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        _LOGGER.debug("[DISCOVERY] %s << %s", config.host, info)
 | 
			
		||||
        if debug_enabled:
 | 
			
		||||
            data = (
 | 
			
		||||
                redact_data(info, NEW_DISCOVERY_REDACTORS)
 | 
			
		||||
                if Discover._redact_data
 | 
			
		||||
                else info
 | 
			
		||||
            )
 | 
			
		||||
            _LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data))
 | 
			
		||||
        device = device_class(config.host, protocol=protocol)
 | 
			
		||||
 | 
			
		||||
        di = discovery_result.get_dict()
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,8 @@ from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import asyncio
 | 
			
		||||
import logging
 | 
			
		||||
from pprint import pformat as pf
 | 
			
		||||
from typing import Any, Callable
 | 
			
		||||
 | 
			
		||||
from .deviceconfig import DeviceConfig
 | 
			
		||||
from .exceptions import (
 | 
			
		||||
@@ -14,11 +16,26 @@ from .exceptions import (
 | 
			
		||||
    _RetryableError,
 | 
			
		||||
)
 | 
			
		||||
from .json import dumps as json_dumps
 | 
			
		||||
from .protocol import BaseProtocol, BaseTransport
 | 
			
		||||
from .protocol import BaseProtocol, BaseTransport, mask_mac, redact_data
 | 
			
		||||
from .xortransport import XorEncryption, XorTransport
 | 
			
		||||
 | 
			
		||||
_LOGGER = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
REDACTORS: dict[str, Callable[[Any], Any] | None] = {
 | 
			
		||||
    "latitude": lambda x: 0,
 | 
			
		||||
    "longitude": lambda x: 0,
 | 
			
		||||
    "latitude_i": lambda x: 0,
 | 
			
		||||
    "longitude_i": lambda x: 0,
 | 
			
		||||
    "deviceId": lambda x: "REDACTED_" + x[9::],
 | 
			
		||||
    "id": lambda x: "REDACTED_" + x[9::],
 | 
			
		||||
    "alias": lambda x: "#MASKED_NAME#" if x else "",
 | 
			
		||||
    "mac": mask_mac,
 | 
			
		||||
    "mic_mac": mask_mac,
 | 
			
		||||
    "ssid": lambda x: "#MASKED_SSID#" if x else "",
 | 
			
		||||
    "oemId": lambda x: "REDACTED_" + x[9::],
 | 
			
		||||
    "username": lambda _: "user@example.com",  # cnCloud
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IotProtocol(BaseProtocol):
 | 
			
		||||
    """Class for the legacy TPLink IOT KASA Protocol."""
 | 
			
		||||
@@ -34,6 +51,7 @@ class IotProtocol(BaseProtocol):
 | 
			
		||||
        super().__init__(transport=transport)
 | 
			
		||||
 | 
			
		||||
        self._query_lock = asyncio.Lock()
 | 
			
		||||
        self._redact_data = True
 | 
			
		||||
 | 
			
		||||
    async def query(self, request: str | dict, retry_count: int = 3) -> dict:
 | 
			
		||||
        """Query the device retrying for retry_count on failure."""
 | 
			
		||||
@@ -85,7 +103,24 @@ class IotProtocol(BaseProtocol):
 | 
			
		||||
        raise KasaException("Query reached somehow to unreachable")
 | 
			
		||||
 | 
			
		||||
    async def _execute_query(self, request: str, retry_count: int) -> dict:
 | 
			
		||||
        return await self._transport.send(request)
 | 
			
		||||
        debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
 | 
			
		||||
 | 
			
		||||
        if debug_enabled:
 | 
			
		||||
            _LOGGER.debug(
 | 
			
		||||
                "%s >> %s",
 | 
			
		||||
                self._host,
 | 
			
		||||
                request,
 | 
			
		||||
            )
 | 
			
		||||
        resp = await self._transport.send(request)
 | 
			
		||||
 | 
			
		||||
        if debug_enabled:
 | 
			
		||||
            data = redact_data(resp, REDACTORS) if self._redact_data else resp
 | 
			
		||||
            _LOGGER.debug(
 | 
			
		||||
                "%s << %s",
 | 
			
		||||
                self._host,
 | 
			
		||||
                pf(data),
 | 
			
		||||
            )
 | 
			
		||||
        return resp
 | 
			
		||||
 | 
			
		||||
    async def close(self) -> None:
 | 
			
		||||
        """Close the underlying transport."""
 | 
			
		||||
 
 | 
			
		||||
@@ -50,7 +50,6 @@ import logging
 | 
			
		||||
import secrets
 | 
			
		||||
import struct
 | 
			
		||||
import time
 | 
			
		||||
from pprint import pformat as pf
 | 
			
		||||
from typing import Any, cast
 | 
			
		||||
 | 
			
		||||
from cryptography.hazmat.primitives import padding
 | 
			
		||||
@@ -349,7 +348,7 @@ class KlapTransport(BaseTransport):
 | 
			
		||||
                    + f"request with seq {seq}"
 | 
			
		||||
                )
 | 
			
		||||
        else:
 | 
			
		||||
            _LOGGER.debug("Query posted " + msg)
 | 
			
		||||
            _LOGGER.debug("Device %s query posted %s", self._host, msg)
 | 
			
		||||
 | 
			
		||||
            # Check for mypy
 | 
			
		||||
            if self._encryption_session is not None:
 | 
			
		||||
@@ -357,11 +356,7 @@ class KlapTransport(BaseTransport):
 | 
			
		||||
 | 
			
		||||
            json_payload = json_loads(decrypted_response)
 | 
			
		||||
 | 
			
		||||
            _LOGGER.debug(
 | 
			
		||||
                "%s << %s",
 | 
			
		||||
                self._host,
 | 
			
		||||
                _LOGGER.isEnabledFor(logging.DEBUG) and pf(json_payload),
 | 
			
		||||
            )
 | 
			
		||||
            _LOGGER.debug("Device %s query response received", self._host)
 | 
			
		||||
 | 
			
		||||
            return json_payload
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -18,6 +18,7 @@ import hashlib
 | 
			
		||||
import logging
 | 
			
		||||
import struct
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from typing import Any, Callable, TypeVar, cast
 | 
			
		||||
 | 
			
		||||
# When support for cpython older than 3.11 is dropped
 | 
			
		||||
# async_timeout can be replaced with asyncio.timeout
 | 
			
		||||
@@ -28,6 +29,46 @@ _LOGGER = logging.getLogger(__name__)
 | 
			
		||||
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}
 | 
			
		||||
_UNSIGNED_INT_NETWORK_ORDER = struct.Struct(">I")
 | 
			
		||||
 | 
			
		||||
_T = TypeVar("_T")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def redact_data(data: _T, redactors: dict[str, Callable[[Any], Any] | None]) -> _T:
 | 
			
		||||
    """Redact sensitive data for logging."""
 | 
			
		||||
    if not isinstance(data, (dict, list)):
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
    if isinstance(data, list):
 | 
			
		||||
        return cast(_T, [redact_data(val, redactors) for val in data])
 | 
			
		||||
 | 
			
		||||
    redacted = {**data}
 | 
			
		||||
 | 
			
		||||
    for key, value in redacted.items():
 | 
			
		||||
        if value is None:
 | 
			
		||||
            continue
 | 
			
		||||
        if isinstance(value, str) and not value:
 | 
			
		||||
            continue
 | 
			
		||||
        if key in redactors:
 | 
			
		||||
            if redactor := redactors[key]:
 | 
			
		||||
                try:
 | 
			
		||||
                    redacted[key] = redactor(value)
 | 
			
		||||
                except:  # noqa: E722
 | 
			
		||||
                    redacted[key] = "**REDACTEX**"
 | 
			
		||||
            else:
 | 
			
		||||
                redacted[key] = "**REDACTED**"
 | 
			
		||||
        elif isinstance(value, dict):
 | 
			
		||||
            redacted[key] = redact_data(value, redactors)
 | 
			
		||||
        elif isinstance(value, list):
 | 
			
		||||
            redacted[key] = [redact_data(item, redactors) for item in value]
 | 
			
		||||
 | 
			
		||||
    return cast(_T, redacted)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def mask_mac(mac: str) -> str:
 | 
			
		||||
    """Return mac address with last two octects blanked."""
 | 
			
		||||
    delim = ":" if ":" in mac else "-"
 | 
			
		||||
    rest = delim.join(format(s, "02x") for s in bytes.fromhex("000000"))
 | 
			
		||||
    return f"{mac[:8]}{delim}{rest}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def md5(payload: bytes) -> bytes:
 | 
			
		||||
    """Return the MD5 hash of the payload."""
 | 
			
		||||
 
 | 
			
		||||
@@ -193,11 +193,9 @@ class SmartDevice(Device):
 | 
			
		||||
        if not self._features:
 | 
			
		||||
            await self._initialize_features()
 | 
			
		||||
 | 
			
		||||
        _LOGGER.debug(
 | 
			
		||||
            "Update completed %s: %s",
 | 
			
		||||
            self.host,
 | 
			
		||||
            self._last_update if first_update else resp,
 | 
			
		||||
        )
 | 
			
		||||
        if _LOGGER.isEnabledFor(logging.DEBUG):
 | 
			
		||||
            updated = self._last_update if first_update else resp
 | 
			
		||||
            _LOGGER.debug("Update completed %s: %s", self.host, list(updated.keys()))
 | 
			
		||||
 | 
			
		||||
    def _handle_module_post_update_hook(self, module: SmartModule) -> bool:
 | 
			
		||||
        try:
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,7 @@ import logging
 | 
			
		||||
import time
 | 
			
		||||
import uuid
 | 
			
		||||
from pprint import pformat as pf
 | 
			
		||||
from typing import Any
 | 
			
		||||
from typing import Any, Callable
 | 
			
		||||
 | 
			
		||||
from .exceptions import (
 | 
			
		||||
    SMART_AUTHENTICATION_ERRORS,
 | 
			
		||||
@@ -26,10 +26,31 @@ from .exceptions import (
 | 
			
		||||
    _RetryableError,
 | 
			
		||||
)
 | 
			
		||||
from .json import dumps as json_dumps
 | 
			
		||||
from .protocol import BaseProtocol, BaseTransport, md5
 | 
			
		||||
from .protocol import BaseProtocol, BaseTransport, mask_mac, md5, redact_data
 | 
			
		||||
 | 
			
		||||
_LOGGER = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
REDACTORS: dict[str, Callable[[Any], Any] | None] = {
 | 
			
		||||
    "latitude": lambda x: 0,
 | 
			
		||||
    "longitude": lambda x: 0,
 | 
			
		||||
    "la": lambda x: 0,  # lat on ks240
 | 
			
		||||
    "lo": lambda x: 0,  # lon on ks240
 | 
			
		||||
    "device_id": lambda x: "REDACTED_" + x[9::],
 | 
			
		||||
    "parent_device_id": lambda x: "REDACTED_" + x[9::],  # Hub attached children
 | 
			
		||||
    "original_device_id": lambda x: "REDACTED_" + x[9::],  # Strip children
 | 
			
		||||
    "nickname": lambda x: "I01BU0tFRF9OQU1FIw==" if x else "",
 | 
			
		||||
    "mac": mask_mac,
 | 
			
		||||
    "ssid": lambda x: "I01BU0tFRF9TU0lEIw=" if x else "",
 | 
			
		||||
    "bssid": lambda _: "000000000000",
 | 
			
		||||
    "oem_id": lambda x: "REDACTED_" + x[9::],
 | 
			
		||||
    "setup_code": None,  # matter
 | 
			
		||||
    "setup_payload": None,  # matter
 | 
			
		||||
    "mfi_setup_code": None,  # mfi_ for homekit
 | 
			
		||||
    "mfi_setup_id": None,
 | 
			
		||||
    "mfi_token_token": None,
 | 
			
		||||
    "mfi_token_uuid": None,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SmartProtocol(BaseProtocol):
 | 
			
		||||
    """Class for the new TPLink SMART protocol."""
 | 
			
		||||
@@ -50,6 +71,7 @@ class SmartProtocol(BaseProtocol):
 | 
			
		||||
        self._multi_request_batch_size = (
 | 
			
		||||
            self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
 | 
			
		||||
        )
 | 
			
		||||
        self._redact_data = True
 | 
			
		||||
 | 
			
		||||
    def get_smart_request(self, method, params=None) -> str:
 | 
			
		||||
        """Get a request message as a string."""
 | 
			
		||||
@@ -167,11 +189,15 @@ class SmartProtocol(BaseProtocol):
 | 
			
		||||
                )
 | 
			
		||||
            response_step = await self._transport.send(smart_request)
 | 
			
		||||
            if debug_enabled:
 | 
			
		||||
                if self._redact_data:
 | 
			
		||||
                    data = redact_data(response_step, REDACTORS)
 | 
			
		||||
                else:
 | 
			
		||||
                    data = response_step
 | 
			
		||||
                _LOGGER.debug(
 | 
			
		||||
                    "%s %s << %s",
 | 
			
		||||
                    self._host,
 | 
			
		||||
                    batch_name,
 | 
			
		||||
                    pf(response_step),
 | 
			
		||||
                    pf(data),
 | 
			
		||||
                )
 | 
			
		||||
            try:
 | 
			
		||||
                self._handle_response_error_code(response_step, batch_name)
 | 
			
		||||
 
 | 
			
		||||
@@ -90,21 +90,26 @@ def create_discovery_mock(ip: str, fixture_data: dict):
 | 
			
		||||
        query_data: dict
 | 
			
		||||
        device_type: str
 | 
			
		||||
        encrypt_type: str
 | 
			
		||||
        _datagram: bytes
 | 
			
		||||
        login_version: int | None = None
 | 
			
		||||
        port_override: int | None = None
 | 
			
		||||
 | 
			
		||||
        @property
 | 
			
		||||
        def _datagram(self) -> bytes:
 | 
			
		||||
            if self.default_port == 9999:
 | 
			
		||||
                return XorEncryption.encrypt(json_dumps(self.discovery_data))[4:]
 | 
			
		||||
            else:
 | 
			
		||||
                return (
 | 
			
		||||
                    b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
 | 
			
		||||
                    + json_dumps(self.discovery_data).encode()
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    if "discovery_result" in fixture_data:
 | 
			
		||||
        discovery_data = {"result": fixture_data["discovery_result"]}
 | 
			
		||||
        discovery_data = {"result": fixture_data["discovery_result"].copy()}
 | 
			
		||||
        device_type = fixture_data["discovery_result"]["device_type"]
 | 
			
		||||
        encrypt_type = fixture_data["discovery_result"]["mgt_encrypt_schm"][
 | 
			
		||||
            "encrypt_type"
 | 
			
		||||
        ]
 | 
			
		||||
        login_version = fixture_data["discovery_result"]["mgt_encrypt_schm"].get("lv")
 | 
			
		||||
        datagram = (
 | 
			
		||||
            b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
 | 
			
		||||
            + json_dumps(discovery_data).encode()
 | 
			
		||||
        )
 | 
			
		||||
        dm = _DiscoveryMock(
 | 
			
		||||
            ip,
 | 
			
		||||
            80,
 | 
			
		||||
@@ -113,16 +118,14 @@ def create_discovery_mock(ip: str, fixture_data: dict):
 | 
			
		||||
            fixture_data,
 | 
			
		||||
            device_type,
 | 
			
		||||
            encrypt_type,
 | 
			
		||||
            datagram,
 | 
			
		||||
            login_version,
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        sys_info = fixture_data["system"]["get_sysinfo"]
 | 
			
		||||
        discovery_data = {"system": {"get_sysinfo": sys_info}}
 | 
			
		||||
        discovery_data = {"system": {"get_sysinfo": sys_info.copy()}}
 | 
			
		||||
        device_type = sys_info.get("mic_type") or sys_info.get("type")
 | 
			
		||||
        encrypt_type = "XOR"
 | 
			
		||||
        login_version = None
 | 
			
		||||
        datagram = XorEncryption.encrypt(json_dumps(discovery_data))[4:]
 | 
			
		||||
        dm = _DiscoveryMock(
 | 
			
		||||
            ip,
 | 
			
		||||
            9999,
 | 
			
		||||
@@ -131,7 +134,6 @@ def create_discovery_mock(ip: str, fixture_data: dict):
 | 
			
		||||
            fixture_data,
 | 
			
		||||
            device_type,
 | 
			
		||||
            encrypt_type,
 | 
			
		||||
            datagram,
 | 
			
		||||
            login_version,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@
 | 
			
		||||
# ruff: noqa: S106
 | 
			
		||||
 | 
			
		||||
import asyncio
 | 
			
		||||
import logging
 | 
			
		||||
import re
 | 
			
		||||
import socket
 | 
			
		||||
from unittest.mock import MagicMock
 | 
			
		||||
@@ -565,3 +566,38 @@ async def test_do_discover_external_cancel(mocker):
 | 
			
		||||
    with pytest.raises(asyncio.TimeoutError):
 | 
			
		||||
        async with asyncio_timeout(0):
 | 
			
		||||
            await dp.wait_for_discovery_to_complete()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_discovery_redaction(discovery_mock, caplog: pytest.LogCaptureFixture):
 | 
			
		||||
    """Test query sensitive info redaction."""
 | 
			
		||||
    mac = "12:34:56:78:9A:BC"
 | 
			
		||||
 | 
			
		||||
    if discovery_mock.default_port == 9999:
 | 
			
		||||
        sysinfo = discovery_mock.discovery_data["system"]["get_sysinfo"]
 | 
			
		||||
        if "mac" in sysinfo:
 | 
			
		||||
            sysinfo["mac"] = mac
 | 
			
		||||
        elif "mic_mac" in sysinfo:
 | 
			
		||||
            sysinfo["mic_mac"] = mac
 | 
			
		||||
    else:
 | 
			
		||||
        discovery_mock.discovery_data["result"]["mac"] = mac
 | 
			
		||||
 | 
			
		||||
    # Info no message logging
 | 
			
		||||
    caplog.set_level(logging.INFO)
 | 
			
		||||
    await Discover.discover()
 | 
			
		||||
 | 
			
		||||
    assert mac not in caplog.text
 | 
			
		||||
 | 
			
		||||
    caplog.set_level(logging.DEBUG)
 | 
			
		||||
 | 
			
		||||
    # Debug no redaction
 | 
			
		||||
    caplog.clear()
 | 
			
		||||
    Discover._redact_data = False
 | 
			
		||||
    await Discover.discover()
 | 
			
		||||
    assert mac in caplog.text
 | 
			
		||||
 | 
			
		||||
    # Debug redaction
 | 
			
		||||
    caplog.clear()
 | 
			
		||||
    Discover._redact_data = True
 | 
			
		||||
    await Discover.discover()
 | 
			
		||||
    assert mac not in caplog.text
 | 
			
		||||
    assert "12:34:56:00:00:00" in caplog.text
 | 
			
		||||
 
 | 
			
		||||
@@ -8,9 +8,12 @@ import os
 | 
			
		||||
import pkgutil
 | 
			
		||||
import struct
 | 
			
		||||
import sys
 | 
			
		||||
from typing import cast
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from kasa.iot import IotDevice
 | 
			
		||||
 | 
			
		||||
from ..aestransport import AesTransport
 | 
			
		||||
from ..credentials import Credentials
 | 
			
		||||
from ..deviceconfig import DeviceConfig
 | 
			
		||||
@@ -20,8 +23,12 @@ from ..klaptransport import KlapTransport, KlapTransportV2
 | 
			
		||||
from ..protocol import (
 | 
			
		||||
    BaseProtocol,
 | 
			
		||||
    BaseTransport,
 | 
			
		||||
    mask_mac,
 | 
			
		||||
    redact_data,
 | 
			
		||||
)
 | 
			
		||||
from ..xortransport import XorEncryption, XorTransport
 | 
			
		||||
from .conftest import device_iot
 | 
			
		||||
from .fakeprotocol_iot import FakeIotTransport
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
@@ -614,3 +621,63 @@ def test_deprecated_protocol():
 | 
			
		||||
        host = "127.0.0.1"
 | 
			
		||||
        proto = TPLinkSmartHomeProtocol(host=host)
 | 
			
		||||
        assert proto.config.host == host
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@device_iot
 | 
			
		||||
async def test_iot_queries_redaction(dev: IotDevice, caplog: pytest.LogCaptureFixture):
 | 
			
		||||
    """Test query sensitive info redaction."""
 | 
			
		||||
    device_id = "123456789ABCDEF"
 | 
			
		||||
    cast(FakeIotTransport, dev.protocol._transport).proto["system"]["get_sysinfo"][
 | 
			
		||||
        "deviceId"
 | 
			
		||||
    ] = device_id
 | 
			
		||||
 | 
			
		||||
    # Info no message logging
 | 
			
		||||
    caplog.set_level(logging.INFO)
 | 
			
		||||
    await dev.update()
 | 
			
		||||
    assert device_id not in caplog.text
 | 
			
		||||
 | 
			
		||||
    caplog.set_level(logging.DEBUG, logger="kasa")
 | 
			
		||||
    # The fake iot protocol also logs so disable it
 | 
			
		||||
    test_logger = logging.getLogger("kasa.tests.fakeprotocol_iot")
 | 
			
		||||
    test_logger.setLevel(logging.INFO)
 | 
			
		||||
 | 
			
		||||
    # Debug no redaction
 | 
			
		||||
    caplog.clear()
 | 
			
		||||
    cast(IotProtocol, dev.protocol)._redact_data = False
 | 
			
		||||
    await dev.update()
 | 
			
		||||
    assert device_id in caplog.text
 | 
			
		||||
 | 
			
		||||
    # Debug redaction
 | 
			
		||||
    caplog.clear()
 | 
			
		||||
    cast(IotProtocol, dev.protocol)._redact_data = True
 | 
			
		||||
    await dev.update()
 | 
			
		||||
    assert device_id not in caplog.text
 | 
			
		||||
    assert "REDACTED_" + device_id[9::] in caplog.text
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_redact_data():
 | 
			
		||||
    """Test redact data function."""
 | 
			
		||||
    data = {
 | 
			
		||||
        "device_id": "123456789ABCDEF",
 | 
			
		||||
        "owner": "0987654",
 | 
			
		||||
        "mac": "12:34:56:78:90:AB",
 | 
			
		||||
        "ip": "192.168.1",
 | 
			
		||||
        "no_val": None,
 | 
			
		||||
    }
 | 
			
		||||
    excpected_data = {
 | 
			
		||||
        "device_id": "REDACTED_ABCDEF",
 | 
			
		||||
        "owner": "**REDACTED**",
 | 
			
		||||
        "mac": "12:34:56:00:00:00",
 | 
			
		||||
        "ip": "**REDACTEX**",
 | 
			
		||||
        "no_val": None,
 | 
			
		||||
    }
 | 
			
		||||
    REDACTORS = {
 | 
			
		||||
        "device_id": lambda x: "REDACTED_" + x[9::],
 | 
			
		||||
        "owner": None,
 | 
			
		||||
        "mac": mask_mac,
 | 
			
		||||
        "ip": lambda x: "127.0.0." + x.split(".")[3],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    redacted_data = redact_data(data, REDACTORS)
 | 
			
		||||
 | 
			
		||||
    assert redacted_data == excpected_data
 | 
			
		||||
 
 | 
			
		||||
@@ -1,8 +1,11 @@
 | 
			
		||||
import logging
 | 
			
		||||
from typing import cast
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
import pytest_mock
 | 
			
		||||
 | 
			
		||||
from kasa.smart import SmartDevice
 | 
			
		||||
 | 
			
		||||
from ..exceptions import (
 | 
			
		||||
    SMART_RETRYABLE_ERRORS,
 | 
			
		||||
    DeviceError,
 | 
			
		||||
@@ -10,6 +13,7 @@ from ..exceptions import (
 | 
			
		||||
    SmartErrorCode,
 | 
			
		||||
)
 | 
			
		||||
from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper
 | 
			
		||||
from .conftest import device_smart
 | 
			
		||||
from .fakeprotocol_smart import FakeSmartTransport
 | 
			
		||||
 | 
			
		||||
DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
 | 
			
		||||
@@ -409,3 +413,34 @@ async def test_incomplete_list(mocker, caplog):
 | 
			
		||||
        "Device 127.0.0.123 returned empty results list for method get_preset_rules"
 | 
			
		||||
        in caplog.text
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@device_smart
 | 
			
		||||
async def test_smart_queries_redaction(
 | 
			
		||||
    dev: SmartDevice, caplog: pytest.LogCaptureFixture
 | 
			
		||||
):
 | 
			
		||||
    """Test query sensitive info redaction."""
 | 
			
		||||
    device_id = "123456789ABCDEF"
 | 
			
		||||
    cast(FakeSmartTransport, dev.protocol._transport).info["get_device_info"][
 | 
			
		||||
        "device_id"
 | 
			
		||||
    ] = device_id
 | 
			
		||||
 | 
			
		||||
    # Info no message logging
 | 
			
		||||
    caplog.set_level(logging.INFO)
 | 
			
		||||
    await dev.update()
 | 
			
		||||
    assert device_id not in caplog.text
 | 
			
		||||
 | 
			
		||||
    caplog.set_level(logging.DEBUG)
 | 
			
		||||
 | 
			
		||||
    # Debug no redaction
 | 
			
		||||
    caplog.clear()
 | 
			
		||||
    dev.protocol._redact_data = False
 | 
			
		||||
    await dev.update()
 | 
			
		||||
    assert device_id in caplog.text
 | 
			
		||||
 | 
			
		||||
    # Debug redaction
 | 
			
		||||
    caplog.clear()
 | 
			
		||||
    dev.protocol._redact_data = True
 | 
			
		||||
    await dev.update()
 | 
			
		||||
    assert device_id not in caplog.text
 | 
			
		||||
    assert "REDACTED_" + device_id[9::] in caplog.text
 | 
			
		||||
 
 | 
			
		||||
@@ -19,7 +19,6 @@ import logging
 | 
			
		||||
import socket
 | 
			
		||||
import struct
 | 
			
		||||
from collections.abc import Generator
 | 
			
		||||
from pprint import pformat as pf
 | 
			
		||||
 | 
			
		||||
# When support for cpython older than 3.11 is dropped
 | 
			
		||||
# async_timeout can be replaced with asyncio.timeout
 | 
			
		||||
@@ -78,9 +77,8 @@ class XorTransport(BaseTransport):
 | 
			
		||||
        """Execute a query on the device and wait for the response."""
 | 
			
		||||
        assert self.writer is not None  # noqa: S101
 | 
			
		||||
        assert self.reader is not None  # noqa: S101
 | 
			
		||||
        debug_log = _LOGGER.isEnabledFor(logging.DEBUG)
 | 
			
		||||
        if debug_log:
 | 
			
		||||
            _LOGGER.debug("%s >> %s", self._host, request)
 | 
			
		||||
        _LOGGER.debug("Device %s sending query %s", self._host, request)
 | 
			
		||||
 | 
			
		||||
        self.writer.write(XorEncryption.encrypt(request))
 | 
			
		||||
        await self.writer.drain()
 | 
			
		||||
 | 
			
		||||
@@ -90,8 +88,8 @@ class XorTransport(BaseTransport):
 | 
			
		||||
        buffer = await self.reader.readexactly(length)
 | 
			
		||||
        response = XorEncryption.decrypt(buffer)
 | 
			
		||||
        json_payload = json_loads(response)
 | 
			
		||||
        if debug_log:
 | 
			
		||||
            _LOGGER.debug("%s << %s", self._host, pf(json_payload))
 | 
			
		||||
 | 
			
		||||
        _LOGGER.debug("Device %s query response received", self._host)
 | 
			
		||||
 | 
			
		||||
        return json_payload
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user