diff --git a/.github/workflows/docker-configs/ldap-docker-compose.yml b/.github/workflows/docker-configs/ldap-docker-compose.yml deleted file mode 100644 index 5cf12a8..0000000 --- a/.github/workflows/docker-configs/ldap-docker-compose.yml +++ /dev/null @@ -1,19 +0,0 @@ -version: '2' - -services: - openldap: - image: docker.io/bitnami/openldap:latest - ports: - - '1389:1389' - - '1636:1636' - environment: - - LDAP_ADMIN_USERNAME=admin - - LDAP_ADMIN_PASSWORD=adminpassword - - LDAP_USERS=user01,user02 - - LDAP_PASSWORDS=password1,password2 - volumes: - - 'openldap_data:/bitnami/openldap' - -volumes: - openldap_data: - driver: local diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 012a480..1a806f7 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -32,6 +32,8 @@ jobs: pip install . popd + pip install --upgrade "git+https://github.com/davidpcls/bluesky-authentication.git@main" + pip install . pip install -r requirements-dev.txt pip list diff --git a/.github/workflows/docs_publish.yml b/.github/workflows/docs_publish.yml index 49d73c3..fc1f0d0 100644 --- a/.github/workflows/docs_publish.yml +++ b/.github/workflows/docs_publish.yml @@ -44,6 +44,8 @@ jobs: pip install . popd + pip install --upgrade "git+https://github.com/davidpcls/bluesky-authentication.git@main" + pip install . pip install -r requirements-dev.txt pip list diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index b7d9d54..cf3a820 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -26,7 +26,7 @@ jobs: - name: Fetch tags run: git fetch --tags --prune --unshallow - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - uses: shogo82148/actions-setup-redis@v1 @@ -36,14 +36,8 @@ jobs: run: | # sudo apt install redis - pushd .. - git clone https://github.com/bitnami/containers.git - cd containers/bitnami/openldap/2.6/debian-12 - docker build -t bitnami/openldap:latest . - popd - # Start LDAP - source start_LDAP.sh + bash continuous_integration/scripts/start_LDAP.sh # These packages are installed in the base environment but may be older # versions. Explicitly upgrade them because they often create @@ -62,6 +56,8 @@ jobs: pip install . popd + pip install --upgrade "git+https://github.com/davidpcls/bluesky-authentication.git@main" + pip install --upgrade pip pip install . pip install -r requirements-dev.txt @@ -70,6 +66,19 @@ jobs: pip list - name: Test with pytest + env: + PYTEST_ADDOPTS: "--durations=20" run: | coverage run -m pytest -vv coverage report -m + - name: Dump LDAP diagnostics on failure + if: failure() + run: | + docker ps + docker compose -f continuous_integration/docker-configs/ldap-docker-compose.yml ps + LDAP_CONTAINER_ID=$(docker compose -f continuous_integration/docker-configs/ldap-docker-compose.yml ps -q openldap | tr -d '[:space:]') + if [ -n "$LDAP_CONTAINER_ID" ]; then + docker logs --tail 200 "$LDAP_CONTAINER_ID" + else + docker compose -f continuous_integration/docker-configs/ldap-docker-compose.yml logs --tail 200 openldap + fi diff --git a/bluesky_httpserver/authentication.py b/bluesky_httpserver/_authentication.py similarity index 62% rename from bluesky_httpserver/authentication.py rename to bluesky_httpserver/_authentication.py index 9772974..992c5fb 100644 --- a/bluesky_httpserver/authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -1,18 +1,18 @@ import asyncio -import enum import hashlib import secrets import uuid as uuid_module import warnings from datetime import datetime, timedelta -from typing import Optional +from typing import Any, Optional -from fastapi import APIRouter, Depends, HTTPException, Request, Response, Security, WebSocket +from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response, Security, WebSocket from fastapi.openapi.models import APIKey, APIKeyIn -from fastapi.responses import JSONResponse +from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyQuery from fastapi.security.utils import get_authorization_scheme_param +from sqlalchemy.exc import IntegrityError # To hide third-party warning # .../jose/backends/cryptography_backend.py:18: CryptographyDeprecationWarning: @@ -34,7 +34,15 @@ from .authorization._defaults import _DEFAULT_ANONYMOUS_PROVIDER_NAME from .core import json_or_msgpack from .database import orm -from .database.core import create_user, latest_principal_activity, lookup_valid_api_key, lookup_valid_session +from .database.core import ( + create_user, + get_or_create_principal, + latest_principal_activity, + lookup_valid_api_key, + lookup_valid_pending_session_by_device_code, + lookup_valid_pending_session_by_user_code, + lookup_valid_session, +) from .settings import get_sessionmaker, get_settings from .utils import ( API_KEY_COOKIE_NAME, @@ -49,17 +57,16 @@ ALGORITHM = "HS256" UNIT_SECOND = timedelta(seconds=1) +# Device code flow constants +DEVICE_CODE_MAX_AGE = timedelta(minutes=10) +DEVICE_CODE_POLLING_INTERVAL = 5 # seconds + def utcnow(): "UTC now with second resolution" return datetime.utcnow().replace(microsecond=0) -class Mode(enum.Enum): - password = "password" - external = "external" - - class Token(BaseModel): access_token: str token_type: str @@ -134,28 +141,53 @@ def create_refresh_token(session_id, secret_key, expires_delta): return encoded_jwt -def decode_token(token, secret_keys): - credentials_exception = HTTPException( - status_code=401, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) +def _decode_token_with_secret_keys(token, secret_keys): # The first key in settings.secret_keys is used for *encoding*. # All keys are tried for *decoding* until one works or they all - # fail. They supports key rotation. + # fail. They support key rotation. for secret_key in secret_keys: try: payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) - break + return payload except ExpiredSignatureError: # Do not let this be caught below with the other JWTError types. raise except JWTError: # Try the next key in the key rotation. continue - else: - raise credentials_exception - return payload + return None + + +def decode_token(token, secret_keys, proxied_authenticator=None): + credentials_exception = HTTPException( + status_code=401, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + payload = _decode_token_with_secret_keys(token, secret_keys) + if payload is not None: + return payload + if proxied_authenticator is not None: + return proxied_authenticator.decode_token(token) + raise credentials_exception + + +def _extract_scopes(decoded_access_token: dict[str, Any]) -> set[str]: + if "scp" in decoded_access_token: + scp = decoded_access_token["scp"] + return set(scp) if isinstance(scp, list) else set(scp.split(" ")) + if "scope" in decoded_access_token: + return set(decoded_access_token["scope"].split(" ")) + return set() + + +def _get_proxied_authenticator(authenticators): + if not authenticators: + return None + for authenticator in authenticators.values(): + if hasattr(authenticator, "oauth2_schema") and hasattr(authenticator, "decode_token"): + return authenticator + return None async def get_api_key( @@ -274,27 +306,49 @@ def get_current_principal( request.state.cookies_to_set.append({"key": API_KEY_COOKIE_NAME, "value": api_key}) elif access_token is not None: try: - payload = decode_token(access_token, settings.secret_keys) + payload = decode_token( + access_token, + settings.secret_keys, + _get_proxied_authenticator(authenticators), + ) except ExpiredSignatureError: raise HTTPException( status_code=401, detail="Access token has expired. Refresh token.", headers=headers_for_401, ) - principal = schemas.Principal( - uuid=uuid_module.UUID(hex=payload["sub"]), - type=payload["sub_typ"], - identities=[ - schemas.Identity(id=identity["id"], provider=identity["idp"]) for identity in payload["ids"] - ], - ) - - # scopes = payload["scp"] - - # Combine scopes for all identities (it is expected to be only one identity). - ids = [_["id"] for _ in payload["ids"] if _["idp"] in settings.authentication_provider_names] - scopes = set.union(*[api_access_manager.get_user_scopes(_) for _ in ids]) + token_scopes = _extract_scopes(payload) + if "sub_typ" in payload and "ids" in payload: + principal = schemas.Principal( + uuid=uuid_module.UUID(hex=payload["sub"]), + type=payload["sub_typ"], + identities=[ + schemas.Identity(id=identity["id"], provider=identity["idp"]) for identity in payload["ids"] + ], + ) + ids = [ + _["id"] + for _ in payload["ids"] + if (_["idp"] in settings.authentication_provider_names) + and api_access_manager.is_user_known(_["id"]) + ] + else: + identity_id = payload.get("user") or payload["sub"] + provider = ( + settings.authentication_provider_names[0] + if settings.authentication_provider_names + else _DEFAULT_ANONYMOUS_PROVIDER_NAME + ) + with get_sessionmaker(settings.database_settings)() as db: + principal_orm = get_or_create_principal(db, provider, identity_id) + principal = schemas.Principal( + uuid=principal_orm.uuid, + type="user", + identities=[schemas.Identity(id=identity_id, provider=provider)], + ) + ids = [identity_id] if api_access_manager.is_user_known(identity_id) else [] + scopes = set.union(*[api_access_manager.get_user_scopes(_) for _ in ids]) if ids else set(token_scopes) roles_sets = [api_access_manager.get_user_roles(_) for _ in ids] roles = set.union(*roles_sets) if roles_sets else set() @@ -355,7 +409,7 @@ def get_current_principal( return principal -def get_current_principal_websocket( +async def get_current_principal_websocket( websocket: WebSocket, scopes: str, ): @@ -367,13 +421,31 @@ def get_current_principal_websocket( auth_header = websocket.headers.get("Authorization", "") access_token, api_key = None, None - # Currently we do not support authentication with tokens - # if auth_header.startswith("Bearer "): - # access_token = auth_header[len("Bearer") :].strip() - if auth_header.startswith("ApiKey "): - api_key = auth_header[len("ApiKey") :].strip() + scheme, param = get_authorization_scheme_param(auth_header) + if scheme.lower() == "bearer": + access_token = param + elif scheme.lower() == "apikey": + api_key = param + + if access_token is None: + access_token = websocket.query_params.get("access_token") + if api_key is None: + api_key = websocket.query_params.get("api_key") principal = None + websocket.state.already_accepted = False + no_credentials = (access_token is None) and (api_key is None) + if no_credentials and not settings.allow_anonymous_access: + try: + await websocket.accept() + websocket.state.already_accepted = True + message = await asyncio.wait_for(websocket.receive_json(), timeout=1) + if isinstance(message, dict) and message.get("type") == "auth": + access_token = message.get("access_token") + api_key = message.get("api_key") + except Exception: + return None + try: principal = get_current_principal( request=websocket, @@ -455,7 +527,8 @@ async def auth_code( api_access_manager=Depends(get_api_access_manager), ): request.state.endpoint = "auth" - username = await authenticator.authenticate(request) + user_session_state = await authenticator.authenticate(request) + username = user_session_state.user_name if user_session_state else None if username and api_access_manager.is_user_known(username): scopes = api_access_manager.get_user_scopes(username) @@ -484,7 +557,10 @@ async def handle_credentials( api_access_manager=Depends(get_api_access_manager), ): request.state.endpoint = "auth" - username = await authenticator.authenticate(username=form_data.username, password=form_data.password) + user_session_state = await authenticator.authenticate( + username=form_data.username, password=form_data.password + ) + username = user_session_state.user_name if user_session_state else None err_msg = None if not username: @@ -507,6 +583,442 @@ async def handle_credentials( return handle_credentials +def create_pending_session(db): + """ + Create a pending session for device code flow. + + Returns a dict with 'user_code' (user-facing code) and 'device_code' (for polling). + """ + device_code = secrets.token_bytes(32) + hashed_device_code = hashlib.sha256(device_code).digest() + for _ in range(3): + user_code = secrets.token_hex(4).upper() # 8 digit code + pending_session = orm.PendingSession( + user_code=user_code, + hashed_device_code=hashed_device_code, + expiration_time=utcnow() + DEVICE_CODE_MAX_AGE, + ) + db.add(pending_session) + try: + db.commit() + except IntegrityError: + # Since the user_code is short, we cannot completely dismiss the + # possibility of a collision. Retry. + db.rollback() + continue + break + formatted_user_code = f"{user_code[:4]}-{user_code[4:]}" + return { + "user_code": formatted_user_code, + "device_code": device_code.hex(), + } + + +def build_authorize_route(authenticator, provider): + """Build a GET route that redirects the browser to the OIDC provider for authentication.""" + + async def authorize_redirect( + request: Request, + state: Optional[str] = Query(None), + ): + """Redirect browser to OAuth provider for authentication.""" + redirect_uri = f"{get_base_url(request)}/auth/provider/{provider}/code" + + requested_scopes = {"openid", "offline_access"} + requested_scopes.update(getattr(authenticator, "extra_scopes", [])) + params = { + "client_id": authenticator.client_id, + "response_type": "code", + "scope": " ".join(sorted(requested_scopes)), + "redirect_uri": redirect_uri, + "prompt": "login", + } + if state: + params["state"] = state + + auth_url = authenticator.authorization_endpoint.copy_with(params=params) + return RedirectResponse(url=str(auth_url)) + + return authorize_redirect + + +def build_device_code_authorize_route(authenticator, provider): + """Build a POST route that initiates the device code flow for CLI/headless clients.""" + + async def device_code_authorize( + request: Request, + settings: BaseSettings = Depends(get_settings), + ): + """ + Initiate device code flow. + + Returns authorization_uri for the user to visit in browser, + and device_code + user_code for the CLI client to poll. + """ + request.state.endpoint = "auth" + with get_sessionmaker(settings.database_settings)() as db: + pending_session = create_pending_session(db) + + verification_uri = f"{get_base_url(request)}/auth/provider/{provider}/token" + authorization_uri = authenticator.authorization_endpoint.copy_with( + params={ + "client_id": authenticator.client_id, + "response_type": "code", + "scope": " ".join( + sorted({"openid", "offline_access", *getattr(authenticator, "extra_scopes", [])}) + ), + "redirect_uri": f"{get_base_url(request)}/auth/provider/{provider}/device_code", + "state": pending_session["user_code"].replace("-", ""), + } + ) + return { + "authorization_uri": str(authorization_uri), # URL that user should visit in browser + "verification_uri": str(verification_uri), # URL that terminal client will poll + "interval": DEVICE_CODE_POLLING_INTERVAL, # suggested polling interval + "device_code": pending_session["device_code"], + "expires_in": int(DEVICE_CODE_MAX_AGE.total_seconds()), # seconds + "user_code": pending_session["user_code"], + } + + return device_code_authorize + + +async def _complete_device_code_authorization( + request: Request, + authenticator, + provider: str, + code: str, + user_code: str, + settings: BaseSettings, + api_access_manager, +): + request.state.endpoint = "auth" + action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" + normalized_user_code = user_code.upper().replace("-", "").strip() + + with get_sessionmaker(settings.database_settings)() as db: + pending_session = lookup_valid_pending_session_by_user_code(db, normalized_user_code) + if pending_session is None: + error_html = f""" + + +Error + + + +

Authorization Failed

+
+ Invalid user code. It may have been mistyped, or the pending request may have expired. +
+
Try again + + +""" + return HTMLResponse(content=error_html, status_code=401) + + # Authenticate with the OIDC provider using the authorization code + user_session_state = await authenticator.authenticate(request) + if not user_session_state: + error_html = """ + + +Authentication Failed + + + +

Authentication Failed

+
+ User code was correct but authentication with the identity provider failed. + Please contact the administrator. +
+ + +""" + return HTMLResponse(content=error_html, status_code=401) + + username = user_session_state.user_name + if not api_access_manager.is_user_known(username): + error_html = f""" + + +Authorization Failed + + + +

Authorization Failed

+
User '{username}' is not authorized to access this server.
+ + +""" + return HTMLResponse(content=error_html, status_code=403) + + # Create the session + session = await asyncio.get_running_loop().run_in_executor( + None, _create_session_orm, settings, provider, username, db + ) + + # Link the pending session to the real session + pending_session.session_id = session.id + db.add(pending_session) + db.commit() + + success_html = f""" + + +Success + + + +

Success!

+
+ You have been authenticated. Return to your terminal application - + within {DEVICE_CODE_POLLING_INTERVAL} seconds it should be successfully logged in. +
+ + +""" + return HTMLResponse(content=success_html) + + +def build_device_code_form_route(authenticator, provider): + """Build a GET route that shows the user code entry form.""" + + async def device_code_form( + request: Request, + code: str, + state: Optional[str] = Query(None), + settings: BaseSettings = Depends(get_settings), + api_access_manager=Depends(get_api_access_manager), + ): + """Show form for user to enter user code after browser auth.""" + if state: + return await _complete_device_code_authorization( + request=request, + authenticator=authenticator, + provider=provider, + code=code, + user_code=state, + settings=settings, + api_access_manager=api_access_manager, + ) + + action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" + html_content = f""" + + + + Authorize Session + + + +

Authorize Bluesky HTTP Server Session

+
+ + + +
+ +
+ + +""" + return HTMLResponse(content=html_content) + + return device_code_form + + +def build_device_code_submit_route(authenticator, provider): + """Build a POST route that handles user code submission after browser auth.""" + + async def device_code_submit( + request: Request, + code: str = Form(), + user_code: str = Form(), + settings: BaseSettings = Depends(get_settings), + api_access_manager=Depends(get_api_access_manager), + ): + """Handle user code submission and link to authenticated session.""" + return await _complete_device_code_authorization( + request=request, + authenticator=authenticator, + provider=provider, + code=code, + user_code=user_code, + settings=settings, + api_access_manager=api_access_manager, + ) + + return device_code_submit + + +def _create_session_orm(settings, identity_provider, id, db): + """ + Create a session and return the ORM object (for device code flow). + + Unlike create_session(), this returns the ORM object so we can link it + to the pending session. + """ + # Have we seen this Identity before? + identity = ( + db.query(orm.Identity) + .filter(orm.Identity.id == id) + .filter(orm.Identity.provider == identity_provider) + .first() + ) + now = utcnow() + if identity is None: + # We have not. Make a new Principal and link this new Identity to it. + principal = create_user(db, identity_provider, id) + (new_identity,) = principal.identities + new_identity.latest_login = now + else: + identity.latest_login = now + principal = identity.principal + + session = orm.Session( + principal_id=principal.id, + expiration_time=utcnow() + settings.session_max_age, + ) + db.add(session) + db.commit() + db.refresh(session) + return session + + +def build_device_code_token_route(authenticator, provider): + """Build a POST route for the CLI client to poll for tokens.""" + + async def device_code_token( + request: Request, + body: schemas.DeviceCode, + settings: BaseSettings = Depends(get_settings), + api_access_manager=Depends(get_api_access_manager), + ): + """ + Poll for tokens after device code flow authentication. + + Returns tokens if the user has authenticated, or 400 with + 'authorization_pending' error if still waiting. + """ + request.state.endpoint = "auth" + device_code_hex = body.device_code + try: + device_code = bytes.fromhex(device_code_hex) + except Exception: + # Not valid hex, therefore not a valid device_code + raise HTTPException(status_code=401, detail="Invalid device code") + + with get_sessionmaker(settings.database_settings)() as db: + pending_session = lookup_valid_pending_session_by_device_code(db, device_code) + if pending_session is None: + raise HTTPException( + status_code=404, + detail="No such device_code. The pending request may have expired.", + ) + if pending_session.session_id is None: + raise HTTPException(status_code=400, detail={"error": "authorization_pending"}) + + session = pending_session.session + principal = session.principal + + # Get scopes for the user + # Find an identity to get the username + identity = db.query(orm.Identity).filter(orm.Identity.principal_id == principal.id).first() + if identity and api_access_manager.is_user_known(identity.id): + scopes = api_access_manager.get_user_scopes(identity.id) + else: + scopes = set() + + # The pending session can only be used once + db.delete(pending_session) + db.commit() + + # Generate tokens + data = { + "sub": principal.uuid.hex, + "sub_typ": principal.type.value, + "scp": list(scopes), + "ids": [{"id": ident.id, "idp": ident.provider} for ident in principal.identities], + } + access_token = create_access_token( + data=data, + expires_delta=settings.access_token_max_age, + secret_key=settings.secret_keys[0], + ) + refresh_token = create_refresh_token( + session_id=session.uuid.hex, + expires_delta=settings.refresh_token_max_age, + secret_key=settings.secret_keys[0], + ) + + return { + "access_token": access_token, + "expires_in": int(settings.access_token_max_age / UNIT_SECOND), + "refresh_token": refresh_token, + "refresh_token_expires_in": int(settings.refresh_token_max_age / UNIT_SECOND), + "token_type": "bearer", + } + + return device_code_token + + def generate_apikey(db, principal, apikey_params, request, allowed_scopes, source_api_key_scopes): # Use API key scopes if API key is generated based on existing API key, otherwise used allowed scopes if (source_api_key_scopes is not None) and ("inherit" not in source_api_key_scopes): diff --git a/bluesky_httpserver/app.py b/bluesky_httpserver/app.py index f09acb3..84470d1 100644 --- a/bluesky_httpserver/app.py +++ b/bluesky_httpserver/app.py @@ -9,13 +9,13 @@ import urllib.parse from functools import lru_cache, partial +from bluesky_authentication.protocols import ExternalAuthenticator, InternalAuthenticator from bluesky_queueserver.manager.comms import validate_zmq_key from bluesky_queueserver_api.zmq.aio import REManagerAPI from fastapi import APIRouter, FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi -from .authentication import Mode from .console_output import CollectPublishedConsoleOutput, ConsoleOutputStream, SystemInfoStream from .core import PatchedStreamingResponse from .database.core import purge_expired @@ -160,6 +160,11 @@ def build_app(authentication=None, api_access=None, resource_access=None, server from .authentication import ( base_authentication_router, build_auth_code_route, + build_authorize_route, + build_device_code_authorize_route, + build_device_code_form_route, + build_device_code_submit_route, + build_device_code_token_route, build_handle_credentials_route, oauth2_scheme, ) @@ -179,20 +184,41 @@ def build_app(authentication=None, api_access=None, resource_access=None, server for spec in authentication["providers"]: provider = spec["provider"] authenticator = spec["authenticator"] - mode = authenticator.mode - if mode == Mode.password: + if isinstance(authenticator, InternalAuthenticator): authentication_router.post(f"/provider/{provider}/token")( build_handle_credentials_route(authenticator, provider) ) - elif mode == Mode.external: + elif isinstance(authenticator, ExternalAuthenticator): + # Standard OAuth callback route (authorization code flow) authentication_router.get(f"/provider/{provider}/code")( build_auth_code_route(authenticator, provider) ) authentication_router.post(f"/provider/{provider}/code")( build_auth_code_route(authenticator, provider) ) + # Device code flow routes for CLI/headless clients + # GET /authorize - redirects browser to OIDC provider + authentication_router.get(f"/provider/{provider}/authorize")( + build_authorize_route(authenticator, provider) + ) + # POST /authorize - initiates device code flow (returns device_code, user_code, etc.) + authentication_router.post(f"/provider/{provider}/authorize")( + build_device_code_authorize_route(authenticator, provider) + ) + # GET /device_code - shows user code entry form + authentication_router.get(f"/provider/{provider}/device_code")( + build_device_code_form_route(authenticator, provider) + ) + # POST /device_code - handles user code submission after browser auth + authentication_router.post(f"/provider/{provider}/device_code")( + build_device_code_submit_route(authenticator, provider) + ) + # POST /token - CLI client polls this for tokens + authentication_router.post(f"/provider/{provider}/token")( + build_device_code_token_route(authenticator, provider) + ) else: - raise ValueError(f"unknown authentication mode {mode}") + raise ValueError(f"unknown authenticator type {type(authenticator)}") for custom_router in getattr(authenticator, "include_routers", []): authentication_router.include_router(custom_router, prefix=f"/provider/{provider}") diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py new file mode 100644 index 0000000..3475cd1 --- /dev/null +++ b/bluesky_httpserver/authentication/__init__.py @@ -0,0 +1,37 @@ +from .._authentication import ( + _extract_scopes, + base_authentication_router, + build_auth_code_route, + build_authorize_route, + build_device_code_authorize_route, + build_device_code_form_route, + build_device_code_submit_route, + build_device_code_token_route, + build_handle_credentials_route, + get_current_principal, + get_current_principal_websocket, + oauth2_scheme, +) +from .authenticator_base import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) + +__all__ = [ + "ExternalAuthenticator", + "InternalAuthenticator", + "UserSessionState", + "_extract_scopes", + "get_current_principal", + "get_current_principal_websocket", + "base_authentication_router", + "build_auth_code_route", + "build_authorize_route", + "build_device_code_authorize_route", + "build_device_code_form_route", + "build_device_code_submit_route", + "build_device_code_token_route", + "build_handle_credentials_route", + "oauth2_scheme", +] diff --git a/bluesky_httpserver/authentication/authenticator_base.py b/bluesky_httpserver/authentication/authenticator_base.py new file mode 100644 index 0000000..ae7c599 --- /dev/null +++ b/bluesky_httpserver/authentication/authenticator_base.py @@ -0,0 +1,31 @@ +try: + from bluesky_authentication.protocols import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, + ) +except ModuleNotFoundError: + from abc import ABC + from dataclasses import dataclass + from typing import Optional + + from fastapi import Request + + @dataclass + class UserSessionState: + """Data transfer class to communicate custom session state information.""" + + user_name: str + state: dict = None + + class InternalAuthenticator(ABC): + """Base class for authenticators that use username/password credentials.""" + + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: + raise NotImplementedError + + class ExternalAuthenticator(ABC): + """Base class for authenticators that use external identity providers.""" + + async def authenticate(self, request: Request) -> Optional[UserSessionState]: + raise NotImplementedError diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index 61c2da4..6d4ad01 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -1,830 +1,39 @@ -import asyncio -import functools -import logging -import re -import secrets -from collections.abc import Iterable - -from fastapi import APIRouter, Request -from jose import JWTError, jwk, jwt -from starlette.responses import RedirectResponse - -from .authentication import Mode -from .utils import modules_available - -logger = logging.getLogger(__name__) - - -class DummyAuthenticator: - """ - For test and demo purposes only! - - Accept any username and any password. - - """ - - mode = Mode.password - - async def authenticate(self, username: str, password: str): - return username - - -class DictionaryAuthenticator: - """ - For test and demo purposes only! - - Check passwords from a dictionary of usernames mapped to passwords. - - Parameters - ---------- - - users_to_passwords: dict(str, str) - Mapping of usernames to passwords. - """ - - mode = Mode.password - configuration_schema = """ -$schema": http://json-schema.org/draft-07/schema# -type: object -additionalProperties: false -properties: - users_to_password: - type: object - description: | - Mapping usernames to password. Environment variable expansion should be - used to avoid placing passwords directly in configuration. -""" - - def __init__(self, users_to_passwords): - self._users_to_passwords = users_to_passwords - - async def authenticate(self, username: str, password: str): - true_password = self._users_to_passwords.get(username) - if not true_password: - # Username is not valid. - return - if secrets.compare_digest(true_password, password): - return username - - -class PAMAuthenticator: - mode = Mode.password - configuration_schema = """ -$schema": http://json-schema.org/draft-07/schema# -type: object -additionalProperties: false -properties: - service: - type: string - description: PAM service. Default is 'login'. -""" - - def __init__(self, service="login"): - if not modules_available("pamela"): - raise ModuleNotFoundError("This PAMAuthenticator requires the module 'pamela' to be installed.") - self.service = service - # TODO Try to open a PAM session. - - async def authenticate(self, username: str, password: str): - import pamela - - try: - pamela.authenticate(username, password, service=self.service) - except pamela.PAMError: - # Authentication failed. - return - else: - return username - - -class OIDCAuthenticator: - mode = Mode.external - configuration_schema = """ -$schema": http://json-schema.org/draft-07/schema# -type: object -additionalProperties: false -properties: - client_id: - type: string - client_secret: - type: string - redirect_uri: - type: string - token_uri: - type: string - authorization_endpoint: - type: string - public_keys: - type: array - item: - type: object - properties: - - alg: - type: string - - e - type: string - - kid - type: string - - kty - type: string - - n - type: string - - use - type: string - required: - - alg - - e - - kid - - kty - - n - - use -""" - - def __init__( - self, - client_id, - client_secret, - redirect_uri, - public_keys, - token_uri, - authorization_endpoint, - confirmation_message, - ): - self.client_id = client_id - self.client_secret = client_secret - self.confirmation_message = confirmation_message - self.redirect_uri = redirect_uri - self.public_keys = public_keys - self.token_uri = token_uri - self.authorization_endpoint = authorization_endpoint.format(client_id=client_id, redirect_uri=redirect_uri) - - async def authenticate(self, request): - code = request.query_params["code"] - response = await exchange_code(self.token_uri, code, self.client_id, self.client_secret, self.redirect_uri) - response_body = response.json() - if response.is_error: - logger.error("Authentication error: %r", response_body) - return None - response_body = response.json() - id_token = response_body["id_token"] - access_token = response_body["access_token"] - # Match the kid in id_token to a key in the list of public_keys. - key = find_key(id_token, self.public_keys) - try: - verified_body = jwt.decode(id_token, key, access_token=access_token, audience=self.client_id) - except JWTError: - logger.exception( - "Authentication error. Unverified token: %r", - jwt.get_unverified_claims(id_token), - ) - return None - return verified_body["sub"] - - -class KeyNotFoundError(Exception): - pass - - -def find_key(token, keys): - """ - Find a key from the configured keys based on the kid claim of the token - - Parameters - ---------- - token : token to search for the kid from - keys: list of keys - - Raises - ------ - KeyNotFoundError: - returned if the token does not have a kid claim - - Returns - ------ - key: found key object - """ - - unverified = jwt.get_unverified_header(token) - kid = unverified.get("kid") - if not kid: - raise KeyNotFoundError("No 'kid' in token") - - for key in keys: - if key["kid"] == kid: - return jwk.construct(key) - return KeyNotFoundError(f"Token specifies {kid} but we have {[k['kid'] for k in keys]}") - - -async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect_uri): - """Method that talks to an IdP to exchange a code for an access_token and/or id_token - Args: - token_url ([type]): [description] - auth_code ([type]): [description] - """ - if not modules_available("httpx"): - raise ModuleNotFoundError("This authenticator requires 'httpx'. (pip install httpx)") - import httpx - - response = httpx.post( - url=token_uri, - data={ - "grant_type": "authorization_code", - "client_id": client_id, - "redirect_uri": redirect_uri, - "code": auth_code, - "client_secret": client_secret, - }, - ) - return response - - -class SAMLAuthenticator: - mode = Mode.external - - def __init__( - self, - saml_settings, # See EXAMPLE_SAML_SETTINGS below. - attribute_name, # which SAML attribute to use as 'id' for Idenity - confirmation_message=None, - ): - self.saml_settings = saml_settings - self.attribute_name = attribute_name - self.confirmation_message = confirmation_message - self.authorization_endpoint = "/login" - - router = APIRouter() - - if not modules_available("onelogin"): - # The PyPI package name is 'python3-saml' - # but it imports as 'onelogin'. - # https://github.com/onelogin/python3-saml - raise ModuleNotFoundError("This SAMLAuthenticator requires 'python3-saml' to be installed.") - - from onelogin.saml2.auth import OneLogin_Saml2_Auth - - @router.get("/login") - async def saml_login(request: Request): - req = await prepare_saml_from_fastapi_request(request) - auth = OneLogin_Saml2_Auth(req, self.saml_settings) - # saml_settings = auth.get_settings() - # metadata = saml_settings.get_sp_metadata() - # errors = saml_settings.validate_metadata(metadata) - # if len(errors) == 0: - # print(metadata) - # else: - # print("Error found on Metadata: %s" % (', '.join(errors))) - callback_url = auth.login() - response = RedirectResponse(url=callback_url) - return response - - self.include_routers = [router] - - async def authenticate(self, request): - if not modules_available("onelogin"): - raise ModuleNotFoundError("This SAMLAuthenticator requires the module 'oneline' to be installed.") - from onelogin.saml2.auth import OneLogin_Saml2_Auth - - req = await prepare_saml_from_fastapi_request(request, True) - auth = OneLogin_Saml2_Auth(req, self.saml_settings) - auth.process_response() # Process IdP response - errors = auth.get_errors() # This method receives an array with the errors - if errors: - raise Exception( - "Error when processing SAML Response: %s %s" % (", ".join(errors), auth.get_last_error_reason()) - ) - if auth.is_authenticated(): - # Return a string that the Identity can use as id. - attribute_as_list = auth.get_attributes()[self.attribute_name] - # Confused in what situation this would have more than one item.... - assert len(attribute_as_list) == 1 - return attribute_as_list[0] - else: - return None - - -async def prepare_saml_from_fastapi_request(request, debug=False): - form_data = await request.form() - rv = { - "http_host": request.client.host, - "server_port": request.url.port, - "script_name": request.url.path, - "post_data": {}, - "get_data": {}, - # Advanced request options - # "https": "", - # "request_uri": "", - # "query_string": "", - # "validate_signature_from_qs": False, - # "lowercase_urlencoding": False - } - if request.query_params: - rv["get_data"] = (request.query_params,) - if "SAMLResponse" in form_data: - SAMLResponse = form_data["SAMLResponse"] - rv["post_data"]["SAMLResponse"] = SAMLResponse - if "RelayState" in form_data: - RelayState = form_data["RelayState"] - rv["post_data"]["RelayState"] = RelayState - return rv - - -class LDAPAuthenticator: - """ - LDAP authenticator. - The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator - The parameter ``use_tls`` was added for convenience of testing. - - Parameters - ---------- - server_address: str or list(str) - Address(es) of the LDAP server(s) to contact. A string value may represent a single - server, a list of strings may represent one or more servers. If a server address - includes port, then the value of ``server_port`` is ignored, otherwise ``server_port`` - or the default port is used to access the server. - - Could be an IP address or hostname. - server_port: int or None - Port on which to contact the LDAP server. Default port is used if ``None``. - - Defaults to ``636`` if ``use_ssl`` is set, ``389`` otherwise. - use_ssl: boolean - Use SSL to communicate with the LDAP server. - - Deprecated in version 3 of LDAP. Your LDAP server must be configured to support this, however. - use_tls: boolean - Enable/disable TLS if ``use_ssl`` is False. By default TLS is enabled. It should not be disabled - in production systems. - - connect_timeout: float - Timeout used for connecting to the LDAP server. Default: 5. - - receive_timeout: float - Timeout used for communication with the LDAP server, e.g. this timeout is used to wait for - completion of 2FA. For smooth operation it should probably exceed timeout set at LDAP server. - Default: 60. - - bind_dn_template: list or str - Template from which to construct the full dn - when authenticating to LDAP. ``{username}`` is replaced - with the actual username used to log in. - - If your LDAP is set in such a way that the userdn can not - be formed from a template, but must be looked up with an attribute - (such as uid or ``sAMAccountName``), please see ``lookup_dn``. It might - be particularly relevant for ActiveDirectory installs. - - Unicode Example: - - .. code-block:: - - "uid={username},ou=people,dc=wikimedia,dc=org" - - List Example: - - .. code-block:: - - [ - "uid={username},ou=people,dc=wikimedia,dc=org", - "uid={username},ou=Developers,dc=wikimedia,dc=org" - ] - allowed_groups: list or None - List of LDAP group DNs that users could be members of to be granted access. - - If a user is in any one of the listed groups, then that user is granted access. - Membership is tested by fetching info about each group and looking for the User's - dn to be a value of one of `member` or `uniqueMember`, *or* if the username being - used to log in with is value of the `uid`. - - Set to an empty list or None to allow all users that have an LDAP account to log in, - without performing any group membership checks. - valid_username_regex: str - Regex for validating usernames - those that do not match this regex will be rejected. - - This is primarily used as a measure against LDAP injection, which has fatal security - considerations. The default works for most LDAP installations, but some users might need - to modify it to fit their custom installs. If you are modifying it, be sure to understand - the implications of allowing additional characters in usernames and what that means for - LDAP injection issues. See https://www.owasp.org/index.php/LDAP_injection for an overview - of LDAP injection. - lookup_dn: boolean - Form user's DN by looking up an entry from directory - - By default, LDAPAuthenticator finds the user's DN by using `bind_dn_template`. - However, in some installations, the user's DN does not contain the username, and - hence needs to be looked up. You can set this to True and then use ``user_search_base`` - and ``user_attribute`` to accomplish this. - user_search_base: str - Base for looking up user accounts in the directory, if `lookup_dn` is set to True. - - LDAPAuthenticator will search all objects matching under this base where the `user_attribute` - is set to the current username to form the userdn. - - For example, if all users objects existed under the base ou=people,dc=wikimedia,dc=org, and - the username users use is set with the attribute `uid`, you can use the following config: - - .. code-block:: - - c.LDAPAuthenticator.lookup_dn = True - c.LDAPAuthenticator.lookup_dn_search_filter = '({login_attr}={login})' - c.LDAPAuthenticator.lookup_dn_search_user = 'ldap_search_user_technical_account' - c.LDAPAuthenticator.lookup_dn_search_password = 'secret' - c.LDAPAuthenticator.user_search_base = 'ou=people,dc=wikimedia,dc=org' - c.LDAPAuthenticator.user_attribute = 'sAMAccountName' - c.LDAPAuthenticator.lookup_dn_user_dn_attribute = 'cn' - c.LDAPAuthenticator.bind_dn_template = '{username}' - user_attribute: str - Attribute containing user's name, if ``lookup_dn`` is set to True. - - See ``user_search_base`` for info on how this attribute is used. - - For most LDAP servers, this is uid. For Active Directory, it is - sAMAccountName. - lookup_dn_search_filter: str or None - How to query LDAP for user name lookup, if ``lookup_dn`` is set to True. - lookup_dn_search_user: str or None - Technical account for user lookup, if ``lookup_dn`` is set to True. - - If both lookup_dn_search_user and lookup_dn_search_password are None, - then anonymous LDAP query will be done. - lookup_dn_search_password: str or None - Technical account for user lookup, if ``lookup_dn`` is set to True. - lookup_dn_user_dn_attribute: str or None - Attribute containing user's name needed for building DN string, if ``lookup_dn`` is set to True. - - See ``user_search_base`` for info on how this attribute is used. - - For most LDAP servers, this is username. For Active Directory, it is cn. - escape_userdn: boolean - If set to True, escape special chars in userdn when authenticating in LDAP. - - On some LDAP servers, when userdn contains chars like '(', ')', '\' - authentication may fail when those chars - are not escaped. - search_filter: str - LDAP3 Search Filter whose results are allowed access - attributes: list or None - List of attributes to be searched - auth_state_attributes: list or None - List of attributes to be returned in auth_state for a user - use_lookup_dn_username: boolean - If set to true uses the ``lookup_dn_user_dn_attribute`` attribute as username instead of - the supplied one. - - This can be useful in an heterogeneous environment, when supplying a UNIX username - to authenticate against AD. - - Examples - -------- - - Using the authenticator class (the code runs in ``asyncio`` loop): - - .. code-block:: - - from bluesky_httpserver.authenticators import LDAPAuthenticator - authenticator = LDAPAuthenticator( - "localhost", 1389, bind_dn_template="cn={username},ou=users,dc=example,dc=org", use_tls=False - ) - await authenticator.authenticate("user01", "password1") - await authenticator.authenticate("user02", "password2") - - - Simple example of a config file (e.g. ``config_ldap.yml``): - - .. code-block:: - - uvicorn: - host: localhost - port: 60610 - authentication: - providers: - - provider: ldap_local - authenticator: bluesky_httpserver.authenticators:LDAPAuthenticator - args: - server_address: localhost - server_port: 1389 - bind_dn_template: "cn={username},ou=users,dc=example,dc=org" - use_tls: false - use_ssl: false - tiled_admins: - - provider: ldap_local - id: user02 - """ - - mode = Mode.password - - def __init__( - self, - server_address, - server_port=None, - *, - use_ssl=False, - use_tls=True, - connect_timeout=5, - receive_timeout=60, - bind_dn_template=None, - allowed_groups=None, - valid_username_regex=r"^[a-z][.a-z0-9_-]*$", - lookup_dn=False, - user_search_base=None, - user_attribute=None, - lookup_dn_search_filter="({login_attr}={login})", - lookup_dn_search_user=None, - lookup_dn_search_password=None, - lookup_dn_user_dn_attribute=None, - escape_userdn=False, - search_filter="", - attributes=None, - auth_state_attributes=None, - use_lookup_dn_username=True, - ): - self.use_ssl = use_ssl - self.use_tls = use_tls - self.connect_timeout = connect_timeout - self.receive_timeout = receive_timeout - self.bind_dn_template = bind_dn_template - self.allowed_groups = allowed_groups - self.valid_username_regex = valid_username_regex - self.lookup_dn = lookup_dn - self.user_search_base = user_search_base - self.user_attribute = user_attribute - self.lookup_dn_search_filter = lookup_dn_search_filter - self.lookup_dn_search_user = lookup_dn_search_user - self.lookup_dn_search_password = lookup_dn_search_password - self.lookup_dn_user_dn_attribute = lookup_dn_user_dn_attribute - self.escape_userdn = escape_userdn - self.search_filter = search_filter - self.attributes = attributes if attributes else [] - self.auth_state_attributes = auth_state_attributes if auth_state_attributes else [] - self.use_lookup_dn_username = use_lookup_dn_username - - if isinstance(server_address, str): - server_address_list = [server_address] - elif isinstance(server_address, Iterable): - server_address_list = list(server_address) - else: - raise TypeError( - f"Unsupported type of `server_address` (list): server_address={server_address} " - f"type(server_address)={type(server_address)}" - ) - if not server_address_list: - raise ValueError("No servers are specified: 'server_address' is an empty list") - - self.server_address_list = server_address_list - self.server_port = server_port if server_port is not None else self._server_port_default() - - def _server_port_default(self): - if self.use_ssl: - return 636 # default SSL port for LDAP - else: - return 389 # default plaintext port for LDAP - - async def resolve_username(self, username_supplied_by_user): - import ldap3 - - search_dn = self.lookup_dn_search_user - if self.escape_userdn: - search_dn = ldap3.utils.conv.escape_filter_chars(search_dn) - conn = await asyncio.get_running_loop().run_in_executor( - None, self.get_connection, search_dn, self.lookup_dn_search_password - ) - is_bound = await asyncio.get_running_loop().run_in_executor(None, conn.bind) - if not is_bound: - msg = "Failed to connect to LDAP server with search user '{search_dn}'" - self.log.warning(msg.format(search_dn=search_dn)) - return (None, None) - - search_filter = self.lookup_dn_search_filter.format( - login_attr=self.user_attribute, login=username_supplied_by_user - ) - msg = "\n".join( - [ - "Looking up user with:", - " search_base = '{search_base}'", - " search_filter = '{search_filter}'", - " attributes = '{attributes}'", - ] - ) - logger.debug( - msg.format( - search_base=self.user_search_base, - search_filter=search_filter, - attributes=self.user_attribute, - ) - ) - - search_func = functools.partial( - conn.search, - search_base=self.user_search_base, - search_scope=ldap3.SUBTREE, - search_filter=search_filter, - attributes=[self.lookup_dn_user_dn_attribute], - ) - await asyncio.get_running_loop().run_in_executor(None, search_func) - - response = conn.response - if len(response) == 0 or "attributes" not in response[0].keys(): - msg = "No entry found for user '{username}' " "when looking up attribute '{attribute}'" - logger.warning(msg.format(username=username_supplied_by_user, attribute=self.user_attribute)) - return (None, None) - - user_dn = response[0]["attributes"][self.lookup_dn_user_dn_attribute] - if isinstance(user_dn, list): - if len(user_dn) == 0: - return (None, None) - elif len(user_dn) == 1: - user_dn = user_dn[0] - else: - msg = ( - "A lookup of the username '{username}' returned a list " - "of entries for the attribute '{attribute}'. Only the " - "first among these ('{first_entry}') was used. The other " - "entries ({other_entries}) were ignored." - ) - logger.warning( - msg.format( - username=username_supplied_by_user, - attribute=self.lookup_dn_user_dn_attribute, - first_entry=user_dn[0], - other_entries=", ".join(user_dn[1:]), - ) - ) - user_dn = user_dn[0] - - return (user_dn, response[0]["dn"]) - - def get_connection(self, userdn, password): - import ldap3 - - # NOTE: setting 'acitve=False' essentially disables exclusion of inactive servers from the pool. - # It probably does not matter if the pool contains only one server, but it could have implications - # when there are multiple servers in the pool. It is not clear what those implications are. - # But using the default 'activate=True' results in the thread being blocked indefinitely - # at the step of creating 'ldap3.Connection' regardless of timeouts in case all the servers are - # inactive (e.g. the pool has one server and it is unaccessible), which is unacceptable. - # Further investigation may be needed in the future. - server_pool = ldap3.ServerPool(None, ldap3.RANDOM, active=False) - for address in self.server_address_list: - if re.search(r".+:\d+", address): - # Port is found in the address - address_split = address.split(":") - server_addr = ":".join(address_split[:-1]) - server_port = int(address_split[-1]) - else: - # Use the default port - server_addr = address - server_port = self.server_port - - server = ldap3.Server( - server_addr, port=server_port, use_ssl=self.use_ssl, connect_timeout=self.connect_timeout - ) - server_pool.add(server) - - auto_bind_no_ssl = ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS - auto_bind = ldap3.AUTO_BIND_NO_TLS if self.use_ssl else auto_bind_no_ssl - conn = ldap3.Connection( - server_pool, user=userdn, password=password, auto_bind=auto_bind, receive_timeout=self.receive_timeout - ) - return conn - - async def get_user_attributes(self, conn, userdn): - attrs = {} - if self.auth_state_attributes: - search_func = functools.partial( - conn.search, userdn, "(objectClass=*)", attributes=self.auth_state_attributes - ) - found = await asyncio.get_running_loop().run_in_executor(None, search_func) - if found: - attrs = conn.entries[0].entry_attributes_as_dict - return attrs - - async def authenticate(self, username: str, password: str): - import ldap3 - - username_saved = username # Save the user name passed as a parameter - - # Protect against invalid usernames as well as LDAP injection attacks - if not re.match(self.valid_username_regex, username): - logger.warning( - "username:%s Illegal characters in username, must match regex %s", - username, - self.valid_username_regex, - ) - return None - - # No empty passwords! - if password is None or password.strip() == "": - logger.warning("username:%s Login denied for blank password", username) - return None - - # bind_dn_template should be of type List[str] - bind_dn_template = self.bind_dn_template - if isinstance(bind_dn_template, str): - bind_dn_template = [bind_dn_template] - - # sanity check - if not self.lookup_dn and not bind_dn_template: - logger.warning("Login not allowed, please configure 'lookup_dn' or 'bind_dn_template'.") - return None - - if self.lookup_dn: - username, resolved_dn = await self.resolve_username(username) - if not username: - return None - if str(self.lookup_dn_user_dn_attribute).upper() == "CN": - # Only escape commas if the lookup attribute is CN - username = re.subn(r"([^\\]),", r"\1\,", username)[0] - if not bind_dn_template: - bind_dn_template = [resolved_dn] - - is_bound = False - for dn in bind_dn_template: - if not dn: - logger.warning("Ignoring blank 'bind_dn_template' entry!") - continue - userdn = dn.format(username=username) - if self.escape_userdn: - userdn = ldap3.utils.conv.escape_filter_chars(userdn) - msg = "Attempting to bind {username} with {userdn}" - logger.debug(msg.format(username=username, userdn=userdn)) - msg = "Status of user bind {username} with {userdn} : {is_bound}" - try: - conn = await asyncio.get_running_loop().run_in_executor( - None, self.get_connection, userdn, password - ) - except ldap3.core.exceptions.LDAPBindError as exc: - is_bound = False - msg += "\n{exc_type}: {exc_msg}".format( - exc_type=exc.__class__.__name__, - exc_msg=exc.args[0] if exc.args else "", - ) - else: - if conn.bound: - is_bound = True - else: - is_bound = await asyncio.get_running_loop().run_in_executor(None, conn.bind) - - msg = msg.format(username=username, userdn=userdn, is_bound=is_bound) - logger.debug(msg) - if is_bound: - break - - if not is_bound: - msg = "Invalid password for user '{username}'" - logger.warning(msg.format(username=username)) - return None - - if self.search_filter: - search_filter = self.search_filter.format(userattr=self.user_attribute, username=username) - - search_func = functools.partial( - conn.search, - search_base=self.user_search_base, - search_scope=ldap3.SUBTREE, - search_filter=search_filter, - attributes=self.attributes, - ) - await asyncio.get_running_loop().run_in_executor(None, search_func) - - n_users = len(conn.response) - if n_users == 0: - msg = "User with '{userattr}={username}' not found in directory" - logger.warning(msg.format(userattr=self.user_attribute, username=username)) - return None - if n_users > 1: - msg = "Duplicate users found! " "{n_users} users found with '{userattr}={username}'" - logger.warning(msg.format(userattr=self.user_attribute, username=username, n_users=n_users)) - return None - - if self.allowed_groups: - logger.debug("username:%s Using dn %s", username, userdn) - found = False - for group in self.allowed_groups: - group_filter = "(|" "(member={userdn})" "(uniqueMember={userdn})" "(memberUid={uid})" ")" - group_filter = group_filter.format(userdn=userdn, uid=username) - group_attributes = ["member", "uniqueMember", "memberUid"] - - search_func = functools.partial( - conn.search, - group, - search_scope=ldap3.BASE, - search_filter=group_filter, - attributes=group_attributes, - ) - found = await asyncio.get_running_loop().run_in_executor(None, search_func) - if found: - break - - if not found: - # If we reach here, then none of the groups matched - msg = "username:{username} User not in any of the allowed groups" - logger.warning(msg.format(username=username)) - return None - - if not self.use_lookup_dn_username: - username = username_saved - - user_info = await self.get_user_attributes(conn, userdn) - if user_info: - logger.debug("username:%s attributes:%s", username, user_info) - return {"name": username, "auth_state": user_info} - return username +import warnings + +from bluesky_authentication.authenticators import ( # noqa: F401 + DictionaryAuthenticator, + DummyAuthenticator, + EntraAuthenticator, + LDAPAuthenticator, + OIDCAuthenticator, + PAMAuthenticator, + ProxiedOIDCAuthenticator, + SAMLAuthenticator, +) +from bluesky_authentication.protocols import ( # noqa: F401 + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) + +warnings.warn( + "Importing authenticators from 'bluesky_httpserver.authenticators' is deprecated " + "and will be removed in a future release. Use 'bluesky_authentication.authenticators' " + "and 'bluesky_authentication.protocols' instead.", + DeprecationWarning, + stacklevel=2, +) + +__all__ = [ + "DictionaryAuthenticator", + "DummyAuthenticator", + "EntraAuthenticator", + "ExternalAuthenticator", + "InternalAuthenticator", + "LDAPAuthenticator", + "OIDCAuthenticator", + "PAMAuthenticator", + "ProxiedOIDCAuthenticator", + "SAMLAuthenticator", + "UserSessionState", +] diff --git a/bluesky_httpserver/config_schemas/service_configuration.yml b/bluesky_httpserver/config_schemas/service_configuration.yml index 57343f7..e7e5148 100644 --- a/bluesky_httpserver/config_schemas/service_configuration.yml +++ b/bluesky_httpserver/config_schemas/service_configuration.yml @@ -83,8 +83,10 @@ properties: description: | Type of Authenticator to use. - These are typically from the tiled.authenticators module, + These are typically from the bluesky_authentication.authenticators module, though user-defined ones may be used as well. + Legacy import paths under bluesky_httpserver.authenticators remain + supported for backward compatibility. This is given as an import path. In an import path, packages/modules are separated by dots, and the object itself it separated by a colon. @@ -92,36 +94,21 @@ properties: Example: ```yaml - authenticator: bluesky_httpserver.examples.DummyAuthenticator + authenticator: bluesky_authentication.authenticators:DummyAuthenticator ``` - args: - type: [object, "null"] - description: | - Named arguments to pass to Authenticator. If there are none, - `args` may be omitted or empty. + args: + type: object + description: | + Named arguments to pass to Authenticator. If there are none, + `args` may be omitted or empty. - Example: + Example: - ```yaml - authenticator: bluesky_httpserver.examples.PAMAuthenticator - args: - service: "custom_service" - ``` - # qserver_admins: - # type: array - # items: - # type: object - # additionalProperties: false - # required: - # - provider - # - id - # properties: - # provider: - # type: string - # id: - # type: string - # description: | - # Give users with these identities 'admin' Role. + ```yaml + authenticator: bluesky_authentication.authenticators:PAMAuthenticator + args: + service: "custom_service" + ``` secret_keys: type: array items: diff --git a/bluesky_httpserver/database/core.py b/bluesky_httpserver/database/core.py index 163fac3..a394fdd 100644 --- a/bluesky_httpserver/database/core.py +++ b/bluesky_httpserver/database/core.py @@ -1,6 +1,7 @@ import hashlib import uuid as uuid_module from datetime import datetime +from typing import Optional from alembic import command from alembic.config import Config @@ -10,13 +11,13 @@ from .alembic_utils import temp_alembic_ini from .base import Base -from .orm import APIKey, Identity, Principal, Session # , Role +from .orm import APIKey, Identity, PendingSession, Principal, Session # , Role # This is the alembic revision ID of the database revision # required by this version of Tiled. -REQUIRED_REVISION = "722ff4e4fcc7" +REQUIRED_REVISION = "a1b2c3d4e5f6" # This is list of all valid revisions (from current to oldest). -ALL_REVISIONS = ["722ff4e4fcc7", "481830dd6c11"] +ALL_REVISIONS = ["a1b2c3d4e5f6", "722ff4e4fcc7", "481830dd6c11"] # def create_default_roles(engine): @@ -208,6 +209,15 @@ def create_user(db, identity_provider, id): return principal +def get_or_create_principal(db, identity_provider, id): + identity = db.query(Identity).filter(Identity.id == id).filter(Identity.provider == identity_provider).first() + if identity is None: + principal = create_user(db, identity_provider, id) + else: + principal = identity.principal + return principal + + def lookup_valid_session(db, session_id): if isinstance(session_id, int): # Old versions of tiled used an integer sid. @@ -294,3 +304,38 @@ def latest_principal_activity(db, principal): if all([t is None for t in all_activity]): return None return max(t for t in all_activity if t is not None) + + +def lookup_valid_pending_session_by_device_code(db, device_code: bytes) -> Optional[PendingSession]: + """ + Look up a pending session by its device code. + + Returns None if the pending session is not found or has expired. + """ + hashed_device_code = hashlib.sha256(device_code).digest() + pending_session = ( + db.query(PendingSession).filter(PendingSession.hashed_device_code == hashed_device_code).first() + ) + if pending_session is None: + return None + if pending_session.expiration_time is not None and pending_session.expiration_time < datetime.utcnow(): + db.delete(pending_session) + db.commit() + return None + return pending_session + + +def lookup_valid_pending_session_by_user_code(db, user_code: str) -> Optional[PendingSession]: + """ + Look up a pending session by its user code. + + Returns None if the pending session is not found or has expired. + """ + pending_session = db.query(PendingSession).filter(PendingSession.user_code == user_code).first() + if pending_session is None: + return None + if pending_session.expiration_time is not None and pending_session.expiration_time < datetime.utcnow(): + db.delete(pending_session) + db.commit() + return None + return pending_session diff --git a/bluesky_httpserver/database/orm.py b/bluesky_httpserver/database/orm.py index 17d7c82..7611824 100644 --- a/bluesky_httpserver/database/orm.py +++ b/bluesky_httpserver/database/orm.py @@ -181,3 +181,24 @@ class Session(Timestamped, Base): revoked = Column(Boolean, default=False, nullable=False) principal = relationship("Principal", back_populates="sessions") + pending_sessions = relationship("PendingSession", back_populates="session") + + +class PendingSession(Timestamped, Base): + """ + This is used only in Device Code Flow for OIDC authentication. + + When a CLI client initiates the device code flow, a pending session is created + with a device_code (for the client to poll) and a user_code (for the user to + enter in the browser). Once the user authenticates, the pending session is + linked to a real session, which the polling client then receives. + """ + + __tablename__ = "pending_sessions" + + hashed_device_code = Column(LargeBinary(32), primary_key=True, index=True, nullable=False) + user_code = Column(Unicode(8), index=True, nullable=False) + expiration_time = Column(DateTime(timezone=False), nullable=False) + session_id = Column(Integer, ForeignKey("sessions.id"), nullable=True) + + session = relationship("Session", back_populates="pending_sessions") diff --git a/bluesky_httpserver/routers/core_api.py b/bluesky_httpserver/routers/core_api.py index 397972b..b70ca46 100644 --- a/bluesky_httpserver/routers/core_api.py +++ b/bluesky_httpserver/routers/core_api.py @@ -1140,12 +1140,13 @@ def is_alive(self): @router.websocket("/console_output/ws") async def console_output_ws(websocket: WebSocket, scopes=["read:console"]): - principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + principal = await get_current_principal_websocket(websocket=websocket, scopes=scopes) if not principal: await websocket.close(code=4001, reason="Invalid token") return - await websocket.accept() + if not getattr(websocket.state, "already_accepted", False): + await websocket.accept() q = SR.console_output_stream.add_queue(websocket) wsmon = WebSocketMonitor(websocket) wsmon.start() @@ -1166,12 +1167,13 @@ async def console_output_ws(websocket: WebSocket, scopes=["read:console"]): @router.websocket("/status/ws") async def status_ws(websocket: WebSocket, scopes=["read:monitor"]): - principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + principal = await get_current_principal_websocket(websocket=websocket, scopes=scopes) if not principal: await websocket.close(code=4001, reason="Invalid token") return - await websocket.accept() + if not getattr(websocket.state, "already_accepted", False): + await websocket.accept() q = SR.system_info_stream.add_queue_status(websocket) wsmon = WebSocketMonitor(websocket) wsmon.start() @@ -1193,12 +1195,13 @@ async def status_ws(websocket: WebSocket, scopes=["read:monitor"]): @router.websocket("/info/ws") async def info_ws(websocket: WebSocket, scopes=["read:monitor"]): - principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + principal = await get_current_principal_websocket(websocket=websocket, scopes=scopes) if not principal: await websocket.close(code=4001, reason="Invalid token") return - await websocket.accept() + if not getattr(websocket.state, "already_accepted", False): + await websocket.accept() q = SR.system_info_stream.add_queue_info(websocket) wsmon = WebSocketMonitor(websocket) wsmon.start() diff --git a/bluesky_httpserver/schemas.py b/bluesky_httpserver/schemas.py index c52d8f2..05a1764 100644 --- a/bluesky_httpserver/schemas.py +++ b/bluesky_httpserver/schemas.py @@ -163,6 +163,23 @@ class RefreshToken(pydantic.BaseModel): refresh_token: str +class DeviceCode(pydantic.BaseModel): + """Schema for device code token polling request.""" + + device_code: str + + +class DeviceCodeResponse(pydantic.BaseModel): + """Schema for device code flow initiation response.""" + + authorization_uri: str + verification_uri: str + device_code: str + user_code: str + expires_in: int + interval: int + + class AuthenticationMode(str, enum.Enum): password = "password" external = "external" @@ -173,6 +190,7 @@ class AboutAuthenticationProvider(pydantic.BaseModel): mode: AuthenticationMode links: Dict[str, str] confirmation_message: Optional[str] = None + extra_scopes: Optional[List[str]] = None class AboutAuthenticationLinks(pydantic.BaseModel): diff --git a/bluesky_httpserver/tests/conftest.py b/bluesky_httpserver/tests/conftest.py index ec69415..8a81df9 100644 --- a/bluesky_httpserver/tests/conftest.py +++ b/bluesky_httpserver/tests/conftest.py @@ -4,6 +4,7 @@ import pytest import requests from bluesky_queueserver.manager.comms import zmq_single_request +from bluesky_queueserver.manager.tests.common import re_manager_cmd # noqa: F401 from bluesky_queueserver.manager.tests.common import set_qserver_zmq_encoding # noqa: F401 from xprocess import ProcessStarter @@ -18,6 +19,22 @@ _user_group = "primary" +def _wait_for_http_server_ready(*, timeout=10, request_prefix="/api"): + """Wait until HTTP server accepts connections and responds to /status.""" + t_stop = ttime.time() + timeout + url = f"http://{SERVER_ADDRESS}:{SERVER_PORT}{request_prefix}/status" + while ttime.time() < t_stop: + try: + response = requests.get(url, timeout=0.5) + # Any HTTP response means the server is up (auth may still reject request). + if response.status_code: + return + except requests.RequestException: + pass + ttime.sleep(0.1) + raise TimeoutError(f"HTTP server is not ready after {timeout} s: {url}") + + @pytest.fixture(scope="module") def fastapi_server(xprocess): class Starter(ProcessStarter): @@ -29,6 +46,7 @@ class Starter(ProcessStarter): # args = f"start-bluesky-httpserver --host={SERVER_ADDRESS} --port {SERVER_PORT}".split() xprocess.ensure("fastapi_server", Starter) + _wait_for_http_server_ready() yield @@ -43,7 +61,11 @@ def fastapi_server_fs(xprocess): to perform additional steps (such as setting environmental variables) before the server is started. """ - def start(http_server_host=SERVER_ADDRESS, http_server_port=SERVER_PORT, api_key=API_KEY_FOR_TESTS): + def start( + http_server_host=SERVER_ADDRESS, + http_server_port=SERVER_PORT, + api_key=API_KEY_FOR_TESTS, + ): class Starter(ProcessStarter): max_read_lines = 53 @@ -55,7 +77,7 @@ class Starter(ProcessStarter): args = f"uvicorn --host={http_server_host} --port {http_server_port} {bqss.__name__}:app".split() xprocess.ensure("fastapi_server", Starter) - ttime.sleep(1) + _wait_for_http_server_ready() yield start @@ -95,7 +117,12 @@ def add_plans_to_queue(): user_group = _user_group user = "HTTP unit test setup" - plan1 = {"name": "count", "args": [["det1", "det2"]], "kwargs": {"num": 10, "delay": 1}, "item_type": "plan"} + plan1 = { + "name": "count", + "args": [["det1", "det2"]], + "kwargs": {"num": 10, "delay": 1}, + "item_type": "plan", + } plan2 = {"name": "count", "args": [["det1", "det2"]], "item_type": "plan"} for plan in (plan1, plan2, plan2): resp2, _ = zmq_single_request("queue_item_add", {"item": plan, "user": user, "user_group": user_group}) @@ -103,7 +130,14 @@ def add_plans_to_queue(): def request_to_json( - request_type, path, *, request_prefix="/api", api_key=API_KEY_FOR_TESTS, token=None, login=None, **kwargs + request_type, + path, + *, + request_prefix="/api", + api_key=API_KEY_FOR_TESTS, + token=None, + login=None, + **kwargs, ): if login: auth = None @@ -195,3 +229,28 @@ def wait_for_ip_kernel_idle(timeout, polling_period=0.2, api_key=API_KEY_FOR_TES return True return False + + +# ============================================================================ +# OIDC Test Fixtures +# ============================================================================ + + +@pytest.fixture +def oidc_base_url() -> str: + """Base URL for mock OIDC provider.""" + return "https://example.com/realms/example/" + + +@pytest.fixture +def well_known_response(oidc_base_url: str) -> dict: + """Mock OIDC well-known configuration response.""" + return { + "id_token_signing_alg_values_supported": ["RS256"], + "issuer": oidc_base_url.rstrip("/"), + "jwks_uri": f"{oidc_base_url}protocol/openid-connect/certs", + "authorization_endpoint": f"{oidc_base_url}protocol/openid-connect/auth", + "token_endpoint": f"{oidc_base_url}protocol/openid-connect/token", + "device_authorization_endpoint": f"{oidc_base_url}protocol/openid-connect/auth/device", + "end_session_endpoint": f"{oidc_base_url}protocol/openid-connect/logout", + } diff --git a/bluesky_httpserver/tests/test_access_control.py b/bluesky_httpserver/tests/test_access_control.py index e6afdf0..244cf46 100644 --- a/bluesky_httpserver/tests/test_access_control.py +++ b/bluesky_httpserver/tests/test_access_control.py @@ -3,7 +3,7 @@ import pprint import pytest -from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd # noqa F401 +from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd, re_manager_factory # noqa F401 from bluesky_httpserver.authorization._defaults import ( _DEFAULT_RESOURCE_ACCESS_GROUP, @@ -47,7 +47,7 @@ allow_anonymous_access: True providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password @@ -61,7 +61,7 @@ allow_anonymous_access: False providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password @@ -177,7 +177,7 @@ allow_anonymous_access: False providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password @@ -688,7 +688,7 @@ def test_authentication_and_authorization_08( authentication: providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password diff --git a/bluesky_httpserver/tests/test_access_policies.py b/bluesky_httpserver/tests/test_access_policies.py index 374d711..627a8b4 100644 --- a/bluesky_httpserver/tests/test_access_policies.py +++ b/bluesky_httpserver/tests/test_access_policies.py @@ -6,7 +6,7 @@ import pytest import requests -from bluesky_queueserver.manager.tests.common import re_manager # noqa F401 +from bluesky_queueserver.manager.tests.common import re_manager, re_manager_factory # noqa F401 from xprocess import ProcessStarter from bluesky_httpserver.authorization import ( @@ -465,7 +465,7 @@ async def read_info(): authentication: providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password diff --git a/bluesky_httpserver/tests/test_auth_api.py b/bluesky_httpserver/tests/test_auth_api.py index ad23e5a..117a1e3 100644 --- a/bluesky_httpserver/tests/test_auth_api.py +++ b/bluesky_httpserver/tests/test_auth_api.py @@ -1,7 +1,7 @@ import pprint import time as ttime -from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd # noqa F401 +from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd, re_manager_factory # noqa F401 from bluesky_httpserver.authorization._defaults import _DEFAULT_ROLES @@ -13,7 +13,7 @@ allow_anonymous_access: True providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password diff --git a/bluesky_httpserver/tests/test_auth_for_websockets.py b/bluesky_httpserver/tests/test_auth_for_websockets.py index 3d26e22..91f69d2 100644 --- a/bluesky_httpserver/tests/test_auth_for_websockets.py +++ b/bluesky_httpserver/tests/test_auth_for_websockets.py @@ -4,7 +4,7 @@ import time as ttime import pytest -from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd # noqa F401 +from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd, re_manager_factory # noqa F401 from websockets.sync.client import connect from .conftest import fastapi_server_fs # noqa: F401 @@ -22,7 +22,7 @@ allow_anonymous_access: True providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password @@ -50,12 +50,13 @@ class _ReceiveSystemInfoSocket(threading.Thread): save messages to the buffer. """ - def __init__(self, *, endpoint, api_key=None, token=None, **kwargs): + def __init__(self, *, endpoint, api_key=None, token=None, auth_message=None, **kwargs): super().__init__(**kwargs) self.received_data_buffer = [] self._exit = False self._api_key = api_key self._token = token + self._auth_message = auth_message self._endpoint = endpoint def run(self): @@ -69,6 +70,8 @@ def run(self): try: with connect(websocket_uri, additional_headers=additional_headers) as websocket: + if self._auth_message is not None: + websocket.send(json.dumps(self._auth_message)) while not self._exit: try: msg_json = websocket.recv(timeout=0.1, decode=False) @@ -94,7 +97,10 @@ def __del__(self): # fmt: off -@pytest.mark.parametrize("ws_auth_type", ["apikey", "apikey_invalid", "none"]) +@pytest.mark.parametrize( + "ws_auth_type", + ["apikey", "apikey_invalid", "token", "token_invalid", "none", "first_message_apikey", "first_message_token"], +) # fmt: on def test_websocket_auth_01( tmpdir, @@ -135,10 +141,14 @@ def test_websocket_auth_01( ws_params = {"api_key": api_key} elif ws_auth_type == "apikey_invalid": ws_params = {"api_key": "InvalidApiKey"} - # elif ws_auth_type == "token": - # ws_params = {"token": token} - # elif ws_auth_type == "token_invalid": - # ws_params = {"token": "InvalidToken"} + elif ws_auth_type == "token": + ws_params = {"token": token} + elif ws_auth_type == "token_invalid": + ws_params = {"token": "InvalidToken"} + elif ws_auth_type == "first_message_apikey": + ws_params = {"auth_message": {"type": "auth", "api_key": api_key}} + elif ws_auth_type == "first_message_token": + ws_params = {"auth_message": {"type": "auth", "access_token": token}} else: assert False, f"Unknown authentication type: {ws_auth_type!r}" @@ -164,7 +174,7 @@ def test_websocket_auth_01( buffer = rsc.received_data_buffer if ws_auth_type in ("none", "apikey_invalid", "token_invalid"): assert len(buffer) == 0 - elif ws_auth_type in ("apikey", "token"): + elif ws_auth_type in ("apikey", "token", "first_message_apikey", "first_message_token"): assert len(buffer) > 0 for msg in buffer: assert "time" in msg, msg diff --git a/bluesky_httpserver/tests/test_authenticators.py b/bluesky_httpserver/tests/test_authenticators.py deleted file mode 100644 index cc2984c..0000000 --- a/bluesky_httpserver/tests/test_authenticators.py +++ /dev/null @@ -1,43 +0,0 @@ -import asyncio - -import pytest - -# fmt: off -from ..authenticators import LDAPAuthenticator - - -@pytest.mark.parametrize("ldap_server_address, ldap_server_port", [ - ("localhost", 1389), - ("localhost:1389", 904), # Random port, ignored - ("localhost:1389", None), - ("127.0.0.1", 1389), - ("127.0.0.1:1389", 904), - (["localhost"], 1389), - (["localhost", "127.0.0.1"], 1389), - (["localhost", "127.0.0.1:1389"], 1389), - (["localhost:1389", "127.0.0.1:1389"], None), -]) -# fmt: on -@pytest.mark.parametrize("use_tls,use_ssl", [(False, False)]) -def test_LDAPAuthenticator_01(use_tls, use_ssl, ldap_server_address, ldap_server_port): - """ - Basic test for ``LDAPAuthenticator``. - - TODO: The test could be extended with enabled TLS or SSL, but it requires configuration - of the LDAP server. - """ - authenticator = LDAPAuthenticator( - ldap_server_address, - ldap_server_port, - bind_dn_template="cn={username},ou=users,dc=example,dc=org", - use_tls=use_tls, - use_ssl=use_ssl, - ) - - async def testing(): - assert await authenticator.authenticate("user01", "password1") == "user01" - assert await authenticator.authenticate("user02", "password2") == "user02" - assert await authenticator.authenticate("user02a", "password2") is None - assert await authenticator.authenticate("user02", "password2a") is None - - asyncio.run(testing()) diff --git a/bluesky_httpserver/tests/test_console_output.py b/bluesky_httpserver/tests/test_console_output.py index 1f089ec..6193db0 100644 --- a/bluesky_httpserver/tests/test_console_output.py +++ b/bluesky_httpserver/tests/test_console_output.py @@ -3,17 +3,16 @@ import re import threading import time as ttime +from typing import Any import pytest import requests -from bluesky_queueserver.manager.tests.common import re_manager_cmd # noqa F401 from websockets.sync.client import connect from bluesky_httpserver.tests.conftest import ( # noqa F401 API_KEY_FOR_TESTS, SERVER_ADDRESS, SERVER_PORT, - fastapi_server_fs, request_to_json, set_qserver_zmq_encoding, wait_for_environment_to_be_closed, @@ -36,37 +35,42 @@ def __init__(self, api_key=API_KEY_FOR_TESTS, **kwargs): self._api_key = api_key def run(self): - kwargs = {"stream": True} + kwargs: dict[str, Any] = {"stream": True} if self._api_key: - auth = None headers = {"Authorization": f"ApiKey {self._api_key}"} - kwargs.update({"auth": auth, "headers": headers}) + kwargs.update({"headers": headers}) + + kwargs["timeout"] = (5, 1) - with requests.get(f"http://{SERVER_ADDRESS}:{SERVER_PORT}/api/stream_console_output", **kwargs) as r: - r.encoding = "utf-8" + while not self._exit: + try: + with requests.get( + f"http://{SERVER_ADDRESS}:{SERVER_PORT}/api/stream_console_output", + **kwargs, + ) as r: + r.encoding = "utf-8" - characters = [] - n_brackets = 0 + characters = [] + n_brackets = 0 - for ch in r.iter_content(decode_unicode=True): - # Note, that some output must be received from the server before the loop exits - if self._exit: - break + for ch in r.iter_content(decode_unicode=True): + if self._exit: + return - characters.append(ch) - if ch == "{": - n_brackets += 1 - elif ch == "}": - n_brackets -= 1 + characters.append(ch) + if ch == "{": + n_brackets += 1 + elif ch == "}": + n_brackets -= 1 - # If the received buffer ('characters') is not empty and the message contains - # equal number of opening and closing brackets then consider the message complete. - if characters and not n_brackets: - line = "".join(characters) - characters = [] + if characters and not n_brackets: + line = "".join(characters) + characters = [] - print(f"{line}") - self.received_data_buffer.append(json.loads(line)) + print(f"{line}") + self.received_data_buffer.append(json.loads(line)) + except requests.exceptions.ReadTimeout: + continue def stop(self): """ @@ -81,7 +85,10 @@ def __del__(self): @pytest.mark.parametrize("zmq_port", (None, 60619)) def test_http_server_stream_console_output_1( - monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port # noqa F811 + monkeypatch, + re_manager_cmd, + fastapi_server_fs, + zmq_port, # noqa F811 ): """ Test for ``stream_console_output`` API @@ -122,7 +129,8 @@ def test_http_server_stream_console_output_1( assert resp2["items"][0] == resp1["item"] assert resp2["running_item"] == {} - rsc.join() + rsc.join(timeout=10) + assert not rsc.is_alive(), "Timed out waiting for stream_console_output thread to terminate" assert len(rsc.received_data_buffer) >= 2, pprint.pformat(rsc.received_data_buffer) @@ -160,7 +168,11 @@ def test_http_server_stream_console_output_1( @pytest.mark.parametrize("zmq_encoding", (None, "json", "msgpack")) @pytest.mark.parametrize("zmq_port", (None, 60619)) def test_http_server_console_output_1( - monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port, zmq_encoding # noqa F811 + monkeypatch, + re_manager_cmd, + fastapi_server_fs, + zmq_port, + zmq_encoding, # noqa F811 ): """ Test for ``console_output`` API (not a streaming version). @@ -238,7 +250,10 @@ def test_http_server_console_output_1( @pytest.mark.parametrize("zmq_port", (None, 60619)) def test_http_server_console_output_update_1( - monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port # noqa F811 + monkeypatch, + re_manager_cmd, + fastapi_server_fs, + zmq_port, # noqa F811 ): """ Test for ``console_output`` API (not a streaming version). @@ -379,7 +394,10 @@ def __del__(self): @pytest.mark.parametrize("zmq_port", (None, 60619)) def test_http_server_console_output_socket_1( - monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port # noqa F811 + monkeypatch, + re_manager_cmd, + fastapi_server_fs, + zmq_port, # noqa F811 ): """ Test for ``/console_output/ws`` websocket @@ -421,7 +439,8 @@ def test_http_server_console_output_socket_1( assert resp2["items"][0] == resp1["item"] assert resp2["running_item"] == {} - rsc.join() + rsc.join(timeout=10) + assert not rsc.is_alive(), "Timed out waiting for console_output websocket thread to terminate" assert len(rsc.received_data_buffer) >= 2, pprint.pformat(rsc.received_data_buffer) diff --git a/bluesky_httpserver/tests/test_core_api_fs.py b/bluesky_httpserver/tests/test_core_api_fs.py index a00d99a..ace9205 100644 --- a/bluesky_httpserver/tests/test_core_api_fs.py +++ b/bluesky_httpserver/tests/test_core_api_fs.py @@ -7,6 +7,7 @@ copy_default_profile_collection, re_manager, re_manager_cmd, + re_manager_factory, re_manager_pc_copy, set_qserver_zmq_address, set_qserver_zmq_public_key, diff --git a/bluesky_httpserver/tests/test_core_api_main.py b/bluesky_httpserver/tests/test_core_api_main.py index b2b5140..815669e 100644 --- a/bluesky_httpserver/tests/test_core_api_main.py +++ b/bluesky_httpserver/tests/test_core_api_main.py @@ -11,6 +11,7 @@ ip_kernel_simple_client, re_manager, re_manager_cmd, + re_manager_factory, re_manager_pc_copy, ) @@ -30,8 +31,17 @@ # Plans used in most of the tests: '_plan1' and '_plan2' are quickly executed '_plan3' runs for 5 seconds. _plan1 = {"name": "count", "args": [["det1", "det2"]], "item_type": "plan"} -_plan2 = {"name": "scan", "args": [["det1", "det2"], "motor", -1, 1, 10], "item_type": "plan"} -_plan3 = {"name": "count", "args": [["det1", "det2"]], "kwargs": {"num": 5, "delay": 1}, "item_type": "plan"} +_plan2 = { + "name": "scan", + "args": [["det1", "det2"], "motor", -1, 1, 10], + "item_type": "plan", +} +_plan3 = { + "name": "count", + "args": [["det1", "det2"]], + "kwargs": {"num": 5, "delay": 1}, + "item_type": "plan", +} _instruction_stop = {"name": "queue_stop", "item_type": "instruction"} @@ -515,8 +525,10 @@ def test_http_server_queue_item_update_2_fail(re_manager, fastapi_server, replac resp2 = request_to_json("post", "/queue/item/update", json=params) assert resp2["success"] is False - assert resp2["msg"] == "Failed to add an item: Failed to replace item: " \ - "Item with UID 'incorrect_uid' is not in the queue" + assert ( + resp2["msg"] == "Failed to add an item: Failed to replace item: " + "Item with UID 'incorrect_uid' is not in the queue" + ) resp3 = request_to_json("get", "/queue/get") assert resp3["items"] != [] @@ -1286,16 +1298,33 @@ def test_http_server_history_clear(re_manager, fastapi_server, clear_params, exp def test_http_server_manager_kill(re_manager, fastapi_server): # noqa F811 + timeout_variants = ( + "Request timeout: ZMQ communication error: timeout occurred", + "Request timeout: ZMQ communication error: Resource temporarily unavailable", + ) + request_to_json("post", "/environment/open") assert wait_for_environment_to_be_created(10), "Timeout" resp = request_to_json("post", "/test/manager/kill") assert "success" not in resp - assert "Request timeout: ZMQ communication error: timeout occurred" in resp["detail"] - - ttime.sleep(10) + assert any(_ in resp["detail"] for _ in timeout_variants) + + deadline = ttime.time() + 20 + last_status = None + while ttime.time() < deadline: + ttime.sleep(0.2) + last_status = request_to_json("get", "/status") + if ( + isinstance(last_status, dict) + and last_status.get("manager_state") == "idle" + and last_status.get("worker_environment_exists") is True + ): + break + else: + assert False, f"Timeout while waiting for manager recovery after kill. Last status: {last_status!r}" - resp = request_to_json("get", "/status") + resp = last_status assert resp["msg"].startswith("RE Manager") assert resp["manager_state"] == "idle" assert resp["items_in_queue"] == 0 diff --git a/bluesky_httpserver/tests/test_server.py b/bluesky_httpserver/tests/test_server.py index 117f4df..10f0d6c 100644 --- a/bluesky_httpserver/tests/test_server.py +++ b/bluesky_httpserver/tests/test_server.py @@ -8,6 +8,7 @@ copy_default_profile_collection, re_manager, re_manager_cmd, + re_manager_factory, re_manager_pc_copy, set_qserver_zmq_address, set_qserver_zmq_public_key, @@ -27,8 +28,17 @@ # Plans used in most of the tests: '_plan1' and '_plan2' are quickly executed '_plan3' runs for 5 seconds. _plan1 = {"name": "count", "args": [["det1", "det2"]], "item_type": "plan"} -_plan2 = {"name": "scan", "args": [["det1", "det2"], "motor", -1, 1, 10], "item_type": "plan"} -_plan3 = {"name": "count", "args": [["det1", "det2"]], "kwargs": {"num": 5, "delay": 1}, "item_type": "plan"} +_plan2 = { + "name": "scan", + "args": [["det1", "det2"], "motor", -1, 1, 10], + "item_type": "plan", +} +_plan3 = { + "name": "count", + "args": [["det1", "det2"]], + "kwargs": {"num": 5, "delay": 1}, + "item_type": "plan", +} _config_public_key = """ @@ -122,7 +132,7 @@ def test_http_server_secure_1(monkeypatch, tmpdir, re_manager_cmd, fastapi_serve @pytest.mark.parametrize("option", ["ev", "cfg_file", "both"]) # fmt: on def test_http_server_set_zmq_address_1( - monkeypatch, tmpdir, re_manager_cmd, fastapi_server_fs, option # noqa: F811 + monkeypatch, tmpdir, re_manager_cmd, fastapi_server_fs, free_tcp_port_factory, option # noqa: F811 ): """ Test if ZMQ address of RE Manager is passed to the HTTP server using 'QSERVER_ZMQ_ADDRESS_CONTROL' @@ -130,11 +140,12 @@ def test_http_server_set_zmq_address_1( channel different from default address, add and execute a plan. """ - # Change ZMQ address to use port 60616 instead of the default port 60615. - zmq_control_address_server = "tcp://*:60616" - zmq_info_address_server = "tcp://*:60617" - zmq_control_address = "tcp://localhost:60616" - zmq_info_address = "tcp://localhost:60617" + zmq_control_port = free_tcp_port_factory() + zmq_info_port = free_tcp_port_factory() + zmq_control_address_server = f"tcp://*:{zmq_control_port}" + zmq_info_address_server = f"tcp://*:{zmq_info_port}" + zmq_control_address = f"tcp://localhost:{zmq_control_port}" + zmq_info_address = f"tcp://localhost:{zmq_info_port}" if option == "ev": monkeypatch.setenv("QSERVER_ZMQ_CONTROL_ADDRESS", zmq_control_address) monkeypatch.setenv("QSERVER_ZMQ_INFO_ADDRESS", zmq_info_address) diff --git a/continuous_integration/docker-configs/ldap-docker-compose.yml b/continuous_integration/docker-configs/ldap-docker-compose.yml new file mode 100644 index 0000000..5fbfc53 --- /dev/null +++ b/continuous_integration/docker-configs/ldap-docker-compose.yml @@ -0,0 +1,16 @@ +services: + openldap: + image: osixia/openldap:1.5.0 + ports: + - '1389:389' + - '1636:636' + environment: + - LDAP_ORGANISATION=Example Inc. + - LDAP_DOMAIN=example.org + - LDAP_ADMIN_PASSWORD=adminpassword + volumes: + - 'openldap_data:/var/lib/ldap' + +volumes: + openldap_data: + driver: local diff --git a/continuous_integration/dockerfiles/test.Dockerfile b/continuous_integration/dockerfiles/test.Dockerfile new file mode 100644 index 0000000..2e994cf --- /dev/null +++ b/continuous_integration/dockerfiles/test.Dockerfile @@ -0,0 +1,29 @@ +ARG PYTHON_VERSION=3.13 +FROM python:${PYTHON_VERSION}-slim + +ENV PYTHONUNBUFFERED=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + bash \ + build-essential \ + git \ + redis-server \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /workspace + +COPY requirements.txt requirements-dev.txt ./ +COPY pyproject.toml setup.py setup.cfg MANIFEST.in versioneer.py README.rst AUTHORS.rst LICENSE ./ +COPY bluesky_httpserver ./bluesky_httpserver + +RUN python -m pip install --upgrade pip setuptools wheel numpy && \ + python -m pip install git+https://github.com/bluesky/bluesky-queueserver.git && \ + python -m pip install git+https://github.com/bluesky/bluesky-queueserver-api.git && \ + python -m pip install -r requirements-dev.txt && \ + python -m pip install . + +COPY scripts/docker/run_shard_in_container.sh /usr/local/bin/run_shard_in_container.sh +RUN chmod +x /usr/local/bin/run_shard_in_container.sh + +ENTRYPOINT ["/usr/local/bin/run_shard_in_container.sh"] diff --git a/continuous_integration/scripts/start_LDAP.sh b/continuous_integration/scripts/start_LDAP.sh new file mode 100755 index 0000000..d2bd48d --- /dev/null +++ b/continuous_integration/scripts/start_LDAP.sh @@ -0,0 +1,195 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +COMPOSE_FILE="${LDAP_COMPOSE_FILE:-$ROOT_DIR/continuous_integration/docker-configs/ldap-docker-compose.yml}" +COMPOSE_PROJECT="${LDAP_COMPOSE_PROJECT:-}" +LDAP_HOST="${LDAP_HOST:-127.0.0.1}" +LDAP_PORT="${LDAP_PORT:-1389}" +LDAP_ADMIN_DN="cn=admin,dc=example,dc=org" +LDAP_ADMIN_PASSWORD="adminpassword" +LDAP_BASE_DN="dc=example,dc=org" + +compose_cmd() { + if [[ -n "$COMPOSE_PROJECT" ]]; then + docker compose -p "$COMPOSE_PROJECT" -f "$COMPOSE_FILE" "$@" + else + docker compose -f "$COMPOSE_FILE" "$@" + fi +} + +get_openldap_container_id() { + compose_cmd ps -q openldap | tr -d '[:space:]' +} + +wait_for_ldap() { + local timeout_seconds="${1:-60}" + local deadline=$((SECONDS + timeout_seconds)) + + while (( SECONDS < deadline )); do + if python - </dev/null 2>&1 +import socket + +with socket.create_connection(("${LDAP_HOST}", ${LDAP_PORT}), timeout=1): + pass +PY + then + return 0 + fi + sleep 1 + done + + return 1 +} + +wait_for_ldap_bind() { + local container_id="$1" + local timeout_seconds="${2:-60}" + local deadline=$((SECONDS + timeout_seconds)) + local rc=0 + + while (( SECONDS < deadline )); do + rc=0 + docker exec "$container_id" ldapsearch \ + -x \ + -H "ldap://127.0.0.1:389" \ + -D "$LDAP_ADMIN_DN" \ + -w "$LDAP_ADMIN_PASSWORD" \ + -b "$LDAP_BASE_DN" \ + -s base \ + "(objectclass=*)" dn >/dev/null 2>&1 || rc=$? + if [[ "$rc" -eq 0 ]]; then + return 0 + fi + sleep 1 + done + + return 1 +} + +wait_for_ldap_test_user_bind() { + local container_id="$1" + local timeout_seconds="${2:-60}" + local deadline=$((SECONDS + timeout_seconds)) + local rc=0 + + while (( SECONDS < deadline )); do + rc=0 + docker exec "$container_id" ldapwhoami \ + -x \ + -H "ldap://127.0.0.1:389" \ + -D "cn=user01,ou=users,$LDAP_BASE_DN" \ + -w "password1" >/dev/null 2>&1 || rc=$? + if [[ "$rc" -eq 0 ]]; then + return 0 + fi + sleep 1 + done + + return 1 +} + +print_ldap_diagnostics() { + local container_id="${1:-}" + + echo "LDAP startup diagnostics:" >&2 + compose_cmd ps >&2 || true + + if [[ -z "$container_id" ]]; then + container_id="$(get_openldap_container_id)" + fi + + if [[ -n "$container_id" ]]; then + docker logs --tail 200 "$container_id" >&2 || true + else + compose_cmd logs --tail 200 openldap >&2 || true + fi +} + +ldap_entry_exists() { + local container_id="$1" + local dn="$2" + + docker exec "$container_id" ldapsearch \ + -x \ + -H "ldap://127.0.0.1:389" \ + -D "$LDAP_ADMIN_DN" \ + -w "$LDAP_ADMIN_PASSWORD" \ + -b "$dn" \ + -s base \ + "(objectclass=*)" dn >/dev/null 2>&1 +} + +ldap_add_if_missing() { + local container_id="$1" + local dn="$2" + local ldif="$3" + + if ldap_entry_exists "$container_id" "$dn"; then + return 0 + fi + + docker exec -i "$container_id" ldapadd \ + -x \ + -H "ldap://127.0.0.1:389" \ + -D "$LDAP_ADMIN_DN" \ + -w "$LDAP_ADMIN_PASSWORD" >/dev/null <&2 + print_ldap_diagnostics + exit 1 +fi + +if ! wait_for_ldap 120; then + echo "LDAP port ${LDAP_HOST}:${LDAP_PORT} did not become reachable in time." >&2 + print_ldap_diagnostics "$CONTAINER_ID" + exit 1 +fi + +echo "LDAP port ${LDAP_HOST}:${LDAP_PORT} is reachable. Waiting for slapd initialization..." +sleep 3 + +if ! wait_for_ldap_bind "$CONTAINER_ID" 120; then + echo "LDAP admin bind did not become ready in time." >&2 + print_ldap_diagnostics "$CONTAINER_ID" + exit 1 +fi + +seed_ldap_test_users "$CONTAINER_ID" + +if ! wait_for_ldap_test_user_bind "$CONTAINER_ID" 60; then + echo "LDAP test-user bind did not become ready in time." >&2 + print_ldap_diagnostics "$CONTAINER_ID" + exit 1 +fi + +docker ps diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst index eb31efa..989fd31 100644 --- a/docs/source/configuration.rst +++ b/docs/source/configuration.rst @@ -14,6 +14,12 @@ and allow high level of customization of functionality, using configuration YML allows greater flexibility and is considered a preferable way of configuring the server in production deployments. +.. note:: + + Canonical authenticator import paths are in ``bluesky_authentication.authenticators``. + Legacy paths in ``bluesky_httpserver.authenticators`` remain supported for backward + compatibility. + Environment variable for passing the path to server configuration file(s): - ``QSERVER_HTTP_SERVER_CONFIG`` - path to a single YML file or a directory with multiple YML files. @@ -218,7 +224,7 @@ authorization policy and enabled public access:: allow_anonymous_access: True providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: ${BOB_PASSWORD} @@ -269,7 +275,7 @@ respectively. The configuration does not enable public access. :: authentication: providers: - provider: ldap - authenticator: bluesky_httpserver.authenticators:LDAPAuthenticator + authenticator: bluesky_authentication.authenticators:LDAPAuthenticator args: server_address: localhost server_port: 1389 @@ -294,6 +300,85 @@ See the documentation on ``LDAPAuthenticator`` for more details. authenticators.LDAPAuthenticator +OIDC Authenticator +++++++++++++++++++ + +``OIDCAuthenticator`` integrates the server with third-party OpenID Connect providers +such as Google, Microsoft Entra ID, ORCID and others. The server does not process user +passwords directly: authentication is delegated to the provider and the server validates +the returned OIDC token. + +General setup steps: + +#. Register an application with the OIDC provider. +#. Configure redirect URIs for the provider application. For provider name ``entra`` and + host ``https://your-server.example`` the redirect URIs are: + + - ``https://your-server.example/api/auth/provider/entra/code`` + - ``https://your-server.example/api/auth/provider/entra/device_code`` + +#. Store the client secret in environment variable and reference it in config. +#. Use provider's ``.well-known/openid-configuration`` URL. + +Typical ``well_known_uri`` values: + +- Google: ``https://accounts.google.com/.well-known/openid-configuration`` +- Microsoft Entra ID: ``https://login.microsoftonline.com//v2.0/.well-known/openid-configuration`` +- ORCID: ``https://orcid.org/.well-known/openid-configuration`` + +Example configuration (Microsoft Entra ID):: + + authentication: + providers: + - provider: entra + authenticator: bluesky_authentication.authenticators:OIDCAuthenticator + args: + audience: 00000000-0000-0000-0000-000000000000 + client_id: 00000000-0000-0000-0000-000000000000 + client_secret: ${BSKY_ENTRA_SECRET} + well_known_uri: https://login.microsoftonline.com//v2.0/.well-known/openid-configuration + confirmation_message: "You have logged in successfully." + api_access: + policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl + args: + users: + : + roles: + - admin + - expert + +Example configuration (Google):: + + authentication: + providers: + - provider: google + authenticator: bluesky_authentication.authenticators:OIDCAuthenticator + args: + audience: + client_id: + client_secret: ${BSKY_GOOGLE_SECRET} + well_known_uri: https://accounts.google.com/.well-known/openid-configuration + api_access: + policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl + args: + users: + : + roles: user + +.. note:: + + The name used in ``api_access/args/users`` must match the identity string produced by + the authenticator for your provider configuration. Verify with ``/api/auth/whoami`` after + successful login. + +See the documentation on ``OIDCAuthenticator`` for parameter details. + +.. autosummary:: + :nosignatures: + :toctree: generated + + authenticators.OIDCAuthenticator + Expiration Time for Tokens and Sessions +++++++++++++++++++++++++++++++++++++++ diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 5e1e9b3..07cd1f3 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -106,6 +106,12 @@ API calls with invalid token or API key are rejected even if public access is en Authentication API for Users ============================ +.. note:: + + Canonical authenticator import paths are in ``bluesky_authentication.authenticators``. + Legacy paths in ``bluesky_httpserver.authenticators`` remain supported for backward + compatibility. + Logging into the Server (Requesting Token) ------------------------------------------ @@ -122,7 +128,7 @@ is an example of a config file sets up ``DictionaryAPIAccessControl`` as a provi authentication: providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: ${BOB_PASSWORD} @@ -154,6 +160,61 @@ Then users ``bob``, ``alice`` and ``tom`` can log into the server as :: If authentication is successful, then the server returns access and refresh tokens. +Logging in with OIDC Providers (Google, Entra, ORCID, ...) +----------------------------------------------------------- + +For providers configured with ``OIDCAuthenticator``, use provider-specific endpoints +under ``/api/auth/provider//...``. + +Browser-first flow +****************** + +If you are already in a browser context, open: + +``/api/auth/provider//authorize`` + +This redirects to the OIDC provider login page and then back to the server callback. + +This can similarly be acheived using ``httpie`` by opening the URL in a browser after getting +the authorization URI from the server:: + + http POST http://localhost:60610/api/auth/provider/entra/authorize + +Which will return a token back to the bluesky http server after the user logs in to the provider +in their browser (or automatically if already logged in). The user then gets a token +for the bluesky HTTP server to use for subsequent API requests. This flow can be used +even when using the bluesky queueserver api in a terminal so long as that session can +spawn a browser for the user to log in to the provider. + +CLI/device flow +*************** + +For terminal clients (i.e. no browser possible), start with +``POST /api/auth/provider//authorize``. +The response includes: + +- ``authorization_uri``: open this URL in a browser +- ``verification_uri``: polling endpoint for the terminal client +- ``device_code`` and ``interval``: values for polling + +Example using ``httpie`` (provider ``entra``):: + + http POST http://localhost:60610/api/auth/provider/entra/authorize + +After opening ``authorization_uri`` in a browser and completing provider login, +poll ``verification_uri`` using ``device_code`` until tokens are issued:: + + http POST http://localhost:60610/api/auth/provider/entra/token \ + device_code='' + +When authorization is still pending, the endpoint returns ``authorization_pending``. +When complete, it returns access and refresh tokens. + +.. note:: + + In common same-device flows the callback can complete automatically without manually + typing the user code. Manual code entry remains available as a fallback path. + Generating API Keys ------------------- diff --git a/requirements-dev.txt b/requirements-dev.txt index dd7212a..e47dd72 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,13 +3,16 @@ black codecov coverage +cryptography fastapi[all] flake8 isort pre-commit pytest +pytest-asyncio pytest-xprocess py +respx sphinx ipython numpydoc diff --git a/requirements.txt b/requirements.txt index f465abd..5d234b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,10 @@ alembic +bluesky-authentication bluesky-queueserver bluesky-queueserver-api +cachetools fastapi +httpx ldap3 orjson pamela diff --git a/scripts/docker/run_shard_in_container.sh b/scripts/docker/run_shard_in_container.sh new file mode 100755 index 0000000..7fc23a7 --- /dev/null +++ b/scripts/docker/run_shard_in_container.sh @@ -0,0 +1,80 @@ +#!/usr/bin/env bash +set -euo pipefail + +SHARD_GROUP="${SHARD_GROUP:-1}" +SHARD_COUNT="${SHARD_COUNT:-3}" +ARTIFACTS_DIR="${ARTIFACTS_DIR:-/artifacts}" +PYTEST_EXTRA_ARGS="${PYTEST_EXTRA_ARGS:-}" + +mkdir -p "$ARTIFACTS_DIR" + +if [[ "$SHARD_GROUP" -lt 1 || "$SHARD_COUNT" -lt 1 || "$SHARD_GROUP" -gt "$SHARD_COUNT" ]]; then + echo "Invalid shard settings: SHARD_GROUP=$SHARD_GROUP SHARD_COUNT=$SHARD_COUNT" >&2 + exit 2 +fi + +export COVERAGE_FILE="$ARTIFACTS_DIR/.coverage.${SHARD_GROUP}" + +redis-server --save "" --appendonly no --daemonize yes +for _ in $(seq 1 50); do + if redis-cli ping >/dev/null 2>&1; then + break + fi + sleep 0.2 +done + +if ! redis-cli ping >/dev/null 2>&1; then + echo "Failed to start redis-server inside container" >&2 + exit 2 +fi + +mapfile -t shard_tests < <( + python - <<'PY' "$SHARD_GROUP" "$SHARD_COUNT" +import glob +import sys + +group = int(sys.argv[1]) +count = int(sys.argv[2]) + +tests = sorted(glob.glob("bluesky_httpserver/tests/test_*.py")) +selected = [path for idx, path in enumerate(tests) if idx % count == (group - 1)] + +for path in selected: + print(path) +PY +) + +if [[ "${#shard_tests[@]}" -eq 0 ]]; then + echo "No tests selected for shard ${SHARD_GROUP}/${SHARD_COUNT}; treating as success." + exit 0 +fi + +pytest_cmd=( + coverage + run + -m + pytest + --junitxml="$ARTIFACTS_DIR/junit.${SHARD_GROUP}.xml" + -vv +) + +if [[ -n "$PYTEST_EXTRA_ARGS" ]]; then + read -r -a extra_args <<< "$PYTEST_EXTRA_ARGS" + pytest_cmd+=("${extra_args[@]}") +fi + +pytest_cmd+=("${shard_tests[@]}") + +set +e +"${pytest_cmd[@]}" +test_status=$? +set -e + +if [[ "$test_status" -eq 5 ]]; then + echo "Pytest collected no tests for shard ${SHARD_GROUP}/${SHARD_COUNT}; treating as success." + test_status=0 +fi + +redis-cli shutdown nosave >/dev/null 2>&1 || true + +exit "$test_status" diff --git a/scripts/run_ci_docker_parallel.sh b/scripts/run_ci_docker_parallel.sh new file mode 100755 index 0000000..efb6594 --- /dev/null +++ b/scripts/run_ci_docker_parallel.sh @@ -0,0 +1,452 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +IMAGE_TAG_BASE="bluesky-httpserver-test:local" +WORKER_COUNT="3" +CHUNK_COUNT="" +PYTHON_VERSIONS="latest" +PYTEST_EXTRA_ARGS="" +ARTIFACTS_DIR="$ROOT_DIR/.docker-test-artifacts" +LDAP_COMPOSE_FILE="$ROOT_DIR/continuous_integration/docker-configs/ldap-docker-compose.yml" +LDAP_COMPOSE_PROJECT="bhs-ci-ldap-parallel-$$" +LDAP_SERVICE_NAME="openldap" +DOCKER_NETWORK_NAME="${LDAP_COMPOSE_PROJECT}_default" + +SUMMARY_TSV="" +SUMMARY_FAIL_LOGS="" +SUMMARY_TXT="" +SUMMARY_JSON="" +TESTS_START_EPOCH="" +TESTS_START_HUMAN="" + +SUPPORTED_PYTHON_VERSIONS=("3.10" "3.11" "3.12" "3.13") + +usage() { + cat <<'EOF' +Run bluesky-httpserver unit tests in Docker with dynamic chunk dispatch and optional Python-version matrix. + +Usage: + scripts/run_ci_docker_parallel.sh [options] + +Options: + --workers N, --worker-count N + Number of concurrent chunk workers (default: 3). + + --chunks N, --chunk-count N + Number of total chunks/splits to execute per Python version. + Default: workers * 3. + + --python-versions VALUE + Python version selection: latest | all | comma-separated list. + Examples: latest, all, 3.12, 3.11,3.13 + Default: latest (currently 3.13). + + --pytest-args "ARGS" + Extra arguments passed to pytest in each chunk. + Example: --pytest-args "-k oidc --maxfail=1" + + --artifacts-dir PATH + Output directory for all artifacts. + Default: .docker-test-artifacts under repository root. + + --image-tag TAG + Base docker image tag. Per-version tags will append -py. + Default: bluesky-httpserver-test:local + + -h, --help + Show this help message. + +Examples: + scripts/run_ci_docker_parallel.sh + scripts/run_ci_docker_parallel.sh --workers 8 --chunks 24 + scripts/run_ci_docker_parallel.sh --python-versions all --workers 8 --chunks 24 + scripts/run_ci_docker_parallel.sh --python-versions 3.11,3.13 --pytest-args "-k test_access_control" +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --workers|--worker-count) + WORKER_COUNT="$2" + shift 2 + ;; + --chunks|--chunk-count) + CHUNK_COUNT="$2" + shift 2 + ;; + --python-versions) + PYTHON_VERSIONS="$2" + shift 2 + ;; + --pytest-args) + PYTEST_EXTRA_ARGS="$2" + shift 2 + ;; + --artifacts-dir) + ARTIFACTS_DIR="$2" + shift 2 + ;; + --image-tag) + IMAGE_TAG_BASE="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown option: $1" >&2 + usage + exit 2 + ;; + esac +done + +if [[ "$WORKER_COUNT" -lt 1 ]]; then + echo "WORKER_COUNT must be >= 1" >&2 + exit 2 +fi + +if [[ -z "$CHUNK_COUNT" ]]; then + CHUNK_COUNT=$(( WORKER_COUNT * 3 )) +fi + +if [[ "$CHUNK_COUNT" -lt 1 ]]; then + echo "CHUNK_COUNT must be >= 1" >&2 + exit 2 +fi + +if ! command -v docker >/dev/null 2>&1; then + echo "docker is required but not found in PATH" >&2 + exit 2 +fi + +if ! docker info >/dev/null 2>&1; then + echo "docker daemon is not available" >&2 + exit 2 +fi + +normalize_python_versions() { + local selection="$1" + local raw + local normalized=() + + if [[ "$selection" == "latest" ]]; then + normalized=("3.13") + elif [[ "$selection" == "all" ]]; then + normalized=("${SUPPORTED_PYTHON_VERSIONS[@]}") + else + raw="${selection//,/ }" + read -r -a normalized <<< "$raw" + fi + + if [[ "${#normalized[@]}" -eq 0 ]]; then + echo "PYTHON_VERSIONS selection produced no versions" >&2 + exit 2 + fi + + for version in "${normalized[@]}"; do + if [[ ! " ${SUPPORTED_PYTHON_VERSIONS[*]} " =~ " ${version} " ]]; then + echo "Unsupported Python version '${version}'. Supported: ${SUPPORTED_PYTHON_VERSIONS[*]}" >&2 + exit 2 + fi + done + + echo "${normalized[@]}" +} + +start_services() { + LDAP_COMPOSE_FILE="$LDAP_COMPOSE_FILE" \ + LDAP_COMPOSE_PROJECT="$LDAP_COMPOSE_PROJECT" \ + LDAP_HOST="127.0.0.1" \ + LDAP_PORT="1389" \ + bash "$ROOT_DIR/continuous_integration/scripts/start_LDAP.sh" >/dev/null +} + +stop_services() { + docker compose -p "$LDAP_COMPOSE_PROJECT" -f "$LDAP_COMPOSE_FILE" down -v >/dev/null 2>&1 || true +} + +cleanup() { + stop_services +} + +collect_junit_totals() { + local artifacts_dir="$1" + + python - "$artifacts_dir" <<'PY' +import glob +import os +import sys +import xml.etree.ElementTree as ET + +artifacts_dir = sys.argv[1] +tests = failures = errors = files = 0 + +for path in sorted(glob.glob(os.path.join(artifacts_dir, "junit.*.xml"))): + files += 1 + try: + root = ET.parse(path).getroot() + except Exception: + continue + + if root.tag == "testsuite": + suites = [root] + elif root.tag == "testsuites": + suites = root.findall("testsuite") + else: + suites = [] + + for suite in suites: + tests += int(suite.attrib.get("tests", 0) or 0) + failures += int(suite.attrib.get("failures", 0) or 0) + errors += int(suite.attrib.get("errors", 0) or 0) + +print(f"{tests} {failures} {errors} {files}") +PY +} + +append_summary_row() { + local py_version="$1" + local chunks_total="$2" + local junit_files="$3" + local tests="$4" + local failures="$5" + local errors="$6" + local status="$7" + + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "$py_version" "$chunks_total" "$junit_files" "$tests" "$failures" "$errors" "$status" >> "$SUMMARY_TSV" +} + +write_summary_files() { + local end_epoch end_human elapsed_sec + + if [[ -z "$SUMMARY_TSV" || -z "$SUMMARY_TXT" || -z "$SUMMARY_JSON" ]]; then + return + fi + + if [[ ! -f "$SUMMARY_TSV" ]]; then + return + fi + + end_epoch="$(date +%s)" + end_human="$(date -u +"%Y-%m-%dT%H:%M:%SZ")" + + if [[ -n "$TESTS_START_EPOCH" ]]; then + elapsed_sec=$(( end_epoch - TESTS_START_EPOCH )) + else + elapsed_sec=0 + fi + + { + echo "Test Run Summary" + echo "Start (UTC): ${TESTS_START_HUMAN:-N/A}" + echo "End (UTC): $end_human" + echo "Elapsed: ${elapsed_sec}s" + echo + printf "%-8s %-8s %-7s %-8s %-10s %-8s %-6s\n" \ + "Python" "Status" "Chunks" "JUnit" "Tests" "Failures" "Errors" + printf "%-8s %-8s %-7s %-8s %-10s %-8s %-6s\n" \ + "------" "------" "------" "-----" "-----" "--------" "------" + + if [[ -s "$SUMMARY_TSV" ]]; then + while IFS=$'\t' read -r py_version chunks_total junit_files tests failures errors status; do + printf "%-8s %-8s %-7s %-8s %-10s %-8s %-6s\n" \ + "$py_version" "$status" "$chunks_total" "$junit_files" "$tests" "$failures" "$errors" + done < "$SUMMARY_TSV" + else + echo "No per-version summary rows were recorded." + fi + + if [[ -s "$SUMMARY_FAIL_LOGS" ]]; then + echo + echo "Failed Chunk Logs" + cat "$SUMMARY_FAIL_LOGS" + fi + } > "$SUMMARY_TXT" + + python - "$SUMMARY_TSV" "$SUMMARY_FAIL_LOGS" "$SUMMARY_JSON" "${TESTS_START_HUMAN:-N/A}" "$end_human" "$elapsed_sec" <<'PY' +import json +import sys + +summary_tsv, fail_logs_path, output_path, start_utc, end_utc, elapsed_sec = sys.argv[1:] + +rows = [] +with open(summary_tsv) as f: + for line in f: + parts = line.rstrip("\n").split("\t") + if len(parts) != 7: + continue + py_version, chunks_total, junit_files, tests, failures, errors, status = parts + rows.append( + { + "python_version": py_version, + "status": status, + "chunks_total": int(chunks_total), + "junit_files": int(junit_files), + "tests": int(tests), + "failures": int(failures), + "errors": int(errors), + } + ) + +failed_logs = [] +with open(fail_logs_path) as f: + failed_logs = [line.strip() for line in f if line.strip()] + +payload = { + "start_utc": start_utc, + "end_utc": end_utc, + "elapsed_seconds": int(elapsed_sec), + "python_versions": rows, + "failed_chunk_logs": failed_logs, +} + +with open(output_path, "w") as f: + json.dump(payload, f, indent=2) + f.write("\n") +PY + + echo "==> Test run end time (UTC): $end_human" + echo "==> Test run elapsed: ${elapsed_sec}s" + echo "==> Summary written: $SUMMARY_TXT" + echo "==> Summary JSON: $SUMMARY_JSON" +} + +on_exit() { + local exit_code=$? + write_summary_files || true + cleanup + trap - EXIT + exit "$exit_code" +} + +trap on_exit EXIT + +read -r -a SELECTED_PYTHON_VERSIONS <<< "$(normalize_python_versions "$PYTHON_VERSIONS")" + +echo "==> Preparing artifacts directory: $ARTIFACTS_DIR" +rm -rf "$ARTIFACTS_DIR" +mkdir -p "$ARTIFACTS_DIR" + +SUMMARY_TSV="$ARTIFACTS_DIR/.summary_rows.tsv" +SUMMARY_FAIL_LOGS="$ARTIFACTS_DIR/.summary_fail_logs.txt" +SUMMARY_TXT="$ARTIFACTS_DIR/summary.txt" +SUMMARY_JSON="$ARTIFACTS_DIR/summary.json" + +: > "$SUMMARY_TSV" +: > "$SUMMARY_FAIL_LOGS" + +echo "==> Starting shared services (LDAP)" +start_services + +TESTS_START_EPOCH="$(date +%s)" +TESTS_START_HUMAN="$(date -u +"%Y-%m-%dT%H:%M:%SZ")" +echo "==> Test run start time (UTC): $TESTS_START_HUMAN" +echo "==> Python versions selected: ${SELECTED_PYTHON_VERSIONS[*]}" + +run_chunk() { + local group="$1" + local log_file="$CURRENT_ARTIFACTS_DIR/shard.${group}.log" + + if docker run --rm \ + --network "$DOCKER_NETWORK_NAME" \ + -e SHARD_GROUP="$group" \ + -e SHARD_COUNT="$CHUNK_COUNT" \ + -e ARTIFACTS_DIR="/artifacts" \ + -e PYTEST_EXTRA_ARGS="$PYTEST_EXTRA_ARGS" \ + -e QSERVER_TEST_LDAP_HOST="$LDAP_SERVICE_NAME" \ + -e QSERVER_TEST_LDAP_PORT="389" \ + -e QSERVER_TEST_REDIS_ADDR="localhost" \ + -e QSERVER_HTTP_TEST_BIND_HOST="127.0.0.1" \ + -e QSERVER_HTTP_TEST_HOST="127.0.0.1" \ + -v "$CURRENT_ARTIFACTS_DIR:/artifacts" \ + "$CURRENT_IMAGE_TAG" >"$log_file" 2>&1; then + : > "$CURRENT_ARTIFACTS_DIR/.status.${group}.ok" + else + : > "$CURRENT_ARTIFACTS_DIR/.status.${group}.fail" + exit 1 + fi +} + +export -f run_chunk +export CHUNK_COUNT PYTEST_EXTRA_ARGS DOCKER_NETWORK_NAME LDAP_SERVICE_NAME + +for PYTHON_VERSION in "${SELECTED_PYTHON_VERSIONS[@]}"; do + CURRENT_IMAGE_TAG="${IMAGE_TAG_BASE}-py${PYTHON_VERSION}" + CURRENT_ARTIFACTS_DIR="$ARTIFACTS_DIR/py${PYTHON_VERSION}" + export CURRENT_IMAGE_TAG CURRENT_ARTIFACTS_DIR + + echo "==> Building test image: $CURRENT_IMAGE_TAG (Python $PYTHON_VERSION)" + docker build \ + --build-arg PYTHON_VERSION="$PYTHON_VERSION" \ + -f "$ROOT_DIR/continuous_integration/dockerfiles/test.Dockerfile" \ + -t "$CURRENT_IMAGE_TAG" \ + "$ROOT_DIR" + + mkdir -p "$CURRENT_ARTIFACTS_DIR" + + echo "==> [Python $PYTHON_VERSION] Starting dynamic dispatch: $WORKER_COUNT workers over $CHUNK_COUNT chunks" + if ! seq 1 "$CHUNK_COUNT" | xargs -P "$WORKER_COUNT" -I {} bash -lc 'run_chunk "$1"' _ {}; then + echo "One or more chunks failed for Python $PYTHON_VERSION." >&2 + read -r TOTAL_TESTS TOTAL_FAILURES TOTAL_ERRORS TOTAL_JUNIT_FILES < <(collect_junit_totals "$CURRENT_ARTIFACTS_DIR") + for group in $(seq 1 "$CHUNK_COUNT"); do + if [[ -f "$CURRENT_ARTIFACTS_DIR/.status.${group}.fail" ]]; then + echo "Chunk $group failed. Log: $CURRENT_ARTIFACTS_DIR/shard.${group}.log" >&2 + echo "$CURRENT_ARTIFACTS_DIR/shard.${group}.log" >> "$SUMMARY_FAIL_LOGS" + fi + done + append_summary_row "py${PYTHON_VERSION}" "$CHUNK_COUNT" "$TOTAL_JUNIT_FILES" \ + "$TOTAL_TESTS" "$TOTAL_FAILURES" "$TOTAL_ERRORS" "FAIL" + exit 1 + fi + + for group in $(seq 1 "$CHUNK_COUNT"); do + if [[ -f "$CURRENT_ARTIFACTS_DIR/.status.${group}.ok" ]]; then + echo "[Python $PYTHON_VERSION] Chunk $group completed successfully" + fi + done + + rm -f "$CURRENT_ARTIFACTS_DIR"/.status.*.ok "$CURRENT_ARTIFACTS_DIR"/.status.*.fail + + echo "==> [Python $PYTHON_VERSION] Merging coverage artifacts" + docker run --rm \ + --entrypoint bash \ + -v "$CURRENT_ARTIFACTS_DIR:/artifacts" \ + "$CURRENT_IMAGE_TAG" \ + -lc "set -euo pipefail; \ + python -m coverage combine /artifacts/.coverage.* && \ + python -m coverage xml -o /artifacts/coverage.xml && \ + python -m coverage report -m > /artifacts/coverage.txt" + + if [[ "${#SELECTED_PYTHON_VERSIONS[@]}" -eq 1 ]]; then + cp "$CURRENT_ARTIFACTS_DIR/coverage.xml" "$ROOT_DIR/coverage.xml" + else + cp "$CURRENT_ARTIFACTS_DIR/coverage.xml" "$ROOT_DIR/coverage.py${PYTHON_VERSION}.xml" + fi + + read -r TOTAL_TESTS TOTAL_FAILURES TOTAL_ERRORS TOTAL_JUNIT_FILES < <(collect_junit_totals "$CURRENT_ARTIFACTS_DIR") + echo "==> [Python $PYTHON_VERSION] JUnit summary: tests=$TOTAL_TESTS failures=$TOTAL_FAILURES errors=$TOTAL_ERRORS files=$TOTAL_JUNIT_FILES" + + VERSION_STATUS="PASS" + if [[ "$TOTAL_FAILURES" -gt 0 || "$TOTAL_ERRORS" -gt 0 ]]; then + VERSION_STATUS="FAIL" + fi + + append_summary_row "py${PYTHON_VERSION}" "$CHUNK_COUNT" "$TOTAL_JUNIT_FILES" \ + "$TOTAL_TESTS" "$TOTAL_FAILURES" "$TOTAL_ERRORS" "$VERSION_STATUS" +done + +echo "==> Completed. Artifacts:" +echo " versioned logs : $ARTIFACTS_DIR/py/shard..log" +echo " versioned junit : $ARTIFACTS_DIR/py/junit..xml" +echo " versioned coverage : $ARTIFACTS_DIR/py/{coverage.txt,coverage.xml}" +echo " run summary : $ARTIFACTS_DIR/{summary.txt,summary.json}" + +if [[ "${#SELECTED_PYTHON_VERSIONS[@]}" -eq 1 ]]; then + echo " root coverage xml : $ROOT_DIR/coverage.xml" +else + echo " root coverage xmls : $ROOT_DIR/coverage.py.xml" +fi diff --git a/start_LDAP.sh b/start_LDAP.sh deleted file mode 100644 index 8b612de..0000000 --- a/start_LDAP.sh +++ /dev/null @@ -1,8 +0,0 @@ - -#!/bin/bash -set -e - -# Start LDAP server in docker container -# sudo docker pull osixia/openldap:latest -sudo docker compose -f .github/workflows/docker-configs/ldap-docker-compose.yml up -d -sudo docker ps