mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-10-31 20:51:54 +00:00 
			
		
		
		
	Ensure connections are closed when cli is finished (#752)
* Ensure connections are closed when cli is finished * Test for close calls on error and success
This commit is contained in:
		
							
								
								
									
										10
									
								
								kasa/cli.py
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								kasa/cli.py
									
									
									
									
									
								
							| @@ -5,6 +5,7 @@ import json | ||||
| import logging | ||||
| import re | ||||
| import sys | ||||
| from contextlib import asynccontextmanager | ||||
| from functools import singledispatch, wraps | ||||
| from pprint import pformat as pf | ||||
| from typing import Any, Dict, cast | ||||
| @@ -365,7 +366,14 @@ async def cli( | ||||
|     if ctx.invoked_subcommand not in SKIP_UPDATE_COMMANDS and not device_family: | ||||
|         await dev.update() | ||||
|  | ||||
|     ctx.obj = dev | ||||
|     @asynccontextmanager | ||||
|     async def async_wrapped_device(device: Device): | ||||
|         try: | ||||
|             yield device | ||||
|         finally: | ||||
|             await device.disconnect() | ||||
|  | ||||
|     ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev)) | ||||
|  | ||||
|     if ctx.invoked_subcommand is None: | ||||
|         return await ctx.invoke(state) | ||||
|   | ||||
| @@ -49,6 +49,20 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic | ||||
|     if host: | ||||
|         config = DeviceConfig(host=host) | ||||
|  | ||||
|     if (protocol := get_protocol(config=config)) is None: | ||||
|         raise UnsupportedDeviceException( | ||||
|             f"Unsupported device for {config.host}: " | ||||
|             + f"{config.connection_type.device_family.value}" | ||||
|         ) | ||||
|  | ||||
|     try: | ||||
|         return await _connect(config, protocol) | ||||
|     except: | ||||
|         await protocol.close() | ||||
|         raise | ||||
|  | ||||
|  | ||||
| async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> "Device": | ||||
|     debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) | ||||
|     if debug_enabled: | ||||
|         start_time = time.perf_counter() | ||||
| @@ -63,12 +77,6 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic | ||||
|             ) | ||||
|             start_time = time.perf_counter() | ||||
|  | ||||
|     if (protocol := get_protocol(config=config)) is None: | ||||
|         raise UnsupportedDeviceException( | ||||
|             f"Unsupported device for {config.host}: " | ||||
|             + f"{config.connection_type.device_family.value}" | ||||
|         ) | ||||
|  | ||||
|     device_class: Optional[Type[Device]] | ||||
|     device: Optional[Device] = None | ||||
|  | ||||
|   | ||||
| @@ -53,7 +53,7 @@ async def test_connect( | ||||
|         host=host, credentials=Credentials("foor", "bar"), connection_type=ctype | ||||
|     ) | ||||
|     protocol_class = get_protocol(config).__class__ | ||||
|  | ||||
|     close_mock = mocker.patch.object(protocol_class, "close") | ||||
|     dev = await connect( | ||||
|         config=config, | ||||
|     ) | ||||
| @@ -61,8 +61,9 @@ async def test_connect( | ||||
|     assert isinstance(dev.protocol, protocol_class) | ||||
|  | ||||
|     assert dev.config == config | ||||
|  | ||||
|     assert close_mock.call_count == 0 | ||||
|     await dev.disconnect() | ||||
|     assert close_mock.call_count == 1 | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("custom_port", [123, None]) | ||||
| @@ -116,8 +117,12 @@ async def test_connect_query_fails(all_fixture_data: dict, mocker): | ||||
|     config = DeviceConfig( | ||||
|         host=host, credentials=Credentials("foor", "bar"), connection_type=ctype | ||||
|     ) | ||||
|     protocol_class = get_protocol(config).__class__ | ||||
|     close_mock = mocker.patch.object(protocol_class, "close") | ||||
|     assert close_mock.call_count == 0 | ||||
|     with pytest.raises(SmartDeviceException): | ||||
|         await connect(config=config) | ||||
|     assert close_mock.call_count == 1 | ||||
|  | ||||
|  | ||||
| async def test_connect_http_client(all_fixture_data, mocker): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steven B
					Steven B