diff --git a/kasa/cli/common.py b/kasa/cli/common.py
index 5114f7af..d0ef9dc3 100644
--- a/kasa/cli/common.py
+++ b/kasa/cli/common.py
@@ -10,7 +10,7 @@ from collections.abc import Callable
 from contextlib import contextmanager
 from functools import singledispatch, update_wrapper, wraps
 from gettext import gettext
-from typing import TYPE_CHECKING, Any, Final
+from typing import TYPE_CHECKING, Any, Final, NoReturn
 
 import asyncclick as click
 
@@ -57,7 +57,7 @@ def echo(*args, **kwargs) -> None:
         _echo(*args, **kwargs)
 
 
-def error(msg: str) -> None:
+def error(msg: str) -> NoReturn:
     """Print an error and exit."""
     echo(f"[bold red]{msg}[/bold red]")
     sys.exit(1)
@@ -68,6 +68,16 @@ def json_formatter_cb(result: Any, **kwargs) -> None:
     if not kwargs.get("json"):
         return
 
+    # Calling the discover command directly always returns a DeviceDict so if host
+    # was specified just format the device json
+    if (
+        (host := kwargs.get("host"))
+        and isinstance(result, dict)
+        and (dev := result.get(host))
+        and isinstance(dev, Device)
+    ):
+        result = dev
+
     @singledispatch
     def to_serializable(val):
         """Regular obj-to-string for json serialization.
@@ -85,6 +95,25 @@ def json_formatter_cb(result: Any, **kwargs) -> None:
     print(json_content)
 
 
+async def invoke_subcommand(
+    command: click.BaseCommand,
+    ctx: click.Context,
+    args: list[str] | None = None,
+    **extra: Any,
+) -> Any:
+    """Invoke a click subcommand.
+
+    Calling ctx.Invoke() treats the command like a simple callback and doesn't
+    process any result_callbacks so we use this pattern from the click docs
+    https://click.palletsprojects.com/en/stable/exceptions/#what-if-i-don-t-want-that.
+    """
+    if args is None:
+        args = []
+    sub_ctx = await command.make_context(command.name, args, parent=ctx, **extra)
+    async with sub_ctx:
+        return await command.invoke(sub_ctx)
+
+
 def pass_dev_or_child(wrapped_function: Callable) -> Callable:
     """Pass the device or child to the click command based on the child options."""
     child_help = (
diff --git a/kasa/cli/device.py b/kasa/cli/device.py
index 0ef8a76f..a10f485d 100644
--- a/kasa/cli/device.py
+++ b/kasa/cli/device.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 from pprint import pformat as pf
+from typing import TYPE_CHECKING
 
 import asyncclick as click
 
@@ -82,6 +83,8 @@ async def state(ctx, dev: Device):
         echo()
         from .discover import _echo_discovery_info
 
+        if TYPE_CHECKING:
+            assert dev._discovery_info
         _echo_discovery_info(dev._discovery_info)
 
     return dev.internal_state
diff --git a/kasa/cli/discover.py b/kasa/cli/discover.py
index ff201ce6..07500f3b 100644
--- a/kasa/cli/discover.py
+++ b/kasa/cli/discover.py
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import asyncio
 from pprint import pformat as pf
+from typing import TYPE_CHECKING, cast
 
 import asyncclick as click
 
@@ -17,8 +18,12 @@ from kasa import (
 from kasa.discover import (
     NEW_DISCOVERY_REDACTORS,
     ConnectAttempt,
+    DeviceDict,
     DiscoveredRaw,
     DiscoveryResult,
+    OnDiscoveredCallable,
+    OnDiscoveredRawCallable,
+    OnUnsupportedCallable,
 )
 from kasa.iot.iotdevice import _extract_sys_info
 from kasa.protocols.iotprotocol import REDACTORS as IOT_REDACTORS
@@ -30,15 +35,33 @@ from .common import echo, error
 
 @click.group(invoke_without_command=True)
 @click.pass_context
-async def discover(ctx):
+async def discover(ctx: click.Context):
     """Discover devices in the network."""
     if ctx.invoked_subcommand is None:
         return await ctx.invoke(detail)
 
 
+@discover.result_callback()
+@click.pass_context
+async def _close_protocols(ctx: click.Context, discovered: DeviceDict):
+    """Close all the device protocols if discover was invoked directly by the user."""
+    if _discover_is_root_cmd(ctx):
+        for dev in discovered.values():
+            await dev.disconnect()
+    return discovered
+
+
+def _discover_is_root_cmd(ctx: click.Context) -> bool:
+    """Will return true if discover was invoked directly by the user."""
+    root_ctx = ctx.find_root()
+    return (
+        root_ctx.invoked_subcommand is None or root_ctx.invoked_subcommand == "discover"
+    )
+
+
 @discover.command()
 @click.pass_context
-async def detail(ctx):
+async def detail(ctx: click.Context) -> DeviceDict:
     """Discover devices in the network using udp broadcasts."""
     unsupported = []
     auth_failed = []
@@ -59,10 +82,14 @@ async def detail(ctx):
     from .device import state
 
     async def print_discovered(dev: Device) -> None:
+        if TYPE_CHECKING:
+            assert ctx.parent
         async with sem:
             try:
                 await dev.update()
             except AuthenticationError:
+                if TYPE_CHECKING:
+                    assert dev._discovery_info
                 auth_failed.append(dev._discovery_info)
                 echo("== Authentication failed for device ==")
                 _echo_discovery_info(dev._discovery_info)
@@ -73,9 +100,11 @@ async def detail(ctx):
             echo()
 
     discovered = await _discover(
-        ctx, print_discovered=print_discovered, print_unsupported=print_unsupported
+        ctx,
+        print_discovered=print_discovered if _discover_is_root_cmd(ctx) else None,
+        print_unsupported=print_unsupported,
     )
-    if ctx.parent.parent.params["host"]:
+    if ctx.find_root().params["host"]:
         return discovered
 
     echo(f"Found {len(discovered)} devices")
@@ -96,7 +125,7 @@ async def detail(ctx):
     help="Set flag to redact sensitive data from raw output.",
 )
 @click.pass_context
-async def raw(ctx, redact: bool):
+async def raw(ctx: click.Context, redact: bool) -> DeviceDict:
     """Return raw discovery data returned from devices."""
 
     def print_raw(discovered: DiscoveredRaw):
@@ -116,7 +145,7 @@ async def raw(ctx, redact: bool):
 
 @discover.command()
 @click.pass_context
-async def list(ctx):
+async def list(ctx: click.Context) -> DeviceDict:
     """List devices in the network in a table using udp broadcasts."""
     sem = asyncio.Semaphore()
 
@@ -147,18 +176,24 @@ async def list(ctx):
         f"{'HOST':<15} {'MODEL':<9} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} "
         f"{'HTTPS':<5} {'LV':<3} {'ALIAS'}"
     )
-    return await _discover(
+    discovered = await _discover(
         ctx,
         print_discovered=print_discovered,
         print_unsupported=print_unsupported,
         do_echo=False,
     )
+    return discovered
 
 
 async def _discover(
-    ctx, *, print_discovered=None, print_unsupported=None, print_raw=None, do_echo=True
-):
-    params = ctx.parent.parent.params
+    ctx: click.Context,
+    *,
+    print_discovered: OnDiscoveredCallable | None = None,
+    print_unsupported: OnUnsupportedCallable | None = None,
+    print_raw: OnDiscoveredRawCallable | None = None,
+    do_echo=True,
+) -> DeviceDict:
+    params = ctx.find_root().params
     target = params["target"]
     username = params["username"]
     password = params["password"]
@@ -170,8 +205,9 @@ async def _discover(
     credentials = Credentials(username, password) if username and password else None
 
     if host:
+        host = cast(str, host)
         echo(f"Discovering device {host} for {discovery_timeout} seconds")
-        return await Discover.discover_single(
+        dev = await Discover.discover_single(
             host,
             port=port,
             credentials=credentials,
@@ -180,6 +216,12 @@ async def _discover(
             on_unsupported=print_unsupported,
             on_discovered_raw=print_raw,
         )
+        if dev:
+            if print_discovered:
+                await print_discovered(dev)
+            return {host: dev}
+        else:
+            return {}
     if do_echo:
         echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
     discovered_devices = await Discover.discover(
@@ -193,21 +235,18 @@ async def _discover(
         on_discovered_raw=print_raw,
     )
 
-    for device in discovered_devices.values():
-        await device.protocol.close()
-
     return discovered_devices
 
 
 @discover.command()
 @click.pass_context
-async def config(ctx):
+async def config(ctx: click.Context) -> DeviceDict:
     """Bypass udp discovery and try to show connection config for a device.
 
     Bypasses udp discovery and shows the parameters required to connect
     directly to the device.
     """
-    params = ctx.parent.parent.params
+    params = ctx.find_root().params
     username = params["username"]
     password = params["password"]
     timeout = params["timeout"]
@@ -239,6 +278,7 @@ async def config(ctx):
             f"--encrypt-type {cparams.encryption_type.value} "
             f"{'--https' if cparams.https else '--no-https'}"
         )
+        return {host: dev}
     else:
         error(f"Unable to connect to {host}")
 
@@ -251,7 +291,7 @@ def _echo_dictionary(discovery_info: dict) -> None:
         echo(f"\t{key_name_and_spaces}{value}")
 
 
-def _echo_discovery_info(discovery_info) -> None:
+def _echo_discovery_info(discovery_info: dict) -> None:
     # We don't have discovery info when all connection params are passed manually
     if discovery_info is None:
         return
diff --git a/kasa/cli/main.py b/kasa/cli/main.py
index fbcdf391..debde60c 100755
--- a/kasa/cli/main.py
+++ b/kasa/cli/main.py
@@ -22,6 +22,7 @@ from .common import (
     CatchAllExceptions,
     echo,
     error,
+    invoke_subcommand,
     json_formatter_cb,
     pass_dev_or_child,
 )
@@ -295,9 +296,10 @@ async def cli(
         echo("No host name given, trying discovery..")
         from .discover import discover
 
-        return await ctx.invoke(discover)
+        return await invoke_subcommand(discover, ctx)
 
     device_updated = False
+    device_discovered = False
 
     if type is not None and type not in {"smart", "camera"}:
         from kasa.deviceconfig import DeviceConfig
@@ -351,12 +353,14 @@ async def cli(
             return
         echo(f"Found hostname by alias: {dev.host}")
         device_updated = True
-    else:
+    else:  # host will be set
         from .discover import discover
 
-        dev = await ctx.invoke(discover)
-        if not dev:
+        discovered = await invoke_subcommand(discover, ctx)
+        if not discovered:
             error(f"Unable to create device for {host}")
+        dev = discovered[host]
+        device_discovered = True
 
     # Skip update on specific commands, or if device factory,
     # that performs an update was used for the device.
@@ -372,11 +376,14 @@ async def cli(
 
     ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev))
 
-    if ctx.invoked_subcommand is None:
+    # discover command has already invoked state
+    if ctx.invoked_subcommand is None and not device_discovered:
         from .device import state
 
         return await ctx.invoke(state)
 
+    return dev
+
 
 @cli.command()
 @pass_dev_or_child