mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-22 19:23:34 +00:00
Ensure connections are closed when cli is finished (#752)
* Ensure connections are closed when cli is finished * Test for close calls on error and success
This commit is contained in:
parent
5d81e9f94c
commit
45f251e57e
10
kasa/cli.py
10
kasa/cli.py
@ -5,6 +5,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import singledispatch, wraps
|
||||
from pprint import pformat as pf
|
||||
from typing import Any, Dict, cast
|
||||
@ -365,7 +366,14 @@ async def cli(
|
||||
if ctx.invoked_subcommand not in SKIP_UPDATE_COMMANDS and not device_family:
|
||||
await dev.update()
|
||||
|
||||
ctx.obj = dev
|
||||
@asynccontextmanager
|
||||
async def async_wrapped_device(device: Device):
|
||||
try:
|
||||
yield device
|
||||
finally:
|
||||
await device.disconnect()
|
||||
|
||||
ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev))
|
||||
|
||||
if ctx.invoked_subcommand is None:
|
||||
return await ctx.invoke(state)
|
||||
|
@ -49,6 +49,20 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic
|
||||
if host:
|
||||
config = DeviceConfig(host=host)
|
||||
|
||||
if (protocol := get_protocol(config=config)) is None:
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported device for {config.host}: "
|
||||
+ f"{config.connection_type.device_family.value}"
|
||||
)
|
||||
|
||||
try:
|
||||
return await _connect(config, protocol)
|
||||
except:
|
||||
await protocol.close()
|
||||
raise
|
||||
|
||||
|
||||
async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> "Device":
|
||||
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
if debug_enabled:
|
||||
start_time = time.perf_counter()
|
||||
@ -63,12 +77,6 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic
|
||||
)
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if (protocol := get_protocol(config=config)) is None:
|
||||
raise UnsupportedDeviceException(
|
||||
f"Unsupported device for {config.host}: "
|
||||
+ f"{config.connection_type.device_family.value}"
|
||||
)
|
||||
|
||||
device_class: Optional[Type[Device]]
|
||||
device: Optional[Device] = None
|
||||
|
||||
|
@ -53,7 +53,7 @@ async def test_connect(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
protocol_class = get_protocol(config).__class__
|
||||
|
||||
close_mock = mocker.patch.object(protocol_class, "close")
|
||||
dev = await connect(
|
||||
config=config,
|
||||
)
|
||||
@ -61,8 +61,9 @@ async def test_connect(
|
||||
assert isinstance(dev.protocol, protocol_class)
|
||||
|
||||
assert dev.config == config
|
||||
|
||||
assert close_mock.call_count == 0
|
||||
await dev.disconnect()
|
||||
assert close_mock.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_port", [123, None])
|
||||
@ -116,8 +117,12 @@ async def test_connect_query_fails(all_fixture_data: dict, mocker):
|
||||
config = DeviceConfig(
|
||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||
)
|
||||
protocol_class = get_protocol(config).__class__
|
||||
close_mock = mocker.patch.object(protocol_class, "close")
|
||||
assert close_mock.call_count == 0
|
||||
with pytest.raises(SmartDeviceException):
|
||||
await connect(config=config)
|
||||
assert close_mock.call_count == 1
|
||||
|
||||
|
||||
async def test_connect_http_client(all_fixture_data, mocker):
|
||||
|
Loading…
Reference in New Issue
Block a user