diff --git a/kasa/transports/sslaestransport.py b/kasa/transports/sslaestransport.py index 2061d293..dd997914 100644 --- a/kasa/transports/sslaestransport.py +++ b/kasa/transports/sslaestransport.py @@ -7,6 +7,7 @@ import base64 import hashlib import logging import secrets +import socket import ssl from enum import Enum, auto from typing import TYPE_CHECKING, Any, cast @@ -107,12 +108,7 @@ class SslAesTransport(BaseTransport): self._app_url = URL(f"https://{self._host_port}") self._token_url: URL | None = None self._ssl_context: ssl.SSLContext | None = None - ref = str(self._token_url) if self._token_url else str(self._app_url) - self._headers = { - **self.COMMON_HEADERS, - "Host": self._host_port, - "Referer": ref, - } + self._headers: dict | None = None self._seq: int | None = None self._pwd_hash: str | None = None self._username: str | None = None @@ -187,6 +183,34 @@ class SslAesTransport(BaseTransport): ) return self._ssl_context + async def _get_host_ip(self) -> str: + def get_ip() -> str: + # From https://stackoverflow.com/a/28950776 + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.settimeout(0) + try: + # doesn't even have to be reachable + s.connect(("10.254.254.254", 1)) + ip = s.getsockname()[0] + except Exception: + ip = "127.0.0.1" + finally: + s.close() + return ip + + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, get_ip) + + async def _get_headers(self) -> dict: + if not self._headers: + this_ip = await self._get_host_ip() + self._headers = { + **self.COMMON_HEADERS, + "Referer": f"https://{this_ip}/", + "Host": self._host_port, + } + return self._headers + async def send_secure_passthrough(self, request: str) -> dict[str, Any]: """Send encrypted message as passthrough.""" if self._state is TransportState.ESTABLISHED and self._token_url: @@ -207,7 +231,7 @@ class SslAesTransport(BaseTransport): tag = self.generate_tag( passthrough_request_str, self._local_nonce, self._pwd_hash, self._seq ) - headers = {**self._headers, "Seq": str(self._seq), "Tapo_tag": tag} + headers = {**await self._get_headers(), "Seq": str(self._seq), "Tapo_tag": tag} self._seq += 1 status_code, resp_dict = await self._http_client.post( url, @@ -326,7 +350,7 @@ class SslAesTransport(BaseTransport): status_code, resp_dict = await http_client.post( self._app_url, json=body, - headers=self._headers, + headers=await self._get_headers(), ssl=await self._get_ssl_context(), ) if status_code != 200: @@ -443,7 +467,7 @@ class SslAesTransport(BaseTransport): status_code, resp_dict = await http_client.post( self._app_url, json=body, - headers=self._headers, + headers=await self._get_headers(), ssl=await self._get_ssl_context(), ) diff --git a/tests/transports/test_sslaestransport.py b/tests/transports/test_sslaestransport.py index 6816fa35..a2f3edfb 100644 --- a/tests/transports/test_sslaestransport.py +++ b/tests/transports/test_sslaestransport.py @@ -27,7 +27,8 @@ from kasa.transports.sslaestransport import ( ) # Transport tests are not designed for real devices -pytestmark = [pytest.mark.requires_dummy] +# SslAesTransport use a socket to get it's own ip address +pytestmark = [pytest.mark.requires_dummy, pytest.mark.enable_socket] MOCK_ADMIN_USER = get_default_credentials(DEFAULT_CREDENTIALS["TAPOCAMERA"]).username MOCK_PWD = "correct_pwd" # noqa: S105