mirror of
				https://github.com/python-kasa/python-kasa.git
				synced 2025-11-03 22:22:06 +00:00 
			
		
		
		
	Fix port-override for aes&klap transports (#734)
* Fix port-override for aes&klap transports * Add tests for port override
This commit is contained in:
		@@ -102,7 +102,7 @@ class AesTransport(BaseTransport):
 | 
			
		||||
        self._session_cookie: Optional[Dict[str, str]] = None
 | 
			
		||||
 | 
			
		||||
        self._key_pair: Optional[KeyPair] = None
 | 
			
		||||
        self._app_url = URL(f"http://{self._host}/app")
 | 
			
		||||
        self._app_url = URL(f"http://{self._host}:{self._port}/app")
 | 
			
		||||
        self._token_url: Optional[URL] = None
 | 
			
		||||
 | 
			
		||||
        _LOGGER.debug("Created AES transport for %s", self._host)
 | 
			
		||||
@@ -257,7 +257,6 @@ class AesTransport(BaseTransport):
 | 
			
		||||
        self._session_expire_at = None
 | 
			
		||||
        self._session_cookie = None
 | 
			
		||||
 | 
			
		||||
        url = f"http://{self._host}/app"
 | 
			
		||||
        # Device needs the content length or it will response with 500
 | 
			
		||||
        headers = {
 | 
			
		||||
            **self.COMMON_HEADERS,
 | 
			
		||||
@@ -266,7 +265,7 @@ class AesTransport(BaseTransport):
 | 
			
		||||
        http_client = self._http_client
 | 
			
		||||
 | 
			
		||||
        status_code, resp_dict = await http_client.post(
 | 
			
		||||
            url,
 | 
			
		||||
            self._app_url,
 | 
			
		||||
            json=self._generate_key_pair_payload(),
 | 
			
		||||
            headers=headers,
 | 
			
		||||
            cookies_dict=self._session_cookie,
 | 
			
		||||
 
 | 
			
		||||
@@ -334,6 +334,7 @@ async def cli(
 | 
			
		||||
        )
 | 
			
		||||
        config = DeviceConfig(
 | 
			
		||||
            host=host,
 | 
			
		||||
            port_override=port,
 | 
			
		||||
            credentials=credentials,
 | 
			
		||||
            credentials_hash=credentials_hash,
 | 
			
		||||
            timeout=timeout,
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,6 @@
 | 
			
		||||
"""Module for HttpClientSession class."""
 | 
			
		||||
import asyncio
 | 
			
		||||
import logging
 | 
			
		||||
from typing import Any, Dict, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import aiohttp
 | 
			
		||||
@@ -13,6 +14,8 @@ from .exceptions import (
 | 
			
		||||
)
 | 
			
		||||
from .json import loads as json_loads
 | 
			
		||||
 | 
			
		||||
_LOGGER = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_cookie_jar() -> aiohttp.CookieJar:
 | 
			
		||||
    """Return a new cookie jar with the correct options for device communication."""
 | 
			
		||||
@@ -54,6 +57,7 @@ class HttpClient:
 | 
			
		||||
 | 
			
		||||
        If the request is provided via the json parameter json will be returned.
 | 
			
		||||
        """
 | 
			
		||||
        _LOGGER.debug("Posting to %s", url)
 | 
			
		||||
        response_data = None
 | 
			
		||||
        self._last_url = url
 | 
			
		||||
        self.client.cookie_jar.clear()
 | 
			
		||||
 
 | 
			
		||||
@@ -121,7 +121,7 @@ class KlapTransport(BaseTransport):
 | 
			
		||||
        self._session_cookie: Optional[Dict[str, Any]] = None
 | 
			
		||||
 | 
			
		||||
        _LOGGER.debug("Created KLAP transport for %s", self._host)
 | 
			
		||||
        self._app_url = URL(f"http://{self._host}/app")
 | 
			
		||||
        self._app_url = URL(f"http://{self._host}:{self._port}/app")
 | 
			
		||||
        self._request_url = self._app_url / "request"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
 
 | 
			
		||||
@@ -209,6 +209,17 @@ async def test_passthrough_errors(mocker, error_code):
 | 
			
		||||
        await transport.send(json_dumps(request))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_port_override():
 | 
			
		||||
    """Test that port override sets the app_url."""
 | 
			
		||||
    host = "127.0.0.1"
 | 
			
		||||
    config = DeviceConfig(
 | 
			
		||||
        host, credentials=Credentials("foo", "bar"), port_override=12345
 | 
			
		||||
    )
 | 
			
		||||
    transport = AesTransport(config=config)
 | 
			
		||||
 | 
			
		||||
    assert str(transport._app_url) == "http://127.0.0.1:12345/app"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MockAesDevice:
 | 
			
		||||
    class _mock_response:
 | 
			
		||||
        def __init__(self, status, json: dict):
 | 
			
		||||
@@ -256,7 +267,7 @@ class MockAesDevice:
 | 
			
		||||
        elif json["method"] == "login_device":
 | 
			
		||||
            return await self._return_login_response(url, json)
 | 
			
		||||
        else:
 | 
			
		||||
            assert str(url) == f"http://{self.host}/app?token={self.token}"
 | 
			
		||||
            assert str(url) == f"http://{self.host}:80/app?token={self.token}"
 | 
			
		||||
            return await self._return_send_response(url, json)
 | 
			
		||||
 | 
			
		||||
    async def _return_handshake_response(self, url: URL, json: Dict[str, Any]):
 | 
			
		||||
 
 | 
			
		||||
@@ -323,14 +323,14 @@ async def test_handshake(
 | 
			
		||||
    async def _return_handshake_response(url: URL, params=None, data=None, *_, **__):
 | 
			
		||||
        nonlocal client_seed, server_seed, device_auth_hash
 | 
			
		||||
 | 
			
		||||
        if str(url) == "http://127.0.0.1/app/handshake1":
 | 
			
		||||
        if str(url) == "http://127.0.0.1:80/app/handshake1":
 | 
			
		||||
            client_seed = data
 | 
			
		||||
            seed_auth_hash = _sha256(
 | 
			
		||||
                seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            return _mock_response(200, server_seed + seed_auth_hash)
 | 
			
		||||
        elif str(url) == "http://127.0.0.1/app/handshake2":
 | 
			
		||||
        elif str(url) == "http://127.0.0.1:80/app/handshake2":
 | 
			
		||||
            seed_auth_hash = _sha256(
 | 
			
		||||
                seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash)
 | 
			
		||||
            )
 | 
			
		||||
@@ -367,14 +367,14 @@ async def test_query(mocker):
 | 
			
		||||
    async def _return_response(url: URL, params=None, data=None, *_, **__):
 | 
			
		||||
        nonlocal client_seed, server_seed, device_auth_hash, seq
 | 
			
		||||
 | 
			
		||||
        if str(url) == "http://127.0.0.1/app/handshake1":
 | 
			
		||||
        if str(url) == "http://127.0.0.1:80/app/handshake1":
 | 
			
		||||
            client_seed = data
 | 
			
		||||
            client_seed_auth_hash = _sha256(data + device_auth_hash)
 | 
			
		||||
 | 
			
		||||
            return _mock_response(200, server_seed + client_seed_auth_hash)
 | 
			
		||||
        elif str(url) == "http://127.0.0.1/app/handshake2":
 | 
			
		||||
        elif str(url) == "http://127.0.0.1:80/app/handshake2":
 | 
			
		||||
            return _mock_response(200, b"")
 | 
			
		||||
        elif str(url) == "http://127.0.0.1/app/request":
 | 
			
		||||
        elif str(url) == "http://127.0.0.1:80/app/request":
 | 
			
		||||
            encryption_session = KlapEncryptionSession(
 | 
			
		||||
                protocol._transport._encryption_session.local_seed,
 | 
			
		||||
                protocol._transport._encryption_session.remote_seed,
 | 
			
		||||
@@ -419,16 +419,16 @@ async def test_authentication_failures(mocker, response_status, expectation):
 | 
			
		||||
    async def _return_response(url: URL, params=None, data=None, *_, **__):
 | 
			
		||||
        nonlocal client_seed, server_seed, device_auth_hash, response_status
 | 
			
		||||
 | 
			
		||||
        if str(url) == "http://127.0.0.1/app/handshake1":
 | 
			
		||||
        if str(url) == "http://127.0.0.1:80/app/handshake1":
 | 
			
		||||
            client_seed = data
 | 
			
		||||
            client_seed_auth_hash = _sha256(data + device_auth_hash)
 | 
			
		||||
 | 
			
		||||
            return _mock_response(
 | 
			
		||||
                response_status[0], server_seed + client_seed_auth_hash
 | 
			
		||||
            )
 | 
			
		||||
        elif str(url) == "http://127.0.0.1/app/handshake2":
 | 
			
		||||
        elif str(url) == "http://127.0.0.1:80/app/handshake2":
 | 
			
		||||
            return _mock_response(response_status[1], b"")
 | 
			
		||||
        elif str(url) == "http://127.0.0.1/app/request":
 | 
			
		||||
        elif str(url) == "http://127.0.0.1:80/app/request":
 | 
			
		||||
            return _mock_response(response_status[2], b"")
 | 
			
		||||
 | 
			
		||||
    mocker.patch.object(aiohttp.ClientSession, "post", side_effect=_return_response)
 | 
			
		||||
@@ -438,3 +438,14 @@ async def test_authentication_failures(mocker, response_status, expectation):
 | 
			
		||||
 | 
			
		||||
    with expectation:
 | 
			
		||||
        await protocol.query({})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_port_override():
 | 
			
		||||
    """Test that port override sets the app_url."""
 | 
			
		||||
    host = "127.0.0.1"
 | 
			
		||||
    config = DeviceConfig(
 | 
			
		||||
        host, credentials=Credentials("foo", "bar"), port_override=12345
 | 
			
		||||
    )
 | 
			
		||||
    transport = KlapTransport(config=config)
 | 
			
		||||
 | 
			
		||||
    assert str(transport._app_url) == "http://127.0.0.1:12345/app"
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user