Ensure connections are closed when cli is finished (#752)

* Ensure connections are closed when cli is finished

* Test for close calls on error and success
This commit is contained in:
Steven B 2024-02-14 17:03:50 +00:00 committed by GitHub
parent 5d81e9f94c
commit 45f251e57e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 30 additions and 9 deletions

View File

@ -5,6 +5,7 @@ import json
import logging
import re
import sys
from contextlib import asynccontextmanager
from functools import singledispatch, wraps
from pprint import pformat as pf
from typing import Any, Dict, cast
@ -365,7 +366,14 @@ async def cli(
if ctx.invoked_subcommand not in SKIP_UPDATE_COMMANDS and not device_family:
await dev.update()
ctx.obj = dev
@asynccontextmanager
async def async_wrapped_device(device: Device):
try:
yield device
finally:
await device.disconnect()
ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev))
if ctx.invoked_subcommand is None:
return await ctx.invoke(state)

View File

@ -49,6 +49,20 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic
if host:
config = DeviceConfig(host=host)
if (protocol := get_protocol(config=config)) is None:
raise UnsupportedDeviceException(
f"Unsupported device for {config.host}: "
+ f"{config.connection_type.device_family.value}"
)
try:
return await _connect(config, protocol)
except:
await protocol.close()
raise
async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> "Device":
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_enabled:
start_time = time.perf_counter()
@ -63,12 +77,6 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic
)
start_time = time.perf_counter()
if (protocol := get_protocol(config=config)) is None:
raise UnsupportedDeviceException(
f"Unsupported device for {config.host}: "
+ f"{config.connection_type.device_family.value}"
)
device_class: Optional[Type[Device]]
device: Optional[Device] = None

View File

@ -53,7 +53,7 @@ async def test_connect(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
protocol_class = get_protocol(config).__class__
close_mock = mocker.patch.object(protocol_class, "close")
dev = await connect(
config=config,
)
@ -61,8 +61,9 @@ async def test_connect(
assert isinstance(dev.protocol, protocol_class)
assert dev.config == config
assert close_mock.call_count == 0
await dev.disconnect()
assert close_mock.call_count == 1
@pytest.mark.parametrize("custom_port", [123, None])
@ -116,8 +117,12 @@ async def test_connect_query_fails(all_fixture_data: dict, mocker):
config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
)
protocol_class = get_protocol(config).__class__
close_mock = mocker.patch.object(protocol_class, "close")
assert close_mock.call_count == 0
with pytest.raises(SmartDeviceException):
await connect(config=config)
assert close_mock.call_count == 1
async def test_connect_http_client(all_fixture_data, mocker):