mirror of
https://github.com/python-kasa/python-kasa.git
synced 2025-05-30 21:51:24 +00:00
534 lines
16 KiB
Python
534 lines
16 KiB
Python
"""Module for cli rest api."""
|
|
|
|
import asyncio
|
|
import datetime
|
|
import inspect
|
|
import logging
|
|
import socket
|
|
import ssl
|
|
from contextlib import suppress
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import asyncclick as click
|
|
from aiohttp import web
|
|
from cryptography import x509
|
|
from cryptography.hazmat.primitives import hashes, serialization
|
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
from cryptography.x509.oid import NameOID
|
|
|
|
from kasa import (
|
|
Device,
|
|
)
|
|
from kasa.json import dumps as json_dumps
|
|
from kasa.json import loads as json_loads
|
|
from kasa.module import _is_bound_feature
|
|
|
|
from .common import echo
|
|
from .discover import _discover
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
logging.getLogger("aiohttp").setLevel(logging.WARNING)
|
|
|
|
CERT_FILENAME = "certificate.pem"
|
|
KEY_FILENAME = "key.pem"
|
|
DEFAULT_PASSPHRASE = "passthrough" # noqa: S105
|
|
|
|
|
|
async def wait_on_keyboard_interrupt(msg: str):
|
|
"""Non loop blocking get input."""
|
|
echo(msg + ", press Ctrl-C to cancel\n")
|
|
|
|
with suppress(asyncio.CancelledError):
|
|
await asyncio.Event().wait()
|
|
|
|
|
|
async def _get_host_ip() -> str:
|
|
def get_ip() -> str:
|
|
# From https://stackoverflow.com/a/28950776
|
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
s.settimeout(0)
|
|
try:
|
|
# doesn't even have to be reachable
|
|
s.connect(("10.254.254.254", 1))
|
|
ip = s.getsockname()[0]
|
|
finally:
|
|
s.close()
|
|
return ip
|
|
|
|
loop = asyncio.get_running_loop()
|
|
return await loop.run_in_executor(None, get_ip)
|
|
|
|
|
|
@click.command()
|
|
@click.option(
|
|
"--rest-port",
|
|
default=8080,
|
|
required=False,
|
|
help="Port for rest api.",
|
|
)
|
|
@click.option(
|
|
"--rest-host",
|
|
default=None,
|
|
required=False,
|
|
envvar="KASA_LISTEN_IP",
|
|
help="Host for rest api.",
|
|
)
|
|
@click.option(
|
|
"--cert-file",
|
|
default=None,
|
|
required=False,
|
|
help="Cert file for https.",
|
|
)
|
|
@click.option(
|
|
"--key-file",
|
|
default=None,
|
|
required=False,
|
|
help="Key file for https.",
|
|
)
|
|
@click.option(
|
|
"--key-passphrase",
|
|
default=None,
|
|
required=False,
|
|
help="Passphrase for https key.",
|
|
)
|
|
@click.option(
|
|
"--https/--no-https",
|
|
default=True,
|
|
is_flag=True,
|
|
type=bool,
|
|
help=(
|
|
"Use https, recommended to ensure passwords are not sent unencrypted. "
|
|
"If no cert-file provided will auto-create self-signed cert."
|
|
),
|
|
)
|
|
@click.option(
|
|
"--localhost/--no-localhost",
|
|
default=True,
|
|
is_flag=True,
|
|
type=bool,
|
|
help=("Start server on localhost or primary device ip. "),
|
|
)
|
|
@click.option(
|
|
"--self-signed-folder",
|
|
default="./self-signed",
|
|
help="Location to store auto-created self-signed cert files.",
|
|
)
|
|
@click.option(
|
|
"--secure/--no-secure",
|
|
default=True,
|
|
is_flag=True,
|
|
type=bool,
|
|
help=(
|
|
"Require username and password in requests for "
|
|
"devices that require authentication."
|
|
),
|
|
)
|
|
@click.pass_context
|
|
async def rest(
|
|
ctx: click.Context,
|
|
rest_port: int,
|
|
rest_host: str | None,
|
|
cert_file: str | None,
|
|
key_file: str | None,
|
|
key_passphrase: str | None,
|
|
https: bool,
|
|
self_signed_folder: str,
|
|
localhost: bool,
|
|
secure: bool,
|
|
) -> None:
|
|
"""Start the rest api.
|
|
|
|
Example calls:
|
|
|
|
List all device attributes:
|
|
POST /device?host=<device_ip>
|
|
|
|
Username and password required in all requests for devices that require
|
|
authentication unless --no-secure is set:
|
|
POST /device?host=<device_ip> '{"username": "user@example.com", "password": "pwd"}'
|
|
|
|
Set a device attribute
|
|
POST /device?host=<device_ip> '{"name": "set_alias", "value": "Study cam"}'
|
|
|
|
List all device module ids:
|
|
POST /module?host=<device_ip>
|
|
|
|
List all module attributes
|
|
POST /module?host=<device_ip> '{"id": "Light"}'
|
|
|
|
Set a module attribute
|
|
POST /module?host=<device_ip> '{"id": "Light", "name": "set_brightness",
|
|
"value": 50}'
|
|
|
|
List all device feature ids:
|
|
POST /feature?host=<device_ip>
|
|
|
|
Set a feature value
|
|
POST /feature?host=<device_ip> '{"id": "brightness", "value": 50}'
|
|
"""
|
|
if not rest_host:
|
|
if localhost:
|
|
rest_host = "localhost"
|
|
else:
|
|
rest_host = await _get_host_ip()
|
|
|
|
scheme = "https" if https or cert_file else "http"
|
|
echo(f"Starting the rest api on {scheme}://{rest_host}:{rest_port}")
|
|
|
|
devices = await _discover(ctx.parent)
|
|
|
|
if https and not cert_file:
|
|
if not key_passphrase:
|
|
key_passphrase = DEFAULT_PASSPHRASE
|
|
|
|
self_signed_path = Path(self_signed_folder)
|
|
cert_file_path = self_signed_path / CERT_FILENAME
|
|
key_file_path = self_signed_path / KEY_FILENAME
|
|
cert_file = str(cert_file_path)
|
|
key_file = str(key_file_path)
|
|
|
|
if not cert_file_path.exists() or not key_file_path.exists():
|
|
echo("Creating self-signed certificate")
|
|
|
|
self_signed_path.mkdir(exist_ok=True)
|
|
_create_self_signed_key(key_file, cert_file, key_passphrase)
|
|
|
|
if TYPE_CHECKING:
|
|
assert ctx.parent
|
|
username = ctx.parent.params["username"]
|
|
password = ctx.parent.params["password"]
|
|
server = RestServer(devices, username=username, password=password, secure=secure)
|
|
|
|
await server.start(
|
|
rest_host,
|
|
rest_port,
|
|
cert_file=cert_file,
|
|
key_file=key_file,
|
|
key_passphrase=key_passphrase,
|
|
)
|
|
|
|
msg = f"Started rest api on {scheme}://{rest_host}:{rest_port}"
|
|
|
|
await wait_on_keyboard_interrupt(msg)
|
|
|
|
echo("\nStopping rest api")
|
|
|
|
await server.stop()
|
|
|
|
|
|
class RestServer:
|
|
"""Rest server class."""
|
|
|
|
def __init__(
|
|
self,
|
|
devices: dict[str, Device],
|
|
*,
|
|
username: str | None,
|
|
password: str | None,
|
|
secure: bool,
|
|
) -> None:
|
|
self.devices = devices
|
|
self.running = False
|
|
self._username = username
|
|
self._password = password
|
|
self._secure = secure
|
|
|
|
@staticmethod
|
|
def _serializable(o: object, attr_name: str) -> Any | None:
|
|
val = getattr(o, attr_name)
|
|
if hasattr(val, "to_dict"):
|
|
return val.to_dict()
|
|
try:
|
|
json_dumps(val)
|
|
return val
|
|
except (TypeError, OverflowError):
|
|
return None
|
|
|
|
def _check_auth(self, request_dict: dict[str, Any], device: Device) -> bool:
|
|
if not self._secure:
|
|
return True
|
|
|
|
if not device.device_info.requires_auth:
|
|
return True
|
|
|
|
if (un := request_dict.get("username")) and (
|
|
pw := request_dict.get("password")
|
|
):
|
|
return (un == self._username) and (pw == self._password)
|
|
|
|
return False
|
|
|
|
async def _module(self, request: web.Request) -> web.Response:
|
|
if not (host := request.query.get("host")):
|
|
return web.Response(status=400, text="No host provided")
|
|
|
|
def _get_interface(mod):
|
|
for base in mod.__class__.__bases__:
|
|
if base.__module__.startswith("kasa.interfaces"):
|
|
return base
|
|
return None
|
|
|
|
dev = self.devices[host]
|
|
await dev.update()
|
|
|
|
if req_body := await request.read():
|
|
req = json_loads(req_body.decode())
|
|
else:
|
|
req = {}
|
|
|
|
if not self._check_auth(req, dev):
|
|
return web.Response(status=401)
|
|
|
|
if not (module_id := req.get("id")):
|
|
list_result = {
|
|
"result": [
|
|
mod_name
|
|
for mod_name, mod in dev.modules.items()
|
|
if _get_interface(mod) is not None
|
|
]
|
|
}
|
|
body = json_dumps(list_result)
|
|
return web.Response(body=body)
|
|
|
|
if not (module := dev.modules.get(module_id)) or not (
|
|
interface_cls := _get_interface(module)
|
|
):
|
|
return web.Response(status=400)
|
|
|
|
# TODO make valid_temperature_range a FeatureAttribute
|
|
skip = {
|
|
"valid_temperature_range",
|
|
"is_color",
|
|
"is_dimmable",
|
|
"is_variable_color_temp",
|
|
"has_effects",
|
|
}
|
|
properties = {
|
|
attr_name: val
|
|
for attr_name in vars(interface_cls)
|
|
if attr_name[0] != "_"
|
|
and attr_name not in skip
|
|
and (attr := getattr(module.__class__, attr_name))
|
|
and isinstance(attr, property)
|
|
and (not _is_bound_feature(attr) or module.has_feature(attr_name))
|
|
and (val := self._serializable(module, attr_name))
|
|
}
|
|
|
|
# Return all the properties
|
|
if "name" not in req:
|
|
result = {"result": properties}
|
|
body = json_dumps(result)
|
|
return web.Response(body=body)
|
|
|
|
setter_properties = {
|
|
attr_name
|
|
for attr_name in vars(interface_cls)
|
|
if attr_name[:3] == "set"
|
|
and (attr := getattr(module.__class__, attr_name))
|
|
and inspect.iscoroutinefunction(attr)
|
|
}
|
|
|
|
# Set a value on the module
|
|
if (value := req.get("value")) is not None:
|
|
if req["name"] not in setter_properties:
|
|
return web.Response(status=400)
|
|
res = await getattr(module, req["name"])(value)
|
|
result = {"result": res}
|
|
return web.Response(body=json_dumps(result))
|
|
|
|
# Call a method with no params
|
|
callable_methods = {
|
|
attr_name
|
|
for attr_name in vars(interface_cls)
|
|
if (attr := getattr(module.__class__, attr_name))
|
|
and inspect.iscoroutinefunction(attr)
|
|
}
|
|
if req["name"] not in callable_methods:
|
|
return web.Response(status=400)
|
|
|
|
res = await getattr(module, req["name"])()
|
|
result = {"result": res}
|
|
return web.Response(body=json_dumps(result))
|
|
|
|
async def _device(self, request: web.Request) -> web.Response:
|
|
if not (host := request.query.get("host")):
|
|
return web.Response(status=400, text="No host provided")
|
|
|
|
dev = self.devices[host]
|
|
await dev.update()
|
|
|
|
if req_body := await request.read():
|
|
req = json_loads(req_body.decode())
|
|
else:
|
|
req = {}
|
|
|
|
if not self._check_auth(req, dev):
|
|
return web.Response(status=401)
|
|
|
|
if not (name := req.get("name")):
|
|
skip = {"internal_state", "sys_info", "config", "hw_info"}
|
|
properties = {
|
|
attr_name: val
|
|
for attr_name in vars(Device)
|
|
if attr_name[0] != "_"
|
|
and attr_name not in skip
|
|
and (attr := getattr(Device, attr_name))
|
|
and isinstance(attr, property)
|
|
and (val := self._serializable(dev, attr_name))
|
|
}
|
|
|
|
result = {"result": properties}
|
|
return web.Response(body=json_dumps(result))
|
|
|
|
setter_properties = {
|
|
attr_name
|
|
for attr_name in vars(Device)
|
|
if attr_name[:3] == "set"
|
|
and (attr := getattr(Device, attr_name))
|
|
and inspect.iscoroutinefunction(attr)
|
|
}
|
|
if name not in setter_properties:
|
|
return web.Response(status=400)
|
|
|
|
res = await getattr(dev, name)(req["value"])
|
|
result = {"result": res}
|
|
return web.Response(body=json_dumps(result))
|
|
|
|
async def _feature(self, request: web.Request) -> web.Response:
|
|
if not (host := request.query.get("host")):
|
|
return web.Response(status=400, text="No host provided")
|
|
|
|
dev = self.devices[host]
|
|
await dev.update()
|
|
features = dev.features
|
|
|
|
if req_body := await request.read():
|
|
req = json_loads(req_body.decode())
|
|
else:
|
|
req = {}
|
|
|
|
if not self._check_auth(req, dev):
|
|
return web.Response(status=401)
|
|
|
|
if not (feat_id := req.get("id")):
|
|
feats = {feat.id: feat.value for feat in features.values()}
|
|
result = {"result": feats}
|
|
body = json_dumps(result)
|
|
return web.Response(body=body)
|
|
|
|
if not (feat := features.get(feat_id)):
|
|
return web.Response(status=400)
|
|
|
|
await feat.set_value(req["value"])
|
|
return web.Response()
|
|
|
|
async def start(
|
|
self,
|
|
rest_ip: str,
|
|
rest_port: int,
|
|
*,
|
|
cert_file: str | None = None,
|
|
key_file: str | None = None,
|
|
key_passphrase: str | None = None,
|
|
) -> None:
|
|
"""Start the server."""
|
|
app = web.Application()
|
|
app.add_routes(
|
|
[
|
|
web.post("/device", self._device),
|
|
web.post("/module", self._module),
|
|
web.post("/feature", self._feature),
|
|
]
|
|
)
|
|
|
|
self.runner = web.AppRunner(app)
|
|
await self.runner.setup()
|
|
|
|
if cert_file:
|
|
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
|
ssl_context.load_cert_chain(cert_file, key_file, key_passphrase)
|
|
else:
|
|
ssl_context = None
|
|
|
|
self.site = web.TCPSite(
|
|
self.runner, rest_ip, rest_port, ssl_context=ssl_context
|
|
)
|
|
try:
|
|
await self.site.start()
|
|
except Exception as ex:
|
|
_LOGGER.exception(
|
|
"Error trying to start rest api on %s:%s: %s", rest_ip, rest_port, ex
|
|
)
|
|
raise
|
|
|
|
_LOGGER.debug(
|
|
"Rest api running on %s:%s",
|
|
rest_ip,
|
|
rest_port,
|
|
)
|
|
self.running = True
|
|
|
|
async def stop(self) -> None:
|
|
"""Stop the rest api."""
|
|
if not self.running:
|
|
_LOGGER.debug("Rest api already stopped")
|
|
return
|
|
|
|
_LOGGER.debug("Stopping rest api")
|
|
self.running = False
|
|
|
|
await self.site.stop()
|
|
await self.runner.shutdown()
|
|
|
|
|
|
def _create_self_signed_key(key_file: str, certificate_file: str, passphrase: str):
|
|
key = rsa.generate_private_key(
|
|
public_exponent=65537,
|
|
key_size=2048,
|
|
)
|
|
# Write our key to disk for safe keeping
|
|
with open(key_file, "wb") as f:
|
|
f.write(
|
|
key.private_bytes(
|
|
encoding=serialization.Encoding.PEM,
|
|
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
|
encryption_algorithm=serialization.BestAvailableEncryption(
|
|
passphrase.encode()
|
|
),
|
|
)
|
|
)
|
|
|
|
# Various details about who we are. For a self-signed certificate the
|
|
# subject and issuer are always the same.
|
|
subject = issuer = x509.Name(
|
|
[
|
|
x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
|
|
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
|
|
x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
|
|
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "My Company"),
|
|
x509.NameAttribute(NameOID.COMMON_NAME, "mysite.com"),
|
|
]
|
|
)
|
|
cert = (
|
|
x509.CertificateBuilder()
|
|
.subject_name(subject)
|
|
.issuer_name(issuer)
|
|
.public_key(key.public_key())
|
|
.serial_number(x509.random_serial_number())
|
|
.not_valid_before(datetime.datetime.now(datetime.UTC))
|
|
.not_valid_after(
|
|
# Our certificate will be valid for 10 days
|
|
datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=10)
|
|
)
|
|
.add_extension(
|
|
x509.SubjectAlternativeName([x509.DNSName("localhost")]),
|
|
critical=False,
|
|
# Sign our certificate with our private key
|
|
)
|
|
.sign(key, hashes.SHA256())
|
|
)
|
|
# Write our certificate out to disk.
|
|
with open(certificate_file, "wb") as f:
|
|
f.write(cert.public_bytes(serialization.Encoding.PEM))
|