From 6f58c99bf713cb5d6902db26232aaea49b807c0b Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Tue, 24 Dec 2024 15:23:12 +0000 Subject: [PATCH] Add rest api command to the cli --- .gitignore | 3 + kasa/cli/discover.py | 17 +- kasa/cli/main.py | 3 +- kasa/cli/rest.py | 533 +++++++++++++++++++++++++++++++++++++++ kasa/interfaces/light.py | 4 + 5 files changed, 554 insertions(+), 6 deletions(-) create mode 100644 kasa/cli/rest.py diff --git a/.gitignore b/.gitignore index 573a4c08..90b5badf 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,6 @@ venv /build docs/build + +# self-signed certs +self-signed diff --git a/kasa/cli/discover.py b/kasa/cli/discover.py index 2470434b..78b3dd96 100644 --- a/kasa/cli/discover.py +++ b/kasa/cli/discover.py @@ -73,7 +73,9 @@ async def detail(ctx): echo() discovered = await _discover( - ctx, print_discovered=print_discovered, print_unsupported=print_unsupported + ctx.parent.parent, + print_discovered=print_discovered, + print_unsupported=print_unsupported, ) if ctx.parent.parent.params["host"]: return discovered @@ -111,7 +113,7 @@ async def raw(ctx, redact: bool): ) echo(json_dumps(discovered, indent=True)) - return await _discover(ctx, print_raw=print_raw, do_echo=False) + return await _discover(ctx.parent.parent, print_raw=print_raw, do_echo=False) @discover.command() @@ -148,7 +150,7 @@ async def list(ctx): f"{'HTTPS':<5} {'LV':<3} {'ALIAS'}" ) return await _discover( - ctx, + ctx.parent.parent, print_discovered=print_discovered, print_unsupported=print_unsupported, do_echo=False, @@ -156,9 +158,14 @@ async def list(ctx): async def _discover( - ctx, *, print_discovered=None, print_unsupported=None, print_raw=None, do_echo=True + root_ctx, + *, + print_discovered=None, + print_unsupported=None, + print_raw=None, + do_echo=True, ): - params = ctx.parent.parent.params + params = root_ctx.params target = params["target"] username = params["username"] password = params["password"] diff --git a/kasa/cli/main.py b/kasa/cli/main.py index fbcdf391..1664a0db 100755 --- a/kasa/cli/main.py +++ b/kasa/cli/main.py @@ -76,6 +76,7 @@ def _legacy_type_to_class(_type: str) -> Any: "schedule": None, "usage": None, "energy": "usage", + "rest": None, # device commands runnnable at top level "state": "device", "on": "device", @@ -270,7 +271,7 @@ async def cli( # but this keeps mypy happy for now logging.basicConfig(**logging_config) # type: ignore - if ctx.invoked_subcommand == "discover": + if ctx.invoked_subcommand in {"discover", "rest"}: return if alias is not None and host is not None: diff --git a/kasa/cli/rest.py b/kasa/cli/rest.py new file mode 100644 index 00000000..3ac8f9b6 --- /dev/null +++ b/kasa/cli/rest.py @@ -0,0 +1,533 @@ +"""Module for cli rest api.""" + +import asyncio +import datetime +import inspect +import logging +import socket +import ssl +from contextlib import suppress +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import asyncclick as click +from aiohttp import web +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + +from kasa import ( + Device, +) +from kasa.json import dumps as json_dumps +from kasa.json import loads as json_loads +from kasa.module import _is_bound_feature + +from .common import echo +from .discover import _discover + +_LOGGER = logging.getLogger(__name__) +logging.getLogger("aiohttp").setLevel(logging.WARNING) + +CERT_FILENAME = "certificate.pem" +KEY_FILENAME = "key.pem" +DEFAULT_PASSPHRASE = "passthrough" # noqa: S105 + + +async def wait_on_keyboard_interrupt(msg: str): + """Non loop blocking get input.""" + echo(msg + ", press Ctrl-C to cancel\n") + + with suppress(asyncio.CancelledError): + await asyncio.Event().wait() + + +async def _get_host_ip() -> str: + def get_ip() -> str: + # From https://stackoverflow.com/a/28950776 + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.settimeout(0) + try: + # doesn't even have to be reachable + s.connect(("10.254.254.254", 1)) + ip = s.getsockname()[0] + finally: + s.close() + return ip + + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, get_ip) + + +@click.command() +@click.option( + "--rest-port", + default=8080, + required=False, + help="Port for rest api.", +) +@click.option( + "--rest-host", + default=None, + required=False, + envvar="KASA_LISTEN_IP", + help="Host for rest api.", +) +@click.option( + "--cert-file", + default=None, + required=False, + help="Cert file for https.", +) +@click.option( + "--key-file", + default=None, + required=False, + help="Key file for https.", +) +@click.option( + "--key-passphrase", + default=None, + required=False, + help="Passphrase for https key.", +) +@click.option( + "--https/--no-https", + default=True, + is_flag=True, + type=bool, + help=( + "Use https, recommended to ensure passwords are not sent unencrypted. " + "If no cert-file provided will auto-create self-signed cert." + ), +) +@click.option( + "--localhost/--no-localhost", + default=True, + is_flag=True, + type=bool, + help=("Start server on localhost or primary device ip. "), +) +@click.option( + "--self-signed-folder", + default="./self-signed", + help="Location to store auto-created self-signed cert files.", +) +@click.option( + "--secure/--no-secure", + default=True, + is_flag=True, + type=bool, + help=( + "Require username and password in requests for " + "devices that require authentication." + ), +) +@click.pass_context +async def rest( + ctx: click.Context, + rest_port: int, + rest_host: str | None, + cert_file: str | None, + key_file: str | None, + key_passphrase: str | None, + https: bool, + self_signed_folder: str, + localhost: bool, + secure: bool, +) -> None: + """Start the rest api. + + Example calls: + + List all device attributes: + POST /device?host= + + Username and password required in all requests for devices that require + authentication unless --no-secure is set: + POST /device?host= '{"username": "user@example.com", "password": "pwd"}' + + Set a device attribute + POST /device?host= '{"name": "set_alias", "value": "Study cam"}' + + List all device module ids: + POST /module?host= + + List all module attributes + POST /module?host= '{"id": "Light"}' + + Set a module attribute + POST /module?host= '{"id": "Light", "name": "set_brightness", + "value": 50}' + + List all device feature ids: + POST /feature?host= + + Set a feature value + POST /feature?host= '{"id": "brightness", "value": 50}' + """ + if not rest_host: + if localhost: + rest_host = "localhost" + else: + rest_host = await _get_host_ip() + + scheme = "https" if https or cert_file else "http" + echo(f"Starting the rest api on {scheme}://{rest_host}:{rest_port}") + + devices = await _discover(ctx.parent) + + if https and not cert_file: + if not key_passphrase: + key_passphrase = DEFAULT_PASSPHRASE + + self_signed_path = Path(self_signed_folder) + cert_file_path = self_signed_path / CERT_FILENAME + key_file_path = self_signed_path / KEY_FILENAME + cert_file = str(cert_file_path) + key_file = str(key_file_path) + + if not cert_file_path.exists() or not key_file_path.exists(): + echo("Creating self-signed certificate") + + self_signed_path.mkdir(exist_ok=True) + _create_self_signed_key(key_file, cert_file, key_passphrase) + + if TYPE_CHECKING: + assert ctx.parent + username = ctx.parent.params["username"] + password = ctx.parent.params["password"] + server = RestServer(devices, username=username, password=password, secure=secure) + + await server.start( + rest_host, + rest_port, + cert_file=cert_file, + key_file=key_file, + key_passphrase=key_passphrase, + ) + + msg = f"Started rest api on {scheme}://{rest_host}:{rest_port}" + + await wait_on_keyboard_interrupt(msg) + + echo("\nStopping rest api") + + await server.stop() + + +class RestServer: + """Rest server class.""" + + def __init__( + self, + devices: dict[str, Device], + *, + username: str | None, + password: str | None, + secure: bool, + ) -> None: + self.devices = devices + self.running = False + self._username = username + self._password = password + self._secure = secure + + @staticmethod + def _serializable(o: object, attr_name: str) -> Any | None: + val = getattr(o, attr_name) + if hasattr(val, "to_dict"): + return val.to_dict() + try: + json_dumps(val) + return val + except (TypeError, OverflowError): + return None + + def _check_auth(self, request_dict: dict[str, Any], device: Device) -> bool: + if not self._secure: + return True + + if not device.device_info.requires_auth: + return True + + if (un := request_dict.get("username")) and ( + pw := request_dict.get("password") + ): + return (un == self._username) and (pw == self._password) + + return False + + async def _module(self, request: web.Request) -> web.Response: + if not (host := request.query.get("host")): + return web.Response(status=400, text="No host provided") + + def _get_interface(mod): + for base in mod.__class__.__bases__: + if base.__module__.startswith("kasa.interfaces"): + return base + return None + + dev = self.devices[host] + await dev.update() + + if req_body := await request.read(): + req = json_loads(req_body.decode()) + else: + req = {} + + if not self._check_auth(req, dev): + return web.Response(status=401) + + if not (module_id := req.get("id")): + list_result = { + "result": [ + mod_name + for mod_name, mod in dev.modules.items() + if _get_interface(mod) is not None + ] + } + body = json_dumps(list_result) + return web.Response(body=body) + + if not (module := dev.modules.get(module_id)) or not ( + interface_cls := _get_interface(module) + ): + return web.Response(status=400) + + # TODO make valid_temperature_range a FeatureAttribute + skip = { + "valid_temperature_range", + "is_color", + "is_dimmable", + "is_variable_color_temp", + "has_effects", + } + properties = { + attr_name: val + for attr_name in vars(interface_cls) + if attr_name[0] != "_" + and attr_name not in skip + and (attr := getattr(module.__class__, attr_name)) + and isinstance(attr, property) + and (not _is_bound_feature(attr) or module.has_feature(attr_name)) + and (val := self._serializable(module, attr_name)) + } + + # Return all the properties + if "name" not in req: + result = {"result": properties} + body = json_dumps(result) + return web.Response(body=body) + + setter_properties = { + attr_name + for attr_name in vars(interface_cls) + if attr_name[:3] == "set" + and (attr := getattr(module.__class__, attr_name)) + and inspect.iscoroutinefunction(attr) + } + + # Set a value on the module + if (value := req.get("value")) is not None: + if req["name"] not in setter_properties: + return web.Response(status=400) + res = await getattr(module, req["name"])(value) + result = {"result": res} + return web.Response(body=json_dumps(result)) + + # Call a method with no params + callable_methods = { + attr_name + for attr_name in vars(interface_cls) + if (attr := getattr(module.__class__, attr_name)) + and inspect.iscoroutinefunction(attr) + } + if req["name"] not in callable_methods: + return web.Response(status=400) + + res = await getattr(module, req["name"])() + result = {"result": res} + return web.Response(body=json_dumps(result)) + + async def _device(self, request: web.Request) -> web.Response: + if not (host := request.query.get("host")): + return web.Response(status=400, text="No host provided") + + dev = self.devices[host] + await dev.update() + + if req_body := await request.read(): + req = json_loads(req_body.decode()) + else: + req = {} + + if not self._check_auth(req, dev): + return web.Response(status=401) + + if not (name := req.get("name")): + skip = {"internal_state", "sys_info", "config", "hw_info"} + properties = { + attr_name: val + for attr_name in vars(Device) + if attr_name[0] != "_" + and attr_name not in skip + and (attr := getattr(Device, attr_name)) + and isinstance(attr, property) + and (val := self._serializable(dev, attr_name)) + } + + result = {"result": properties} + return web.Response(body=json_dumps(result)) + + setter_properties = { + attr_name + for attr_name in vars(Device) + if attr_name[:3] == "set" + and (attr := getattr(Device, attr_name)) + and inspect.iscoroutinefunction(attr) + } + if name not in setter_properties: + return web.Response(status=400) + + res = await getattr(dev, name)(req["value"]) + result = {"result": res} + return web.Response(body=json_dumps(result)) + + async def _feature(self, request: web.Request) -> web.Response: + if not (host := request.query.get("host")): + return web.Response(status=400, text="No host provided") + + dev = self.devices[host] + await dev.update() + features = dev.features + + if req_body := await request.read(): + req = json_loads(req_body.decode()) + else: + req = {} + + if not self._check_auth(req, dev): + return web.Response(status=401) + + if not (feat_id := req.get("id")): + feats = {feat.id: feat.value for feat in features.values()} + result = {"result": feats} + body = json_dumps(result) + return web.Response(body=body) + + if not (feat := features.get(feat_id)): + return web.Response(status=400) + + await feat.set_value(req["value"]) + return web.Response() + + async def start( + self, + rest_ip: str, + rest_port: int, + *, + cert_file: str | None = None, + key_file: str | None = None, + key_passphrase: str | None = None, + ) -> None: + """Start the server.""" + app = web.Application() + app.add_routes( + [ + web.post("/device", self._device), + web.post("/module", self._module), + web.post("/feature", self._feature), + ] + ) + + self.runner = web.AppRunner(app) + await self.runner.setup() + + if cert_file: + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain(cert_file, key_file, key_passphrase) + else: + ssl_context = None + + self.site = web.TCPSite( + self.runner, rest_ip, rest_port, ssl_context=ssl_context + ) + try: + await self.site.start() + except Exception as ex: + _LOGGER.exception( + "Error trying to start rest api on %s:%s: %s", rest_ip, rest_port, ex + ) + raise + + _LOGGER.debug( + "Rest api running on %s:%s", + rest_ip, + rest_port, + ) + self.running = True + + async def stop(self) -> None: + """Stop the rest api.""" + if not self.running: + _LOGGER.debug("Rest api already stopped") + return + + _LOGGER.debug("Stopping rest api") + self.running = False + + await self.site.stop() + await self.runner.shutdown() + + +def _create_self_signed_key(key_file: str, certificate_file: str, passphrase: str): + key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + # Write our key to disk for safe keeping + with open(key_file, "wb") as f: + f.write( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.BestAvailableEncryption( + passphrase.encode() + ), + ) + ) + + # Various details about who we are. For a self-signed certificate the + # subject and issuer are always the same. + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"), + x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "My Company"), + x509.NameAttribute(NameOID.COMMON_NAME, "mysite.com"), + ] + ) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.UTC)) + .not_valid_after( + # Our certificate will be valid for 10 days + datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=10) + ) + .add_extension( + x509.SubjectAlternativeName([x509.DNSName("localhost")]), + critical=False, + # Sign our certificate with our private key + ) + .sign(key, hashes.SHA256()) + ) + # Write our certificate out to disk. + with open(certificate_file, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) diff --git a/kasa/interfaces/light.py b/kasa/interfaces/light.py index 89058f98..01ad2995 100644 --- a/kasa/interfaces/light.py +++ b/kasa/interfaces/light.py @@ -96,6 +96,10 @@ class HSV(NamedTuple): saturation: int value: int + def to_dict(self) -> dict: + """Return dict represenation.""" + return {"hue": self.hue, "saturation": self.saturation, "value": self.value} + class Light(Module, ABC): """Base class for TP-Link Light."""