Add rest api command to the cli

This commit is contained in:
Steven B 2024-12-24 15:23:12 +00:00
parent 63f4f82791
commit 6f58c99bf7
No known key found for this signature in database
GPG Key ID: 6D5B46B3679F2A43
5 changed files with 554 additions and 6 deletions

3
.gitignore vendored
View File

@ -28,3 +28,6 @@ venv
/build /build
docs/build docs/build
# self-signed certs
self-signed

View File

@ -73,7 +73,9 @@ async def detail(ctx):
echo() echo()
discovered = await _discover( discovered = await _discover(
ctx, print_discovered=print_discovered, print_unsupported=print_unsupported ctx.parent.parent,
print_discovered=print_discovered,
print_unsupported=print_unsupported,
) )
if ctx.parent.parent.params["host"]: if ctx.parent.parent.params["host"]:
return discovered return discovered
@ -111,7 +113,7 @@ async def raw(ctx, redact: bool):
) )
echo(json_dumps(discovered, indent=True)) echo(json_dumps(discovered, indent=True))
return await _discover(ctx, print_raw=print_raw, do_echo=False) return await _discover(ctx.parent.parent, print_raw=print_raw, do_echo=False)
@discover.command() @discover.command()
@ -148,7 +150,7 @@ async def list(ctx):
f"{'HTTPS':<5} {'LV':<3} {'ALIAS'}" f"{'HTTPS':<5} {'LV':<3} {'ALIAS'}"
) )
return await _discover( return await _discover(
ctx, ctx.parent.parent,
print_discovered=print_discovered, print_discovered=print_discovered,
print_unsupported=print_unsupported, print_unsupported=print_unsupported,
do_echo=False, do_echo=False,
@ -156,9 +158,14 @@ async def list(ctx):
async def _discover( async def _discover(
ctx, *, print_discovered=None, print_unsupported=None, print_raw=None, do_echo=True root_ctx,
*,
print_discovered=None,
print_unsupported=None,
print_raw=None,
do_echo=True,
): ):
params = ctx.parent.parent.params params = root_ctx.params
target = params["target"] target = params["target"]
username = params["username"] username = params["username"]
password = params["password"] password = params["password"]

View File

@ -76,6 +76,7 @@ def _legacy_type_to_class(_type: str) -> Any:
"schedule": None, "schedule": None,
"usage": None, "usage": None,
"energy": "usage", "energy": "usage",
"rest": None,
# device commands runnnable at top level # device commands runnnable at top level
"state": "device", "state": "device",
"on": "device", "on": "device",
@ -270,7 +271,7 @@ async def cli(
# but this keeps mypy happy for now # but this keeps mypy happy for now
logging.basicConfig(**logging_config) # type: ignore logging.basicConfig(**logging_config) # type: ignore
if ctx.invoked_subcommand == "discover": if ctx.invoked_subcommand in {"discover", "rest"}:
return return
if alias is not None and host is not None: if alias is not None and host is not None:

533
kasa/cli/rest.py Normal file
View File

@ -0,0 +1,533 @@
"""Module for cli rest api."""
import asyncio
import datetime
import inspect
import logging
import socket
import ssl
from contextlib import suppress
from pathlib import Path
from typing import TYPE_CHECKING, Any
import asyncclick as click
from aiohttp import web
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from kasa import (
Device,
)
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.module import _is_bound_feature
from .common import echo
from .discover import _discover
_LOGGER = logging.getLogger(__name__)
logging.getLogger("aiohttp").setLevel(logging.WARNING)
CERT_FILENAME = "certificate.pem"
KEY_FILENAME = "key.pem"
DEFAULT_PASSPHRASE = "passthrough" # noqa: S105
async def wait_on_keyboard_interrupt(msg: str):
"""Non loop blocking get input."""
echo(msg + ", press Ctrl-C to cancel\n")
with suppress(asyncio.CancelledError):
await asyncio.Event().wait()
async def _get_host_ip() -> str:
def get_ip() -> str:
# From https://stackoverflow.com/a/28950776
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(0)
try:
# doesn't even have to be reachable
s.connect(("10.254.254.254", 1))
ip = s.getsockname()[0]
finally:
s.close()
return ip
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, get_ip)
@click.command()
@click.option(
"--rest-port",
default=8080,
required=False,
help="Port for rest api.",
)
@click.option(
"--rest-host",
default=None,
required=False,
envvar="KASA_LISTEN_IP",
help="Host for rest api.",
)
@click.option(
"--cert-file",
default=None,
required=False,
help="Cert file for https.",
)
@click.option(
"--key-file",
default=None,
required=False,
help="Key file for https.",
)
@click.option(
"--key-passphrase",
default=None,
required=False,
help="Passphrase for https key.",
)
@click.option(
"--https/--no-https",
default=True,
is_flag=True,
type=bool,
help=(
"Use https, recommended to ensure passwords are not sent unencrypted. "
"If no cert-file provided will auto-create self-signed cert."
),
)
@click.option(
"--localhost/--no-localhost",
default=True,
is_flag=True,
type=bool,
help=("Start server on localhost or primary device ip. "),
)
@click.option(
"--self-signed-folder",
default="./self-signed",
help="Location to store auto-created self-signed cert files.",
)
@click.option(
"--secure/--no-secure",
default=True,
is_flag=True,
type=bool,
help=(
"Require username and password in requests for "
"devices that require authentication."
),
)
@click.pass_context
async def rest(
ctx: click.Context,
rest_port: int,
rest_host: str | None,
cert_file: str | None,
key_file: str | None,
key_passphrase: str | None,
https: bool,
self_signed_folder: str,
localhost: bool,
secure: bool,
) -> None:
"""Start the rest api.
Example calls:
List all device attributes:
POST /device?host=<device_ip>
Username and password required in all requests for devices that require
authentication unless --no-secure is set:
POST /device?host=<device_ip> '{"username": "user@example.com", "password": "pwd"}'
Set a device attribute
POST /device?host=<device_ip> '{"name": "set_alias", "value": "Study cam"}'
List all device module ids:
POST /module?host=<device_ip>
List all module attributes
POST /module?host=<device_ip> '{"id": "Light"}'
Set a module attribute
POST /module?host=<device_ip> '{"id": "Light", "name": "set_brightness",
"value": 50}'
List all device feature ids:
POST /feature?host=<device_ip>
Set a feature value
POST /feature?host=<device_ip> '{"id": "brightness", "value": 50}'
"""
if not rest_host:
if localhost:
rest_host = "localhost"
else:
rest_host = await _get_host_ip()
scheme = "https" if https or cert_file else "http"
echo(f"Starting the rest api on {scheme}://{rest_host}:{rest_port}")
devices = await _discover(ctx.parent)
if https and not cert_file:
if not key_passphrase:
key_passphrase = DEFAULT_PASSPHRASE
self_signed_path = Path(self_signed_folder)
cert_file_path = self_signed_path / CERT_FILENAME
key_file_path = self_signed_path / KEY_FILENAME
cert_file = str(cert_file_path)
key_file = str(key_file_path)
if not cert_file_path.exists() or not key_file_path.exists():
echo("Creating self-signed certificate")
self_signed_path.mkdir(exist_ok=True)
_create_self_signed_key(key_file, cert_file, key_passphrase)
if TYPE_CHECKING:
assert ctx.parent
username = ctx.parent.params["username"]
password = ctx.parent.params["password"]
server = RestServer(devices, username=username, password=password, secure=secure)
await server.start(
rest_host,
rest_port,
cert_file=cert_file,
key_file=key_file,
key_passphrase=key_passphrase,
)
msg = f"Started rest api on {scheme}://{rest_host}:{rest_port}"
await wait_on_keyboard_interrupt(msg)
echo("\nStopping rest api")
await server.stop()
class RestServer:
"""Rest server class."""
def __init__(
self,
devices: dict[str, Device],
*,
username: str | None,
password: str | None,
secure: bool,
) -> None:
self.devices = devices
self.running = False
self._username = username
self._password = password
self._secure = secure
@staticmethod
def _serializable(o: object, attr_name: str) -> Any | None:
val = getattr(o, attr_name)
if hasattr(val, "to_dict"):
return val.to_dict()
try:
json_dumps(val)
return val
except (TypeError, OverflowError):
return None
def _check_auth(self, request_dict: dict[str, Any], device: Device) -> bool:
if not self._secure:
return True
if not device.device_info.requires_auth:
return True
if (un := request_dict.get("username")) and (
pw := request_dict.get("password")
):
return (un == self._username) and (pw == self._password)
return False
async def _module(self, request: web.Request) -> web.Response:
if not (host := request.query.get("host")):
return web.Response(status=400, text="No host provided")
def _get_interface(mod):
for base in mod.__class__.__bases__:
if base.__module__.startswith("kasa.interfaces"):
return base
return None
dev = self.devices[host]
await dev.update()
if req_body := await request.read():
req = json_loads(req_body.decode())
else:
req = {}
if not self._check_auth(req, dev):
return web.Response(status=401)
if not (module_id := req.get("id")):
list_result = {
"result": [
mod_name
for mod_name, mod in dev.modules.items()
if _get_interface(mod) is not None
]
}
body = json_dumps(list_result)
return web.Response(body=body)
if not (module := dev.modules.get(module_id)) or not (
interface_cls := _get_interface(module)
):
return web.Response(status=400)
# TODO make valid_temperature_range a FeatureAttribute
skip = {
"valid_temperature_range",
"is_color",
"is_dimmable",
"is_variable_color_temp",
"has_effects",
}
properties = {
attr_name: val
for attr_name in vars(interface_cls)
if attr_name[0] != "_"
and attr_name not in skip
and (attr := getattr(module.__class__, attr_name))
and isinstance(attr, property)
and (not _is_bound_feature(attr) or module.has_feature(attr_name))
and (val := self._serializable(module, attr_name))
}
# Return all the properties
if "name" not in req:
result = {"result": properties}
body = json_dumps(result)
return web.Response(body=body)
setter_properties = {
attr_name
for attr_name in vars(interface_cls)
if attr_name[:3] == "set"
and (attr := getattr(module.__class__, attr_name))
and inspect.iscoroutinefunction(attr)
}
# Set a value on the module
if (value := req.get("value")) is not None:
if req["name"] not in setter_properties:
return web.Response(status=400)
res = await getattr(module, req["name"])(value)
result = {"result": res}
return web.Response(body=json_dumps(result))
# Call a method with no params
callable_methods = {
attr_name
for attr_name in vars(interface_cls)
if (attr := getattr(module.__class__, attr_name))
and inspect.iscoroutinefunction(attr)
}
if req["name"] not in callable_methods:
return web.Response(status=400)
res = await getattr(module, req["name"])()
result = {"result": res}
return web.Response(body=json_dumps(result))
async def _device(self, request: web.Request) -> web.Response:
if not (host := request.query.get("host")):
return web.Response(status=400, text="No host provided")
dev = self.devices[host]
await dev.update()
if req_body := await request.read():
req = json_loads(req_body.decode())
else:
req = {}
if not self._check_auth(req, dev):
return web.Response(status=401)
if not (name := req.get("name")):
skip = {"internal_state", "sys_info", "config", "hw_info"}
properties = {
attr_name: val
for attr_name in vars(Device)
if attr_name[0] != "_"
and attr_name not in skip
and (attr := getattr(Device, attr_name))
and isinstance(attr, property)
and (val := self._serializable(dev, attr_name))
}
result = {"result": properties}
return web.Response(body=json_dumps(result))
setter_properties = {
attr_name
for attr_name in vars(Device)
if attr_name[:3] == "set"
and (attr := getattr(Device, attr_name))
and inspect.iscoroutinefunction(attr)
}
if name not in setter_properties:
return web.Response(status=400)
res = await getattr(dev, name)(req["value"])
result = {"result": res}
return web.Response(body=json_dumps(result))
async def _feature(self, request: web.Request) -> web.Response:
if not (host := request.query.get("host")):
return web.Response(status=400, text="No host provided")
dev = self.devices[host]
await dev.update()
features = dev.features
if req_body := await request.read():
req = json_loads(req_body.decode())
else:
req = {}
if not self._check_auth(req, dev):
return web.Response(status=401)
if not (feat_id := req.get("id")):
feats = {feat.id: feat.value for feat in features.values()}
result = {"result": feats}
body = json_dumps(result)
return web.Response(body=body)
if not (feat := features.get(feat_id)):
return web.Response(status=400)
await feat.set_value(req["value"])
return web.Response()
async def start(
self,
rest_ip: str,
rest_port: int,
*,
cert_file: str | None = None,
key_file: str | None = None,
key_passphrase: str | None = None,
) -> None:
"""Start the server."""
app = web.Application()
app.add_routes(
[
web.post("/device", self._device),
web.post("/module", self._module),
web.post("/feature", self._feature),
]
)
self.runner = web.AppRunner(app)
await self.runner.setup()
if cert_file:
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(cert_file, key_file, key_passphrase)
else:
ssl_context = None
self.site = web.TCPSite(
self.runner, rest_ip, rest_port, ssl_context=ssl_context
)
try:
await self.site.start()
except Exception as ex:
_LOGGER.exception(
"Error trying to start rest api on %s:%s: %s", rest_ip, rest_port, ex
)
raise
_LOGGER.debug(
"Rest api running on %s:%s",
rest_ip,
rest_port,
)
self.running = True
async def stop(self) -> None:
"""Stop the rest api."""
if not self.running:
_LOGGER.debug("Rest api already stopped")
return
_LOGGER.debug("Stopping rest api")
self.running = False
await self.site.stop()
await self.runner.shutdown()
def _create_self_signed_key(key_file: str, certificate_file: str, passphrase: str):
key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
)
# Write our key to disk for safe keeping
with open(key_file, "wb") as f:
f.write(
key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.BestAvailableEncryption(
passphrase.encode()
),
)
)
# Various details about who we are. For a self-signed certificate the
# subject and issuer are always the same.
subject = issuer = x509.Name(
[
x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "My Company"),
x509.NameAttribute(NameOID.COMMON_NAME, "mysite.com"),
]
)
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.now(datetime.UTC))
.not_valid_after(
# Our certificate will be valid for 10 days
datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=10)
)
.add_extension(
x509.SubjectAlternativeName([x509.DNSName("localhost")]),
critical=False,
# Sign our certificate with our private key
)
.sign(key, hashes.SHA256())
)
# Write our certificate out to disk.
with open(certificate_file, "wb") as f:
f.write(cert.public_bytes(serialization.Encoding.PEM))

View File

@ -96,6 +96,10 @@ class HSV(NamedTuple):
saturation: int saturation: int
value: int value: int
def to_dict(self) -> dict:
"""Return dict represenation."""
return {"hue": self.hue, "saturation": self.saturation, "value": self.value}
class Light(Module, ABC): class Light(Module, ABC):
"""Base class for TP-Link Light.""" """Base class for TP-Link Light."""