mirror of
https://github.com/python-kasa/python-kasa.git
synced 2024-12-23 03:33:35 +00:00
Ensure connections are closed when cli is finished (#752)
* Ensure connections are closed when cli is finished * Test for close calls on error and success
This commit is contained in:
parent
5d81e9f94c
commit
45f251e57e
10
kasa/cli.py
10
kasa/cli.py
@ -5,6 +5,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from functools import singledispatch, wraps
|
from functools import singledispatch, wraps
|
||||||
from pprint import pformat as pf
|
from pprint import pformat as pf
|
||||||
from typing import Any, Dict, cast
|
from typing import Any, Dict, cast
|
||||||
@ -365,7 +366,14 @@ async def cli(
|
|||||||
if ctx.invoked_subcommand not in SKIP_UPDATE_COMMANDS and not device_family:
|
if ctx.invoked_subcommand not in SKIP_UPDATE_COMMANDS and not device_family:
|
||||||
await dev.update()
|
await dev.update()
|
||||||
|
|
||||||
ctx.obj = dev
|
@asynccontextmanager
|
||||||
|
async def async_wrapped_device(device: Device):
|
||||||
|
try:
|
||||||
|
yield device
|
||||||
|
finally:
|
||||||
|
await device.disconnect()
|
||||||
|
|
||||||
|
ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev))
|
||||||
|
|
||||||
if ctx.invoked_subcommand is None:
|
if ctx.invoked_subcommand is None:
|
||||||
return await ctx.invoke(state)
|
return await ctx.invoke(state)
|
||||||
|
@ -49,6 +49,20 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic
|
|||||||
if host:
|
if host:
|
||||||
config = DeviceConfig(host=host)
|
config = DeviceConfig(host=host)
|
||||||
|
|
||||||
|
if (protocol := get_protocol(config=config)) is None:
|
||||||
|
raise UnsupportedDeviceException(
|
||||||
|
f"Unsupported device for {config.host}: "
|
||||||
|
+ f"{config.connection_type.device_family.value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await _connect(config, protocol)
|
||||||
|
except:
|
||||||
|
await protocol.close()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> "Device":
|
||||||
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||||
if debug_enabled:
|
if debug_enabled:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
@ -63,12 +77,6 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic
|
|||||||
)
|
)
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
if (protocol := get_protocol(config=config)) is None:
|
|
||||||
raise UnsupportedDeviceException(
|
|
||||||
f"Unsupported device for {config.host}: "
|
|
||||||
+ f"{config.connection_type.device_family.value}"
|
|
||||||
)
|
|
||||||
|
|
||||||
device_class: Optional[Type[Device]]
|
device_class: Optional[Type[Device]]
|
||||||
device: Optional[Device] = None
|
device: Optional[Device] = None
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ async def test_connect(
|
|||||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||||
)
|
)
|
||||||
protocol_class = get_protocol(config).__class__
|
protocol_class = get_protocol(config).__class__
|
||||||
|
close_mock = mocker.patch.object(protocol_class, "close")
|
||||||
dev = await connect(
|
dev = await connect(
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
@ -61,8 +61,9 @@ async def test_connect(
|
|||||||
assert isinstance(dev.protocol, protocol_class)
|
assert isinstance(dev.protocol, protocol_class)
|
||||||
|
|
||||||
assert dev.config == config
|
assert dev.config == config
|
||||||
|
assert close_mock.call_count == 0
|
||||||
await dev.disconnect()
|
await dev.disconnect()
|
||||||
|
assert close_mock.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("custom_port", [123, None])
|
@pytest.mark.parametrize("custom_port", [123, None])
|
||||||
@ -116,8 +117,12 @@ async def test_connect_query_fails(all_fixture_data: dict, mocker):
|
|||||||
config = DeviceConfig(
|
config = DeviceConfig(
|
||||||
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
|
||||||
)
|
)
|
||||||
|
protocol_class = get_protocol(config).__class__
|
||||||
|
close_mock = mocker.patch.object(protocol_class, "close")
|
||||||
|
assert close_mock.call_count == 0
|
||||||
with pytest.raises(SmartDeviceException):
|
with pytest.raises(SmartDeviceException):
|
||||||
await connect(config=config)
|
await connect(config=config)
|
||||||
|
assert close_mock.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_connect_http_client(all_fixture_data, mocker):
|
async def test_connect_http_client(all_fixture_data, mocker):
|
||||||
|
Loading…
Reference in New Issue
Block a user