Ensure connections are closed when cli is finished (#752)

* Ensure connections are closed when cli is finished

* Test for close calls on error and success
This commit is contained in:
Steven B 2024-02-14 17:03:50 +00:00 committed by GitHub
parent 5d81e9f94c
commit 45f251e57e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 30 additions and 9 deletions

View File

@ -5,6 +5,7 @@ import json
import logging import logging
import re import re
import sys import sys
from contextlib import asynccontextmanager
from functools import singledispatch, wraps from functools import singledispatch, wraps
from pprint import pformat as pf from pprint import pformat as pf
from typing import Any, Dict, cast from typing import Any, Dict, cast
@ -365,7 +366,14 @@ async def cli(
if ctx.invoked_subcommand not in SKIP_UPDATE_COMMANDS and not device_family: if ctx.invoked_subcommand not in SKIP_UPDATE_COMMANDS and not device_family:
await dev.update() await dev.update()
ctx.obj = dev @asynccontextmanager
async def async_wrapped_device(device: Device):
try:
yield device
finally:
await device.disconnect()
ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev))
if ctx.invoked_subcommand is None: if ctx.invoked_subcommand is None:
return await ctx.invoke(state) return await ctx.invoke(state)

View File

@ -49,6 +49,20 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic
if host: if host:
config = DeviceConfig(host=host) config = DeviceConfig(host=host)
if (protocol := get_protocol(config=config)) is None:
raise UnsupportedDeviceException(
f"Unsupported device for {config.host}: "
+ f"{config.connection_type.device_family.value}"
)
try:
return await _connect(config, protocol)
except:
await protocol.close()
raise
async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> "Device":
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
if debug_enabled: if debug_enabled:
start_time = time.perf_counter() start_time = time.perf_counter()
@ -63,12 +77,6 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic
) )
start_time = time.perf_counter() start_time = time.perf_counter()
if (protocol := get_protocol(config=config)) is None:
raise UnsupportedDeviceException(
f"Unsupported device for {config.host}: "
+ f"{config.connection_type.device_family.value}"
)
device_class: Optional[Type[Device]] device_class: Optional[Type[Device]]
device: Optional[Device] = None device: Optional[Device] = None

View File

@ -53,7 +53,7 @@ async def test_connect(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
) )
protocol_class = get_protocol(config).__class__ protocol_class = get_protocol(config).__class__
close_mock = mocker.patch.object(protocol_class, "close")
dev = await connect( dev = await connect(
config=config, config=config,
) )
@ -61,8 +61,9 @@ async def test_connect(
assert isinstance(dev.protocol, protocol_class) assert isinstance(dev.protocol, protocol_class)
assert dev.config == config assert dev.config == config
assert close_mock.call_count == 0
await dev.disconnect() await dev.disconnect()
assert close_mock.call_count == 1
@pytest.mark.parametrize("custom_port", [123, None]) @pytest.mark.parametrize("custom_port", [123, None])
@ -116,8 +117,12 @@ async def test_connect_query_fails(all_fixture_data: dict, mocker):
config = DeviceConfig( config = DeviceConfig(
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
) )
protocol_class = get_protocol(config).__class__
close_mock = mocker.patch.object(protocol_class, "close")
assert close_mock.call_count == 0
with pytest.raises(SmartDeviceException): with pytest.raises(SmartDeviceException):
await connect(config=config) await connect(config=config)
assert close_mock.call_count == 1
async def test_connect_http_client(all_fixture_data, mocker): async def test_connect_http_client(all_fixture_data, mocker):