From f5b495992389481fbd85e6fa5c75f1e4ca2acf43 Mon Sep 17 00:00:00 2001 From: Wen Chuan Lee Date: Sun, 7 Jun 2026 00:25:57 -0700 Subject: [PATCH] feat: Replace BasicAuth with OAuth2 client credentials flow OpenSky Network dropped basic authentication on March 18, 2026 and now exclusively uses OAuth2 client credentials. This updates the authenticate() method to accept client_id and client_secret instead of BasicAuth, with automatic token refresh. OAuth2 state (credentials + token + expiry) is bundled into a private _OAuthSession dataclass to keep the OpenSky instance attribute count at 8, matching pre-OAuth2. Ref: openskynetwork/opensky-api#85 Ref: home-assistant/core#156643 --- README.md | 5 + src/python_opensky/const.py | 3 + src/python_opensky/opensky.py | 99 ++++++++++++++-- tests/test_states.py | 206 ++++++++++++++++++++++++++++++++-- 4 files changed, 291 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index f3c4ba81..480e8de2 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,11 @@ from python_opensky import OpenSky, StatesResponse async def main() -> None: """Show example of fetching all flight states.""" async with OpenSky() as opensky: + # Optional: authenticate for higher rate limits + # await opensky.authenticate( + # client_id="your_client_id", + # client_secret="your_client_secret", + # ) states: StatesResponse = await opensky.get_states() print(states) diff --git a/src/python_opensky/const.py b/src/python_opensky/const.py index e4529f8b..2b0bad1e 100644 --- a/src/python_opensky/const.py +++ b/src/python_opensky/const.py @@ -46,3 +46,6 @@ class AircraftCategory(int, Enum): MAX_LATITUDE = "lamax" MIN_LONGITUDE = "lomin" MAX_LONGITUDE = "lomax" + +TOKEN_URL = "https://auth.opensky-network.org/auth/realms/opensky-network/protocol/openid-connect/token" # noqa: S105 +TOKEN_REFRESH_MARGIN = 30 diff --git a/src/python_opensky/opensky.py b/src/python_opensky/opensky.py index 07f9e002..8afff507 100644 --- a/src/python_opensky/opensky.py +++ b/src/python_opensky/opensky.py @@ -10,11 +10,18 @@ from importlib import metadata from typing import TYPE_CHECKING, Any, cast -from aiohttp import BasicAuth, ClientError, ClientResponseError, ClientSession -from aiohttp.hdrs import METH_GET +from aiohttp import ClientError, ClientResponseError, ClientSession +from aiohttp.hdrs import METH_GET, METH_POST from yarl import URL -from .const import MAX_LATITUDE, MAX_LONGITUDE, MIN_LATITUDE, MIN_LONGITUDE +from .const import ( + MAX_LATITUDE, + MAX_LONGITUDE, + MIN_LATITUDE, + MIN_LONGITUDE, + TOKEN_REFRESH_MARGIN, + TOKEN_URL, +) from .exceptions import ( OpenSkyConnectionError, OpenSkyError, @@ -29,6 +36,16 @@ VERSION = metadata.version(__package__) +@dataclass +class _OAuthSession: + """OAuth2 client credentials and the access token they hold.""" + + client_id: str + client_secret: str + token: str | None = None + expires_at: datetime | None = None + + @dataclass class OpenSky: """Main class for handling connections with OpenSky.""" @@ -40,21 +57,23 @@ class OpenSky: timezone = UTC _close_session: bool = False _credit_usage: dict[datetime, int] = field(default_factory=dict) - _auth: BasicAuth | None = None + _oauth: _OAuthSession | None = None _contributing_user: bool = False async def authenticate( self, - auth: BasicAuth, + client_id: str, + client_secret: str, *, contributing_user: bool = False, ) -> None: """Authenticate the user.""" - self._auth = auth + self._oauth = _OAuthSession(client_id=client_id, client_secret=client_secret) try: + await self._refresh_token() await self.get_states(bounding_box=BoundingBox(0.0, 0.0, 1.0, 1.0)) except OpenSkyUnauthenticatedError as exc: - self._auth = None + self._oauth = None raise OpenSkyUnauthenticatedError from exc self._contributing_user = contributing_user if contributing_user: @@ -70,7 +89,64 @@ def is_contributing_user(self) -> bool: @property def is_authenticated(self) -> bool: """Return if the user is correctly authenticated.""" - return self._auth is not None + return self._oauth is not None + + async def _refresh_token(self) -> None: + """Refresh the OAuth2 access token.""" + assert self._oauth is not None # noqa: S101 — callers guard + if self.session is None: + self.session = ClientSession() + self._close_session = True + + try: + async with asyncio.timeout(self.request_timeout): + response = await self.session.request( + METH_POST, + TOKEN_URL, + data={ + "grant_type": "client_credentials", + "client_id": self._oauth.client_id, + "client_secret": self._oauth.client_secret, + }, + ) + except TimeoutError as exception: + msg = "Timeout occurred while connecting to the OpenSky API" + raise OpenSkyConnectionError(msg) from exception + except ( + ClientError, + ClientResponseError, + socket.gaierror, + ) as exception: + msg = "Error occurred while communicating with OpenSky API" + raise OpenSkyConnectionError(msg) from exception + + if response.status == 401: + raise OpenSkyUnauthenticatedError + + try: + response.raise_for_status() + except ClientResponseError as exception: + msg = "Error occurred while communicating with OpenSky API" + raise OpenSkyConnectionError(msg) from exception + + token_data = await response.json() + self._oauth.token = token_data["access_token"] + self._oauth.expires_at = datetime.now(UTC) + timedelta( + seconds=token_data["expires_in"] - TOKEN_REFRESH_MARGIN, + ) + + async def _get_access_token(self) -> str | None: + """Get a valid access token, refreshing if needed.""" + if self._oauth is None: + return None + if ( + self._oauth.token + and self._oauth.expires_at + and self._oauth.expires_at > datetime.now(UTC) + ): + return self._oauth.token + await self._refresh_token() + return self._oauth.token async def _request( self, @@ -116,12 +192,15 @@ async def _request( self.session = ClientSession() self._close_session = True + token = await self._get_access_token() + if token is not None: + headers["Authorization"] = f"Bearer {token}" + try: async with asyncio.timeout(self.request_timeout): response = await self.session.request( METH_GET, url.with_query(data), - auth=self._auth, headers=headers, ) response.raise_for_status() @@ -185,7 +264,7 @@ async def get_states( async def get_own_states(self, time: int = 0) -> StatesResponse: """Retrieve state vectors from your own sensors.""" - if not self._auth: + if self._oauth is None: raise OpenSkyUnauthenticatedError params = { "time": time, diff --git a/tests/test_states.py b/tests/test_states.py index 78cd5bee..c033eb15 100644 --- a/tests/test_states.py +++ b/tests/test_states.py @@ -2,10 +2,12 @@ import asyncio from dataclasses import asdict +from datetime import UTC, datetime +from typing import Any import aiohttp import pytest -from aiohttp import BasicAuth, ClientError +from aiohttp import ClientError from aiohttp.web_request import BaseRequest from aresponses import Response, ResponsesMockServer from syrupy.assertion import SnapshotAssertion @@ -22,6 +24,29 @@ from . import load_fixture OPENSKY_URL = "opensky-network.org" +TOKEN_HOST = "auth.opensky-network.org" # noqa: S105 +TOKEN_PATH = "/auth/realms/opensky-network/protocol/openid-connect/token" # noqa: S105 +TOKEN_RESPONSE = '{"access_token": "test-token", "expires_in": 1800}' # noqa: S105 + + +def _add_token_mock( + aresponses: ResponsesMockServer, + *, + repeat: int = 1, + status: int = 200, +) -> None: + """Add a token endpoint mock.""" + aresponses.add( + TOKEN_HOST, + TOKEN_PATH, + "POST", + aresponses.Response( + status=status, + headers={"Content-Type": "application/json"}, + text=TOKEN_RESPONSE, + ), + repeat=repeat, + ) async def test_states( @@ -73,6 +98,7 @@ async def test_own_states( aresponses: ResponsesMockServer, ) -> None: """Test retrieving own states.""" + _add_token_mock(aresponses) aresponses.add( OPENSKY_URL, "/api/states/all", @@ -96,7 +122,8 @@ async def test_own_states( async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) await opensky.authenticate( - BasicAuth(login="test", password="test"), + client_id="test_id", + client_secret="test_secret", contributing_user=True, ) response: StatesResponse = await opensky.get_own_states() @@ -110,6 +137,7 @@ async def test_unavailable_own_states( aresponses: ResponsesMockServer, ) -> None: """Test retrieving no own states.""" + _add_token_mock(aresponses) aresponses.add( OPENSKY_URL, "/api/states/all", @@ -133,7 +161,8 @@ async def test_unavailable_own_states( async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) await opensky.authenticate( - BasicAuth(login="test", password="test"), + client_id="test_id", + client_secret="test_secret", contributing_user=True, ) response: StatesResponse = await opensky.get_own_states() @@ -236,12 +265,13 @@ async def response_handler(_: BaseRequest) -> Response: async def test_auth(aresponses: ResponsesMockServer) -> None: """Test request authentication.""" + _add_token_mock(aresponses, repeat=1) def response_handler(request: BaseRequest) -> Response: """Response handler for this test.""" assert request.headers assert request.headers["Authorization"] - assert request.headers["Authorization"] == "Basic dGVzdDp0ZXN0" + assert request.headers["Authorization"] == "Bearer test-token" return aresponses.Response( status=200, headers={"Content-Type": "application/json"}, @@ -258,7 +288,10 @@ def response_handler(request: BaseRequest) -> Response: async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) - await opensky.authenticate(BasicAuth(login="test", password="test")) + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) await opensky.get_states() await opensky.close() @@ -266,20 +299,23 @@ def response_handler(request: BaseRequest) -> Response: async def test_unauthorized(aresponses: ResponsesMockServer) -> None: """Test request authentication.""" aresponses.add( - OPENSKY_URL, - "/api/states/all", - "GET", + TOKEN_HOST, + TOKEN_PATH, + "POST", aresponses.Response( status=401, headers={"Content-Type": "application/json"}, - text=load_fixture("states.json"), + text="{}", ), ) async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) try: - await opensky.authenticate(BasicAuth(login="test", password="test")) + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) pytest.fail("Should've thrown exception") except OpenSkyUnauthenticatedError: pass @@ -289,6 +325,7 @@ async def test_unauthorized(aresponses: ResponsesMockServer) -> None: async def test_user_credits(aresponses: ResponsesMockServer) -> None: """Test authenticated user credits.""" + _add_token_mock(aresponses, repeat=2) aresponses.add( OPENSKY_URL, "/api/states/all", @@ -303,10 +340,14 @@ async def test_user_credits(aresponses: ResponsesMockServer) -> None: async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) assert opensky.opensky_credits == 400 - await opensky.authenticate(BasicAuth(login="test", password="test")) + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) assert opensky.opensky_credits == 4000 await opensky.authenticate( - BasicAuth(login="test", password="test"), + client_id="test_id", + client_secret="test_secret", contributing_user=True, ) assert opensky.opensky_credits == 8000 @@ -397,3 +438,144 @@ async def test_calculating_credit_usage() -> None: max_longitude=10.9, ) assert opensky.calculate_credit_costs(bounding_box) == 4 + + +async def test_token_refresh(aresponses: ResponsesMockServer) -> None: + """Test that token is refreshed when expired.""" + _add_token_mock(aresponses, repeat=2) + aresponses.add( + OPENSKY_URL, + "/api/states/all", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text=load_fixture("states.json"), + ), + repeat=2, + ) + async with aiohttp.ClientSession() as session: + opensky = OpenSky(session=session) + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) + # Expire the token + assert opensky._oauth is not None # noqa: SLF001 + opensky._oauth.expires_at = datetime(2020, 1, 1, tzinfo=UTC) # noqa: SLF001 + await opensky.get_states() + await opensky.close() + + +async def test_token_refresh_new_session(aresponses: ResponsesMockServer) -> None: + """Test that _refresh_token creates a session if none exists.""" + _add_token_mock(aresponses) + aresponses.add( + OPENSKY_URL, + "/api/states/all", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text=load_fixture("states.json"), + ), + ) + async with OpenSky() as opensky: + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) + assert opensky.session + + +async def test_token_refresh_timeout(aresponses: ResponsesMockServer) -> None: + """Test token refresh timeout.""" + + async def response_handler(_: BaseRequest) -> Response: + await asyncio.sleep(2) + return aresponses.Response(body="Timeout") + + aresponses.add( + TOKEN_HOST, + TOKEN_PATH, + "POST", + response_handler, + ) + + async with aiohttp.ClientSession() as session: + opensky = OpenSky(session=session, request_timeout=1) + with pytest.raises(OpenSkyConnectionError): + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) + await opensky.close() + + +async def test_token_refresh_connection_error() -> None: + """Test token refresh connection error.""" + async with aiohttp.ClientSession() as session: + opensky = OpenSky(session=session) + # Patch session.request to raise ClientError + original_request = session.request + + async def mock_request(*_args: Any, **_kwargs: Any) -> None: + raise ClientError + + session.request = mock_request # type: ignore[assignment] + with pytest.raises(OpenSkyConnectionError): + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) + session.request = original_request # type: ignore[assignment] + await opensky.close() + + +async def test_token_refresh_server_error( + aresponses: ResponsesMockServer, +) -> None: + """Test token refresh with server error response.""" + aresponses.add( + TOKEN_HOST, + TOKEN_PATH, + "POST", + aresponses.Response( + status=500, + headers={"Content-Type": "application/json"}, + text="{}", + ), + ) + + async with aiohttp.ClientSession() as session: + opensky = OpenSky(session=session) + with pytest.raises(OpenSkyConnectionError): + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) + await opensky.close() + + +async def test_api_returns_401(aresponses: ResponsesMockServer) -> None: + """Test that a 401 from the API raises OpenSkyUnauthenticatedError.""" + _add_token_mock(aresponses) + aresponses.add( + OPENSKY_URL, + "/api/states/all", + "GET", + aresponses.Response( + status=401, + headers={"Content-Type": "application/json"}, + text="{}", + ), + ) + async with aiohttp.ClientSession() as session: + opensky = OpenSky(session=session) + with pytest.raises(OpenSkyUnauthenticatedError): + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) + assert opensky.is_authenticated is False + await opensky.close()