Enable ruff check for ANN (#1139)

This commit is contained in:
Teemu R.
2024-11-10 19:55:13 +01:00
committed by GitHub
parent 6b44fe6242
commit 66eb17057e
89 changed files with 596 additions and 452 deletions

View File

@@ -89,9 +89,19 @@ import logging
import secrets
import socket
import struct
from collections.abc import Awaitable
from asyncio.transports import DatagramTransport
from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
NamedTuple,
Optional,
Type,
cast,
)
from aiohttp import ClientSession
@@ -140,8 +150,8 @@ class ConnectAttempt(NamedTuple):
device: type
OnDiscoveredCallable = Callable[[Device], Awaitable[None]]
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]]
OnDiscoveredCallable = Callable[[Device], Coroutine]
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Coroutine]
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
DeviceDict = Dict[str, Device]
@@ -156,7 +166,7 @@ class _AesDiscoveryQuery:
keypair: KeyPair | None = None
@classmethod
def generate_query(cls):
def generate_query(cls) -> bytearray:
if not cls.keypair:
cls.keypair = KeyPair.create_key_pair(key_size=2048)
secret = secrets.token_bytes(4)
@@ -215,7 +225,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
credentials: Credentials | None = None,
timeout: int | None = None,
) -> None:
self.transport = None
self.transport: DatagramTransport | None = None
self.discovery_packets = discovery_packets
self.interface = interface
self.on_discovered = on_discovered
@@ -239,16 +249,19 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.target_discovered: bool = False
self._started_event = asyncio.Event()
def _run_callback_task(self, coro):
task = asyncio.create_task(coro)
def _run_callback_task(self, coro: Coroutine) -> None:
task: asyncio.Task = asyncio.create_task(coro)
self.callback_tasks.append(task)
async def wait_for_discovery_to_complete(self):
async def wait_for_discovery_to_complete(self) -> None:
"""Wait for the discovery task to complete."""
# Give some time for connection_made event to be received
async with asyncio_timeout(self.DISCOVERY_START_TIMEOUT):
await self._started_event.wait()
try:
if TYPE_CHECKING:
assert isinstance(self.discover_task, asyncio.Task)
await self.discover_task
except asyncio.CancelledError:
# if target_discovered then cancel was called internally
@@ -257,11 +270,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
# Wait for any pending callbacks to complete
await asyncio.gather(*self.callback_tasks)
def connection_made(self, transport) -> None:
def connection_made(self, transport: DatagramTransport) -> None: # type: ignore[override]
"""Set socket options for broadcasting."""
self.transport = transport
self.transport = cast(DatagramTransport, transport)
sock = transport.get_extra_info("socket")
sock = self.transport.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@@ -292,7 +305,11 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self.transport.sendto(aes_discovery_query, self.target_2) # type: ignore
await asyncio.sleep(sleep_between_packets)
def datagram_received(self, data, addr) -> None:
def datagram_received(
self,
data: bytes,
addr: tuple[str, int],
) -> None:
"""Handle discovery responses."""
if TYPE_CHECKING:
assert _AesDiscoveryQuery.keypair
@@ -338,18 +355,18 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
self._handle_discovered_event()
def _handle_discovered_event(self):
def _handle_discovered_event(self) -> None:
"""If target is in seen_hosts cancel discover_task."""
if self.target in self.seen_hosts:
self.target_discovered = True
if self.discover_task:
self.discover_task.cancel()
def error_received(self, ex):
def error_received(self, ex: Exception) -> None:
"""Handle asyncio.Protocol errors."""
_LOGGER.error("Got error: %s", ex)
def connection_lost(self, ex): # pragma: no cover
def connection_lost(self, ex: Exception | None) -> None: # pragma: no cover
"""Cancel the discover task if running."""
if self.discover_task:
self.discover_task.cancel()
@@ -372,17 +389,17 @@ class Discover:
@staticmethod
async def discover(
*,
target="255.255.255.255",
on_discovered=None,
discovery_timeout=5,
discovery_packets=3,
interface=None,
on_unsupported=None,
credentials=None,
target: str = "255.255.255.255",
on_discovered: OnDiscoveredCallable | None = None,
discovery_timeout: int = 5,
discovery_packets: int = 3,
interface: str | None = None,
on_unsupported: OnUnsupportedCallable | None = None,
credentials: Credentials | None = None,
username: str | None = None,
password: str | None = None,
port=None,
timeout=None,
port: int | None = None,
timeout: int | None = None,
) -> DeviceDict:
"""Discover supported devices.
@@ -636,7 +653,7 @@ class Discover:
)
if not dev_class:
raise UnsupportedDeviceError(
"Unknown device type: %s" % discovery_result.device_type,
f"Unknown device type: {discovery_result.device_type}",
discovery_result=info,
)
return dev_class