From c474afd59bf019e5f61c2fca963cb0bad0f061c3 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Tue, 27 Jan 2026 12:01:34 -0600 Subject: [PATCH 01/18] Updating authenticators from latest in Tiled --- bluesky_httpserver/app.py | 9 +- bluesky_httpserver/authentication.py | 17 +- bluesky_httpserver/authentication/__init__.py | 11 + .../authentication/authenticator_base.py | 39 ++ bluesky_httpserver/authenticators.py | 462 ++++++++++++------ 5 files changed, 365 insertions(+), 173 deletions(-) create mode 100644 bluesky_httpserver/authentication/__init__.py create mode 100644 bluesky_httpserver/authentication/authenticator_base.py diff --git a/bluesky_httpserver/app.py b/bluesky_httpserver/app.py index f09acb3..9a8420a 100644 --- a/bluesky_httpserver/app.py +++ b/bluesky_httpserver/app.py @@ -15,7 +15,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi -from .authentication import Mode +from .authentication import ExternalAuthenticator, InternalAuthenticator from .console_output import CollectPublishedConsoleOutput, ConsoleOutputStream, SystemInfoStream from .core import PatchedStreamingResponse from .database.core import purge_expired @@ -179,12 +179,11 @@ def build_app(authentication=None, api_access=None, resource_access=None, server for spec in authentication["providers"]: provider = spec["provider"] authenticator = spec["authenticator"] - mode = authenticator.mode - if mode == Mode.password: + if isinstance(authenticator, InternalAuthenticator): authentication_router.post(f"/provider/{provider}/token")( build_handle_credentials_route(authenticator, provider) ) - elif mode == Mode.external: + elif isinstance(authenticator, ExternalAuthenticator): authentication_router.get(f"/provider/{provider}/code")( build_auth_code_route(authenticator, provider) ) @@ -192,7 +191,7 @@ def build_app(authentication=None, api_access=None, resource_access=None, server build_auth_code_route(authenticator, provider) ) else: - raise ValueError(f"unknown authentication mode {mode}") + raise ValueError(f"unknown authenticator type {type(authenticator)}") for custom_router in getattr(authenticator, "include_routers", []): authentication_router.include_router(custom_router, prefix=f"/provider/{provider}") diff --git a/bluesky_httpserver/authentication.py b/bluesky_httpserver/authentication.py index 9772974..a30db6a 100644 --- a/bluesky_httpserver/authentication.py +++ b/bluesky_httpserver/authentication.py @@ -31,6 +31,11 @@ from pydantic_settings import BaseSettings from . import schemas +from .authentication.authenticator_base import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) from .authorization._defaults import _DEFAULT_ANONYMOUS_PROVIDER_NAME from .core import json_or_msgpack from .database import orm @@ -54,12 +59,6 @@ def utcnow(): "UTC now with second resolution" return datetime.utcnow().replace(microsecond=0) - -class Mode(enum.Enum): - password = "password" - external = "external" - - class Token(BaseModel): access_token: str token_type: str @@ -455,7 +454,8 @@ async def auth_code( api_access_manager=Depends(get_api_access_manager), ): request.state.endpoint = "auth" - username = await authenticator.authenticate(request) + user_session_state = await authenticator.authenticate(request) + username = user_session_state.user_name if user_session_state else None if username and api_access_manager.is_user_known(username): scopes = api_access_manager.get_user_scopes(username) @@ -484,7 +484,8 @@ async def handle_credentials( api_access_manager=Depends(get_api_access_manager), ): request.state.endpoint = "auth" - username = await authenticator.authenticate(username=form_data.username, password=form_data.password) + user_session_state = await authenticator.authenticate(username=form_data.username, password=form_data.password) + username = user_session_state.user_name if user_session_state else None err_msg = None if not username: diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py new file mode 100644 index 0000000..58c758f --- /dev/null +++ b/bluesky_httpserver/authentication/__init__.py @@ -0,0 +1,11 @@ +from .authenticator_base import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) + +__all__ = [ + "ExternalAuthenticator", + "InternalAuthenticator", + "UserSessionState", +] diff --git a/bluesky_httpserver/authentication/authenticator_base.py b/bluesky_httpserver/authentication/authenticator_base.py new file mode 100644 index 0000000..7a2cff3 --- /dev/null +++ b/bluesky_httpserver/authentication/authenticator_base.py @@ -0,0 +1,39 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Optional + +from fastapi import Request + + +@dataclass +class UserSessionState: + """Data transfer class to communicate custom session state information.""" + + user_name: str + state: dict = None + + +class InternalAuthenticator(ABC): + """ + Base class for authenticators that use username/password credentials. + + Subclasses must implement the authenticate method which takes a username + and password and returns a UserSessionState on success or None on failure. + """ + + async def authenticate( + self, username: str, password: str + ) -> Optional[UserSessionState]: + raise NotImplementedError + + +class ExternalAuthenticator(ABC): + """ + Base class for authenticators that use external identity providers. + + Subclasses must implement the authenticate method which takes a FastAPI + Request object and returns a UserSessionState on success or None on failure. + """ + + async def authenticate(self, request: Request) -> Optional[UserSessionState]: + raise NotImplementedError diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index 61c2da4..3b439f4 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -1,21 +1,32 @@ import asyncio +import base64 import functools import logging import re import secrets from collections.abc import Iterable +from datetime import timedelta +from typing import Any, List, Mapping, Optional, cast +import httpx +from cachetools import TTLCache, cached from fastapi import APIRouter, Request -from jose import JWTError, jwk, jwt +from fastapi.security import OAuth2, OAuth2AuthorizationCodeBearer +from jose import JWTError, jwt +from pydantic import Secret from starlette.responses import RedirectResponse -from .authentication import Mode -from .utils import modules_available +from .authentication import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) +from .utils import get_root_url, modules_available logger = logging.getLogger(__name__) -class DummyAuthenticator: +class DummyAuthenticator(InternalAuthenticator): """ For test and demo purposes only! @@ -23,26 +34,20 @@ class DummyAuthenticator: """ - mode = Mode.password + def __init__(self, confirmation_message: str = ""): + self.confirmation_message = confirmation_message - async def authenticate(self, username: str, password: str): - return username + async def authenticate(self, username: str, password: str) -> UserSessionState: + return UserSessionState(username, {}) -class DictionaryAuthenticator: +class DictionaryAuthenticator(InternalAuthenticator): """ For test and demo purposes only! Check passwords from a dictionary of usernames mapped to passwords. - - Parameters - ---------- - - users_to_passwords: dict(str, str) - Mapping of usernames to passwords. """ - mode = Mode.password configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -50,25 +55,32 @@ class DictionaryAuthenticator: properties: users_to_password: type: object - description: | - Mapping usernames to password. Environment variable expansion should be - used to avoid placing passwords directly in configuration. + description: | + Mapping usernames to password. Environment variable expansion should be + used to avoid placing passwords directly in configuration. + confirmation_message: + type: string + description: May be displayed by client after successful login. """ - def __init__(self, users_to_passwords): + def __init__( + self, users_to_passwords: Mapping[str, str], confirmation_message: str = "" + ): self._users_to_passwords = users_to_passwords + self.confirmation_message = confirmation_message - async def authenticate(self, username: str, password: str): + async def authenticate( + self, username: str, password: str + ) -> Optional[UserSessionState]: true_password = self._users_to_passwords.get(username) if not true_password: # Username is not valid. - return + return None if secrets.compare_digest(true_password, password): - return username + return UserSessionState(username, {}) -class PAMAuthenticator: - mode = Mode.password +class PAMAuthenticator(InternalAuthenticator): configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -77,90 +89,149 @@ class PAMAuthenticator: service: type: string description: PAM service. Default is 'login'. + confirmation_message: + type: string + description: May be displayed by client after successful login. """ - def __init__(self, service="login"): + def __init__(self, service: str = "login", confirmation_message: str = ""): if not modules_available("pamela"): - raise ModuleNotFoundError("This PAMAuthenticator requires the module 'pamela' to be installed.") + raise ModuleNotFoundError( + "This PAMAuthenticator requires the module 'pamela' to be installed." + ) self.service = service + self.confirmation_message = confirmation_message # TODO Try to open a PAM session. - async def authenticate(self, username: str, password: str): + async def authenticate( + self, username: str, password: str + ) -> Optional[UserSessionState]: import pamela try: pamela.authenticate(username, password, service=self.service) + return UserSessionState(username, {}) except pamela.PAMError: # Authentication failed. - return - else: - return username + return None -class OIDCAuthenticator: - mode = Mode.external +class OIDCAuthenticator(ExternalAuthenticator): configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object additionalProperties: false properties: + audience: + type: string client_id: type: string client_secret: type: string - redirect_uri: + well_known_uri: type: string - token_uri: + confirmation_message: type: string - authorization_endpoint: + redirect_on_success: + type: string + redirect_on_failure: type: string - public_keys: - type: array - item: - type: object - properties: - - alg: - type: string - - e - type: string - - kid - type: string - - kty - type: string - - n - type: string - - use - type: string - required: - - alg - - e - - kid - - kty - - n - - use """ def __init__( self, - client_id, - client_secret, - redirect_uri, - public_keys, - token_uri, - authorization_endpoint, - confirmation_message, + audience: str, + client_id: str, + client_secret: str, + well_known_uri: str, + confirmation_message: str = "", + redirect_on_success: Optional[str] = None, + redirect_on_failure: Optional[str] = None, ): - self.client_id = client_id - self.client_secret = client_secret + self._audience = audience + self._client_id = client_id + self._client_secret = Secret(client_secret) + self._well_known_url = well_known_uri self.confirmation_message = confirmation_message - self.redirect_uri = redirect_uri - self.public_keys = public_keys - self.token_uri = token_uri - self.authorization_endpoint = authorization_endpoint.format(client_id=client_id, redirect_uri=redirect_uri) - - async def authenticate(self, request): - code = request.query_params["code"] - response = await exchange_code(self.token_uri, code, self.client_id, self.client_secret, self.redirect_uri) + self.redirect_on_success = redirect_on_success + self.redirect_on_failure = redirect_on_failure + + @functools.cached_property + def _config_from_oidc_url(self) -> dict[str, Any]: + response: httpx.Response = httpx.get(self._well_known_url) + response.raise_for_status() + return response.json() + + @functools.cached_property + def client_id(self) -> str: + return self._client_id + + @functools.cached_property + def id_token_signing_alg_values_supported(self) -> list[str]: + return cast( + list[str], + self._config_from_oidc_url.get("id_token_signing_alg_values_supported"), + ) + + @functools.cached_property + def issuer(self) -> str: + return cast(str, self._config_from_oidc_url.get("issuer")) + + @functools.cached_property + def jwks_uri(self) -> str: + return cast(str, self._config_from_oidc_url.get("jwks_uri")) + + @functools.cached_property + def token_endpoint(self) -> str: + return cast(str, self._config_from_oidc_url.get("token_endpoint")) + + @functools.cached_property + def authorization_endpoint(self) -> httpx.URL: + return httpx.URL( + cast(str, self._config_from_oidc_url.get("authorization_endpoint")) + ) + + @functools.cached_property + def device_authorization_endpoint(self) -> str: + return cast( + str, self._config_from_oidc_url.get("device_authorization_endpoint") + ) + + @functools.cached_property + def end_session_endpoint(self) -> str: + return cast(str, self._config_from_oidc_url.get("end_session_endpoint")) + + @cached(TTLCache(maxsize=1, ttl=timedelta(days=7).total_seconds())) + def keys(self) -> List[str]: + return httpx.get(self.jwks_uri).raise_for_status().json().get("keys", []) + + def decode_token(self, token: str) -> dict[str, Any]: + return jwt.decode( + token, + key=self.keys(), + algorithms=self.id_token_signing_alg_values_supported, + audience=self._audience, + issuer=self.issuer, + ) + + async def authenticate(self, request: Request) -> Optional[UserSessionState]: + code = request.query_params.get("code") + if not code: + logger.warning( + "Authentication failed: No authorization code parameter provided." + ) + return None + # A proxy in the middle may make the request into something like + # 'http://localhost:8000/...' so we fix the first part but keep + # the original URI path. + redirect_uri = f"{get_root_url(request)}{request.url.path}" + response = await exchange_code( + self.token_endpoint, + code, + self._client_id, + self._client_secret.get_secret_value(), + redirect_uri, + ) response_body = response.json() if response.is_error: logger.error("Authentication error: %r", response_body) @@ -168,63 +239,84 @@ async def authenticate(self, request): response_body = response.json() id_token = response_body["id_token"] access_token = response_body["access_token"] - # Match the kid in id_token to a key in the list of public_keys. - key = find_key(id_token, self.public_keys) try: - verified_body = jwt.decode(id_token, key, access_token=access_token, audience=self.client_id) + verified_body = self.decode_token(access_token) except JWTError: logger.exception( "Authentication error. Unverified token: %r", jwt.get_unverified_claims(id_token), ) return None - return verified_body["sub"] + return UserSessionState(verified_body["sub"], {}) -class KeyNotFoundError(Exception): - pass - - -def find_key(token, keys): - """ - Find a key from the configured keys based on the kid claim of the token - - Parameters - ---------- - token : token to search for the kid from - keys: list of keys - - Raises - ------ - KeyNotFoundError: - returned if the token does not have a kid claim - - Returns - ------ - key: found key object - """ +class ProxiedOIDCAuthenticator(OIDCAuthenticator): + configuration_schema = """ +$schema": http://json-schema.org/draft-07/schema# +type: object +additionalProperties: false +properties: + audience: + type: string + client_id: + type: string + well_known_uri: + type: string + scopes: + type: array + items: + type: string + description: | + Optional list of OAuth2 scopes to request. If provided, authorization + should be enforced by an external policy agent (for example ExternalPolicyDecisionPoint) + rather than by this authenticator. + device_flow_client_id: + type: string + confirmation_message: + type: string +""" - unverified = jwt.get_unverified_header(token) - kid = unverified.get("kid") - if not kid: - raise KeyNotFoundError("No 'kid' in token") + def __init__( + self, + audience: str, + client_id: str, + well_known_uri: str, + device_flow_client_id: str, + scopes: Optional[List[str]] = None, + confirmation_message: str = "", + ): + super().__init__( + audience=audience, + client_id=client_id, + client_secret="", + well_known_uri=well_known_uri, + confirmation_message=confirmation_message, + ) + self.scopes = scopes + self.device_flow_client_id = device_flow_client_id + self._oidc_bearer = OAuth2AuthorizationCodeBearer( + authorizationUrl=str(self.authorization_endpoint), + tokenUrl=self.token_endpoint, + ) - for key in keys: - if key["kid"] == kid: - return jwk.construct(key) - return KeyNotFoundError(f"Token specifies {kid} but we have {[k['kid'] for k in keys]}") + @property + def oauth2_schema(self) -> OAuth2: + return self._oidc_bearer -async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect_uri): +async def exchange_code( + token_uri: str, + auth_code: str, + client_id: str, + client_secret: str, + redirect_uri: str, +) -> httpx.Response: """Method that talks to an IdP to exchange a code for an access_token and/or id_token Args: token_url ([type]): [description] auth_code ([type]): [description] """ - if not modules_available("httpx"): - raise ModuleNotFoundError("This authenticator requires 'httpx'. (pip install httpx)") - import httpx - + auth_value = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() response = httpx.post( url=token_uri, data={ @@ -234,18 +326,18 @@ async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect "code": auth_code, "client_secret": client_secret, }, + headers={"Authorization": f"Basic {auth_value}"}, ) return response -class SAMLAuthenticator: - mode = Mode.external +class SAMLAuthenticator(ExternalAuthenticator): def __init__( self, saml_settings, # See EXAMPLE_SAML_SETTINGS below. - attribute_name, # which SAML attribute to use as 'id' for Idenity - confirmation_message=None, + attribute_name: str, # which SAML attribute to use as 'id' for Identity + confirmation_message: str = "", ): self.saml_settings = saml_settings self.attribute_name = attribute_name @@ -258,30 +350,26 @@ def __init__( # The PyPI package name is 'python3-saml' # but it imports as 'onelogin'. # https://github.com/onelogin/python3-saml - raise ModuleNotFoundError("This SAMLAuthenticator requires 'python3-saml' to be installed.") + raise ModuleNotFoundError( + "This SAMLAuthenticator requires 'python3-saml' to be installed." + ) from onelogin.saml2.auth import OneLogin_Saml2_Auth @router.get("/login") - async def saml_login(request: Request): + async def saml_login(request: Request) -> RedirectResponse: req = await prepare_saml_from_fastapi_request(request) auth = OneLogin_Saml2_Auth(req, self.saml_settings) - # saml_settings = auth.get_settings() - # metadata = saml_settings.get_sp_metadata() - # errors = saml_settings.validate_metadata(metadata) - # if len(errors) == 0: - # print(metadata) - # else: - # print("Error found on Metadata: %s" % (', '.join(errors))) callback_url = auth.login() - response = RedirectResponse(url=callback_url) - return response + return RedirectResponse(url=callback_url) self.include_routers = [router] - async def authenticate(self, request): + async def authenticate(self, request: Request) -> Optional[UserSessionState]: if not modules_available("onelogin"): - raise ModuleNotFoundError("This SAMLAuthenticator requires the module 'oneline' to be installed.") + raise ModuleNotFoundError( + "This SAMLAuthenticator requires the module 'oneline' to be installed." + ) from onelogin.saml2.auth import OneLogin_Saml2_Auth req = await prepare_saml_from_fastapi_request(request, True) @@ -290,26 +378,27 @@ async def authenticate(self, request): errors = auth.get_errors() # This method receives an array with the errors if errors: raise Exception( - "Error when processing SAML Response: %s %s" % (", ".join(errors), auth.get_last_error_reason()) + "Error when processing SAML Response: %s %s" + % (", ".join(errors), auth.get_last_error_reason()) ) if auth.is_authenticated(): # Return a string that the Identity can use as id. attribute_as_list = auth.get_attributes()[self.attribute_name] # Confused in what situation this would have more than one item.... assert len(attribute_as_list) == 1 - return attribute_as_list[0] + return UserSessionState(attribute_as_list[0], {}) else: return None -async def prepare_saml_from_fastapi_request(request, debug=False): +async def prepare_saml_from_fastapi_request(request: Request) -> Mapping[str, str]: form_data = await request.form() rv = { "http_host": request.client.host, "server_port": request.url.port, "script_name": request.url.path, "post_data": {}, - "get_data": {}, + "get_data": {} # Advanced request options # "https": "", # "request_uri": "", @@ -328,7 +417,7 @@ async def prepare_saml_from_fastapi_request(request, debug=False): return rv -class LDAPAuthenticator: +class LDAPAuthenticator(InternalAuthenticator): """ LDAP authenticator. The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator @@ -472,6 +561,8 @@ class LDAPAuthenticator: This can be useful in an heterogeneous environment, when supplying a UNIX username to authenticate against AD. + confirmation_message: str + May be displayed by client after successful login. Examples -------- @@ -510,8 +601,6 @@ class LDAPAuthenticator: id: user02 """ - mode = Mode.password - def __init__( self, server_address, @@ -536,6 +625,7 @@ def __init__( attributes=None, auth_state_attributes=None, use_lookup_dn_username=True, + confirmation_message="", ): self.use_ssl = use_ssl self.use_tls = use_tls @@ -554,7 +644,9 @@ def __init__( self.escape_userdn = escape_userdn self.search_filter = search_filter self.attributes = attributes if attributes else [] - self.auth_state_attributes = auth_state_attributes if auth_state_attributes else [] + self.auth_state_attributes = ( + auth_state_attributes if auth_state_attributes else [] + ) self.use_lookup_dn_username = use_lookup_dn_username if isinstance(server_address, str): @@ -567,10 +659,15 @@ def __init__( f"type(server_address)={type(server_address)}" ) if not server_address_list: - raise ValueError("No servers are specified: 'server_address' is an empty list") + raise ValueError( + "No servers are specified: 'server_address' is an empty list" + ) self.server_address_list = server_address_list - self.server_port = server_port if server_port is not None else self._server_port_default() + self.server_port = ( + server_port if server_port is not None else self._server_port_default() + ) + self.confirmation_message = confirmation_message def _server_port_default(self): if self.use_ssl: @@ -623,8 +720,15 @@ async def resolve_username(self, username_supplied_by_user): response = conn.response if len(response) == 0 or "attributes" not in response[0].keys(): - msg = "No entry found for user '{username}' " "when looking up attribute '{attribute}'" - logger.warning(msg.format(username=username_supplied_by_user, attribute=self.user_attribute)) + msg = ( + "No entry found for user '{username}' " + "when looking up attribute '{attribute}'" + ) + logger.warning( + msg.format( + username=username_supplied_by_user, attribute=self.user_attribute + ) + ) return (None, None) user_dn = response[0]["attributes"][self.lookup_dn_user_dn_attribute] @@ -655,7 +759,7 @@ async def resolve_username(self, username_supplied_by_user): def get_connection(self, userdn, password): import ldap3 - # NOTE: setting 'acitve=False' essentially disables exclusion of inactive servers from the pool. + # NOTE: setting 'active=False' essentially disables exclusion of inactive servers from the pool. # It probably does not matter if the pool contains only one server, but it could have implications # when there are multiple servers in the pool. It is not clear what those implications are. # But using the default 'activate=True' results in the thread being blocked indefinitely @@ -675,14 +779,23 @@ def get_connection(self, userdn, password): server_port = self.server_port server = ldap3.Server( - server_addr, port=server_port, use_ssl=self.use_ssl, connect_timeout=self.connect_timeout + server_addr, + port=server_port, + use_ssl=self.use_ssl, + connect_timeout=self.connect_timeout, ) server_pool.add(server) - auto_bind_no_ssl = ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS + auto_bind_no_ssl = ( + ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS + ) auto_bind = ldap3.AUTO_BIND_NO_TLS if self.use_ssl else auto_bind_no_ssl conn = ldap3.Connection( - server_pool, user=userdn, password=password, auto_bind=auto_bind, receive_timeout=self.receive_timeout + server_pool, + user=userdn, + password=password, + auto_bind=auto_bind, + receive_timeout=self.receive_timeout, ) return conn @@ -690,14 +803,19 @@ async def get_user_attributes(self, conn, userdn): attrs = {} if self.auth_state_attributes: search_func = functools.partial( - conn.search, userdn, "(objectClass=*)", attributes=self.auth_state_attributes + conn.search, + userdn, + "(objectClass=*)", + attributes=self.auth_state_attributes, ) found = await asyncio.get_running_loop().run_in_executor(None, search_func) if found: attrs = conn.entries[0].entry_attributes_as_dict return attrs - async def authenticate(self, username: str, password: str): + async def authenticate( + self, username: str, password: str + ) -> Optional[UserSessionState]: import ldap3 username_saved = username # Save the user name passed as a parameter @@ -723,7 +841,9 @@ async def authenticate(self, username: str, password: str): # sanity check if not self.lookup_dn and not bind_dn_template: - logger.warning("Login not allowed, please configure 'lookup_dn' or 'bind_dn_template'.") + logger.warning( + "Login not allowed, please configure 'lookup_dn' or 'bind_dn_template'." + ) return None if self.lookup_dn: @@ -761,7 +881,9 @@ async def authenticate(self, username: str, password: str): if conn.bound: is_bound = True else: - is_bound = await asyncio.get_running_loop().run_in_executor(None, conn.bind) + is_bound = await asyncio.get_running_loop().run_in_executor( + None, conn.bind + ) msg = msg.format(username=username, userdn=userdn, is_bound=is_bound) logger.debug(msg) @@ -774,7 +896,9 @@ async def authenticate(self, username: str, password: str): return None if self.search_filter: - search_filter = self.search_filter.format(userattr=self.user_attribute, username=username) + search_filter = self.search_filter.format( + userattr=self.user_attribute, username=username + ) search_func = functools.partial( conn.search, @@ -788,18 +912,33 @@ async def authenticate(self, username: str, password: str): n_users = len(conn.response) if n_users == 0: msg = "User with '{userattr}={username}' not found in directory" - logger.warning(msg.format(userattr=self.user_attribute, username=username)) + logger.warning( + msg.format(userattr=self.user_attribute, username=username) + ) return None if n_users > 1: - msg = "Duplicate users found! " "{n_users} users found with '{userattr}={username}'" - logger.warning(msg.format(userattr=self.user_attribute, username=username, n_users=n_users)) + msg = ( + "Duplicate users found! " + "{n_users} users found with '{userattr}={username}'" + ) + logger.warning( + msg.format( + userattr=self.user_attribute, username=username, n_users=n_users + ) + ) return None if self.allowed_groups: logger.debug("username:%s Using dn %s", username, userdn) found = False for group in self.allowed_groups: - group_filter = "(|" "(member={userdn})" "(uniqueMember={userdn})" "(memberUid={uid})" ")" + group_filter = ( + "(|" + "(member={userdn})" + "(uniqueMember={userdn})" + "(memberUid={uid})" + ")" + ) group_filter = group_filter.format(userdn=userdn, uid=username) group_attributes = ["member", "uniqueMember", "memberUid"] @@ -810,7 +949,9 @@ async def authenticate(self, username: str, password: str): search_filter=group_filter, attributes=group_attributes, ) - found = await asyncio.get_running_loop().run_in_executor(None, search_func) + found = await asyncio.get_running_loop().run_in_executor( + None, search_func + ) if found: break @@ -826,5 +967,6 @@ async def authenticate(self, username: str, password: str): user_info = await self.get_user_attributes(conn, userdn) if user_info: logger.debug("username:%s attributes:%s", username, user_info) - return {"name": username, "auth_state": user_info} - return username + # this path might never have been worked out...is it ever hit? + return UserSessionState(username, user_info) + return UserSessionState(username, {}) From f54ef363ba1b9a04dad06b1fe7cb76ab184a5078 Mon Sep 17 00:00:00 2001 From: Dmitri Gavrilov Date: Thu, 5 Feb 2026 11:20:34 -0500 Subject: [PATCH 02/18] TST: fix unit tests --- .../{authentication.py => _authentication.py} | 11 ++++------- bluesky_httpserver/authentication/__init__.py | 14 ++++++++++++++ bluesky_httpserver/tests/test_authenticators.py | 6 +++--- requirements.txt | 1 + 4 files changed, 22 insertions(+), 10 deletions(-) rename bluesky_httpserver/{authentication.py => _authentication.py} (99%) diff --git a/bluesky_httpserver/authentication.py b/bluesky_httpserver/_authentication.py similarity index 99% rename from bluesky_httpserver/authentication.py rename to bluesky_httpserver/_authentication.py index a30db6a..0375794 100644 --- a/bluesky_httpserver/authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -1,5 +1,4 @@ import asyncio -import enum import hashlib import secrets import uuid as uuid_module @@ -31,11 +30,6 @@ from pydantic_settings import BaseSettings from . import schemas -from .authentication.authenticator_base import ( - ExternalAuthenticator, - InternalAuthenticator, - UserSessionState, -) from .authorization._defaults import _DEFAULT_ANONYMOUS_PROVIDER_NAME from .core import json_or_msgpack from .database import orm @@ -59,6 +53,7 @@ def utcnow(): "UTC now with second resolution" return datetime.utcnow().replace(microsecond=0) + class Token(BaseModel): access_token: str token_type: str @@ -484,7 +479,9 @@ async def handle_credentials( api_access_manager=Depends(get_api_access_manager), ): request.state.endpoint = "auth" - user_session_state = await authenticator.authenticate(username=form_data.username, password=form_data.password) + user_session_state = await authenticator.authenticate( + username=form_data.username, password=form_data.password + ) username = user_session_state.user_name if user_session_state else None err_msg = None diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py index 58c758f..fc35cdd 100644 --- a/bluesky_httpserver/authentication/__init__.py +++ b/bluesky_httpserver/authentication/__init__.py @@ -1,3 +1,11 @@ +from .._authentication import ( + base_authentication_router, + build_auth_code_route, + build_handle_credentials_route, + get_current_principal, + get_current_principal_websocket, + oauth2_scheme, +) from .authenticator_base import ( ExternalAuthenticator, InternalAuthenticator, @@ -8,4 +16,10 @@ "ExternalAuthenticator", "InternalAuthenticator", "UserSessionState", + "get_current_principal", + "get_current_principal_websocket", + "base_authentication_router", + "build_auth_code_route", + "build_handle_credentials_route", + "oauth2_scheme", ] diff --git a/bluesky_httpserver/tests/test_authenticators.py b/bluesky_httpserver/tests/test_authenticators.py index cc2984c..183ce75 100644 --- a/bluesky_httpserver/tests/test_authenticators.py +++ b/bluesky_httpserver/tests/test_authenticators.py @@ -3,7 +3,7 @@ import pytest # fmt: off -from ..authenticators import LDAPAuthenticator +from ..authenticators import LDAPAuthenticator, UserSessionState @pytest.mark.parametrize("ldap_server_address, ldap_server_port", [ @@ -35,8 +35,8 @@ def test_LDAPAuthenticator_01(use_tls, use_ssl, ldap_server_address, ldap_server ) async def testing(): - assert await authenticator.authenticate("user01", "password1") == "user01" - assert await authenticator.authenticate("user02", "password2") == "user02" + assert await authenticator.authenticate("user01", "password1") == UserSessionState("user01", {}) + assert await authenticator.authenticate("user02", "password2") == UserSessionState("user02", {}) assert await authenticator.authenticate("user02a", "password2") is None assert await authenticator.authenticate("user02", "password2a") is None diff --git a/requirements.txt b/requirements.txt index f465abd..818362f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ alembic bluesky-queueserver bluesky-queueserver-api +cachetools fastapi ldap3 orjson From be73eda1c78b7a76e7784a979de38783325aa32e Mon Sep 17 00:00:00 2001 From: Dmitri Gavrilov Date: Thu, 5 Feb 2026 11:26:58 -0500 Subject: [PATCH 03/18] STY: reformat with black --- .../authentication/authenticator_base.py | 4 +- bluesky_httpserver/authenticators.py | 111 ++++-------------- 2 files changed, 27 insertions(+), 88 deletions(-) diff --git a/bluesky_httpserver/authentication/authenticator_base.py b/bluesky_httpserver/authentication/authenticator_base.py index 7a2cff3..af103c5 100644 --- a/bluesky_httpserver/authentication/authenticator_base.py +++ b/bluesky_httpserver/authentication/authenticator_base.py @@ -21,9 +21,7 @@ class InternalAuthenticator(ABC): and password and returns a UserSessionState on success or None on failure. """ - async def authenticate( - self, username: str, password: str - ) -> Optional[UserSessionState]: + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: raise NotImplementedError diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index 3b439f4..78b6cf1 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -63,15 +63,11 @@ class DictionaryAuthenticator(InternalAuthenticator): description: May be displayed by client after successful login. """ - def __init__( - self, users_to_passwords: Mapping[str, str], confirmation_message: str = "" - ): + def __init__(self, users_to_passwords: Mapping[str, str], confirmation_message: str = ""): self._users_to_passwords = users_to_passwords self.confirmation_message = confirmation_message - async def authenticate( - self, username: str, password: str - ) -> Optional[UserSessionState]: + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: true_password = self._users_to_passwords.get(username) if not true_password: # Username is not valid. @@ -96,16 +92,12 @@ class PAMAuthenticator(InternalAuthenticator): def __init__(self, service: str = "login", confirmation_message: str = ""): if not modules_available("pamela"): - raise ModuleNotFoundError( - "This PAMAuthenticator requires the module 'pamela' to be installed." - ) + raise ModuleNotFoundError("This PAMAuthenticator requires the module 'pamela' to be installed.") self.service = service self.confirmation_message = confirmation_message # TODO Try to open a PAM session. - async def authenticate( - self, username: str, password: str - ) -> Optional[UserSessionState]: + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: import pamela try: @@ -187,15 +179,11 @@ def token_endpoint(self) -> str: @functools.cached_property def authorization_endpoint(self) -> httpx.URL: - return httpx.URL( - cast(str, self._config_from_oidc_url.get("authorization_endpoint")) - ) + return httpx.URL(cast(str, self._config_from_oidc_url.get("authorization_endpoint"))) @functools.cached_property def device_authorization_endpoint(self) -> str: - return cast( - str, self._config_from_oidc_url.get("device_authorization_endpoint") - ) + return cast(str, self._config_from_oidc_url.get("device_authorization_endpoint")) @functools.cached_property def end_session_endpoint(self) -> str: @@ -217,9 +205,7 @@ def decode_token(self, token: str) -> dict[str, Any]: async def authenticate(self, request: Request) -> Optional[UserSessionState]: code = request.query_params.get("code") if not code: - logger.warning( - "Authentication failed: No authorization code parameter provided." - ) + logger.warning("Authentication failed: No authorization code parameter provided.") return None # A proxy in the middle may make the request into something like # 'http://localhost:8000/...' so we fix the first part but keep @@ -350,9 +336,7 @@ def __init__( # The PyPI package name is 'python3-saml' # but it imports as 'onelogin'. # https://github.com/onelogin/python3-saml - raise ModuleNotFoundError( - "This SAMLAuthenticator requires 'python3-saml' to be installed." - ) + raise ModuleNotFoundError("This SAMLAuthenticator requires 'python3-saml' to be installed.") from onelogin.saml2.auth import OneLogin_Saml2_Auth @@ -367,9 +351,7 @@ async def saml_login(request: Request) -> RedirectResponse: async def authenticate(self, request: Request) -> Optional[UserSessionState]: if not modules_available("onelogin"): - raise ModuleNotFoundError( - "This SAMLAuthenticator requires the module 'oneline' to be installed." - ) + raise ModuleNotFoundError("This SAMLAuthenticator requires the module 'oneline' to be installed.") from onelogin.saml2.auth import OneLogin_Saml2_Auth req = await prepare_saml_from_fastapi_request(request, True) @@ -378,8 +360,7 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]: errors = auth.get_errors() # This method receives an array with the errors if errors: raise Exception( - "Error when processing SAML Response: %s %s" - % (", ".join(errors), auth.get_last_error_reason()) + "Error when processing SAML Response: %s %s" % (", ".join(errors), auth.get_last_error_reason()) ) if auth.is_authenticated(): # Return a string that the Identity can use as id. @@ -398,7 +379,7 @@ async def prepare_saml_from_fastapi_request(request: Request) -> Mapping[str, st "server_port": request.url.port, "script_name": request.url.path, "post_data": {}, - "get_data": {} + "get_data": {}, # Advanced request options # "https": "", # "request_uri": "", @@ -644,9 +625,7 @@ def __init__( self.escape_userdn = escape_userdn self.search_filter = search_filter self.attributes = attributes if attributes else [] - self.auth_state_attributes = ( - auth_state_attributes if auth_state_attributes else [] - ) + self.auth_state_attributes = auth_state_attributes if auth_state_attributes else [] self.use_lookup_dn_username = use_lookup_dn_username if isinstance(server_address, str): @@ -659,14 +638,10 @@ def __init__( f"type(server_address)={type(server_address)}" ) if not server_address_list: - raise ValueError( - "No servers are specified: 'server_address' is an empty list" - ) + raise ValueError("No servers are specified: 'server_address' is an empty list") self.server_address_list = server_address_list - self.server_port = ( - server_port if server_port is not None else self._server_port_default() - ) + self.server_port = server_port if server_port is not None else self._server_port_default() self.confirmation_message = confirmation_message def _server_port_default(self): @@ -720,15 +695,8 @@ async def resolve_username(self, username_supplied_by_user): response = conn.response if len(response) == 0 or "attributes" not in response[0].keys(): - msg = ( - "No entry found for user '{username}' " - "when looking up attribute '{attribute}'" - ) - logger.warning( - msg.format( - username=username_supplied_by_user, attribute=self.user_attribute - ) - ) + msg = "No entry found for user '{username}' " "when looking up attribute '{attribute}'" + logger.warning(msg.format(username=username_supplied_by_user, attribute=self.user_attribute)) return (None, None) user_dn = response[0]["attributes"][self.lookup_dn_user_dn_attribute] @@ -786,9 +754,7 @@ def get_connection(self, userdn, password): ) server_pool.add(server) - auto_bind_no_ssl = ( - ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS - ) + auto_bind_no_ssl = ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS auto_bind = ldap3.AUTO_BIND_NO_TLS if self.use_ssl else auto_bind_no_ssl conn = ldap3.Connection( server_pool, @@ -813,9 +779,7 @@ async def get_user_attributes(self, conn, userdn): attrs = conn.entries[0].entry_attributes_as_dict return attrs - async def authenticate( - self, username: str, password: str - ) -> Optional[UserSessionState]: + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: import ldap3 username_saved = username # Save the user name passed as a parameter @@ -841,9 +805,7 @@ async def authenticate( # sanity check if not self.lookup_dn and not bind_dn_template: - logger.warning( - "Login not allowed, please configure 'lookup_dn' or 'bind_dn_template'." - ) + logger.warning("Login not allowed, please configure 'lookup_dn' or 'bind_dn_template'.") return None if self.lookup_dn: @@ -881,9 +843,7 @@ async def authenticate( if conn.bound: is_bound = True else: - is_bound = await asyncio.get_running_loop().run_in_executor( - None, conn.bind - ) + is_bound = await asyncio.get_running_loop().run_in_executor(None, conn.bind) msg = msg.format(username=username, userdn=userdn, is_bound=is_bound) logger.debug(msg) @@ -896,9 +856,7 @@ async def authenticate( return None if self.search_filter: - search_filter = self.search_filter.format( - userattr=self.user_attribute, username=username - ) + search_filter = self.search_filter.format(userattr=self.user_attribute, username=username) search_func = functools.partial( conn.search, @@ -912,33 +870,18 @@ async def authenticate( n_users = len(conn.response) if n_users == 0: msg = "User with '{userattr}={username}' not found in directory" - logger.warning( - msg.format(userattr=self.user_attribute, username=username) - ) + logger.warning(msg.format(userattr=self.user_attribute, username=username)) return None if n_users > 1: - msg = ( - "Duplicate users found! " - "{n_users} users found with '{userattr}={username}'" - ) - logger.warning( - msg.format( - userattr=self.user_attribute, username=username, n_users=n_users - ) - ) + msg = "Duplicate users found! " "{n_users} users found with '{userattr}={username}'" + logger.warning(msg.format(userattr=self.user_attribute, username=username, n_users=n_users)) return None if self.allowed_groups: logger.debug("username:%s Using dn %s", username, userdn) found = False for group in self.allowed_groups: - group_filter = ( - "(|" - "(member={userdn})" - "(uniqueMember={userdn})" - "(memberUid={uid})" - ")" - ) + group_filter = "(|" "(member={userdn})" "(uniqueMember={userdn})" "(memberUid={uid})" ")" group_filter = group_filter.format(userdn=userdn, uid=username) group_attributes = ["member", "uniqueMember", "memberUid"] @@ -949,9 +892,7 @@ async def authenticate( search_filter=group_filter, attributes=group_attributes, ) - found = await asyncio.get_running_loop().run_in_executor( - None, search_func - ) + found = await asyncio.get_running_loop().run_in_executor(None, search_func) if found: break From 122aa17434a1a35a2e0dab9ebc402a04b0722873 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Tue, 17 Feb 2026 15:04:02 -0600 Subject: [PATCH 04/18] Working version for logging in with Entra This is working okay, although it doens't really work smoothly for the API based login and the http command based login isn't great, as it requires the user to copy and past token around. Compared to ldap which just logs the user in. So still some work to do here to smooth out the user experience. --- bluesky_httpserver/_authentication.py | 363 +++++++++++++++++- bluesky_httpserver/app.py | 27 ++ bluesky_httpserver/authentication/__init__.py | 10 + bluesky_httpserver/authenticators.py | 27 +- .../config_schemas/examples/oidc_config.yml | 78 ++++ .../config_schemas/service_configuration.yml | 32 +- bluesky_httpserver/database/core.py | 40 +- bluesky_httpserver/database/orm.py | 21 + bluesky_httpserver/schemas.py | 17 + bluesky_httpserver/tests/conftest.py | 25 ++ .../tests/test_oidc_authenticators.py | 224 +++++++++++ requirements-dev.txt | 3 + requirements.txt | 1 + 13 files changed, 843 insertions(+), 25 deletions(-) create mode 100644 bluesky_httpserver/config_schemas/examples/oidc_config.yml create mode 100644 bluesky_httpserver/tests/test_oidc_authenticators.py diff --git a/bluesky_httpserver/_authentication.py b/bluesky_httpserver/_authentication.py index 0375794..a0d28b1 100644 --- a/bluesky_httpserver/_authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -6,12 +6,13 @@ from datetime import datetime, timedelta from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Request, Response, Security, WebSocket +from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response, Security, WebSocket from fastapi.openapi.models import APIKey, APIKeyIn -from fastapi.responses import JSONResponse +from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyQuery from fastapi.security.utils import get_authorization_scheme_param +from sqlalchemy.exc import IntegrityError # To hide third-party warning # .../jose/backends/cryptography_backend.py:18: CryptographyDeprecationWarning: @@ -33,7 +34,14 @@ from .authorization._defaults import _DEFAULT_ANONYMOUS_PROVIDER_NAME from .core import json_or_msgpack from .database import orm -from .database.core import create_user, latest_principal_activity, lookup_valid_api_key, lookup_valid_session +from .database.core import ( + create_user, + latest_principal_activity, + lookup_valid_api_key, + lookup_valid_pending_session_by_device_code, + lookup_valid_pending_session_by_user_code, + lookup_valid_session, +) from .settings import get_sessionmaker, get_settings from .utils import ( API_KEY_COOKIE_NAME, @@ -48,6 +56,10 @@ ALGORITHM = "HS256" UNIT_SECOND = timedelta(seconds=1) +# Device code flow constants +DEVICE_CODE_MAX_AGE = timedelta(minutes=10) +DEVICE_CODE_POLLING_INTERVAL = 5 # seconds + def utcnow(): "UTC now with second resolution" @@ -505,6 +517,351 @@ async def handle_credentials( return handle_credentials +def create_pending_session(db): + """ + Create a pending session for device code flow. + + Returns a dict with 'user_code' (user-facing code) and 'device_code' (for polling). + """ + device_code = secrets.token_bytes(32) + hashed_device_code = hashlib.sha256(device_code).digest() + for _ in range(3): + user_code = secrets.token_hex(4).upper() # 8 digit code + pending_session = orm.PendingSession( + user_code=user_code, + hashed_device_code=hashed_device_code, + expiration_time=utcnow() + DEVICE_CODE_MAX_AGE, + ) + db.add(pending_session) + try: + db.commit() + except IntegrityError: + # Since the user_code is short, we cannot completely dismiss the + # possibility of a collision. Retry. + db.rollback() + continue + break + formatted_user_code = f"{user_code[:4]}-{user_code[4:]}" + return { + "user_code": formatted_user_code, + "device_code": device_code.hex(), + } + + +def build_authorize_route(authenticator, provider): + """Build a GET route that redirects the browser to the OIDC provider for authentication.""" + + async def authorize_redirect( + request: Request, + state: Optional[str] = Query(None), + ): + """Redirect browser to OAuth provider for authentication.""" + redirect_uri = f"{get_base_url(request)}/auth/provider/{provider}/code" + + params = { + "client_id": authenticator.client_id, + "response_type": "code", + "scope": "openid profile email", + "redirect_uri": redirect_uri, + } + if state: + params["state"] = state + + auth_url = authenticator.authorization_endpoint.copy_with(params=params) + return RedirectResponse(url=str(auth_url)) + + return authorize_redirect + + +def build_device_code_authorize_route(authenticator, provider): + """Build a POST route that initiates the device code flow for CLI/headless clients.""" + + async def device_code_authorize( + request: Request, + settings: BaseSettings = Depends(get_settings), + ): + """ + Initiate device code flow. + + Returns authorization_uri for the user to visit in browser, + and device_code + user_code for the CLI client to poll. + """ + request.state.endpoint = "auth" + with get_sessionmaker(settings.database_settings)() as db: + pending_session = create_pending_session(db) + + verification_uri = f"{get_base_url(request)}/auth/provider/{provider}/token" + authorization_uri = authenticator.authorization_endpoint.copy_with( + params={ + "client_id": authenticator.client_id, + "response_type": "code", + "scope": "openid profile email", + "redirect_uri": f"{get_base_url(request)}/auth/provider/{provider}/device_code", + } + ) + return { + "authorization_uri": str(authorization_uri), # URL that user should visit in browser + "verification_uri": str(verification_uri), # URL that terminal client will poll + "interval": DEVICE_CODE_POLLING_INTERVAL, # suggested polling interval + "device_code": pending_session["device_code"], + "expires_in": int(DEVICE_CODE_MAX_AGE.total_seconds()), # seconds + "user_code": pending_session["user_code"], + } + + return device_code_authorize + + +def build_device_code_form_route(authenticator, provider): + """Build a GET route that shows the user code entry form.""" + + async def device_code_form( + request: Request, + code: str, + ): + """Show form for user to enter user code after browser auth.""" + action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" + html_content = f""" + + + + Authorize Session + + + +

Authorize Bluesky HTTP Server Session

+
+ + + +
+ +
+ + +""" + return HTMLResponse(content=html_content) + + return device_code_form + + +def build_device_code_submit_route(authenticator, provider): + """Build a POST route that handles user code submission after browser auth.""" + + async def device_code_submit( + request: Request, + code: str = Form(), + user_code: str = Form(), + settings: BaseSettings = Depends(get_settings), + api_access_manager=Depends(get_api_access_manager), + ): + """Handle user code submission and link to authenticated session.""" + request.state.endpoint = "auth" + action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" + normalized_user_code = user_code.upper().replace("-", "").strip() + + with get_sessionmaker(settings.database_settings)() as db: + pending_session = lookup_valid_pending_session_by_user_code(db, normalized_user_code) + if pending_session is None: + error_html = f""" + + +Error + + + +

Authorization Failed

+
Invalid user code. It may have been mistyped, or the pending request may have expired.
+
Try again + + +""" + return HTMLResponse(content=error_html, status_code=401) + + # Authenticate with the OIDC provider using the authorization code + user_session_state = await authenticator.authenticate(request) + if not user_session_state: + error_html = """ + + +Authentication Failed + + + +

Authentication Failed

+
User code was correct but authentication with the identity provider failed. Please contact the administrator.
+ + +""" + return HTMLResponse(content=error_html, status_code=401) + + username = user_session_state.user_name + if not api_access_manager.is_user_known(username): + error_html = f""" + + +Authorization Failed + + + +

Authorization Failed

+
User '{username}' is not authorized to access this server.
+ + +""" + return HTMLResponse(content=error_html, status_code=403) + + scopes = api_access_manager.get_user_scopes(username) + + # Create the session + session = await asyncio.get_running_loop().run_in_executor( + None, _create_session_orm, settings, provider, username, db + ) + + # Link the pending session to the real session + pending_session.session_id = session.id + db.add(pending_session) + db.commit() + + success_html = f""" + + +Success + + + +

Success!

+
You have been authenticated. Return to your terminal application - within {DEVICE_CODE_POLLING_INTERVAL} seconds it should be successfully logged in.
+ + +""" + return HTMLResponse(content=success_html) + + return device_code_submit + + +def _create_session_orm(settings, identity_provider, id, db): + """ + Create a session and return the ORM object (for device code flow). + + Unlike create_session(), this returns the ORM object so we can link it + to the pending session. + """ + # Have we seen this Identity before? + identity = ( + db.query(orm.Identity) + .filter(orm.Identity.id == id) + .filter(orm.Identity.provider == identity_provider) + .first() + ) + now = utcnow() + if identity is None: + # We have not. Make a new Principal and link this new Identity to it. + principal = create_user(db, identity_provider, id) + (new_identity,) = principal.identities + new_identity.latest_login = now + else: + identity.latest_login = now + principal = identity.principal + + session = orm.Session( + principal_id=principal.id, + expiration_time=utcnow() + settings.session_max_age, + ) + db.add(session) + db.commit() + db.refresh(session) + return session + + +def build_device_code_token_route(authenticator, provider): + """Build a POST route for the CLI client to poll for tokens.""" + + async def device_code_token( + request: Request, + body: schemas.DeviceCode, + settings: BaseSettings = Depends(get_settings), + api_access_manager=Depends(get_api_access_manager), + ): + """ + Poll for tokens after device code flow authentication. + + Returns tokens if the user has authenticated, or 400 with + 'authorization_pending' error if still waiting. + """ + request.state.endpoint = "auth" + device_code_hex = body.device_code + try: + device_code = bytes.fromhex(device_code_hex) + except Exception: + # Not valid hex, therefore not a valid device_code + raise HTTPException(status_code=401, detail="Invalid device code") + + with get_sessionmaker(settings.database_settings)() as db: + pending_session = lookup_valid_pending_session_by_device_code(db, device_code) + if pending_session is None: + raise HTTPException( + status_code=404, + detail="No such device_code. The pending request may have expired.", + ) + if pending_session.session_id is None: + raise HTTPException(status_code=400, detail={"error": "authorization_pending"}) + + session = pending_session.session + principal = session.principal + + # Get scopes for the user + # Find an identity to get the username + identity = db.query(orm.Identity).filter(orm.Identity.principal_id == principal.id).first() + if identity and api_access_manager.is_user_known(identity.id): + scopes = api_access_manager.get_user_scopes(identity.id) + else: + scopes = set() + + # The pending session can only be used once + db.delete(pending_session) + db.commit() + + # Generate tokens + data = { + "sub": principal.uuid.hex, + "sub_typ": principal.type.value, + "scp": list(scopes), + "ids": [{"id": ident.id, "idp": ident.provider} for ident in principal.identities], + } + access_token = create_access_token( + data=data, + expires_delta=settings.access_token_max_age, + secret_key=settings.secret_keys[0], + ) + refresh_token = create_refresh_token( + session_id=session.uuid.hex, + expires_delta=settings.refresh_token_max_age, + secret_key=settings.secret_keys[0], + ) + + return { + "access_token": access_token, + "expires_in": int(settings.access_token_max_age / UNIT_SECOND), + "refresh_token": refresh_token, + "refresh_token_expires_in": int(settings.refresh_token_max_age / UNIT_SECOND), + "token_type": "bearer", + } + + return device_code_token + + def generate_apikey(db, principal, apikey_params, request, allowed_scopes, source_api_key_scopes): # Use API key scopes if API key is generated based on existing API key, otherwise used allowed scopes if (source_api_key_scopes is not None) and ("inherit" not in source_api_key_scopes): diff --git a/bluesky_httpserver/app.py b/bluesky_httpserver/app.py index 9a8420a..0d96667 100644 --- a/bluesky_httpserver/app.py +++ b/bluesky_httpserver/app.py @@ -160,6 +160,11 @@ def build_app(authentication=None, api_access=None, resource_access=None, server from .authentication import ( base_authentication_router, build_auth_code_route, + build_authorize_route, + build_device_code_authorize_route, + build_device_code_form_route, + build_device_code_submit_route, + build_device_code_token_route, build_handle_credentials_route, oauth2_scheme, ) @@ -184,12 +189,34 @@ def build_app(authentication=None, api_access=None, resource_access=None, server build_handle_credentials_route(authenticator, provider) ) elif isinstance(authenticator, ExternalAuthenticator): + # Standard OAuth callback route (authorization code flow) authentication_router.get(f"/provider/{provider}/code")( build_auth_code_route(authenticator, provider) ) authentication_router.post(f"/provider/{provider}/code")( build_auth_code_route(authenticator, provider) ) + # Device code flow routes for CLI/headless clients + # GET /authorize - redirects browser to OIDC provider + authentication_router.get(f"/provider/{provider}/authorize")( + build_authorize_route(authenticator, provider) + ) + # POST /authorize - initiates device code flow (returns device_code, user_code, etc.) + authentication_router.post(f"/provider/{provider}/authorize")( + build_device_code_authorize_route(authenticator, provider) + ) + # GET /device_code - shows user code entry form + authentication_router.get(f"/provider/{provider}/device_code")( + build_device_code_form_route(authenticator, provider) + ) + # POST /device_code - handles user code submission after browser auth + authentication_router.post(f"/provider/{provider}/device_code")( + build_device_code_submit_route(authenticator, provider) + ) + # POST /token - CLI client polls this for tokens + authentication_router.post(f"/provider/{provider}/token")( + build_device_code_token_route(authenticator, provider) + ) else: raise ValueError(f"unknown authenticator type {type(authenticator)}") for custom_router in getattr(authenticator, "include_routers", []): diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py index fc35cdd..85d835e 100644 --- a/bluesky_httpserver/authentication/__init__.py +++ b/bluesky_httpserver/authentication/__init__.py @@ -1,6 +1,11 @@ from .._authentication import ( base_authentication_router, build_auth_code_route, + build_authorize_route, + build_device_code_authorize_route, + build_device_code_form_route, + build_device_code_submit_route, + build_device_code_token_route, build_handle_credentials_route, get_current_principal, get_current_principal_websocket, @@ -20,6 +25,11 @@ "get_current_principal_websocket", "base_authentication_router", "build_auth_code_route", + "build_authorize_route", + "build_device_code_authorize_route", + "build_device_code_form_route", + "build_device_code_submit_route", + "build_device_code_token_route", "build_handle_credentials_route", "oauth2_scheme", ] diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index 78b6cf1..e8d108d 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -224,16 +224,37 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]: return None response_body = response.json() id_token = response_body["id_token"] - access_token = response_body["access_token"] + # NOTE: We decode the id_token, not access_token, because: + # 1. The id_token is the OIDC identity assertion meant for the client + # 2. Some providers (like Microsoft Entra) return opaque access_tokens + # that cannot be decoded with the JWKS keys when the resource is + # a first-party Microsoft API (e.g., Graph API with User.Read scope) try: - verified_body = self.decode_token(access_token) + verified_body = self.decode_token(id_token) except JWTError: logger.exception( "Authentication error. Unverified token: %r", jwt.get_unverified_claims(id_token), ) return None - return UserSessionState(verified_body["sub"], {}) + # Use preferred_username as the user identifier, extracting just the username + # part if it's in email format (user@domain.com -> user) + preferred_username = verified_body.get("preferred_username") + if preferred_username and "@" in preferred_username: + user_id = preferred_username.split("@")[0] + elif preferred_username: + user_id = preferred_username + else: + user_id = verified_body["sub"] + logger.info( + "OIDC authentication successful. user_id=%r (sub=%r, preferred_username=%r, email=%r, name=%r)", + user_id, + verified_body.get("sub"), + verified_body.get("preferred_username"), + verified_body.get("email"), + verified_body.get("name"), + ) + return UserSessionState(user_id, {}) class ProxiedOIDCAuthenticator(OIDCAuthenticator): diff --git a/bluesky_httpserver/config_schemas/examples/oidc_config.yml b/bluesky_httpserver/config_schemas/examples/oidc_config.yml new file mode 100644 index 0000000..c2f8d24 --- /dev/null +++ b/bluesky_httpserver/config_schemas/examples/oidc_config.yml @@ -0,0 +1,78 @@ +# Example OIDC Configuration for Bluesky HTTP Server +# +# This example shows how to configure OIDC (OpenID Connect) authentication. +# OIDC is used by providers like Google, Microsoft Entra (Azure AD), Okta, Keycloak, etc. +# +# Required environment variables: +# - OIDC_CLIENT_ID: The client ID from your OIDC provider +# - OIDC_CLIENT_SECRET: The client secret from your OIDC provider +# - OIDC_WELL_KNOWN_URI: The .well-known/openid-configuration URL +# +# Example for Google: +# OIDC_WELL_KNOWN_URI=https://accounts.google.com/.well-known/openid-configuration +# +# Example for Microsoft Entra (Azure AD): +# OIDC_WELL_KNOWN_URI=https://login.microsoftonline.com/{tenant-id}/v2.0/.well-known/openid-configuration +# +# Example for Keycloak: +# OIDC_WELL_KNOWN_URI=https://your-keycloak-server/realms/{realm}/.well-known/openid-configuration + +authentication: + providers: + - provider: oidc + authenticator: bluesky_httpserver.authenticators:OIDCAuthenticator + args: + # The audience should match the client_id or be a value expected by your OIDC provider + audience: ${OIDC_CLIENT_ID} + client_id: ${OIDC_CLIENT_ID} + client_secret: ${OIDC_CLIENT_SECRET} + well_known_uri: ${OIDC_WELL_KNOWN_URI} + confirmation_message: "You have successfully logged in via OIDC as {id}." + # Optional: redirect URLs after authentication + # redirect_on_success: https://your-app.example.com/success + # redirect_on_failure: https://your-app.example.com/login-failed + + # Secret keys used to sign secure tokens (generate with: openssl rand -hex 32) + secret_keys: + - ${SECRET_KEY} + + # Allow unauthenticated access to public endpoints + allow_anonymous_access: false + + # Token lifetimes (in seconds) + access_token_max_age: 900 # 15 minutes + refresh_token_max_age: 604800 # 7 days + +# Database for storing sessions and API keys +database: + uri: ${DATABASE_URI} + pool_size: 5 + pool_pre_ping: true + +# API access control - configure which users have access +api_access: + policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl + args: + users: + # Add users identified by their OIDC subject ID (sub claim) + # The ID typically looks like an email or UUID depending on your OIDC provider + user@example.com: + roles: + - admin + - user + +# Resource access control +resource_access: + policy: bluesky_httpserver.authorization:DefaultResourceAccessControl + args: + default_group: root + +# Queue Server connection +qserver_zmq_configuration: + control_address: tcp://localhost:60615 + info_address: tcp://localhost:60625 + +# HTTP Server configuration +uvicorn: + host: 0.0.0.0 + port: 8000 diff --git a/bluesky_httpserver/config_schemas/service_configuration.yml b/bluesky_httpserver/config_schemas/service_configuration.yml index 57343f7..12f01a3 100644 --- a/bluesky_httpserver/config_schemas/service_configuration.yml +++ b/bluesky_httpserver/config_schemas/service_configuration.yml @@ -47,14 +47,14 @@ properties: properties: custom_routers: type: array - item: + items: type: string description: | The list of Python modules with custom routers. Overrides the list of modules set using QSERVER_HTTP_CUSTOM_ROUTERS environment variable. custom_modules: type: array - item: + items: type: string description: | THE FUNCTIONALITY WILL BE DEPRECATED IN FAVOR OF CUSTOM ROUTERS. Overrides the list of modules @@ -65,7 +65,7 @@ properties: properties: providers: type: array - item: + items: type: object additionalProperties: false required: @@ -83,7 +83,7 @@ properties: description: | Type of Authenticator to use. - These are typically from the tiled.authenticators module, + These are typically from the bluesky_httpserver.authenticators module, though user-defined ones may be used as well. This is given as an import path. In an import path, packages/modules @@ -92,21 +92,21 @@ properties: Example: ```yaml - authenticator: bluesky_httpserver.examples.DummyAuthenticator + authenticator: bluesky_httpserver.authenticators:DummyAuthenticator ``` - args: - type: [object, "null"] - description: | - Named arguments to pass to Authenticator. If there are none, - `args` may be omitted or empty. + args: + type: object + description: | + Named arguments to pass to Authenticator. If there are none, + `args` may be omitted or empty. - Example: + Example: - ```yaml - authenticator: bluesky_httpserver.examples.PAMAuthenticator - args: - service: "custom_service" - ``` + ```yaml + authenticator: bluesky_httpserver.authenticators:PAMAuthenticator + args: + service: "custom_service" + ``` # qserver_admins: # type: array # items: diff --git a/bluesky_httpserver/database/core.py b/bluesky_httpserver/database/core.py index 163fac3..f096edc 100644 --- a/bluesky_httpserver/database/core.py +++ b/bluesky_httpserver/database/core.py @@ -1,6 +1,7 @@ import hashlib import uuid as uuid_module from datetime import datetime +from typing import Optional from alembic import command from alembic.config import Config @@ -10,13 +11,13 @@ from .alembic_utils import temp_alembic_ini from .base import Base -from .orm import APIKey, Identity, Principal, Session # , Role +from .orm import APIKey, Identity, PendingSession, Principal, Session # , Role # This is the alembic revision ID of the database revision # required by this version of Tiled. -REQUIRED_REVISION = "722ff4e4fcc7" +REQUIRED_REVISION = "a1b2c3d4e5f6" # This is list of all valid revisions (from current to oldest). -ALL_REVISIONS = ["722ff4e4fcc7", "481830dd6c11"] +ALL_REVISIONS = ["a1b2c3d4e5f6", "722ff4e4fcc7", "481830dd6c11"] # def create_default_roles(engine): @@ -294,3 +295,36 @@ def latest_principal_activity(db, principal): if all([t is None for t in all_activity]): return None return max(t for t in all_activity if t is not None) + + +def lookup_valid_pending_session_by_device_code(db, device_code: bytes) -> Optional[PendingSession]: + """ + Look up a pending session by its device code. + + Returns None if the pending session is not found or has expired. + """ + hashed_device_code = hashlib.sha256(device_code).digest() + pending_session = db.query(PendingSession).filter(PendingSession.hashed_device_code == hashed_device_code).first() + if pending_session is None: + return None + if pending_session.expiration_time is not None and pending_session.expiration_time < datetime.utcnow(): + db.delete(pending_session) + db.commit() + return None + return pending_session + + +def lookup_valid_pending_session_by_user_code(db, user_code: str) -> Optional[PendingSession]: + """ + Look up a pending session by its user code. + + Returns None if the pending session is not found or has expired. + """ + pending_session = db.query(PendingSession).filter(PendingSession.user_code == user_code).first() + if pending_session is None: + return None + if pending_session.expiration_time is not None and pending_session.expiration_time < datetime.utcnow(): + db.delete(pending_session) + db.commit() + return None + return pending_session diff --git a/bluesky_httpserver/database/orm.py b/bluesky_httpserver/database/orm.py index 17d7c82..7611824 100644 --- a/bluesky_httpserver/database/orm.py +++ b/bluesky_httpserver/database/orm.py @@ -181,3 +181,24 @@ class Session(Timestamped, Base): revoked = Column(Boolean, default=False, nullable=False) principal = relationship("Principal", back_populates="sessions") + pending_sessions = relationship("PendingSession", back_populates="session") + + +class PendingSession(Timestamped, Base): + """ + This is used only in Device Code Flow for OIDC authentication. + + When a CLI client initiates the device code flow, a pending session is created + with a device_code (for the client to poll) and a user_code (for the user to + enter in the browser). Once the user authenticates, the pending session is + linked to a real session, which the polling client then receives. + """ + + __tablename__ = "pending_sessions" + + hashed_device_code = Column(LargeBinary(32), primary_key=True, index=True, nullable=False) + user_code = Column(Unicode(8), index=True, nullable=False) + expiration_time = Column(DateTime(timezone=False), nullable=False) + session_id = Column(Integer, ForeignKey("sessions.id"), nullable=True) + + session = relationship("Session", back_populates="pending_sessions") diff --git a/bluesky_httpserver/schemas.py b/bluesky_httpserver/schemas.py index c52d8f2..f1d9fcb 100644 --- a/bluesky_httpserver/schemas.py +++ b/bluesky_httpserver/schemas.py @@ -163,6 +163,23 @@ class RefreshToken(pydantic.BaseModel): refresh_token: str +class DeviceCode(pydantic.BaseModel): + """Schema for device code token polling request.""" + + device_code: str + + +class DeviceCodeResponse(pydantic.BaseModel): + """Schema for device code flow initiation response.""" + + authorization_uri: str + verification_uri: str + device_code: str + user_code: str + expires_in: int + interval: int + + class AuthenticationMode(str, enum.Enum): password = "password" external = "external" diff --git a/bluesky_httpserver/tests/conftest.py b/bluesky_httpserver/tests/conftest.py index ec69415..3c43529 100644 --- a/bluesky_httpserver/tests/conftest.py +++ b/bluesky_httpserver/tests/conftest.py @@ -195,3 +195,28 @@ def wait_for_ip_kernel_idle(timeout, polling_period=0.2, api_key=API_KEY_FOR_TES return True return False + + +# ============================================================================ +# OIDC Test Fixtures +# ============================================================================ + +@pytest.fixture +def oidc_base_url() -> str: + """Base URL for mock OIDC provider.""" + return "https://example.com/realms/example/" + + +@pytest.fixture +def well_known_response(oidc_base_url: str) -> dict: + """Mock OIDC well-known configuration response.""" + return { + "id_token_signing_alg_values_supported": ["RS256"], + "issuer": oidc_base_url.rstrip("/"), + "jwks_uri": f"{oidc_base_url}protocol/openid-connect/certs", + "authorization_endpoint": f"{oidc_base_url}protocol/openid-connect/auth", + "token_endpoint": f"{oidc_base_url}protocol/openid-connect/token", + "device_authorization_endpoint": f"{oidc_base_url}protocol/openid-connect/auth/device", + "end_session_endpoint": f"{oidc_base_url}protocol/openid-connect/logout", + } + diff --git a/bluesky_httpserver/tests/test_oidc_authenticators.py b/bluesky_httpserver/tests/test_oidc_authenticators.py new file mode 100644 index 0000000..30303e4 --- /dev/null +++ b/bluesky_httpserver/tests/test_oidc_authenticators.py @@ -0,0 +1,224 @@ +"""Tests for OIDC Authenticator functionality.""" + +import time +from typing import Any, Tuple + +import httpx +import pytest +from cryptography.hazmat.primitives.asymmetric import rsa +from jose import ExpiredSignatureError, jwt +from jose.backends import RSAKey +from respx import MockRouter + +from bluesky_httpserver.authenticators import OIDCAuthenticator, ProxiedOIDCAuthenticator + + +@pytest.fixture +def oidc_well_known_url(oidc_base_url: str) -> str: + return f"{oidc_base_url}.well-known/openid-configuration" + + +@pytest.fixture +def keys() -> Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: + """Generate RSA key pair for testing.""" + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + public_key = private_key.public_key() + return (private_key, public_key) + + +@pytest.fixture +def json_web_keyset(keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]) -> list[dict[str, Any]]: + """Create a JSON Web Key Set from the test keys.""" + _, public_key = keys + return [RSAKey(key=public_key, algorithm="RS256").to_dict()] + + +@pytest.fixture +def mock_oidc_server( + respx_mock: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + json_web_keyset: list[dict[str, Any]], +) -> MockRouter: + """Set up mock OIDC server endpoints.""" + respx_mock.get(oidc_well_known_url).mock( + return_value=httpx.Response(httpx.codes.OK, json=well_known_response) + ) + respx_mock.get(well_known_response["jwks_uri"]).mock( + return_value=httpx.Response(httpx.codes.OK, json={"keys": json_web_keyset}) + ) + return respx_mock + + +def create_token(issued: bool, expired: bool) -> dict[str, Any]: + """Create a test JWT token.""" + now = time.time() + return { + "aud": "test_client", + "exp": (now - 1500) if expired else (now + 1500), + "iat": (now - 1500) if issued else (now + 1500), + "iss": "https://example.com/realms/example", + "sub": "test_user", + } + + +def encrypt_token(token: dict[str, Any], private_key: rsa.RSAPrivateKey) -> str: + """Encrypt a token with the test private key.""" + return jwt.encode( + token, + key=private_key, + algorithm="RS256", + headers={"kid": "test_key"}, + ) + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +class TestOIDCAuthenticator: + """Tests for OIDCAuthenticator class.""" + + def test_oidc_authenticator_caching( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + json_web_keyset: list[dict[str, Any]], + ): + """Test that OIDC configuration is cached after first fetch.""" + authenticator = OIDCAuthenticator( + audience="test_client", + client_id="test_client", + client_secret="secret", + well_known_uri=oidc_well_known_url, + ) + + # Access multiple properties to ensure caching works + assert authenticator.client_id == "test_client" + assert authenticator.authorization_endpoint == well_known_response["authorization_endpoint"] + assert ( + authenticator.id_token_signing_alg_values_supported + == well_known_response["id_token_signing_alg_values_supported"] + ) + assert authenticator.issuer == well_known_response["issuer"] + assert authenticator.jwks_uri == well_known_response["jwks_uri"] + assert authenticator.token_endpoint == well_known_response["token_endpoint"] + assert ( + authenticator.device_authorization_endpoint + == well_known_response["device_authorization_endpoint"] + ) + assert authenticator.end_session_endpoint == well_known_response["end_session_endpoint"] + + # Should only call well-known endpoint once due to caching + assert len(mock_oidc_server.calls) == 1 + call_request = mock_oidc_server.calls[0].request + assert call_request.method == "GET" + assert call_request.url == oidc_well_known_url + + # Keys should also be cached + assert authenticator.keys() == json_web_keyset + assert len(mock_oidc_server.calls) == 2 # Now also fetched JWKS + + # Multiple calls should still be cached + for _ in range(5): + assert authenticator.keys() == json_web_keyset + assert len(mock_oidc_server.calls) == 2 # No new calls + + @pytest.mark.parametrize("issued", [True, False]) + @pytest.mark.parametrize("expired", [True, False]) + def test_oidc_token_decoding( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + issued: bool, + expired: bool, + keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey], + ): + """Test token decoding with various validity scenarios.""" + private_key, _ = keys + authenticator = OIDCAuthenticator( + audience="test_client", + client_id="test_client", + client_secret="secret", + well_known_uri=oidc_well_known_url, + ) + + token = create_token(issued, expired) + encrypted = encrypt_token(token, private_key) + + if not expired: + # Non-expired tokens should decode successfully + decoded = authenticator.decode_token(encrypted) + assert decoded["sub"] == "test_user" + assert decoded["aud"] == "test_client" + else: + # Expired tokens should raise an error + with pytest.raises(ExpiredSignatureError): + authenticator.decode_token(encrypted) + + def test_oidc_authenticator_properties( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + ): + """Test that all authenticator properties are correctly set.""" + authenticator = OIDCAuthenticator( + audience="my_audience", + client_id="my_client_id", + client_secret="my_secret", + well_known_uri=oidc_well_known_url, + confirmation_message="Logged in as {id}", + redirect_on_success="https://app.example.com/success", + redirect_on_failure="https://app.example.com/failure", + ) + + assert authenticator.client_id == "my_client_id" + assert authenticator.confirmation_message == "Logged in as {id}" + assert authenticator.redirect_on_success == "https://app.example.com/success" + assert authenticator.redirect_on_failure == "https://app.example.com/failure" + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +class TestProxiedOIDCAuthenticator: + """Tests for ProxiedOIDCAuthenticator class.""" + + @pytest.mark.asyncio + async def test_proxied_oidc_oauth2_schema( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + ): + """Test that ProxiedOIDCAuthenticator extracts bearer token correctly.""" + authenticator = ProxiedOIDCAuthenticator( + audience="test_client", + client_id="test_client", + well_known_uri=oidc_well_known_url, + device_flow_client_id="test_cli_client", + ) + + # Create a mock request with Authorization header + test_request = httpx.Request( + "GET", + "http://example.com/api/test", + headers={"Authorization": "Bearer TEST_TOKEN"}, + ) + + # The oauth2_schema should extract the bearer token + token = await authenticator.oauth2_schema(test_request) + assert token == "TEST_TOKEN" + + def test_proxied_oidc_with_scopes( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + ): + """Test ProxiedOIDCAuthenticator with custom scopes.""" + authenticator = ProxiedOIDCAuthenticator( + audience="test_client", + client_id="test_client", + well_known_uri=oidc_well_known_url, + device_flow_client_id="test_cli_client", + scopes=["openid", "profile", "email"], + ) + + assert authenticator.scopes == ["openid", "profile", "email"] + assert authenticator.device_flow_client_id == "test_cli_client" diff --git a/requirements-dev.txt b/requirements-dev.txt index dd7212a..e47dd72 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,13 +3,16 @@ black codecov coverage +cryptography fastapi[all] flake8 isort pre-commit pytest +pytest-asyncio pytest-xprocess py +respx sphinx ipython numpydoc diff --git a/requirements.txt b/requirements.txt index 818362f..1377ef0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ bluesky-queueserver bluesky-queueserver-api cachetools fastapi +httpx ldap3 orjson pamela From d90ad0cad8720d09786dfd0f02a5250c996f7165 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Tue, 17 Feb 2026 15:14:06 -0600 Subject: [PATCH 05/18] Removing some unnecessary code. --- bluesky_httpserver/_authentication.py | 2 -- bluesky_httpserver/authenticators.py | 1 - 2 files changed, 3 deletions(-) diff --git a/bluesky_httpserver/_authentication.py b/bluesky_httpserver/_authentication.py index a0d28b1..c745dff 100644 --- a/bluesky_httpserver/_authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -721,8 +721,6 @@ async def device_code_submit( """ return HTMLResponse(content=error_html, status_code=403) - scopes = api_access_manager.get_user_scopes(username) - # Create the session session = await asyncio.get_running_loop().run_in_executor( None, _create_session_orm, settings, provider, username, db diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index e8d108d..a58fedf 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -222,7 +222,6 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]: if response.is_error: logger.error("Authentication error: %r", response_body) return None - response_body = response.json() id_token = response_body["id_token"] # NOTE: We decode the id_token, not access_token, because: # 1. The id_token is the OIDC identity assertion meant for the client From 24857905f573f805940d6e3a0c4cefd409b87bf9 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Mon, 23 Feb 2026 10:06:43 -0600 Subject: [PATCH 06/18] Working example that does not require device-codes This solves the problem that what was implemented was actually authenticating the application and not the user like expected. It worked but it required that the user input a code. This solves that problem so that when you click the login link, if you are already logged in with you SSO provider you'll just automatically log in to the HTTP Server. Likewise if you use the bluesky queueserver api, when you call RM.Login you'll automatically be logged in, no user interaction required. --- bluesky_httpserver/_authentication.py | 193 +++++++++++------- .../config_schemas/examples/oidc_config.yml | 78 ------- .../config_schemas/service_configuration.yml | 43 ++-- 3 files changed, 128 insertions(+), 186 deletions(-) delete mode 100644 bluesky_httpserver/config_schemas/examples/oidc_config.yml diff --git a/bluesky_httpserver/_authentication.py b/bluesky_httpserver/_authentication.py index c745dff..0cb046f 100644 --- a/bluesky_httpserver/_authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -597,6 +597,7 @@ async def device_code_authorize( "response_type": "code", "scope": "openid profile email", "redirect_uri": f"{get_base_url(request)}/auth/provider/{provider}/device_code", + "state": pending_session["user_code"].replace("-", ""), } ) return { @@ -611,66 +612,23 @@ async def device_code_authorize( return device_code_authorize -def build_device_code_form_route(authenticator, provider): - """Build a GET route that shows the user code entry form.""" - - async def device_code_form( - request: Request, - code: str, - ): - """Show form for user to enter user code after browser auth.""" - action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" - html_content = f""" - - - - Authorize Session - - - -

Authorize Bluesky HTTP Server Session

-
- - - -
- -
- - -""" - return HTMLResponse(content=html_content) - - return device_code_form - - -def build_device_code_submit_route(authenticator, provider): - """Build a POST route that handles user code submission after browser auth.""" - - async def device_code_submit( - request: Request, - code: str = Form(), - user_code: str = Form(), - settings: BaseSettings = Depends(get_settings), - api_access_manager=Depends(get_api_access_manager), - ): - """Handle user code submission and link to authenticated session.""" - request.state.endpoint = "auth" - action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" - normalized_user_code = user_code.upper().replace("-", "").strip() +async def _complete_device_code_authorization( + request: Request, + authenticator, + provider: str, + code: str, + user_code: str, + settings: BaseSettings, + api_access_manager, +): + request.state.endpoint = "auth" + action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" + normalized_user_code = user_code.upper().replace("-", "").strip() - with get_sessionmaker(settings.database_settings)() as db: - pending_session = lookup_valid_pending_session_by_user_code(db, normalized_user_code) - if pending_session is None: - error_html = f""" + with get_sessionmaker(settings.database_settings)() as db: + pending_session = lookup_valid_pending_session_by_user_code(db, normalized_user_code) + if pending_session is None: + error_html = f""" Error @@ -684,12 +642,12 @@ async def device_code_submit( """ - return HTMLResponse(content=error_html, status_code=401) + return HTMLResponse(content=error_html, status_code=401) - # Authenticate with the OIDC provider using the authorization code - user_session_state = await authenticator.authenticate(request) - if not user_session_state: - error_html = """ + # Authenticate with the OIDC provider using the authorization code + user_session_state = await authenticator.authenticate(request) + if not user_session_state: + error_html = """ Authentication Failed @@ -702,11 +660,11 @@ async def device_code_submit( """ - return HTMLResponse(content=error_html, status_code=401) + return HTMLResponse(content=error_html, status_code=401) - username = user_session_state.user_name - if not api_access_manager.is_user_known(username): - error_html = f""" + username = user_session_state.user_name + if not api_access_manager.is_user_known(username): + error_html = f""" Authorization Failed @@ -719,19 +677,19 @@ async def device_code_submit( """ - return HTMLResponse(content=error_html, status_code=403) + return HTMLResponse(content=error_html, status_code=403) - # Create the session - session = await asyncio.get_running_loop().run_in_executor( - None, _create_session_orm, settings, provider, username, db - ) + # Create the session + session = await asyncio.get_running_loop().run_in_executor( + None, _create_session_orm, settings, provider, username, db + ) - # Link the pending session to the real session - pending_session.session_id = session.id - db.add(pending_session) - db.commit() + # Link the pending session to the real session + pending_session.session_id = session.id + db.add(pending_session) + db.commit() - success_html = f""" + success_html = f""" Success @@ -744,7 +702,84 @@ async def device_code_submit( """ - return HTMLResponse(content=success_html) + return HTMLResponse(content=success_html) + + +def build_device_code_form_route(authenticator, provider): + """Build a GET route that shows the user code entry form.""" + + async def device_code_form( + request: Request, + code: str, + state: Optional[str] = Query(None), + settings: BaseSettings = Depends(get_settings), + api_access_manager=Depends(get_api_access_manager), + ): + """Show form for user to enter user code after browser auth.""" + if state: + return await _complete_device_code_authorization( + request=request, + authenticator=authenticator, + provider=provider, + code=code, + user_code=state, + settings=settings, + api_access_manager=api_access_manager, + ) + + action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" + html_content = f""" + + + + Authorize Session + + + +

Authorize Bluesky HTTP Server Session

+
+ + + +
+ +
+ + +""" + return HTMLResponse(content=html_content) + + return device_code_form + + +def build_device_code_submit_route(authenticator, provider): + """Build a POST route that handles user code submission after browser auth.""" + + async def device_code_submit( + request: Request, + code: str = Form(), + user_code: str = Form(), + settings: BaseSettings = Depends(get_settings), + api_access_manager=Depends(get_api_access_manager), + ): + """Handle user code submission and link to authenticated session.""" + return await _complete_device_code_authorization( + request=request, + authenticator=authenticator, + provider=provider, + code=code, + user_code=user_code, + settings=settings, + api_access_manager=api_access_manager, + ) return device_code_submit diff --git a/bluesky_httpserver/config_schemas/examples/oidc_config.yml b/bluesky_httpserver/config_schemas/examples/oidc_config.yml deleted file mode 100644 index c2f8d24..0000000 --- a/bluesky_httpserver/config_schemas/examples/oidc_config.yml +++ /dev/null @@ -1,78 +0,0 @@ -# Example OIDC Configuration for Bluesky HTTP Server -# -# This example shows how to configure OIDC (OpenID Connect) authentication. -# OIDC is used by providers like Google, Microsoft Entra (Azure AD), Okta, Keycloak, etc. -# -# Required environment variables: -# - OIDC_CLIENT_ID: The client ID from your OIDC provider -# - OIDC_CLIENT_SECRET: The client secret from your OIDC provider -# - OIDC_WELL_KNOWN_URI: The .well-known/openid-configuration URL -# -# Example for Google: -# OIDC_WELL_KNOWN_URI=https://accounts.google.com/.well-known/openid-configuration -# -# Example for Microsoft Entra (Azure AD): -# OIDC_WELL_KNOWN_URI=https://login.microsoftonline.com/{tenant-id}/v2.0/.well-known/openid-configuration -# -# Example for Keycloak: -# OIDC_WELL_KNOWN_URI=https://your-keycloak-server/realms/{realm}/.well-known/openid-configuration - -authentication: - providers: - - provider: oidc - authenticator: bluesky_httpserver.authenticators:OIDCAuthenticator - args: - # The audience should match the client_id or be a value expected by your OIDC provider - audience: ${OIDC_CLIENT_ID} - client_id: ${OIDC_CLIENT_ID} - client_secret: ${OIDC_CLIENT_SECRET} - well_known_uri: ${OIDC_WELL_KNOWN_URI} - confirmation_message: "You have successfully logged in via OIDC as {id}." - # Optional: redirect URLs after authentication - # redirect_on_success: https://your-app.example.com/success - # redirect_on_failure: https://your-app.example.com/login-failed - - # Secret keys used to sign secure tokens (generate with: openssl rand -hex 32) - secret_keys: - - ${SECRET_KEY} - - # Allow unauthenticated access to public endpoints - allow_anonymous_access: false - - # Token lifetimes (in seconds) - access_token_max_age: 900 # 15 minutes - refresh_token_max_age: 604800 # 7 days - -# Database for storing sessions and API keys -database: - uri: ${DATABASE_URI} - pool_size: 5 - pool_pre_ping: true - -# API access control - configure which users have access -api_access: - policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl - args: - users: - # Add users identified by their OIDC subject ID (sub claim) - # The ID typically looks like an email or UUID depending on your OIDC provider - user@example.com: - roles: - - admin - - user - -# Resource access control -resource_access: - policy: bluesky_httpserver.authorization:DefaultResourceAccessControl - args: - default_group: root - -# Queue Server connection -qserver_zmq_configuration: - control_address: tcp://localhost:60615 - info_address: tcp://localhost:60625 - -# HTTP Server configuration -uvicorn: - host: 0.0.0.0 - port: 8000 diff --git a/bluesky_httpserver/config_schemas/service_configuration.yml b/bluesky_httpserver/config_schemas/service_configuration.yml index 12f01a3..a76e4d3 100644 --- a/bluesky_httpserver/config_schemas/service_configuration.yml +++ b/bluesky_httpserver/config_schemas/service_configuration.yml @@ -47,14 +47,14 @@ properties: properties: custom_routers: type: array - items: + item: type: string description: | The list of Python modules with custom routers. Overrides the list of modules set using QSERVER_HTTP_CUSTOM_ROUTERS environment variable. custom_modules: type: array - items: + item: type: string description: | THE FUNCTIONALITY WILL BE DEPRECATED IN FAVOR OF CUSTOM ROUTERS. Overrides the list of modules @@ -65,7 +65,7 @@ properties: properties: providers: type: array - items: + item: type: object additionalProperties: false required: @@ -94,34 +94,19 @@ properties: ```yaml authenticator: bluesky_httpserver.authenticators:DummyAuthenticator ``` - args: - type: object - description: | - Named arguments to pass to Authenticator. If there are none, - `args` may be omitted or empty. + args: + type: object + description: | + Named arguments to pass to Authenticator. If there are none, + `args` may be omitted or empty. - Example: + Example: - ```yaml - authenticator: bluesky_httpserver.authenticators:PAMAuthenticator - args: - service: "custom_service" - ``` - # qserver_admins: - # type: array - # items: - # type: object - # additionalProperties: false - # required: - # - provider - # - id - # properties: - # provider: - # type: string - # id: - # type: string - # description: | - # Give users with these identities 'admin' Role. + ```yaml + authenticator: bluesky_httpserver.authenticators:PAMAuthenticator + args: + service: "custom_service" + ``` secret_keys: type: array items: From 96cd9db5b4f23104d737cbc90f00bc9bd67d07c4 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Mon, 23 Feb 2026 11:00:17 -0600 Subject: [PATCH 07/18] Fixes from running black --- bluesky_httpserver/database/core.py | 4 +++- bluesky_httpserver/tests/conftest.py | 2 +- bluesky_httpserver/tests/test_oidc_authenticators.py | 9 ++------- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/bluesky_httpserver/database/core.py b/bluesky_httpserver/database/core.py index f096edc..52d102f 100644 --- a/bluesky_httpserver/database/core.py +++ b/bluesky_httpserver/database/core.py @@ -304,7 +304,9 @@ def lookup_valid_pending_session_by_device_code(db, device_code: bytes) -> Optio Returns None if the pending session is not found or has expired. """ hashed_device_code = hashlib.sha256(device_code).digest() - pending_session = db.query(PendingSession).filter(PendingSession.hashed_device_code == hashed_device_code).first() + pending_session = ( + db.query(PendingSession).filter(PendingSession.hashed_device_code == hashed_device_code).first() + ) if pending_session is None: return None if pending_session.expiration_time is not None and pending_session.expiration_time < datetime.utcnow(): diff --git a/bluesky_httpserver/tests/conftest.py b/bluesky_httpserver/tests/conftest.py index 3c43529..8851e71 100644 --- a/bluesky_httpserver/tests/conftest.py +++ b/bluesky_httpserver/tests/conftest.py @@ -201,6 +201,7 @@ def wait_for_ip_kernel_idle(timeout, polling_period=0.2, api_key=API_KEY_FOR_TES # OIDC Test Fixtures # ============================================================================ + @pytest.fixture def oidc_base_url() -> str: """Base URL for mock OIDC provider.""" @@ -219,4 +220,3 @@ def well_known_response(oidc_base_url: str) -> dict: "device_authorization_endpoint": f"{oidc_base_url}protocol/openid-connect/auth/device", "end_session_endpoint": f"{oidc_base_url}protocol/openid-connect/logout", } - diff --git a/bluesky_httpserver/tests/test_oidc_authenticators.py b/bluesky_httpserver/tests/test_oidc_authenticators.py index 30303e4..f3249cd 100644 --- a/bluesky_httpserver/tests/test_oidc_authenticators.py +++ b/bluesky_httpserver/tests/test_oidc_authenticators.py @@ -41,9 +41,7 @@ def mock_oidc_server( json_web_keyset: list[dict[str, Any]], ) -> MockRouter: """Set up mock OIDC server endpoints.""" - respx_mock.get(oidc_well_known_url).mock( - return_value=httpx.Response(httpx.codes.OK, json=well_known_response) - ) + respx_mock.get(oidc_well_known_url).mock(return_value=httpx.Response(httpx.codes.OK, json=well_known_response)) respx_mock.get(well_known_response["jwks_uri"]).mock( return_value=httpx.Response(httpx.codes.OK, json={"keys": json_web_keyset}) ) @@ -101,10 +99,7 @@ def test_oidc_authenticator_caching( assert authenticator.issuer == well_known_response["issuer"] assert authenticator.jwks_uri == well_known_response["jwks_uri"] assert authenticator.token_endpoint == well_known_response["token_endpoint"] - assert ( - authenticator.device_authorization_endpoint - == well_known_response["device_authorization_endpoint"] - ) + assert authenticator.device_authorization_endpoint == well_known_response["device_authorization_endpoint"] assert authenticator.end_session_endpoint == well_known_response["end_session_endpoint"] # Should only call well-known endpoint once due to caching From 967fcbab3de634ce66f79ba9451173435476edbe Mon Sep 17 00:00:00 2001 From: David Pastl Date: Mon, 23 Feb 2026 13:07:07 -0600 Subject: [PATCH 08/18] Adding documentation on how to use OIDC --- docs/source/configuration.rst | 79 +++++++++++++++++++++++++++++++++++ docs/source/usage.rst | 43 +++++++++++++++++++ 2 files changed, 122 insertions(+) diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst index eb31efa..8852a31 100644 --- a/docs/source/configuration.rst +++ b/docs/source/configuration.rst @@ -294,6 +294,85 @@ See the documentation on ``LDAPAuthenticator`` for more details. authenticators.LDAPAuthenticator +OIDC Authenticator +++++++++++++++++++ + +``OIDCAuthenticator`` integrates the server with third-party OpenID Connect providers +such as Google, Microsoft Entra ID, ORCID and others. The server does not process user +passwords directly: authentication is delegated to the provider and the server validates +the returned OIDC token. + +General setup steps: + +#. Register an application with the OIDC provider. +#. Configure redirect URIs for the provider application. For provider name ``entra`` and + host ``https://your-server.example`` the redirect URIs are: + + - ``https://your-server.example/api/auth/provider/entra/code`` + - ``https://your-server.example/api/auth/provider/entra/device_code`` + +#. Store the client secret in environment variable and reference it in config. +#. Use provider's ``.well-known/openid-configuration`` URL. + +Typical ``well_known_uri`` values: + +- Google: ``https://accounts.google.com/.well-known/openid-configuration`` +- Microsoft Entra ID: ``https://login.microsoftonline.com//v2.0/.well-known/openid-configuration`` +- ORCID: ``https://orcid.org/.well-known/openid-configuration`` + +Example configuration (Microsoft Entra ID):: + + authentication: + providers: + - provider: entra + authenticator: bluesky_httpserver.authenticators:OIDCAuthenticator + args: + audience: 00000000-0000-0000-0000-000000000000 + client_id: 00000000-0000-0000-0000-000000000000 + client_secret: ${BSKY_ENTRA_SECRET} + well_known_uri: https://login.microsoftonline.com//v2.0/.well-known/openid-configuration + confirmation_message: "You have logged in successfully." + api_access: + policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl + args: + users: + : + roles: + - admin + - expert + +Example configuration (Google):: + + authentication: + providers: + - provider: google + authenticator: bluesky_httpserver.authenticators:OIDCAuthenticator + args: + audience: + client_id: + client_secret: ${BSKY_GOOGLE_SECRET} + well_known_uri: https://accounts.google.com/.well-known/openid-configuration + api_access: + policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl + args: + users: + : + roles: user + +.. note:: + + The name used in ``api_access/args/users`` must match the identity string produced by + the authenticator for your provider configuration. Verify with ``/api/auth/whoami`` after + successful login. + +See the documentation on ``OIDCAuthenticator`` for parameter details. + +.. autosummary:: + :nosignatures: + :toctree: generated + + authenticators.OIDCAuthenticator + Expiration Time for Tokens and Sessions +++++++++++++++++++++++++++++++++++++++ diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 5e1e9b3..d6c3a10 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -154,6 +154,49 @@ Then users ``bob``, ``alice`` and ``tom`` can log into the server as :: If authentication is successful, then the server returns access and refresh tokens. +Logging in with OIDC Providers (Google, Entra, ORCID, ...) +----------------------------------------------------------- + +For providers configured with ``OIDCAuthenticator``, use provider-specific endpoints +under ``/api/auth/provider//...``. + +Browser-first flow +++++++++++++++++++ + +If you are already in a browser context, open: + +``/api/auth/provider//authorize`` + +This redirects to the OIDC provider login page and then back to the server callback. + +CLI/device flow ++++++++++++++++ + +For terminal clients, start with ``POST /api/auth/provider//authorize``. +The response includes: + +- ``authorization_uri``: open this URL in a browser +- ``verification_uri``: polling endpoint for the terminal client +- ``device_code`` and ``interval``: values for polling + +Example using ``httpie`` (provider ``entra``):: + + http POST http://localhost:60610/api/auth/provider/entra/authorize + +After opening ``authorization_uri`` in a browser and completing provider login, +poll ``verification_uri`` using ``device_code`` until tokens are issued:: + + http POST http://localhost:60610/api/auth/provider/entra/token \ + device_code='' + +When authorization is still pending, the endpoint returns ``authorization_pending``. +When complete, it returns access and refresh tokens. + +.. note:: + + In common same-device flows the callback can complete automatically without manually + typing the user code. Manual code entry remains available as a fallback path. + Generating API Keys ------------------- From 5906b28b889224527c937da912944460513723ac Mon Sep 17 00:00:00 2001 From: David Pastl Date: Tue, 24 Feb 2026 16:17:24 -0600 Subject: [PATCH 09/18] Fixes for unit tests, moving start LDAP These should correct some of the problems in the last CI workflow. I moved the LDAP and docker image into the continuous_integration folder so it matches tiled. --- .github/workflows/testing.yml | 2 +- bluesky_httpserver/_authentication.py | 79 +++++- bluesky_httpserver/tests/conftest.py | 19 +- .../tests/test_authenticators.py | 245 +++++++++++++++++- .../docker-configs/ldap-docker-compose.yml | 6 +- continuous_integration/scripts/start_LDAP.sh | 7 + docs/source/usage.rst | 4 +- start_LDAP.sh | 8 - 8 files changed, 340 insertions(+), 30 deletions(-) rename {.github/workflows => continuous_integration}/docker-configs/ldap-docker-compose.yml (74%) create mode 100755 continuous_integration/scripts/start_LDAP.sh delete mode 100644 start_LDAP.sh diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index b7d9d54..5355c05 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -43,7 +43,7 @@ jobs: popd # Start LDAP - source start_LDAP.sh + source continuous_integration/scripts/start_LDAP.sh # These packages are installed in the base environment but may be older # versions. Explicitly upgrade them because they often create diff --git a/bluesky_httpserver/_authentication.py b/bluesky_httpserver/_authentication.py index 0cb046f..c1144f5 100644 --- a/bluesky_httpserver/_authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -632,12 +632,22 @@ async def _complete_device_code_authorization( Error - +

Authorization Failed

-
Invalid user code. It may have been mistyped, or the pending request may have expired.
+
+ Invalid user code. It may have been mistyped, or the pending request may have expired. +

Try again @@ -651,12 +661,23 @@ async def _complete_device_code_authorization( Authentication Failed - +

Authentication Failed

-
User code was correct but authentication with the identity provider failed. Please contact the administrator.
+
+ User code was correct but authentication with the identity provider failed. + Please contact the administrator. +
""" @@ -668,8 +689,16 @@ async def _complete_device_code_authorization( Authorization Failed - +

Authorization Failed

@@ -693,12 +722,23 @@ async def _complete_device_code_authorization( Success - +

Success!

-
You have been authenticated. Return to your terminal application - within {DEVICE_CODE_POLLING_INTERVAL} seconds it should be successfully logged in.
+
+ You have been authenticated. Return to your terminal application - + within {DEVICE_CODE_POLLING_INTERVAL} seconds it should be successfully logged in. +
""" @@ -738,8 +778,21 @@ async def device_code_form( h1 {{ color: #333; }} form {{ margin-top: 20px; }} label {{ display: block; margin-bottom: 10px; }} - input[type="text"] {{ padding: 10px; font-size: 16px; width: 200px; text-transform: uppercase; }} - input[type="submit"] {{ padding: 10px 20px; font-size: 16px; background-color: #007bff; color: white; border: none; cursor: pointer; margin-top: 10px; }} + input[type="text"] {{ + padding: 10px; + font-size: 16px; + width: 200px; + text-transform: uppercase; + }} + input[type="submit"] {{ + padding: 10px 20px; + font-size: 16px; + background-color: #007bff; + color: white; + border: none; + cursor: pointer; + margin-top: 10px; + }} input[type="submit"]:hover {{ background-color: #0056b3; }} diff --git a/bluesky_httpserver/tests/conftest.py b/bluesky_httpserver/tests/conftest.py index 8851e71..d5cafdb 100644 --- a/bluesky_httpserver/tests/conftest.py +++ b/bluesky_httpserver/tests/conftest.py @@ -18,6 +18,22 @@ _user_group = "primary" +def _wait_for_http_server_ready(*, timeout=10, request_prefix="/api"): + """Wait until HTTP server accepts connections and responds to /status.""" + t_stop = ttime.time() + timeout + url = f"http://{SERVER_ADDRESS}:{SERVER_PORT}{request_prefix}/status" + while ttime.time() < t_stop: + try: + response = requests.get(url, timeout=0.5) + # Any HTTP response means the server is up (auth may still reject request). + if response.status_code: + return + except requests.RequestException: + pass + ttime.sleep(0.1) + raise TimeoutError(f"HTTP server is not ready after {timeout} s: {url}") + + @pytest.fixture(scope="module") def fastapi_server(xprocess): class Starter(ProcessStarter): @@ -29,6 +45,7 @@ class Starter(ProcessStarter): # args = f"start-bluesky-httpserver --host={SERVER_ADDRESS} --port {SERVER_PORT}".split() xprocess.ensure("fastapi_server", Starter) + _wait_for_http_server_ready() yield @@ -55,7 +72,7 @@ class Starter(ProcessStarter): args = f"uvicorn --host={http_server_host} --port {http_server_port} {bqss.__name__}:app".split() xprocess.ensure("fastapi_server", Starter) - ttime.sleep(1) + _wait_for_http_server_ready() yield start diff --git a/bluesky_httpserver/tests/test_authenticators.py b/bluesky_httpserver/tests/test_authenticators.py index 183ce75..28e2601 100644 --- a/bluesky_httpserver/tests/test_authenticators.py +++ b/bluesky_httpserver/tests/test_authenticators.py @@ -1,9 +1,17 @@ import asyncio +import time +from typing import Any, Tuple +import httpx import pytest +from cryptography.hazmat.primitives.asymmetric import rsa +from jose import ExpiredSignatureError, jwt +from jose.backends import RSAKey +from respx import MockRouter +from starlette.datastructures import QueryParams, URL # fmt: off -from ..authenticators import LDAPAuthenticator, UserSessionState +from ..authenticators import LDAPAuthenticator, OIDCAuthenticator, ProxiedOIDCAuthenticator, UserSessionState @pytest.mark.parametrize("ldap_server_address, ldap_server_port", [ @@ -41,3 +49,238 @@ async def testing(): assert await authenticator.authenticate("user02", "password2a") is None asyncio.run(testing()) + + +@pytest.fixture +def oidc_well_known_url(oidc_base_url: str) -> str: + return f"{oidc_base_url}.well-known/openid-configuration" + + +@pytest.fixture +def keys() -> Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + public_key = private_key.public_key() + return (private_key, public_key) + + +@pytest.fixture +def json_web_keyset(keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]) -> list[dict[str, Any]]: + _, public_key = keys + return [RSAKey(key=public_key, algorithm="RS256").to_dict()] + + +@pytest.fixture +def mock_oidc_server( + respx_mock: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + json_web_keyset: list[dict[str, Any]], +) -> MockRouter: + respx_mock.get(oidc_well_known_url).mock(return_value=httpx.Response(httpx.codes.OK, json=well_known_response)) + respx_mock.get(well_known_response["jwks_uri"]).mock( + return_value=httpx.Response(httpx.codes.OK, json={"keys": json_web_keyset}) + ) + return respx_mock + + +def token(issued: bool, expired: bool) -> dict[str, str]: + now = time.time() + return { + "aud": "tiled", + "exp": (now - 1500) if expired else (now + 1500), + "iat": (now - 1500) if issued else (now + 1500), + "iss": "https://example.com/realms/example", + "sub": "Jane Doe", + } + + +def encrypted_token(token_data: dict[str, str], private_key: rsa.RSAPrivateKey) -> str: + return jwt.encode( + token_data, + key=private_key, + algorithm="RS256", + headers={"kid": "secret"}, + ) + + +def test_oidc_authenticator_caching( + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + json_web_keyset: list[dict[str, Any]], +): + authenticator = OIDCAuthenticator("tiled", "tiled", "secret", well_known_uri=oidc_well_known_url) + assert authenticator.client_id == "tiled" + assert authenticator.authorization_endpoint == well_known_response["authorization_endpoint"] + assert authenticator.id_token_signing_alg_values_supported == well_known_response[ + "id_token_signing_alg_values_supported" + ] + assert authenticator.issuer == well_known_response["issuer"] + assert authenticator.jwks_uri == well_known_response["jwks_uri"] + assert authenticator.token_endpoint == well_known_response["token_endpoint"] + assert authenticator.device_authorization_endpoint == well_known_response["device_authorization_endpoint"] + assert authenticator.end_session_endpoint == well_known_response["end_session_endpoint"] + + assert len(mock_oidc_server.calls) == 1 + call_request = mock_oidc_server.calls[0].request + assert call_request.method == "GET" + assert call_request.url == oidc_well_known_url + + assert authenticator.keys() == json_web_keyset + assert len(mock_oidc_server.calls) == 2 + keys_request = mock_oidc_server.calls[1].request + assert keys_request.method == "GET" + assert keys_request.url == well_known_response["jwks_uri"] + + for _ in range(10): + assert authenticator.keys() == json_web_keyset + + assert len(mock_oidc_server.calls) == 2 + + +@pytest.mark.parametrize("issued", [True, False]) +@pytest.mark.parametrize("expired", [True, False]) +def test_oidc_decoding( + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + issued: bool, + expired: bool, + keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey], +): + private_key, _ = keys + authenticator = OIDCAuthenticator("tiled", "tiled", "secret", well_known_uri=oidc_well_known_url) + access_token = token(issued, expired) + encrypted_access_token = encrypted_token(access_token, private_key) + + if not expired: + assert authenticator.decode_token(encrypted_access_token) == access_token + else: + with pytest.raises(ExpiredSignatureError): + authenticator.decode_token(encrypted_access_token) + + +@pytest.mark.asyncio +async def test_proxied_oidc_token_retrieval(oidc_well_known_url: str, mock_oidc_server: MockRouter): + authenticator = ProxiedOIDCAuthenticator("tiled", "tiled", oidc_well_known_url, + device_flow_client_id="tiled-cli") + test_request = httpx.Request("GET", "http://example.com", headers={"Authorization": "bearer FOO"}) + + assert "FOO" == await authenticator.oauth2_schema(test_request) + + +def create_mock_oidc_request(query_params=None): + if query_params is None: + query_params = {} + + class MockRequest: + def __init__(self, request_query_params): + self.query_params = QueryParams(request_query_params) + self.scope = { + "type": "http", + "scheme": "http", + "server": ("localhost", 8000), + "path": "/api/v1/auth/provider/orcid/code", + "headers": [], + } + self.headers = {"host": "localhost:8000"} + self.url = URL("http://localhost:8000/api/v1/auth/provider/orcid/code") + + return MockRequest(query_params) + + +@pytest.mark.asyncio +async def test_OIDCAuthenticator_mock( + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + monkeypatch, +): + mock_jwt_payload = { + "sub": "0009-0008-8698-7745", + "aud": "APP-TEST-CLIENT-ID", + "iss": well_known_response["issuer"], + "exp": 9999999999, + "iat": 1000000000, + "given_name": "Test User", + } + + mock_oidc_server.post(well_known_response["token_endpoint"]).mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "mock-access-token", + "id_token": "mock-id-token", + "token_type": "bearer", + }, + ) + ) + + authenticator = OIDCAuthenticator( + audience="APP-TEST-CLIENT-ID", + client_id="APP-TEST-CLIENT-ID", + client_secret="test-secret", + well_known_uri=oidc_well_known_url, + ) + + mock_request = create_mock_oidc_request({"code": "test-auth-code"}) + + def mock_jwt_decode(*args, **kwargs): + return mock_jwt_payload + + def mock_jwk_construct(*args, **kwargs): + class MockJWK: + pass + + return MockJWK() + + monkeypatch.setattr("jose.jwt.decode", mock_jwt_decode) + monkeypatch.setattr("jose.jwk.construct", mock_jwk_construct) + + user_session = await authenticator.authenticate(mock_request) + + assert user_session is not None + assert user_session.user_name == "0009-0008-8698-7745" + + +@pytest.mark.asyncio +async def test_OIDCAuthenticator_missing_code_parameter(oidc_well_known_url: str): + authenticator = OIDCAuthenticator( + audience="APP-TEST-CLIENT-ID", + client_id="APP-TEST-CLIENT-ID", + client_secret="test-secret", + well_known_uri=oidc_well_known_url, + ) + + mock_request = create_mock_oidc_request({}) + + result = await authenticator.authenticate(mock_request) + assert result is None + + +@pytest.mark.asyncio +async def test_OIDCAuthenticator_token_exchange_failure( + oidc_well_known_url: str, + mock_oidc_server, + well_known_response, +): + mock_oidc_server.post(well_known_response["token_endpoint"]).mock( + return_value=httpx.Response( + 400, + json={ + "error": "invalid_client", + "error_description": "Client not found: APP-TEST-CLIENT-ID", + }, + ) + ) + + authenticator = OIDCAuthenticator( + audience="APP-TEST-CLIENT-ID", + client_id="APP-TEST-CLIENT-ID", + client_secret="test-secret", + well_known_uri=oidc_well_known_url, + ) + + mock_request = create_mock_oidc_request({"code": "invalid-code"}) + + result = await authenticator.authenticate(mock_request) + assert result is None diff --git a/.github/workflows/docker-configs/ldap-docker-compose.yml b/continuous_integration/docker-configs/ldap-docker-compose.yml similarity index 74% rename from .github/workflows/docker-configs/ldap-docker-compose.yml rename to continuous_integration/docker-configs/ldap-docker-compose.yml index 5cf12a8..2b2c45a 100644 --- a/.github/workflows/docker-configs/ldap-docker-compose.yml +++ b/continuous_integration/docker-configs/ldap-docker-compose.yml @@ -1,8 +1,6 @@ -version: '2' - services: openldap: - image: docker.io/bitnami/openldap:latest + image: osixia/openldap:latest ports: - '1389:1389' - '1636:1636' @@ -12,7 +10,7 @@ services: - LDAP_USERS=user01,user02 - LDAP_PASSWORDS=password1,password2 volumes: - - 'openldap_data:/bitnami/openldap' + - 'openldap_data:/var/lib/ldap' volumes: openldap_data: diff --git a/continuous_integration/scripts/start_LDAP.sh b/continuous_integration/scripts/start_LDAP.sh new file mode 100755 index 0000000..c6a5fbc --- /dev/null +++ b/continuous_integration/scripts/start_LDAP.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -e + +# Start LDAP server in docker container +docker pull osixia/openldap:latest +docker compose -f continuous_integration/docker-configs/ldap-docker-compose.yml up -d +docker ps \ No newline at end of file diff --git a/docs/source/usage.rst b/docs/source/usage.rst index d6c3a10..6cd168c 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -161,7 +161,7 @@ For providers configured with ``OIDCAuthenticator``, use provider-specific endpo under ``/api/auth/provider//...``. Browser-first flow -++++++++++++++++++ +~~~~~~~~~~~~~~~~~ If you are already in a browser context, open: @@ -170,7 +170,7 @@ If you are already in a browser context, open: This redirects to the OIDC provider login page and then back to the server callback. CLI/device flow -+++++++++++++++ +~~~~~~~~~~~~~~~ For terminal clients, start with ``POST /api/auth/provider//authorize``. The response includes: diff --git a/start_LDAP.sh b/start_LDAP.sh deleted file mode 100644 index 8b612de..0000000 --- a/start_LDAP.sh +++ /dev/null @@ -1,8 +0,0 @@ - -#!/bin/bash -set -e - -# Start LDAP server in docker container -# sudo docker pull osixia/openldap:latest -sudo docker compose -f .github/workflows/docker-configs/ldap-docker-compose.yml up -d -sudo docker ps From 28483f97c8b90ca330ba289cec09e68235a5c83e Mon Sep 17 00:00:00 2001 From: David Pastl Date: Tue, 24 Feb 2026 16:19:37 -0600 Subject: [PATCH 10/18] fixing pre-commit issues --- bluesky_httpserver/tests/test_authenticators.py | 4 ++-- continuous_integration/scripts/start_LDAP.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bluesky_httpserver/tests/test_authenticators.py b/bluesky_httpserver/tests/test_authenticators.py index 28e2601..53c6bbe 100644 --- a/bluesky_httpserver/tests/test_authenticators.py +++ b/bluesky_httpserver/tests/test_authenticators.py @@ -8,7 +8,7 @@ from jose import ExpiredSignatureError, jwt from jose.backends import RSAKey from respx import MockRouter -from starlette.datastructures import QueryParams, URL +from starlette.datastructures import URL, QueryParams # fmt: off from ..authenticators import LDAPAuthenticator, OIDCAuthenticator, ProxiedOIDCAuthenticator, UserSessionState @@ -161,7 +161,7 @@ def test_oidc_decoding( @pytest.mark.asyncio async def test_proxied_oidc_token_retrieval(oidc_well_known_url: str, mock_oidc_server: MockRouter): - authenticator = ProxiedOIDCAuthenticator("tiled", "tiled", oidc_well_known_url, + authenticator = ProxiedOIDCAuthenticator("tiled", "tiled", oidc_well_known_url, device_flow_client_id="tiled-cli") test_request = httpx.Request("GET", "http://example.com", headers={"Authorization": "bearer FOO"}) diff --git a/continuous_integration/scripts/start_LDAP.sh b/continuous_integration/scripts/start_LDAP.sh index c6a5fbc..ecfa1cf 100755 --- a/continuous_integration/scripts/start_LDAP.sh +++ b/continuous_integration/scripts/start_LDAP.sh @@ -4,4 +4,4 @@ set -e # Start LDAP server in docker container docker pull osixia/openldap:latest docker compose -f continuous_integration/docker-configs/ldap-docker-compose.yml up -d -docker ps \ No newline at end of file +docker ps From a4551e34e70394a8563cc88de4e00c89c43a1c8b Mon Sep 17 00:00:00 2001 From: David Pastl Date: Wed, 25 Feb 2026 08:54:52 -0600 Subject: [PATCH 11/18] fixing documentation issues This addresses documentation problems, the levels were incorrect as I did not understand what the next level should have been in the docs. I've also updated the usage documentation a little to be more useful. --- docs/source/usage.rst | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 6cd168c..299bdcb 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -161,7 +161,7 @@ For providers configured with ``OIDCAuthenticator``, use provider-specific endpo under ``/api/auth/provider//...``. Browser-first flow -~~~~~~~~~~~~~~~~~ +****************** If you are already in a browser context, open: @@ -169,10 +169,22 @@ If you are already in a browser context, open: This redirects to the OIDC provider login page and then back to the server callback. +This can similarly be acheived using ``httpie`` by opening the URL in a browser after getting +the authorization URI from the server:: + + http POST http://localhost:60610/api/auth/provider/entra/authorize + +Which will return a token back to the bluesky http server after the user logs in to the provider +in their browser (or automatically if already logged in). The user then gets a token +for the bluesky HTTP server to use for subsequent API requests. This flow can be used +even when using the bluesky queueserver api in a terminal so long as that session can +spawn a browser for the user to log in to the provider. + CLI/device flow -~~~~~~~~~~~~~~~ +*************** -For terminal clients, start with ``POST /api/auth/provider//authorize``. +For terminal clients (i.e. no browser possible), start with +``POST /api/auth/provider//authorize``. The response includes: - ``authorization_uri``: open this URL in a browser From 8fa89ab6509052f584c99b46545bee510ee0419c Mon Sep 17 00:00:00 2001 From: David Pastl Date: Fri, 13 Mar 2026 12:02:04 -0600 Subject: [PATCH 12/18] Adding in helper scripts for testing These allow for running the unit tests in a containerized system just like how they are done in the ci pipeline, but locally and in a way that can maximize processor usage and minimize runtime. --- docker/test.Dockerfile | 29 ++ scripts/docker/run_shard_in_container.sh | 80 ++++ scripts/run_ci_docker_parallel.sh | 480 +++++++++++++++++++++++ 3 files changed, 589 insertions(+) create mode 100644 docker/test.Dockerfile create mode 100755 scripts/docker/run_shard_in_container.sh create mode 100755 scripts/run_ci_docker_parallel.sh diff --git a/docker/test.Dockerfile b/docker/test.Dockerfile new file mode 100644 index 0000000..2e994cf --- /dev/null +++ b/docker/test.Dockerfile @@ -0,0 +1,29 @@ +ARG PYTHON_VERSION=3.13 +FROM python:${PYTHON_VERSION}-slim + +ENV PYTHONUNBUFFERED=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + bash \ + build-essential \ + git \ + redis-server \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /workspace + +COPY requirements.txt requirements-dev.txt ./ +COPY pyproject.toml setup.py setup.cfg MANIFEST.in versioneer.py README.rst AUTHORS.rst LICENSE ./ +COPY bluesky_httpserver ./bluesky_httpserver + +RUN python -m pip install --upgrade pip setuptools wheel numpy && \ + python -m pip install git+https://github.com/bluesky/bluesky-queueserver.git && \ + python -m pip install git+https://github.com/bluesky/bluesky-queueserver-api.git && \ + python -m pip install -r requirements-dev.txt && \ + python -m pip install . + +COPY scripts/docker/run_shard_in_container.sh /usr/local/bin/run_shard_in_container.sh +RUN chmod +x /usr/local/bin/run_shard_in_container.sh + +ENTRYPOINT ["/usr/local/bin/run_shard_in_container.sh"] diff --git a/scripts/docker/run_shard_in_container.sh b/scripts/docker/run_shard_in_container.sh new file mode 100755 index 0000000..7fc23a7 --- /dev/null +++ b/scripts/docker/run_shard_in_container.sh @@ -0,0 +1,80 @@ +#!/usr/bin/env bash +set -euo pipefail + +SHARD_GROUP="${SHARD_GROUP:-1}" +SHARD_COUNT="${SHARD_COUNT:-3}" +ARTIFACTS_DIR="${ARTIFACTS_DIR:-/artifacts}" +PYTEST_EXTRA_ARGS="${PYTEST_EXTRA_ARGS:-}" + +mkdir -p "$ARTIFACTS_DIR" + +if [[ "$SHARD_GROUP" -lt 1 || "$SHARD_COUNT" -lt 1 || "$SHARD_GROUP" -gt "$SHARD_COUNT" ]]; then + echo "Invalid shard settings: SHARD_GROUP=$SHARD_GROUP SHARD_COUNT=$SHARD_COUNT" >&2 + exit 2 +fi + +export COVERAGE_FILE="$ARTIFACTS_DIR/.coverage.${SHARD_GROUP}" + +redis-server --save "" --appendonly no --daemonize yes +for _ in $(seq 1 50); do + if redis-cli ping >/dev/null 2>&1; then + break + fi + sleep 0.2 +done + +if ! redis-cli ping >/dev/null 2>&1; then + echo "Failed to start redis-server inside container" >&2 + exit 2 +fi + +mapfile -t shard_tests < <( + python - <<'PY' "$SHARD_GROUP" "$SHARD_COUNT" +import glob +import sys + +group = int(sys.argv[1]) +count = int(sys.argv[2]) + +tests = sorted(glob.glob("bluesky_httpserver/tests/test_*.py")) +selected = [path for idx, path in enumerate(tests) if idx % count == (group - 1)] + +for path in selected: + print(path) +PY +) + +if [[ "${#shard_tests[@]}" -eq 0 ]]; then + echo "No tests selected for shard ${SHARD_GROUP}/${SHARD_COUNT}; treating as success." + exit 0 +fi + +pytest_cmd=( + coverage + run + -m + pytest + --junitxml="$ARTIFACTS_DIR/junit.${SHARD_GROUP}.xml" + -vv +) + +if [[ -n "$PYTEST_EXTRA_ARGS" ]]; then + read -r -a extra_args <<< "$PYTEST_EXTRA_ARGS" + pytest_cmd+=("${extra_args[@]}") +fi + +pytest_cmd+=("${shard_tests[@]}") + +set +e +"${pytest_cmd[@]}" +test_status=$? +set -e + +if [[ "$test_status" -eq 5 ]]; then + echo "Pytest collected no tests for shard ${SHARD_GROUP}/${SHARD_COUNT}; treating as success." + test_status=0 +fi + +redis-cli shutdown nosave >/dev/null 2>&1 || true + +exit "$test_status" diff --git a/scripts/run_ci_docker_parallel.sh b/scripts/run_ci_docker_parallel.sh new file mode 100755 index 0000000..c9caee7 --- /dev/null +++ b/scripts/run_ci_docker_parallel.sh @@ -0,0 +1,480 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +IMAGE_TAG_BASE="bluesky-httpserver-test:local" +WORKER_COUNT="3" +CHUNK_COUNT="" +PYTHON_VERSIONS="latest" +PYTEST_EXTRA_ARGS="" +ARTIFACTS_DIR="$ROOT_DIR/.docker-test-artifacts" +DOCKER_NETWORK_NAME="bhs-ci-net" +LDAP_CONTAINER_NAME="bhs-ci-ldap" + +SUMMARY_TSV="" +SUMMARY_FAIL_LOGS="" +SUMMARY_TXT="" +SUMMARY_JSON="" +TESTS_START_EPOCH="" +TESTS_START_HUMAN="" + +SUPPORTED_PYTHON_VERSIONS=("3.10" "3.11" "3.12" "3.13") + +usage() { + cat <<'EOF' +Run bluesky-httpserver unit tests in Docker with dynamic chunk dispatch and optional Python-version matrix. + +Usage: + scripts/run_ci_docker_parallel.sh [options] + +Options: + --workers N, --worker-count N + Number of concurrent chunk workers (default: 3). + + --chunks N, --chunk-count N + Number of total chunks/splits to execute per Python version. + Default: workers * 3. + + --python-versions VALUE + Python version selection: latest | all | comma-separated list. + Examples: latest, all, 3.12, 3.11,3.13 + Default: latest (currently 3.13). + + --pytest-args "ARGS" + Extra arguments passed to pytest in each chunk. + Example: --pytest-args "-k oidc --maxfail=1" + + --artifacts-dir PATH + Output directory for all artifacts. + Default: .docker-test-artifacts under repository root. + + --image-tag TAG + Base docker image tag. Per-version tags will append -py. + Default: bluesky-httpserver-test:local + + -h, --help + Show this help message. + +Examples: + scripts/run_ci_docker_parallel.sh + scripts/run_ci_docker_parallel.sh --workers 8 --chunks 24 + scripts/run_ci_docker_parallel.sh --python-versions all --workers 8 --chunks 24 + scripts/run_ci_docker_parallel.sh --python-versions 3.11,3.13 --pytest-args "-k test_access_control" +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --workers|--worker-count) + WORKER_COUNT="$2" + shift 2 + ;; + --chunks|--chunk-count) + CHUNK_COUNT="$2" + shift 2 + ;; + --python-versions) + PYTHON_VERSIONS="$2" + shift 2 + ;; + --pytest-args) + PYTEST_EXTRA_ARGS="$2" + shift 2 + ;; + --artifacts-dir) + ARTIFACTS_DIR="$2" + shift 2 + ;; + --image-tag) + IMAGE_TAG_BASE="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown option: $1" >&2 + usage + exit 2 + ;; + esac +done + +if [[ "$WORKER_COUNT" -lt 1 ]]; then + echo "WORKER_COUNT must be >= 1" >&2 + exit 2 +fi + +if [[ -z "$CHUNK_COUNT" ]]; then + CHUNK_COUNT=$(( WORKER_COUNT * 3 )) +fi + +if [[ "$CHUNK_COUNT" -lt 1 ]]; then + echo "CHUNK_COUNT must be >= 1" >&2 + exit 2 +fi + +if ! command -v docker >/dev/null 2>&1; then + echo "docker is required but not found in PATH" >&2 + exit 2 +fi + +if ! docker info >/dev/null 2>&1; then + echo "docker daemon is not available" >&2 + exit 2 +fi + +normalize_python_versions() { + local selection="$1" + local raw + local normalized=() + + if [[ "$selection" == "latest" ]]; then + normalized=("3.13") + elif [[ "$selection" == "all" ]]; then + normalized=("${SUPPORTED_PYTHON_VERSIONS[@]}") + else + raw="${selection//,/ }" + read -r -a normalized <<< "$raw" + fi + + if [[ "${#normalized[@]}" -eq 0 ]]; then + echo "PYTHON_VERSIONS selection produced no versions" >&2 + exit 2 + fi + + for version in "${normalized[@]}"; do + if [[ ! " ${SUPPORTED_PYTHON_VERSIONS[*]} " =~ " ${version} " ]]; then + echo "Unsupported Python version '${version}'. Supported: ${SUPPORTED_PYTHON_VERSIONS[*]}" >&2 + exit 2 + fi + done + + echo "${normalized[@]}" +} + +ensure_ldap_image() { + local image_ref="bitnami/openldap:latest" + if docker image inspect "$image_ref" >/dev/null 2>&1; then + return + fi + + echo "LDAP image $image_ref not found locally; trying docker pull..." + if docker pull "$image_ref"; then + return + fi + + echo "docker pull failed; building bitnami/openldap:latest from source (CI fallback)." + local workdir="$ROOT_DIR/.docker-test-artifacts/bitnami-containers" + rm -rf "$workdir" + git clone --depth 1 https://github.com/bitnami/containers.git "$workdir" + (cd "$workdir/bitnami/openldap/2.6/debian-12" && docker build -t "$image_ref" .) +} + +start_services() { + ensure_ldap_image + + docker network rm "$DOCKER_NETWORK_NAME" >/dev/null 2>&1 || true + docker network create "$DOCKER_NETWORK_NAME" >/dev/null + + docker rm -f "$LDAP_CONTAINER_NAME" >/dev/null 2>&1 || true + docker run -d --rm \ + --name "$LDAP_CONTAINER_NAME" \ + --network "$DOCKER_NETWORK_NAME" \ + -e LDAP_ADMIN_USERNAME=admin \ + -e LDAP_ADMIN_PASSWORD=adminpassword \ + -e LDAP_USERS=user01,user02 \ + -e LDAP_PASSWORDS=password1,password2 \ + bitnami/openldap:latest >/dev/null + + sleep 2 +} + +stop_services() { + docker rm -f "$LDAP_CONTAINER_NAME" >/dev/null 2>&1 || true + docker network rm "$DOCKER_NETWORK_NAME" >/dev/null 2>&1 || true +} + +cleanup() { + stop_services +} + +collect_junit_totals() { + local artifacts_dir="$1" + + python - "$artifacts_dir" <<'PY' +import glob +import os +import sys +import xml.etree.ElementTree as ET + +artifacts_dir = sys.argv[1] +tests = failures = errors = files = 0 + +for path in sorted(glob.glob(os.path.join(artifacts_dir, "junit.*.xml"))): + files += 1 + try: + root = ET.parse(path).getroot() + except Exception: + continue + + if root.tag == "testsuite": + suites = [root] + elif root.tag == "testsuites": + suites = root.findall("testsuite") + else: + suites = [] + + for suite in suites: + tests += int(suite.attrib.get("tests", 0) or 0) + failures += int(suite.attrib.get("failures", 0) or 0) + errors += int(suite.attrib.get("errors", 0) or 0) + +print(f"{tests} {failures} {errors} {files}") +PY +} + +append_summary_row() { + local py_version="$1" + local chunks_total="$2" + local junit_files="$3" + local tests="$4" + local failures="$5" + local errors="$6" + local status="$7" + + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "$py_version" "$chunks_total" "$junit_files" "$tests" "$failures" "$errors" "$status" >> "$SUMMARY_TSV" +} + +write_summary_files() { + local end_epoch end_human elapsed_sec + + if [[ -z "$SUMMARY_TSV" || -z "$SUMMARY_TXT" || -z "$SUMMARY_JSON" ]]; then + return + fi + + if [[ ! -f "$SUMMARY_TSV" ]]; then + return + fi + + end_epoch="$(date +%s)" + end_human="$(date -u +"%Y-%m-%dT%H:%M:%SZ")" + + if [[ -n "$TESTS_START_EPOCH" ]]; then + elapsed_sec=$(( end_epoch - TESTS_START_EPOCH )) + else + elapsed_sec=0 + fi + + { + echo "Test Run Summary" + echo "Start (UTC): ${TESTS_START_HUMAN:-N/A}" + echo "End (UTC): $end_human" + echo "Elapsed: ${elapsed_sec}s" + echo + printf "%-8s %-8s %-7s %-8s %-10s %-8s %-6s\n" \ + "Python" "Status" "Chunks" "JUnit" "Tests" "Failures" "Errors" + printf "%-8s %-8s %-7s %-8s %-10s %-8s %-6s\n" \ + "------" "------" "------" "-----" "-----" "--------" "------" + + if [[ -s "$SUMMARY_TSV" ]]; then + while IFS=$'\t' read -r py_version chunks_total junit_files tests failures errors status; do + printf "%-8s %-8s %-7s %-8s %-10s %-8s %-6s\n" \ + "$py_version" "$status" "$chunks_total" "$junit_files" "$tests" "$failures" "$errors" + done < "$SUMMARY_TSV" + else + echo "No per-version summary rows were recorded." + fi + + if [[ -s "$SUMMARY_FAIL_LOGS" ]]; then + echo + echo "Failed Chunk Logs" + cat "$SUMMARY_FAIL_LOGS" + fi + } > "$SUMMARY_TXT" + + python - "$SUMMARY_TSV" "$SUMMARY_FAIL_LOGS" "$SUMMARY_JSON" "${TESTS_START_HUMAN:-N/A}" "$end_human" "$elapsed_sec" <<'PY' +import json +import sys + +summary_tsv, fail_logs_path, output_path, start_utc, end_utc, elapsed_sec = sys.argv[1:] + +rows = [] +with open(summary_tsv) as f: + for line in f: + parts = line.rstrip("\n").split("\t") + if len(parts) != 7: + continue + py_version, chunks_total, junit_files, tests, failures, errors, status = parts + rows.append( + { + "python_version": py_version, + "status": status, + "chunks_total": int(chunks_total), + "junit_files": int(junit_files), + "tests": int(tests), + "failures": int(failures), + "errors": int(errors), + } + ) + +failed_logs = [] +with open(fail_logs_path) as f: + failed_logs = [line.strip() for line in f if line.strip()] + +payload = { + "start_utc": start_utc, + "end_utc": end_utc, + "elapsed_seconds": int(elapsed_sec), + "python_versions": rows, + "failed_chunk_logs": failed_logs, +} + +with open(output_path, "w") as f: + json.dump(payload, f, indent=2) + f.write("\n") +PY + + echo "==> Test run end time (UTC): $end_human" + echo "==> Test run elapsed: ${elapsed_sec}s" + echo "==> Summary written: $SUMMARY_TXT" + echo "==> Summary JSON: $SUMMARY_JSON" +} + +on_exit() { + local exit_code=$? + write_summary_files || true + cleanup + trap - EXIT + exit "$exit_code" +} + +trap on_exit EXIT + +read -r -a SELECTED_PYTHON_VERSIONS <<< "$(normalize_python_versions "$PYTHON_VERSIONS")" + +echo "==> Preparing artifacts directory: $ARTIFACTS_DIR" +rm -rf "$ARTIFACTS_DIR" +mkdir -p "$ARTIFACTS_DIR" + +SUMMARY_TSV="$ARTIFACTS_DIR/.summary_rows.tsv" +SUMMARY_FAIL_LOGS="$ARTIFACTS_DIR/.summary_fail_logs.txt" +SUMMARY_TXT="$ARTIFACTS_DIR/summary.txt" +SUMMARY_JSON="$ARTIFACTS_DIR/summary.json" + +: > "$SUMMARY_TSV" +: > "$SUMMARY_FAIL_LOGS" + +echo "==> Starting shared services (LDAP)" +start_services + +TESTS_START_EPOCH="$(date +%s)" +TESTS_START_HUMAN="$(date -u +"%Y-%m-%dT%H:%M:%SZ")" +echo "==> Test run start time (UTC): $TESTS_START_HUMAN" +echo "==> Python versions selected: ${SELECTED_PYTHON_VERSIONS[*]}" + +run_chunk() { + local group="$1" + local log_file="$CURRENT_ARTIFACTS_DIR/shard.${group}.log" + + if docker run --rm \ + --network "$DOCKER_NETWORK_NAME" \ + -e SHARD_GROUP="$group" \ + -e SHARD_COUNT="$CHUNK_COUNT" \ + -e ARTIFACTS_DIR="/artifacts" \ + -e PYTEST_EXTRA_ARGS="$PYTEST_EXTRA_ARGS" \ + -e QSERVER_TEST_LDAP_HOST="$LDAP_CONTAINER_NAME" \ + -e QSERVER_TEST_LDAP_PORT="1389" \ + -e QSERVER_TEST_REDIS_ADDR="localhost" \ + -e QSERVER_HTTP_TEST_BIND_HOST="127.0.0.1" \ + -e QSERVER_HTTP_TEST_HOST="127.0.0.1" \ + -v "$CURRENT_ARTIFACTS_DIR:/artifacts" \ + "$CURRENT_IMAGE_TAG" >"$log_file" 2>&1; then + : > "$CURRENT_ARTIFACTS_DIR/.status.${group}.ok" + else + : > "$CURRENT_ARTIFACTS_DIR/.status.${group}.fail" + exit 1 + fi +} + +export -f run_chunk +export CHUNK_COUNT PYTEST_EXTRA_ARGS DOCKER_NETWORK_NAME LDAP_CONTAINER_NAME + +for PYTHON_VERSION in "${SELECTED_PYTHON_VERSIONS[@]}"; do + CURRENT_IMAGE_TAG="${IMAGE_TAG_BASE}-py${PYTHON_VERSION}" + CURRENT_ARTIFACTS_DIR="$ARTIFACTS_DIR/py${PYTHON_VERSION}" + export CURRENT_IMAGE_TAG CURRENT_ARTIFACTS_DIR + + echo "==> Building test image: $CURRENT_IMAGE_TAG (Python $PYTHON_VERSION)" + docker build \ + --build-arg PYTHON_VERSION="$PYTHON_VERSION" \ + -f "$ROOT_DIR/docker/test.Dockerfile" \ + -t "$CURRENT_IMAGE_TAG" \ + "$ROOT_DIR" + + mkdir -p "$CURRENT_ARTIFACTS_DIR" + + echo "==> [Python $PYTHON_VERSION] Starting dynamic dispatch: $WORKER_COUNT workers over $CHUNK_COUNT chunks" + if ! seq 1 "$CHUNK_COUNT" | xargs -P "$WORKER_COUNT" -I {} bash -lc 'run_chunk "$1"' _ {}; then + echo "One or more chunks failed for Python $PYTHON_VERSION." >&2 + read -r TOTAL_TESTS TOTAL_FAILURES TOTAL_ERRORS TOTAL_JUNIT_FILES < <(collect_junit_totals "$CURRENT_ARTIFACTS_DIR") + for group in $(seq 1 "$CHUNK_COUNT"); do + if [[ -f "$CURRENT_ARTIFACTS_DIR/.status.${group}.fail" ]]; then + echo "Chunk $group failed. Log: $CURRENT_ARTIFACTS_DIR/shard.${group}.log" >&2 + echo "$CURRENT_ARTIFACTS_DIR/shard.${group}.log" >> "$SUMMARY_FAIL_LOGS" + fi + done + append_summary_row "py${PYTHON_VERSION}" "$CHUNK_COUNT" "$TOTAL_JUNIT_FILES" \ + "$TOTAL_TESTS" "$TOTAL_FAILURES" "$TOTAL_ERRORS" "FAIL" + exit 1 + fi + + for group in $(seq 1 "$CHUNK_COUNT"); do + if [[ -f "$CURRENT_ARTIFACTS_DIR/.status.${group}.ok" ]]; then + echo "[Python $PYTHON_VERSION] Chunk $group completed successfully" + fi + done + + rm -f "$CURRENT_ARTIFACTS_DIR"/.status.*.ok "$CURRENT_ARTIFACTS_DIR"/.status.*.fail + + echo "==> [Python $PYTHON_VERSION] Merging coverage artifacts" + docker run --rm \ + --entrypoint bash \ + -v "$CURRENT_ARTIFACTS_DIR:/artifacts" \ + "$CURRENT_IMAGE_TAG" \ + -lc "set -euo pipefail; \ + python -m coverage combine /artifacts/.coverage.* && \ + python -m coverage xml -o /artifacts/coverage.xml && \ + python -m coverage report -m > /artifacts/coverage.txt" + + if [[ "${#SELECTED_PYTHON_VERSIONS[@]}" -eq 1 ]]; then + cp "$CURRENT_ARTIFACTS_DIR/coverage.xml" "$ROOT_DIR/coverage.xml" + else + cp "$CURRENT_ARTIFACTS_DIR/coverage.xml" "$ROOT_DIR/coverage.py${PYTHON_VERSION}.xml" + fi + + read -r TOTAL_TESTS TOTAL_FAILURES TOTAL_ERRORS TOTAL_JUNIT_FILES < <(collect_junit_totals "$CURRENT_ARTIFACTS_DIR") + echo "==> [Python $PYTHON_VERSION] JUnit summary: tests=$TOTAL_TESTS failures=$TOTAL_FAILURES errors=$TOTAL_ERRORS files=$TOTAL_JUNIT_FILES" + + VERSION_STATUS="PASS" + if [[ "$TOTAL_FAILURES" -gt 0 || "$TOTAL_ERRORS" -gt 0 ]]; then + VERSION_STATUS="FAIL" + fi + + append_summary_row "py${PYTHON_VERSION}" "$CHUNK_COUNT" "$TOTAL_JUNIT_FILES" \ + "$TOTAL_TESTS" "$TOTAL_FAILURES" "$TOTAL_ERRORS" "$VERSION_STATUS" +done + +echo "==> Completed. Artifacts:" +echo " versioned logs : $ARTIFACTS_DIR/py/shard..log" +echo " versioned junit : $ARTIFACTS_DIR/py/junit..xml" +echo " versioned coverage : $ARTIFACTS_DIR/py/{coverage.txt,coverage.xml}" +echo " run summary : $ARTIFACTS_DIR/{summary.txt,summary.json}" + +if [[ "${#SELECTED_PYTHON_VERSIONS[@]}" -eq 1 ]]; then + echo " root coverage xml : $ROOT_DIR/coverage.xml" +else + echo " root coverage xmls : $ROOT_DIR/coverage.py.xml" +fi From 8298a7b29acd1d461eb7dc50391609a41453e738 Mon Sep 17 00:00:00 2001 From: davidpcls Date: Fri, 20 Mar 2026 11:24:56 -0600 Subject: [PATCH 13/18] Fixing unit tests (#3) This is a set of test changes intended to improve the reliability of unit testing, as the current unit tests are randomly failing due to test design. Primarily this appears to be centered around LDAP. So this work was to: * Fix for ldap errors * Hardening unit tests so they fail less frequency * Try to handle console output more reliably --- .github/workflows/testing.yml | 23 +- bluesky_httpserver/tests/conftest.py | 23 +- .../tests/test_authenticators.py | 29 ++- .../tests/test_console_output.py | 81 +++++--- .../tests/test_core_api_main.py | 44 +++- bluesky_httpserver/tests/test_server.py | 26 ++- .../docker-configs/ldap-docker-compose.yml | 11 +- .../dockerfiles}/test.Dockerfile | 0 continuous_integration/scripts/start_LDAP.sh | 196 +++++++++++++++++- docs/source/usage.rst | 4 +- scripts/run_ci_docker_parallel.sh | 56 ++--- 11 files changed, 371 insertions(+), 122 deletions(-) rename {docker => continuous_integration/dockerfiles}/test.Dockerfile (100%) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 5355c05..adef4fc 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -26,7 +26,7 @@ jobs: - name: Fetch tags run: git fetch --tags --prune --unshallow - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - uses: shogo82148/actions-setup-redis@v1 @@ -36,14 +36,8 @@ jobs: run: | # sudo apt install redis - pushd .. - git clone https://github.com/bitnami/containers.git - cd containers/bitnami/openldap/2.6/debian-12 - docker build -t bitnami/openldap:latest . - popd - # Start LDAP - source continuous_integration/scripts/start_LDAP.sh + bash continuous_integration/scripts/start_LDAP.sh # These packages are installed in the base environment but may be older # versions. Explicitly upgrade them because they often create @@ -70,6 +64,19 @@ jobs: pip list - name: Test with pytest + env: + PYTEST_ADDOPTS: "--durations=20" run: | coverage run -m pytest -vv coverage report -m + - name: Dump LDAP diagnostics on failure + if: failure() + run: | + docker ps + docker compose -f continuous_integration/docker-configs/ldap-docker-compose.yml ps + LDAP_CONTAINER_ID=$(docker compose -f continuous_integration/docker-configs/ldap-docker-compose.yml ps -q openldap | tr -d '[:space:]') + if [ -n "$LDAP_CONTAINER_ID" ]; then + docker logs --tail 200 "$LDAP_CONTAINER_ID" + else + docker compose -f continuous_integration/docker-configs/ldap-docker-compose.yml logs --tail 200 openldap + fi diff --git a/bluesky_httpserver/tests/conftest.py b/bluesky_httpserver/tests/conftest.py index d5cafdb..8a81df9 100644 --- a/bluesky_httpserver/tests/conftest.py +++ b/bluesky_httpserver/tests/conftest.py @@ -4,6 +4,7 @@ import pytest import requests from bluesky_queueserver.manager.comms import zmq_single_request +from bluesky_queueserver.manager.tests.common import re_manager_cmd # noqa: F401 from bluesky_queueserver.manager.tests.common import set_qserver_zmq_encoding # noqa: F401 from xprocess import ProcessStarter @@ -60,7 +61,11 @@ def fastapi_server_fs(xprocess): to perform additional steps (such as setting environmental variables) before the server is started. """ - def start(http_server_host=SERVER_ADDRESS, http_server_port=SERVER_PORT, api_key=API_KEY_FOR_TESTS): + def start( + http_server_host=SERVER_ADDRESS, + http_server_port=SERVER_PORT, + api_key=API_KEY_FOR_TESTS, + ): class Starter(ProcessStarter): max_read_lines = 53 @@ -112,7 +117,12 @@ def add_plans_to_queue(): user_group = _user_group user = "HTTP unit test setup" - plan1 = {"name": "count", "args": [["det1", "det2"]], "kwargs": {"num": 10, "delay": 1}, "item_type": "plan"} + plan1 = { + "name": "count", + "args": [["det1", "det2"]], + "kwargs": {"num": 10, "delay": 1}, + "item_type": "plan", + } plan2 = {"name": "count", "args": [["det1", "det2"]], "item_type": "plan"} for plan in (plan1, plan2, plan2): resp2, _ = zmq_single_request("queue_item_add", {"item": plan, "user": user, "user_group": user_group}) @@ -120,7 +130,14 @@ def add_plans_to_queue(): def request_to_json( - request_type, path, *, request_prefix="/api", api_key=API_KEY_FOR_TESTS, token=None, login=None, **kwargs + request_type, + path, + *, + request_prefix="/api", + api_key=API_KEY_FOR_TESTS, + token=None, + login=None, + **kwargs, ): if login: auth = None diff --git a/bluesky_httpserver/tests/test_authenticators.py b/bluesky_httpserver/tests/test_authenticators.py index 53c6bbe..7b7dd4b 100644 --- a/bluesky_httpserver/tests/test_authenticators.py +++ b/bluesky_httpserver/tests/test_authenticators.py @@ -1,4 +1,5 @@ import asyncio +import os import time from typing import Any, Tuple @@ -10,20 +11,28 @@ from respx import MockRouter from starlette.datastructures import URL, QueryParams -# fmt: off from ..authenticators import LDAPAuthenticator, OIDCAuthenticator, ProxiedOIDCAuthenticator, UserSessionState +LDAP_TEST_HOST = os.environ.get("QSERVER_TEST_LDAP_HOST", "localhost") +LDAP_TEST_PORT = int(os.environ.get("QSERVER_TEST_LDAP_PORT", "1389")) +LDAP_TEST_ALT_HOST = os.environ.get("QSERVER_TEST_LDAP_ALT_HOST") +if not LDAP_TEST_ALT_HOST: + LDAP_TEST_ALT_HOST = "127.0.0.1" if LDAP_TEST_HOST == "localhost" else LDAP_TEST_HOST + + +# fmt: off + @pytest.mark.parametrize("ldap_server_address, ldap_server_port", [ - ("localhost", 1389), - ("localhost:1389", 904), # Random port, ignored - ("localhost:1389", None), - ("127.0.0.1", 1389), - ("127.0.0.1:1389", 904), - (["localhost"], 1389), - (["localhost", "127.0.0.1"], 1389), - (["localhost", "127.0.0.1:1389"], 1389), - (["localhost:1389", "127.0.0.1:1389"], None), + (LDAP_TEST_HOST, LDAP_TEST_PORT), + (f"{LDAP_TEST_HOST}:{LDAP_TEST_PORT}", 904), # Random port, ignored + (f"{LDAP_TEST_HOST}:{LDAP_TEST_PORT}", None), + (LDAP_TEST_ALT_HOST, LDAP_TEST_PORT), + (f"{LDAP_TEST_ALT_HOST}:{LDAP_TEST_PORT}", 904), + ([LDAP_TEST_HOST], LDAP_TEST_PORT), + ([LDAP_TEST_HOST, LDAP_TEST_ALT_HOST], LDAP_TEST_PORT), + ([LDAP_TEST_HOST, f"{LDAP_TEST_ALT_HOST}:{LDAP_TEST_PORT}"], LDAP_TEST_PORT), + ([f"{LDAP_TEST_HOST}:{LDAP_TEST_PORT}", f"{LDAP_TEST_ALT_HOST}:{LDAP_TEST_PORT}"], None), ]) # fmt: on @pytest.mark.parametrize("use_tls,use_ssl", [(False, False)]) diff --git a/bluesky_httpserver/tests/test_console_output.py b/bluesky_httpserver/tests/test_console_output.py index 1f089ec..6193db0 100644 --- a/bluesky_httpserver/tests/test_console_output.py +++ b/bluesky_httpserver/tests/test_console_output.py @@ -3,17 +3,16 @@ import re import threading import time as ttime +from typing import Any import pytest import requests -from bluesky_queueserver.manager.tests.common import re_manager_cmd # noqa F401 from websockets.sync.client import connect from bluesky_httpserver.tests.conftest import ( # noqa F401 API_KEY_FOR_TESTS, SERVER_ADDRESS, SERVER_PORT, - fastapi_server_fs, request_to_json, set_qserver_zmq_encoding, wait_for_environment_to_be_closed, @@ -36,37 +35,42 @@ def __init__(self, api_key=API_KEY_FOR_TESTS, **kwargs): self._api_key = api_key def run(self): - kwargs = {"stream": True} + kwargs: dict[str, Any] = {"stream": True} if self._api_key: - auth = None headers = {"Authorization": f"ApiKey {self._api_key}"} - kwargs.update({"auth": auth, "headers": headers}) + kwargs.update({"headers": headers}) + + kwargs["timeout"] = (5, 1) - with requests.get(f"http://{SERVER_ADDRESS}:{SERVER_PORT}/api/stream_console_output", **kwargs) as r: - r.encoding = "utf-8" + while not self._exit: + try: + with requests.get( + f"http://{SERVER_ADDRESS}:{SERVER_PORT}/api/stream_console_output", + **kwargs, + ) as r: + r.encoding = "utf-8" - characters = [] - n_brackets = 0 + characters = [] + n_brackets = 0 - for ch in r.iter_content(decode_unicode=True): - # Note, that some output must be received from the server before the loop exits - if self._exit: - break + for ch in r.iter_content(decode_unicode=True): + if self._exit: + return - characters.append(ch) - if ch == "{": - n_brackets += 1 - elif ch == "}": - n_brackets -= 1 + characters.append(ch) + if ch == "{": + n_brackets += 1 + elif ch == "}": + n_brackets -= 1 - # If the received buffer ('characters') is not empty and the message contains - # equal number of opening and closing brackets then consider the message complete. - if characters and not n_brackets: - line = "".join(characters) - characters = [] + if characters and not n_brackets: + line = "".join(characters) + characters = [] - print(f"{line}") - self.received_data_buffer.append(json.loads(line)) + print(f"{line}") + self.received_data_buffer.append(json.loads(line)) + except requests.exceptions.ReadTimeout: + continue def stop(self): """ @@ -81,7 +85,10 @@ def __del__(self): @pytest.mark.parametrize("zmq_port", (None, 60619)) def test_http_server_stream_console_output_1( - monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port # noqa F811 + monkeypatch, + re_manager_cmd, + fastapi_server_fs, + zmq_port, # noqa F811 ): """ Test for ``stream_console_output`` API @@ -122,7 +129,8 @@ def test_http_server_stream_console_output_1( assert resp2["items"][0] == resp1["item"] assert resp2["running_item"] == {} - rsc.join() + rsc.join(timeout=10) + assert not rsc.is_alive(), "Timed out waiting for stream_console_output thread to terminate" assert len(rsc.received_data_buffer) >= 2, pprint.pformat(rsc.received_data_buffer) @@ -160,7 +168,11 @@ def test_http_server_stream_console_output_1( @pytest.mark.parametrize("zmq_encoding", (None, "json", "msgpack")) @pytest.mark.parametrize("zmq_port", (None, 60619)) def test_http_server_console_output_1( - monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port, zmq_encoding # noqa F811 + monkeypatch, + re_manager_cmd, + fastapi_server_fs, + zmq_port, + zmq_encoding, # noqa F811 ): """ Test for ``console_output`` API (not a streaming version). @@ -238,7 +250,10 @@ def test_http_server_console_output_1( @pytest.mark.parametrize("zmq_port", (None, 60619)) def test_http_server_console_output_update_1( - monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port # noqa F811 + monkeypatch, + re_manager_cmd, + fastapi_server_fs, + zmq_port, # noqa F811 ): """ Test for ``console_output`` API (not a streaming version). @@ -379,7 +394,10 @@ def __del__(self): @pytest.mark.parametrize("zmq_port", (None, 60619)) def test_http_server_console_output_socket_1( - monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port # noqa F811 + monkeypatch, + re_manager_cmd, + fastapi_server_fs, + zmq_port, # noqa F811 ): """ Test for ``/console_output/ws`` websocket @@ -421,7 +439,8 @@ def test_http_server_console_output_socket_1( assert resp2["items"][0] == resp1["item"] assert resp2["running_item"] == {} - rsc.join() + rsc.join(timeout=10) + assert not rsc.is_alive(), "Timed out waiting for console_output websocket thread to terminate" assert len(rsc.received_data_buffer) >= 2, pprint.pformat(rsc.received_data_buffer) diff --git a/bluesky_httpserver/tests/test_core_api_main.py b/bluesky_httpserver/tests/test_core_api_main.py index b2b5140..0c471bd 100644 --- a/bluesky_httpserver/tests/test_core_api_main.py +++ b/bluesky_httpserver/tests/test_core_api_main.py @@ -30,8 +30,17 @@ # Plans used in most of the tests: '_plan1' and '_plan2' are quickly executed '_plan3' runs for 5 seconds. _plan1 = {"name": "count", "args": [["det1", "det2"]], "item_type": "plan"} -_plan2 = {"name": "scan", "args": [["det1", "det2"], "motor", -1, 1, 10], "item_type": "plan"} -_plan3 = {"name": "count", "args": [["det1", "det2"]], "kwargs": {"num": 5, "delay": 1}, "item_type": "plan"} +_plan2 = { + "name": "scan", + "args": [["det1", "det2"], "motor", -1, 1, 10], + "item_type": "plan", +} +_plan3 = { + "name": "count", + "args": [["det1", "det2"]], + "kwargs": {"num": 5, "delay": 1}, + "item_type": "plan", +} _instruction_stop = {"name": "queue_stop", "item_type": "instruction"} @@ -515,8 +524,10 @@ def test_http_server_queue_item_update_2_fail(re_manager, fastapi_server, replac resp2 = request_to_json("post", "/queue/item/update", json=params) assert resp2["success"] is False - assert resp2["msg"] == "Failed to add an item: Failed to replace item: " \ - "Item with UID 'incorrect_uid' is not in the queue" + assert ( + resp2["msg"] == "Failed to add an item: Failed to replace item: " + "Item with UID 'incorrect_uid' is not in the queue" + ) resp3 = request_to_json("get", "/queue/get") assert resp3["items"] != [] @@ -1286,16 +1297,33 @@ def test_http_server_history_clear(re_manager, fastapi_server, clear_params, exp def test_http_server_manager_kill(re_manager, fastapi_server): # noqa F811 + timeout_variants = ( + "Request timeout: ZMQ communication error: timeout occurred", + "Request timeout: ZMQ communication error: Resource temporarily unavailable", + ) + request_to_json("post", "/environment/open") assert wait_for_environment_to_be_created(10), "Timeout" resp = request_to_json("post", "/test/manager/kill") assert "success" not in resp - assert "Request timeout: ZMQ communication error: timeout occurred" in resp["detail"] - - ttime.sleep(10) + assert any(_ in resp["detail"] for _ in timeout_variants) + + deadline = ttime.time() + 20 + last_status = None + while ttime.time() < deadline: + ttime.sleep(0.2) + last_status = request_to_json("get", "/status") + if ( + isinstance(last_status, dict) + and last_status.get("manager_state") == "idle" + and last_status.get("worker_environment_exists") is True + ): + break + else: + assert False, f"Timeout while waiting for manager recovery after kill. Last status: {last_status!r}" - resp = request_to_json("get", "/status") + resp = last_status assert resp["msg"].startswith("RE Manager") assert resp["manager_state"] == "idle" assert resp["items_in_queue"] == 0 diff --git a/bluesky_httpserver/tests/test_server.py b/bluesky_httpserver/tests/test_server.py index 117f4df..33b82a2 100644 --- a/bluesky_httpserver/tests/test_server.py +++ b/bluesky_httpserver/tests/test_server.py @@ -27,8 +27,17 @@ # Plans used in most of the tests: '_plan1' and '_plan2' are quickly executed '_plan3' runs for 5 seconds. _plan1 = {"name": "count", "args": [["det1", "det2"]], "item_type": "plan"} -_plan2 = {"name": "scan", "args": [["det1", "det2"], "motor", -1, 1, 10], "item_type": "plan"} -_plan3 = {"name": "count", "args": [["det1", "det2"]], "kwargs": {"num": 5, "delay": 1}, "item_type": "plan"} +_plan2 = { + "name": "scan", + "args": [["det1", "det2"], "motor", -1, 1, 10], + "item_type": "plan", +} +_plan3 = { + "name": "count", + "args": [["det1", "det2"]], + "kwargs": {"num": 5, "delay": 1}, + "item_type": "plan", +} _config_public_key = """ @@ -122,7 +131,7 @@ def test_http_server_secure_1(monkeypatch, tmpdir, re_manager_cmd, fastapi_serve @pytest.mark.parametrize("option", ["ev", "cfg_file", "both"]) # fmt: on def test_http_server_set_zmq_address_1( - monkeypatch, tmpdir, re_manager_cmd, fastapi_server_fs, option # noqa: F811 + monkeypatch, tmpdir, re_manager_cmd, fastapi_server_fs, free_tcp_port_factory, option # noqa: F811 ): """ Test if ZMQ address of RE Manager is passed to the HTTP server using 'QSERVER_ZMQ_ADDRESS_CONTROL' @@ -130,11 +139,12 @@ def test_http_server_set_zmq_address_1( channel different from default address, add and execute a plan. """ - # Change ZMQ address to use port 60616 instead of the default port 60615. - zmq_control_address_server = "tcp://*:60616" - zmq_info_address_server = "tcp://*:60617" - zmq_control_address = "tcp://localhost:60616" - zmq_info_address = "tcp://localhost:60617" + zmq_control_port = free_tcp_port_factory() + zmq_info_port = free_tcp_port_factory() + zmq_control_address_server = f"tcp://*:{zmq_control_port}" + zmq_info_address_server = f"tcp://*:{zmq_info_port}" + zmq_control_address = f"tcp://localhost:{zmq_control_port}" + zmq_info_address = f"tcp://localhost:{zmq_info_port}" if option == "ev": monkeypatch.setenv("QSERVER_ZMQ_CONTROL_ADDRESS", zmq_control_address) monkeypatch.setenv("QSERVER_ZMQ_INFO_ADDRESS", zmq_info_address) diff --git a/continuous_integration/docker-configs/ldap-docker-compose.yml b/continuous_integration/docker-configs/ldap-docker-compose.yml index 2b2c45a..5fbfc53 100644 --- a/continuous_integration/docker-configs/ldap-docker-compose.yml +++ b/continuous_integration/docker-configs/ldap-docker-compose.yml @@ -1,14 +1,13 @@ services: openldap: - image: osixia/openldap:latest + image: osixia/openldap:1.5.0 ports: - - '1389:1389' - - '1636:1636' + - '1389:389' + - '1636:636' environment: - - LDAP_ADMIN_USERNAME=admin + - LDAP_ORGANISATION=Example Inc. + - LDAP_DOMAIN=example.org - LDAP_ADMIN_PASSWORD=adminpassword - - LDAP_USERS=user01,user02 - - LDAP_PASSWORDS=password1,password2 volumes: - 'openldap_data:/var/lib/ldap' diff --git a/docker/test.Dockerfile b/continuous_integration/dockerfiles/test.Dockerfile similarity index 100% rename from docker/test.Dockerfile rename to continuous_integration/dockerfiles/test.Dockerfile diff --git a/continuous_integration/scripts/start_LDAP.sh b/continuous_integration/scripts/start_LDAP.sh index ecfa1cf..d2bd48d 100755 --- a/continuous_integration/scripts/start_LDAP.sh +++ b/continuous_integration/scripts/start_LDAP.sh @@ -1,7 +1,195 @@ -#!/bin/bash -set -e +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +COMPOSE_FILE="${LDAP_COMPOSE_FILE:-$ROOT_DIR/continuous_integration/docker-configs/ldap-docker-compose.yml}" +COMPOSE_PROJECT="${LDAP_COMPOSE_PROJECT:-}" +LDAP_HOST="${LDAP_HOST:-127.0.0.1}" +LDAP_PORT="${LDAP_PORT:-1389}" +LDAP_ADMIN_DN="cn=admin,dc=example,dc=org" +LDAP_ADMIN_PASSWORD="adminpassword" +LDAP_BASE_DN="dc=example,dc=org" + +compose_cmd() { + if [[ -n "$COMPOSE_PROJECT" ]]; then + docker compose -p "$COMPOSE_PROJECT" -f "$COMPOSE_FILE" "$@" + else + docker compose -f "$COMPOSE_FILE" "$@" + fi +} + +get_openldap_container_id() { + compose_cmd ps -q openldap | tr -d '[:space:]' +} + +wait_for_ldap() { + local timeout_seconds="${1:-60}" + local deadline=$((SECONDS + timeout_seconds)) + + while (( SECONDS < deadline )); do + if python - </dev/null 2>&1 +import socket + +with socket.create_connection(("${LDAP_HOST}", ${LDAP_PORT}), timeout=1): + pass +PY + then + return 0 + fi + sleep 1 + done + + return 1 +} + +wait_for_ldap_bind() { + local container_id="$1" + local timeout_seconds="${2:-60}" + local deadline=$((SECONDS + timeout_seconds)) + local rc=0 + + while (( SECONDS < deadline )); do + rc=0 + docker exec "$container_id" ldapsearch \ + -x \ + -H "ldap://127.0.0.1:389" \ + -D "$LDAP_ADMIN_DN" \ + -w "$LDAP_ADMIN_PASSWORD" \ + -b "$LDAP_BASE_DN" \ + -s base \ + "(objectclass=*)" dn >/dev/null 2>&1 || rc=$? + if [[ "$rc" -eq 0 ]]; then + return 0 + fi + sleep 1 + done + + return 1 +} + +wait_for_ldap_test_user_bind() { + local container_id="$1" + local timeout_seconds="${2:-60}" + local deadline=$((SECONDS + timeout_seconds)) + local rc=0 + + while (( SECONDS < deadline )); do + rc=0 + docker exec "$container_id" ldapwhoami \ + -x \ + -H "ldap://127.0.0.1:389" \ + -D "cn=user01,ou=users,$LDAP_BASE_DN" \ + -w "password1" >/dev/null 2>&1 || rc=$? + if [[ "$rc" -eq 0 ]]; then + return 0 + fi + sleep 1 + done + + return 1 +} + +print_ldap_diagnostics() { + local container_id="${1:-}" + + echo "LDAP startup diagnostics:" >&2 + compose_cmd ps >&2 || true + + if [[ -z "$container_id" ]]; then + container_id="$(get_openldap_container_id)" + fi + + if [[ -n "$container_id" ]]; then + docker logs --tail 200 "$container_id" >&2 || true + else + compose_cmd logs --tail 200 openldap >&2 || true + fi +} + +ldap_entry_exists() { + local container_id="$1" + local dn="$2" + + docker exec "$container_id" ldapsearch \ + -x \ + -H "ldap://127.0.0.1:389" \ + -D "$LDAP_ADMIN_DN" \ + -w "$LDAP_ADMIN_PASSWORD" \ + -b "$dn" \ + -s base \ + "(objectclass=*)" dn >/dev/null 2>&1 +} + +ldap_add_if_missing() { + local container_id="$1" + local dn="$2" + local ldif="$3" + + if ldap_entry_exists "$container_id" "$dn"; then + return 0 + fi + + docker exec -i "$container_id" ldapadd \ + -x \ + -H "ldap://127.0.0.1:389" \ + -D "$LDAP_ADMIN_DN" \ + -w "$LDAP_ADMIN_PASSWORD" >/dev/null <&2 + print_ldap_diagnostics + exit 1 +fi + +if ! wait_for_ldap 120; then + echo "LDAP port ${LDAP_HOST}:${LDAP_PORT} did not become reachable in time." >&2 + print_ldap_diagnostics "$CONTAINER_ID" + exit 1 +fi + +echo "LDAP port ${LDAP_HOST}:${LDAP_PORT} is reachable. Waiting for slapd initialization..." +sleep 3 + +if ! wait_for_ldap_bind "$CONTAINER_ID" 120; then + echo "LDAP admin bind did not become ready in time." >&2 + print_ldap_diagnostics "$CONTAINER_ID" + exit 1 +fi + +seed_ldap_test_users "$CONTAINER_ID" + +if ! wait_for_ldap_test_user_bind "$CONTAINER_ID" 60; then + echo "LDAP test-user bind did not become ready in time." >&2 + print_ldap_diagnostics "$CONTAINER_ID" + exit 1 +fi + docker ps diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 299bdcb..bcae133 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -169,7 +169,7 @@ If you are already in a browser context, open: This redirects to the OIDC provider login page and then back to the server callback. -This can similarly be acheived using ``httpie`` by opening the URL in a browser after getting +This can similarly be acheived using ``httpie`` by opening the URL in a browser after getting the authorization URI from the server:: http POST http://localhost:60610/api/auth/provider/entra/authorize @@ -183,7 +183,7 @@ spawn a browser for the user to log in to the provider. CLI/device flow *************** -For terminal clients (i.e. no browser possible), start with +For terminal clients (i.e. no browser possible), start with ``POST /api/auth/provider//authorize``. The response includes: diff --git a/scripts/run_ci_docker_parallel.sh b/scripts/run_ci_docker_parallel.sh index c9caee7..efb6594 100755 --- a/scripts/run_ci_docker_parallel.sh +++ b/scripts/run_ci_docker_parallel.sh @@ -8,8 +8,10 @@ CHUNK_COUNT="" PYTHON_VERSIONS="latest" PYTEST_EXTRA_ARGS="" ARTIFACTS_DIR="$ROOT_DIR/.docker-test-artifacts" -DOCKER_NETWORK_NAME="bhs-ci-net" -LDAP_CONTAINER_NAME="bhs-ci-ldap" +LDAP_COMPOSE_FILE="$ROOT_DIR/continuous_integration/docker-configs/ldap-docker-compose.yml" +LDAP_COMPOSE_PROJECT="bhs-ci-ldap-parallel-$$" +LDAP_SERVICE_NAME="openldap" +DOCKER_NETWORK_NAME="${LDAP_COMPOSE_PROJECT}_default" SUMMARY_TSV="" SUMMARY_FAIL_LOGS="" @@ -154,46 +156,16 @@ normalize_python_versions() { echo "${normalized[@]}" } -ensure_ldap_image() { - local image_ref="bitnami/openldap:latest" - if docker image inspect "$image_ref" >/dev/null 2>&1; then - return - fi - - echo "LDAP image $image_ref not found locally; trying docker pull..." - if docker pull "$image_ref"; then - return - fi - - echo "docker pull failed; building bitnami/openldap:latest from source (CI fallback)." - local workdir="$ROOT_DIR/.docker-test-artifacts/bitnami-containers" - rm -rf "$workdir" - git clone --depth 1 https://github.com/bitnami/containers.git "$workdir" - (cd "$workdir/bitnami/openldap/2.6/debian-12" && docker build -t "$image_ref" .) -} - start_services() { - ensure_ldap_image - - docker network rm "$DOCKER_NETWORK_NAME" >/dev/null 2>&1 || true - docker network create "$DOCKER_NETWORK_NAME" >/dev/null - - docker rm -f "$LDAP_CONTAINER_NAME" >/dev/null 2>&1 || true - docker run -d --rm \ - --name "$LDAP_CONTAINER_NAME" \ - --network "$DOCKER_NETWORK_NAME" \ - -e LDAP_ADMIN_USERNAME=admin \ - -e LDAP_ADMIN_PASSWORD=adminpassword \ - -e LDAP_USERS=user01,user02 \ - -e LDAP_PASSWORDS=password1,password2 \ - bitnami/openldap:latest >/dev/null - - sleep 2 + LDAP_COMPOSE_FILE="$LDAP_COMPOSE_FILE" \ + LDAP_COMPOSE_PROJECT="$LDAP_COMPOSE_PROJECT" \ + LDAP_HOST="127.0.0.1" \ + LDAP_PORT="1389" \ + bash "$ROOT_DIR/continuous_integration/scripts/start_LDAP.sh" >/dev/null } stop_services() { - docker rm -f "$LDAP_CONTAINER_NAME" >/dev/null 2>&1 || true - docker network rm "$DOCKER_NETWORK_NAME" >/dev/null 2>&1 || true + docker compose -p "$LDAP_COMPOSE_PROJECT" -f "$LDAP_COMPOSE_FILE" down -v >/dev/null 2>&1 || true } cleanup() { @@ -385,8 +357,8 @@ run_chunk() { -e SHARD_COUNT="$CHUNK_COUNT" \ -e ARTIFACTS_DIR="/artifacts" \ -e PYTEST_EXTRA_ARGS="$PYTEST_EXTRA_ARGS" \ - -e QSERVER_TEST_LDAP_HOST="$LDAP_CONTAINER_NAME" \ - -e QSERVER_TEST_LDAP_PORT="1389" \ + -e QSERVER_TEST_LDAP_HOST="$LDAP_SERVICE_NAME" \ + -e QSERVER_TEST_LDAP_PORT="389" \ -e QSERVER_TEST_REDIS_ADDR="localhost" \ -e QSERVER_HTTP_TEST_BIND_HOST="127.0.0.1" \ -e QSERVER_HTTP_TEST_HOST="127.0.0.1" \ @@ -400,7 +372,7 @@ run_chunk() { } export -f run_chunk -export CHUNK_COUNT PYTEST_EXTRA_ARGS DOCKER_NETWORK_NAME LDAP_CONTAINER_NAME +export CHUNK_COUNT PYTEST_EXTRA_ARGS DOCKER_NETWORK_NAME LDAP_SERVICE_NAME for PYTHON_VERSION in "${SELECTED_PYTHON_VERSIONS[@]}"; do CURRENT_IMAGE_TAG="${IMAGE_TAG_BASE}-py${PYTHON_VERSION}" @@ -410,7 +382,7 @@ for PYTHON_VERSION in "${SELECTED_PYTHON_VERSIONS[@]}"; do echo "==> Building test image: $CURRENT_IMAGE_TAG (Python $PYTHON_VERSION)" docker build \ --build-arg PYTHON_VERSION="$PYTHON_VERSION" \ - -f "$ROOT_DIR/docker/test.Dockerfile" \ + -f "$ROOT_DIR/continuous_integration/dockerfiles/test.Dockerfile" \ -t "$CURRENT_IMAGE_TAG" \ "$ROOT_DIR" From 94d8949c7751e578529f406b5e44c576e00c4822 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Mon, 25 May 2026 12:52:33 -0600 Subject: [PATCH 14/18] First attempt at migrating in the latest changes --- bluesky_httpserver/_authentication.py | 139 +++++++++++++----- bluesky_httpserver/authentication/__init__.py | 2 + bluesky_httpserver/authenticators.py | 135 ++++++++++++++++- bluesky_httpserver/database/core.py | 9 ++ bluesky_httpserver/routers/core_api.py | 15 +- bluesky_httpserver/schemas.py | 1 + .../tests/test_auth_for_websockets.py | 24 ++- .../tests/test_authenticators.py | 84 +++++++++-- .../tests/test_oidc_authenticators.py | 63 +++++++- 9 files changed, 409 insertions(+), 63 deletions(-) diff --git a/bluesky_httpserver/_authentication.py b/bluesky_httpserver/_authentication.py index c1144f5..992c5fb 100644 --- a/bluesky_httpserver/_authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -4,7 +4,7 @@ import uuid as uuid_module import warnings from datetime import datetime, timedelta -from typing import Optional +from typing import Any, Optional from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response, Security, WebSocket from fastapi.openapi.models import APIKey, APIKeyIn @@ -36,6 +36,7 @@ from .database import orm from .database.core import ( create_user, + get_or_create_principal, latest_principal_activity, lookup_valid_api_key, lookup_valid_pending_session_by_device_code, @@ -140,28 +141,53 @@ def create_refresh_token(session_id, secret_key, expires_delta): return encoded_jwt -def decode_token(token, secret_keys): - credentials_exception = HTTPException( - status_code=401, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) +def _decode_token_with_secret_keys(token, secret_keys): # The first key in settings.secret_keys is used for *encoding*. # All keys are tried for *decoding* until one works or they all - # fail. They supports key rotation. + # fail. They support key rotation. for secret_key in secret_keys: try: payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) - break + return payload except ExpiredSignatureError: # Do not let this be caught below with the other JWTError types. raise except JWTError: # Try the next key in the key rotation. continue - else: - raise credentials_exception - return payload + return None + + +def decode_token(token, secret_keys, proxied_authenticator=None): + credentials_exception = HTTPException( + status_code=401, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + payload = _decode_token_with_secret_keys(token, secret_keys) + if payload is not None: + return payload + if proxied_authenticator is not None: + return proxied_authenticator.decode_token(token) + raise credentials_exception + + +def _extract_scopes(decoded_access_token: dict[str, Any]) -> set[str]: + if "scp" in decoded_access_token: + scp = decoded_access_token["scp"] + return set(scp) if isinstance(scp, list) else set(scp.split(" ")) + if "scope" in decoded_access_token: + return set(decoded_access_token["scope"].split(" ")) + return set() + + +def _get_proxied_authenticator(authenticators): + if not authenticators: + return None + for authenticator in authenticators.values(): + if hasattr(authenticator, "oauth2_schema") and hasattr(authenticator, "decode_token"): + return authenticator + return None async def get_api_key( @@ -280,27 +306,49 @@ def get_current_principal( request.state.cookies_to_set.append({"key": API_KEY_COOKIE_NAME, "value": api_key}) elif access_token is not None: try: - payload = decode_token(access_token, settings.secret_keys) + payload = decode_token( + access_token, + settings.secret_keys, + _get_proxied_authenticator(authenticators), + ) except ExpiredSignatureError: raise HTTPException( status_code=401, detail="Access token has expired. Refresh token.", headers=headers_for_401, ) - principal = schemas.Principal( - uuid=uuid_module.UUID(hex=payload["sub"]), - type=payload["sub_typ"], - identities=[ - schemas.Identity(id=identity["id"], provider=identity["idp"]) for identity in payload["ids"] - ], - ) - - # scopes = payload["scp"] - - # Combine scopes for all identities (it is expected to be only one identity). - ids = [_["id"] for _ in payload["ids"] if _["idp"] in settings.authentication_provider_names] - scopes = set.union(*[api_access_manager.get_user_scopes(_) for _ in ids]) + token_scopes = _extract_scopes(payload) + if "sub_typ" in payload and "ids" in payload: + principal = schemas.Principal( + uuid=uuid_module.UUID(hex=payload["sub"]), + type=payload["sub_typ"], + identities=[ + schemas.Identity(id=identity["id"], provider=identity["idp"]) for identity in payload["ids"] + ], + ) + ids = [ + _["id"] + for _ in payload["ids"] + if (_["idp"] in settings.authentication_provider_names) + and api_access_manager.is_user_known(_["id"]) + ] + else: + identity_id = payload.get("user") or payload["sub"] + provider = ( + settings.authentication_provider_names[0] + if settings.authentication_provider_names + else _DEFAULT_ANONYMOUS_PROVIDER_NAME + ) + with get_sessionmaker(settings.database_settings)() as db: + principal_orm = get_or_create_principal(db, provider, identity_id) + principal = schemas.Principal( + uuid=principal_orm.uuid, + type="user", + identities=[schemas.Identity(id=identity_id, provider=provider)], + ) + ids = [identity_id] if api_access_manager.is_user_known(identity_id) else [] + scopes = set.union(*[api_access_manager.get_user_scopes(_) for _ in ids]) if ids else set(token_scopes) roles_sets = [api_access_manager.get_user_roles(_) for _ in ids] roles = set.union(*roles_sets) if roles_sets else set() @@ -361,7 +409,7 @@ def get_current_principal( return principal -def get_current_principal_websocket( +async def get_current_principal_websocket( websocket: WebSocket, scopes: str, ): @@ -373,13 +421,31 @@ def get_current_principal_websocket( auth_header = websocket.headers.get("Authorization", "") access_token, api_key = None, None - # Currently we do not support authentication with tokens - # if auth_header.startswith("Bearer "): - # access_token = auth_header[len("Bearer") :].strip() - if auth_header.startswith("ApiKey "): - api_key = auth_header[len("ApiKey") :].strip() + scheme, param = get_authorization_scheme_param(auth_header) + if scheme.lower() == "bearer": + access_token = param + elif scheme.lower() == "apikey": + api_key = param + + if access_token is None: + access_token = websocket.query_params.get("access_token") + if api_key is None: + api_key = websocket.query_params.get("api_key") principal = None + websocket.state.already_accepted = False + no_credentials = (access_token is None) and (api_key is None) + if no_credentials and not settings.allow_anonymous_access: + try: + await websocket.accept() + websocket.state.already_accepted = True + message = await asyncio.wait_for(websocket.receive_json(), timeout=1) + if isinstance(message, dict) and message.get("type") == "auth": + access_token = message.get("access_token") + api_key = message.get("api_key") + except Exception: + return None + try: principal = get_current_principal( request=websocket, @@ -558,11 +624,14 @@ async def authorize_redirect( """Redirect browser to OAuth provider for authentication.""" redirect_uri = f"{get_base_url(request)}/auth/provider/{provider}/code" + requested_scopes = {"openid", "offline_access"} + requested_scopes.update(getattr(authenticator, "extra_scopes", [])) params = { "client_id": authenticator.client_id, "response_type": "code", - "scope": "openid profile email", + "scope": " ".join(sorted(requested_scopes)), "redirect_uri": redirect_uri, + "prompt": "login", } if state: params["state"] = state @@ -595,7 +664,9 @@ async def device_code_authorize( params={ "client_id": authenticator.client_id, "response_type": "code", - "scope": "openid profile email", + "scope": " ".join( + sorted({"openid", "offline_access", *getattr(authenticator, "extra_scopes", [])}) + ), "redirect_uri": f"{get_base_url(request)}/auth/provider/{provider}/device_code", "state": pending_session["user_code"].replace("-", ""), } diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py index 85d835e..3475cd1 100644 --- a/bluesky_httpserver/authentication/__init__.py +++ b/bluesky_httpserver/authentication/__init__.py @@ -1,4 +1,5 @@ from .._authentication import ( + _extract_scopes, base_authentication_router, build_auth_code_route, build_authorize_route, @@ -21,6 +22,7 @@ "ExternalAuthenticator", "InternalAuthenticator", "UserSessionState", + "_extract_scopes", "get_current_principal", "get_current_principal_websocket", "base_authentication_router", diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index a58fedf..dfb4466 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -4,9 +4,10 @@ import logging import re import secrets +import uuid from collections.abc import Iterable from datetime import timedelta -from typing import Any, List, Mapping, Optional, cast +from typing import Any, Dict, List, Mapping, Optional, cast import httpx from cachetools import TTLCache, cached @@ -189,17 +190,18 @@ def device_authorization_endpoint(self) -> str: def end_session_endpoint(self) -> str: return cast(str, self._config_from_oidc_url.get("end_session_endpoint")) - @cached(TTLCache(maxsize=1, ttl=timedelta(days=7).total_seconds())) + @cached(TTLCache(maxsize=1, ttl=timedelta(hours=1).total_seconds())) def keys(self) -> List[str]: return httpx.get(self.jwks_uri).raise_for_status().json().get("keys", []) - def decode_token(self, token: str) -> dict[str, Any]: + def decode_token(self, id_token: str, access_token: Optional[str] = None) -> dict[str, Any]: return jwt.decode( - token, + id_token, key=self.keys(), algorithms=self.id_token_signing_alg_values_supported, audience=self._audience, issuer=self.issuer, + access_token=access_token, ) async def authenticate(self, request: Request) -> Optional[UserSessionState]: @@ -223,13 +225,14 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]: logger.error("Authentication error: %r", response_body) return None id_token = response_body["id_token"] + access_token = response_body.get("access_token") # NOTE: We decode the id_token, not access_token, because: # 1. The id_token is the OIDC identity assertion meant for the client # 2. Some providers (like Microsoft Entra) return opaque access_tokens # that cannot be decoded with the JWKS keys when the resource is # a first-party Microsoft API (e.g., Graph API with User.Read scope) try: - verified_body = self.decode_token(id_token) + verified_body = self.decode_token(id_token, access_token) except JWTError: logger.exception( "Authentication error. Unverified token: %r", @@ -310,18 +313,139 @@ def oauth2_schema(self) -> OAuth2: return self._oidc_bearer +class EntraAuthenticator(ProxiedOIDCAuthenticator): + def __init__( + self, + audience: str, + client_id: str, + well_known_uri: str, + device_flow_client_id: str, + extra_scopes: Optional[List[str]] = None, + confirmation_message: str = "", + scopes_map: Optional[Dict[str, list[str]]] = None, + client_secret: str = "", + redirect_on_success: Optional[str] = None, + ): + self.scopes_map = scopes_map if scopes_map is not None else {} + self.extra_scopes = extra_scopes or [] + super().__init__( + audience, + client_id, + well_known_uri, + device_flow_client_id, + scopes=None, + confirmation_message=confirmation_message, + ) + if client_secret: + self._client_secret = Secret(client_secret) + self.redirect_on_success = redirect_on_success + + @property + def scopes(self): + mapped = set() + for tiled_scopes in self.scopes_map.values(): + mapped.update(tiled_scopes) + return list(mapped) + + @scopes.setter + def scopes(self, value): + pass + + def decode_token(self, id_token: str, access_token: Optional[str] = None) -> dict[str, Any]: + claims = super().decode_token(id_token, access_token) + original_sub = claims.get("sub") + issuer = claims.get("iss", "") + claims["sub"] = uuid.uuid5(uuid.NAMESPACE_URL, f"{issuer}|{original_sub}").hex + claims["entra_sub"] = original_sub + + claims["entra_username"] = ( + claims.get("nameID") or claims.get("preferred_username") or claims.get("upn") or claims.get("email") + ) + + if user := claims.get("entra_username"): + user = user.strip() + if "\\" in user: + user = user.rsplit("\\", 1)[-1] + elif "@" in user: + user = user.split("@", 1)[0] + else: + user = original_sub + logger.warning( + "EntraAuthenticator: no human-readable username claim found in token " + "(checked nameID, preferred_username, upn, email). Falling back to Entra sub=%r.", + original_sub, + ) + claims["user"] = user + + scp_raw = claims.get("scp", "") + tiled_scope_set = set() + if scp_raw: + for scope in scp_raw.split(" "): + mapped_scopes = self.scopes_map.get(scope) + if mapped_scopes is None: + logger.warning("Unmapped Entra scope in 'scp': %s", scope) + continue + tiled_scope_set.update(mapped_scopes) + else: + for mapped_scopes in self.scopes_map.values(): + tiled_scope_set.update(mapped_scopes) + claims["scope"] = " ".join(tiled_scope_set) + return claims + + async def authenticate(self, request: Request) -> Optional[UserSessionState]: + code = request.query_params.get("code") + if not code: + logger.warning("Authentication failed: No authorization code parameter provided.") + return None + redirect_uri = f"{get_root_url(request)}{request.url.path}" + response = await exchange_code( + self.token_endpoint, + code, + self._client_id, + self._client_secret.get_secret_value(), + redirect_uri, + extra_scopes=self.extra_scopes, + ) + response_body = response.json() + if response.is_error: + logger.error("Authentication error: %r", response_body) + return None + id_token = response_body["id_token"] + access_token = response_body.get("access_token") + refresh_token = response_body.get("refresh_token") + try: + verified_body = self.decode_token(id_token, access_token) + except JWTError: + logger.exception( + "Authentication error. Unverified token: %r", + jwt.get_unverified_claims(id_token), + ) + return None + username = verified_body.get("user") or verified_body["sub"] + state: dict[str, Any] = {} + if access_token: + state["entra_access_token"] = access_token + if refresh_token: + state["entra_refresh_token"] = refresh_token + return UserSessionState(username, state) + + async def exchange_code( token_uri: str, auth_code: str, client_id: str, client_secret: str, redirect_uri: str, + extra_scopes: Optional[List[str]] = None, ) -> httpx.Response: """Method that talks to an IdP to exchange a code for an access_token and/or id_token Args: token_url ([type]): [description] auth_code ([type]): [description] """ + scopes = {"openid", "offline_access"} + if extra_scopes: + scopes.update(extra_scopes) auth_value = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() response = httpx.post( url=token_uri, @@ -331,6 +455,7 @@ async def exchange_code( "redirect_uri": redirect_uri, "code": auth_code, "client_secret": client_secret, + "scope": " ".join(sorted(scopes)), }, headers={"Authorization": f"Basic {auth_value}"}, ) diff --git a/bluesky_httpserver/database/core.py b/bluesky_httpserver/database/core.py index 52d102f..a394fdd 100644 --- a/bluesky_httpserver/database/core.py +++ b/bluesky_httpserver/database/core.py @@ -209,6 +209,15 @@ def create_user(db, identity_provider, id): return principal +def get_or_create_principal(db, identity_provider, id): + identity = db.query(Identity).filter(Identity.id == id).filter(Identity.provider == identity_provider).first() + if identity is None: + principal = create_user(db, identity_provider, id) + else: + principal = identity.principal + return principal + + def lookup_valid_session(db, session_id): if isinstance(session_id, int): # Old versions of tiled used an integer sid. diff --git a/bluesky_httpserver/routers/core_api.py b/bluesky_httpserver/routers/core_api.py index 397972b..b70ca46 100644 --- a/bluesky_httpserver/routers/core_api.py +++ b/bluesky_httpserver/routers/core_api.py @@ -1140,12 +1140,13 @@ def is_alive(self): @router.websocket("/console_output/ws") async def console_output_ws(websocket: WebSocket, scopes=["read:console"]): - principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + principal = await get_current_principal_websocket(websocket=websocket, scopes=scopes) if not principal: await websocket.close(code=4001, reason="Invalid token") return - await websocket.accept() + if not getattr(websocket.state, "already_accepted", False): + await websocket.accept() q = SR.console_output_stream.add_queue(websocket) wsmon = WebSocketMonitor(websocket) wsmon.start() @@ -1166,12 +1167,13 @@ async def console_output_ws(websocket: WebSocket, scopes=["read:console"]): @router.websocket("/status/ws") async def status_ws(websocket: WebSocket, scopes=["read:monitor"]): - principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + principal = await get_current_principal_websocket(websocket=websocket, scopes=scopes) if not principal: await websocket.close(code=4001, reason="Invalid token") return - await websocket.accept() + if not getattr(websocket.state, "already_accepted", False): + await websocket.accept() q = SR.system_info_stream.add_queue_status(websocket) wsmon = WebSocketMonitor(websocket) wsmon.start() @@ -1193,12 +1195,13 @@ async def status_ws(websocket: WebSocket, scopes=["read:monitor"]): @router.websocket("/info/ws") async def info_ws(websocket: WebSocket, scopes=["read:monitor"]): - principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + principal = await get_current_principal_websocket(websocket=websocket, scopes=scopes) if not principal: await websocket.close(code=4001, reason="Invalid token") return - await websocket.accept() + if not getattr(websocket.state, "already_accepted", False): + await websocket.accept() q = SR.system_info_stream.add_queue_info(websocket) wsmon = WebSocketMonitor(websocket) wsmon.start() diff --git a/bluesky_httpserver/schemas.py b/bluesky_httpserver/schemas.py index f1d9fcb..05a1764 100644 --- a/bluesky_httpserver/schemas.py +++ b/bluesky_httpserver/schemas.py @@ -190,6 +190,7 @@ class AboutAuthenticationProvider(pydantic.BaseModel): mode: AuthenticationMode links: Dict[str, str] confirmation_message: Optional[str] = None + extra_scopes: Optional[List[str]] = None class AboutAuthenticationLinks(pydantic.BaseModel): diff --git a/bluesky_httpserver/tests/test_auth_for_websockets.py b/bluesky_httpserver/tests/test_auth_for_websockets.py index 3d26e22..17fb498 100644 --- a/bluesky_httpserver/tests/test_auth_for_websockets.py +++ b/bluesky_httpserver/tests/test_auth_for_websockets.py @@ -50,12 +50,13 @@ class _ReceiveSystemInfoSocket(threading.Thread): save messages to the buffer. """ - def __init__(self, *, endpoint, api_key=None, token=None, **kwargs): + def __init__(self, *, endpoint, api_key=None, token=None, auth_message=None, **kwargs): super().__init__(**kwargs) self.received_data_buffer = [] self._exit = False self._api_key = api_key self._token = token + self._auth_message = auth_message self._endpoint = endpoint def run(self): @@ -69,6 +70,8 @@ def run(self): try: with connect(websocket_uri, additional_headers=additional_headers) as websocket: + if self._auth_message is not None: + websocket.send(json.dumps(self._auth_message)) while not self._exit: try: msg_json = websocket.recv(timeout=0.1, decode=False) @@ -94,7 +97,10 @@ def __del__(self): # fmt: off -@pytest.mark.parametrize("ws_auth_type", ["apikey", "apikey_invalid", "none"]) +@pytest.mark.parametrize( + "ws_auth_type", + ["apikey", "apikey_invalid", "token", "token_invalid", "none", "first_message_apikey", "first_message_token"], +) # fmt: on def test_websocket_auth_01( tmpdir, @@ -135,10 +141,14 @@ def test_websocket_auth_01( ws_params = {"api_key": api_key} elif ws_auth_type == "apikey_invalid": ws_params = {"api_key": "InvalidApiKey"} - # elif ws_auth_type == "token": - # ws_params = {"token": token} - # elif ws_auth_type == "token_invalid": - # ws_params = {"token": "InvalidToken"} + elif ws_auth_type == "token": + ws_params = {"token": token} + elif ws_auth_type == "token_invalid": + ws_params = {"token": "InvalidToken"} + elif ws_auth_type == "first_message_apikey": + ws_params = {"auth_message": {"type": "auth", "api_key": api_key}} + elif ws_auth_type == "first_message_token": + ws_params = {"auth_message": {"type": "auth", "access_token": token}} else: assert False, f"Unknown authentication type: {ws_auth_type!r}" @@ -164,7 +174,7 @@ def test_websocket_auth_01( buffer = rsc.received_data_buffer if ws_auth_type in ("none", "apikey_invalid", "token_invalid"): assert len(buffer) == 0 - elif ws_auth_type in ("apikey", "token"): + elif ws_auth_type in ("apikey", "token", "first_message_apikey", "first_message_token"): assert len(buffer) > 0 for msg in buffer: assert "time" in msg, msg diff --git a/bluesky_httpserver/tests/test_authenticators.py b/bluesky_httpserver/tests/test_authenticators.py index 7b7dd4b..7397550 100644 --- a/bluesky_httpserver/tests/test_authenticators.py +++ b/bluesky_httpserver/tests/test_authenticators.py @@ -10,8 +10,16 @@ from jose.backends import RSAKey from respx import MockRouter from starlette.datastructures import URL, QueryParams +from starlette.requests import Request -from ..authenticators import LDAPAuthenticator, OIDCAuthenticator, ProxiedOIDCAuthenticator, UserSessionState +from .._authentication import build_authorize_route +from ..authenticators import ( + LDAPAuthenticator, + OIDCAuthenticator, + ProxiedOIDCAuthenticator, + UserSessionState, + exchange_code, +) LDAP_TEST_HOST = os.environ.get("QSERVER_TEST_LDAP_HOST", "localhost") LDAP_TEST_PORT = int(os.environ.get("QSERVER_TEST_LDAP_PORT", "1389")) @@ -233,22 +241,21 @@ async def test_OIDCAuthenticator_mock( mock_request = create_mock_oidc_request({"code": "test-auth-code"}) - def mock_jwt_decode(*args, **kwargs): - return mock_jwt_payload - - def mock_jwk_construct(*args, **kwargs): - class MockJWK: - pass + decode_calls = {} - return MockJWK() + def mock_decode_token(id_token, access_token=None): + decode_calls["id_token"] = id_token + decode_calls["access_token"] = access_token + return mock_jwt_payload - monkeypatch.setattr("jose.jwt.decode", mock_jwt_decode) - monkeypatch.setattr("jose.jwk.construct", mock_jwk_construct) + monkeypatch.setattr(authenticator, "decode_token", mock_decode_token) user_session = await authenticator.authenticate(mock_request) assert user_session is not None assert user_session.user_name == "0009-0008-8698-7745" + assert decode_calls["id_token"] == "mock-id-token" + assert decode_calls["access_token"] == "mock-access-token" @pytest.mark.asyncio @@ -293,3 +300,60 @@ async def test_OIDCAuthenticator_token_exchange_failure( result = await authenticator.authenticate(mock_request) assert result is None + + +@pytest.mark.asyncio +async def test_exchange_code_requests_offline_access(monkeypatch): + captured = {} + + def mock_post(*, url, data, headers): + captured["url"] = url + captured["data"] = data + captured["headers"] = headers + return httpx.Response(200, json={"id_token": "X", "access_token": "Y"}) + + monkeypatch.setattr("httpx.post", mock_post) + + await exchange_code( + token_uri="https://idp.example/token", + auth_code="authcode", + client_id="client-id", + client_secret="client-secret", + redirect_uri="https://server.example/callback", + extra_scopes=["api://example/access_as_user"], + ) + + assert captured["url"] == "https://idp.example/token" + assert set(captured["data"]["scope"].split(" ")) == { + "openid", + "offline_access", + "api://example/access_as_user", + } + + +@pytest.mark.asyncio +async def test_authorize_route_requests_extra_scopes_and_prompt(): + class _Authenticator: + client_id = "test-client" + extra_scopes = ["api://example/access_as_user"] + authorization_endpoint = httpx.URL("https://idp.example/auth") + + route = build_authorize_route(_Authenticator(), "oidc") + request = Request( + { + "type": "http", + "scheme": "http", + "path": "/api/auth/provider/oidc/authorize", + "root_path": "", + "query_string": b"", + "headers": [(b"host", b"localhost:8000")], + "server": ("localhost", 8000), + "client": ("127.0.0.1", 54321), + } + ) + response = await route(request) + location = response.headers["location"] + assert "prompt=login" in location + assert "offline_access" in location + assert "openid" in location + assert "api%3A%2F%2Fexample%2Faccess_as_user" in location diff --git a/bluesky_httpserver/tests/test_oidc_authenticators.py b/bluesky_httpserver/tests/test_oidc_authenticators.py index f3249cd..cedf611 100644 --- a/bluesky_httpserver/tests/test_oidc_authenticators.py +++ b/bluesky_httpserver/tests/test_oidc_authenticators.py @@ -10,7 +10,7 @@ from jose.backends import RSAKey from respx import MockRouter -from bluesky_httpserver.authenticators import OIDCAuthenticator, ProxiedOIDCAuthenticator +from bluesky_httpserver.authenticators import EntraAuthenticator, OIDCAuthenticator, ProxiedOIDCAuthenticator @pytest.fixture @@ -217,3 +217,64 @@ def test_proxied_oidc_with_scopes( assert authenticator.scopes == ["openid", "profile", "email"] assert authenticator.device_flow_client_id == "test_cli_client" + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +class TestEntraAuthenticator: + def test_entra_scope_mapping_and_username( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey], + ): + private_key, _ = keys + authenticator = EntraAuthenticator( + audience="test_client", + client_id="test_client", + well_known_uri=oidc_well_known_url, + device_flow_client_id="test_cli_client", + scopes_map={"User.Read": ["read:monitor"]}, + ) + token_claims = { + "aud": "test_client", + "exp": time.time() + 1500, + "iat": time.time() - 1, + "iss": "https://example.com/realms/example", + "sub": "entra-subject", + "preferred_username": "alice@example.org", + "scp": "User.Read", + } + encoded = encrypt_token(token_claims, private_key) + decoded = authenticator.decode_token(encoded) + assert decoded["user"] == "alice" + assert set(decoded["scope"].split(" ")) == {"read:monitor"} + assert decoded["entra_sub"] == "entra-subject" + + def test_entra_unmapped_scope_warning( + self, + caplog, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey], + ): + private_key, _ = keys + authenticator = EntraAuthenticator( + audience="test_client", + client_id="test_client", + well_known_uri=oidc_well_known_url, + device_flow_client_id="test_cli_client", + scopes_map={"Known.Scope": ["read:monitor"]}, + ) + token_claims = { + "aud": "test_client", + "exp": time.time() + 1500, + "iat": time.time() - 1, + "iss": "https://example.com/realms/example", + "sub": "entra-subject", + "scp": "Unknown.Scope", + } + encoded = encrypt_token(token_claims, private_key) + with caplog.at_level("WARNING"): + decoded = authenticator.decode_token(encoded) + assert decoded["scope"] == "" + assert any("Unmapped Entra scope" in record.message for record in caplog.records) From 229bb8568d46a4347d64c3264ba78017678b8810 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Thu, 25 Jun 2026 08:50:28 -0600 Subject: [PATCH 15/18] First commit with common-auth flow --- bluesky_httpserver/app.py | 2 +- .../authentication/authenticator_base.py | 54 +- bluesky_httpserver/authenticators.py | 1084 +---------------- .../config_schemas/service_configuration.yml | 8 +- .../tests/test_access_control.py | 8 +- .../tests/test_access_policies.py | 2 +- bluesky_httpserver/tests/test_auth_api.py | 2 +- .../tests/test_auth_for_websockets.py | 2 +- .../tests/test_authenticators.py | 359 ------ .../tests/test_oidc_authenticators.py | 280 ----- docs/source/configuration.rst | 14 +- docs/source/usage.rst | 8 +- requirements.txt | 1 + 13 files changed, 88 insertions(+), 1736 deletions(-) delete mode 100644 bluesky_httpserver/tests/test_authenticators.py delete mode 100644 bluesky_httpserver/tests/test_oidc_authenticators.py diff --git a/bluesky_httpserver/app.py b/bluesky_httpserver/app.py index 0d96667..29e2273 100644 --- a/bluesky_httpserver/app.py +++ b/bluesky_httpserver/app.py @@ -14,8 +14,8 @@ from fastapi import APIRouter, FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi +from bluesky_authentication.protocols import ExternalAuthenticator, InternalAuthenticator -from .authentication import ExternalAuthenticator, InternalAuthenticator from .console_output import CollectPublishedConsoleOutput, ConsoleOutputStream, SystemInfoStream from .core import PatchedStreamingResponse from .database.core import purge_expired diff --git a/bluesky_httpserver/authentication/authenticator_base.py b/bluesky_httpserver/authentication/authenticator_base.py index af103c5..ae7c599 100644 --- a/bluesky_httpserver/authentication/authenticator_base.py +++ b/bluesky_httpserver/authentication/authenticator_base.py @@ -1,37 +1,31 @@ -from abc import ABC -from dataclasses import dataclass -from typing import Optional +try: + from bluesky_authentication.protocols import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, + ) +except ModuleNotFoundError: + from abc import ABC + from dataclasses import dataclass + from typing import Optional -from fastapi import Request + from fastapi import Request + @dataclass + class UserSessionState: + """Data transfer class to communicate custom session state information.""" -@dataclass -class UserSessionState: - """Data transfer class to communicate custom session state information.""" + user_name: str + state: dict = None - user_name: str - state: dict = None + class InternalAuthenticator(ABC): + """Base class for authenticators that use username/password credentials.""" + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: + raise NotImplementedError -class InternalAuthenticator(ABC): - """ - Base class for authenticators that use username/password credentials. + class ExternalAuthenticator(ABC): + """Base class for authenticators that use external identity providers.""" - Subclasses must implement the authenticate method which takes a username - and password and returns a UserSessionState on success or None on failure. - """ - - async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: - raise NotImplementedError - - -class ExternalAuthenticator(ABC): - """ - Base class for authenticators that use external identity providers. - - Subclasses must implement the authenticate method which takes a FastAPI - Request object and returns a UserSessionState on success or None on failure. - """ - - async def authenticate(self, request: Request) -> Optional[UserSessionState]: - raise NotImplementedError + async def authenticate(self, request: Request) -> Optional[UserSessionState]: + raise NotImplementedError diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index dfb4466..09c073a 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -1,1058 +1,40 @@ -import asyncio -import base64 -import functools -import logging -import re -import secrets -import uuid -from collections.abc import Iterable -from datetime import timedelta -from typing import Any, Dict, List, Mapping, Optional, cast +import warnings -import httpx -from cachetools import TTLCache, cached -from fastapi import APIRouter, Request -from fastapi.security import OAuth2, OAuth2AuthorizationCodeBearer -from jose import JWTError, jwt -from pydantic import Secret -from starlette.responses import RedirectResponse -from .authentication import ( +warnings.warn( + "Importing authenticators from 'bluesky_httpserver.authenticators' is deprecated " + "and will be removed in a future release. Use 'bluesky_authentication.authenticators' " + "and 'bluesky_authentication.protocols' instead.", + DeprecationWarning, + stacklevel=2, +) + +from bluesky_authentication.authenticators import ( # noqa: F401 + DictionaryAuthenticator, + DummyAuthenticator, + EntraAuthenticator, + LDAPAuthenticator, + OIDCAuthenticator, + PAMAuthenticator, + ProxiedOIDCAuthenticator, + SAMLAuthenticator, +) +from bluesky_authentication.protocols import ( # noqa: F401 ExternalAuthenticator, InternalAuthenticator, UserSessionState, ) -from .utils import get_root_url, modules_available - -logger = logging.getLogger(__name__) - - -class DummyAuthenticator(InternalAuthenticator): - """ - For test and demo purposes only! - - Accept any username and any password. - - """ - - def __init__(self, confirmation_message: str = ""): - self.confirmation_message = confirmation_message - - async def authenticate(self, username: str, password: str) -> UserSessionState: - return UserSessionState(username, {}) - - -class DictionaryAuthenticator(InternalAuthenticator): - """ - For test and demo purposes only! - - Check passwords from a dictionary of usernames mapped to passwords. - """ - - configuration_schema = """ -$schema": http://json-schema.org/draft-07/schema# -type: object -additionalProperties: false -properties: - users_to_password: - type: object - description: | - Mapping usernames to password. Environment variable expansion should be - used to avoid placing passwords directly in configuration. - confirmation_message: - type: string - description: May be displayed by client after successful login. -""" - - def __init__(self, users_to_passwords: Mapping[str, str], confirmation_message: str = ""): - self._users_to_passwords = users_to_passwords - self.confirmation_message = confirmation_message - - async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: - true_password = self._users_to_passwords.get(username) - if not true_password: - # Username is not valid. - return None - if secrets.compare_digest(true_password, password): - return UserSessionState(username, {}) - - -class PAMAuthenticator(InternalAuthenticator): - configuration_schema = """ -$schema": http://json-schema.org/draft-07/schema# -type: object -additionalProperties: false -properties: - service: - type: string - description: PAM service. Default is 'login'. - confirmation_message: - type: string - description: May be displayed by client after successful login. -""" - - def __init__(self, service: str = "login", confirmation_message: str = ""): - if not modules_available("pamela"): - raise ModuleNotFoundError("This PAMAuthenticator requires the module 'pamela' to be installed.") - self.service = service - self.confirmation_message = confirmation_message - # TODO Try to open a PAM session. - - async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: - import pamela - - try: - pamela.authenticate(username, password, service=self.service) - return UserSessionState(username, {}) - except pamela.PAMError: - # Authentication failed. - return None - - -class OIDCAuthenticator(ExternalAuthenticator): - configuration_schema = """ -$schema": http://json-schema.org/draft-07/schema# -type: object -additionalProperties: false -properties: - audience: - type: string - client_id: - type: string - client_secret: - type: string - well_known_uri: - type: string - confirmation_message: - type: string - redirect_on_success: - type: string - redirect_on_failure: - type: string -""" - - def __init__( - self, - audience: str, - client_id: str, - client_secret: str, - well_known_uri: str, - confirmation_message: str = "", - redirect_on_success: Optional[str] = None, - redirect_on_failure: Optional[str] = None, - ): - self._audience = audience - self._client_id = client_id - self._client_secret = Secret(client_secret) - self._well_known_url = well_known_uri - self.confirmation_message = confirmation_message - self.redirect_on_success = redirect_on_success - self.redirect_on_failure = redirect_on_failure - - @functools.cached_property - def _config_from_oidc_url(self) -> dict[str, Any]: - response: httpx.Response = httpx.get(self._well_known_url) - response.raise_for_status() - return response.json() - - @functools.cached_property - def client_id(self) -> str: - return self._client_id - - @functools.cached_property - def id_token_signing_alg_values_supported(self) -> list[str]: - return cast( - list[str], - self._config_from_oidc_url.get("id_token_signing_alg_values_supported"), - ) - - @functools.cached_property - def issuer(self) -> str: - return cast(str, self._config_from_oidc_url.get("issuer")) - - @functools.cached_property - def jwks_uri(self) -> str: - return cast(str, self._config_from_oidc_url.get("jwks_uri")) - - @functools.cached_property - def token_endpoint(self) -> str: - return cast(str, self._config_from_oidc_url.get("token_endpoint")) - - @functools.cached_property - def authorization_endpoint(self) -> httpx.URL: - return httpx.URL(cast(str, self._config_from_oidc_url.get("authorization_endpoint"))) - - @functools.cached_property - def device_authorization_endpoint(self) -> str: - return cast(str, self._config_from_oidc_url.get("device_authorization_endpoint")) - - @functools.cached_property - def end_session_endpoint(self) -> str: - return cast(str, self._config_from_oidc_url.get("end_session_endpoint")) - - @cached(TTLCache(maxsize=1, ttl=timedelta(hours=1).total_seconds())) - def keys(self) -> List[str]: - return httpx.get(self.jwks_uri).raise_for_status().json().get("keys", []) - - def decode_token(self, id_token: str, access_token: Optional[str] = None) -> dict[str, Any]: - return jwt.decode( - id_token, - key=self.keys(), - algorithms=self.id_token_signing_alg_values_supported, - audience=self._audience, - issuer=self.issuer, - access_token=access_token, - ) - - async def authenticate(self, request: Request) -> Optional[UserSessionState]: - code = request.query_params.get("code") - if not code: - logger.warning("Authentication failed: No authorization code parameter provided.") - return None - # A proxy in the middle may make the request into something like - # 'http://localhost:8000/...' so we fix the first part but keep - # the original URI path. - redirect_uri = f"{get_root_url(request)}{request.url.path}" - response = await exchange_code( - self.token_endpoint, - code, - self._client_id, - self._client_secret.get_secret_value(), - redirect_uri, - ) - response_body = response.json() - if response.is_error: - logger.error("Authentication error: %r", response_body) - return None - id_token = response_body["id_token"] - access_token = response_body.get("access_token") - # NOTE: We decode the id_token, not access_token, because: - # 1. The id_token is the OIDC identity assertion meant for the client - # 2. Some providers (like Microsoft Entra) return opaque access_tokens - # that cannot be decoded with the JWKS keys when the resource is - # a first-party Microsoft API (e.g., Graph API with User.Read scope) - try: - verified_body = self.decode_token(id_token, access_token) - except JWTError: - logger.exception( - "Authentication error. Unverified token: %r", - jwt.get_unverified_claims(id_token), - ) - return None - # Use preferred_username as the user identifier, extracting just the username - # part if it's in email format (user@domain.com -> user) - preferred_username = verified_body.get("preferred_username") - if preferred_username and "@" in preferred_username: - user_id = preferred_username.split("@")[0] - elif preferred_username: - user_id = preferred_username - else: - user_id = verified_body["sub"] - logger.info( - "OIDC authentication successful. user_id=%r (sub=%r, preferred_username=%r, email=%r, name=%r)", - user_id, - verified_body.get("sub"), - verified_body.get("preferred_username"), - verified_body.get("email"), - verified_body.get("name"), - ) - return UserSessionState(user_id, {}) - - -class ProxiedOIDCAuthenticator(OIDCAuthenticator): - configuration_schema = """ -$schema": http://json-schema.org/draft-07/schema# -type: object -additionalProperties: false -properties: - audience: - type: string - client_id: - type: string - well_known_uri: - type: string - scopes: - type: array - items: - type: string - description: | - Optional list of OAuth2 scopes to request. If provided, authorization - should be enforced by an external policy agent (for example ExternalPolicyDecisionPoint) - rather than by this authenticator. - device_flow_client_id: - type: string - confirmation_message: - type: string -""" - - def __init__( - self, - audience: str, - client_id: str, - well_known_uri: str, - device_flow_client_id: str, - scopes: Optional[List[str]] = None, - confirmation_message: str = "", - ): - super().__init__( - audience=audience, - client_id=client_id, - client_secret="", - well_known_uri=well_known_uri, - confirmation_message=confirmation_message, - ) - self.scopes = scopes - self.device_flow_client_id = device_flow_client_id - self._oidc_bearer = OAuth2AuthorizationCodeBearer( - authorizationUrl=str(self.authorization_endpoint), - tokenUrl=self.token_endpoint, - ) - - @property - def oauth2_schema(self) -> OAuth2: - return self._oidc_bearer - - -class EntraAuthenticator(ProxiedOIDCAuthenticator): - def __init__( - self, - audience: str, - client_id: str, - well_known_uri: str, - device_flow_client_id: str, - extra_scopes: Optional[List[str]] = None, - confirmation_message: str = "", - scopes_map: Optional[Dict[str, list[str]]] = None, - client_secret: str = "", - redirect_on_success: Optional[str] = None, - ): - self.scopes_map = scopes_map if scopes_map is not None else {} - self.extra_scopes = extra_scopes or [] - super().__init__( - audience, - client_id, - well_known_uri, - device_flow_client_id, - scopes=None, - confirmation_message=confirmation_message, - ) - if client_secret: - self._client_secret = Secret(client_secret) - self.redirect_on_success = redirect_on_success - - @property - def scopes(self): - mapped = set() - for tiled_scopes in self.scopes_map.values(): - mapped.update(tiled_scopes) - return list(mapped) - - @scopes.setter - def scopes(self, value): - pass - - def decode_token(self, id_token: str, access_token: Optional[str] = None) -> dict[str, Any]: - claims = super().decode_token(id_token, access_token) - original_sub = claims.get("sub") - issuer = claims.get("iss", "") - claims["sub"] = uuid.uuid5(uuid.NAMESPACE_URL, f"{issuer}|{original_sub}").hex - claims["entra_sub"] = original_sub - - claims["entra_username"] = ( - claims.get("nameID") or claims.get("preferred_username") or claims.get("upn") or claims.get("email") - ) - - if user := claims.get("entra_username"): - user = user.strip() - if "\\" in user: - user = user.rsplit("\\", 1)[-1] - elif "@" in user: - user = user.split("@", 1)[0] - else: - user = original_sub - logger.warning( - "EntraAuthenticator: no human-readable username claim found in token " - "(checked nameID, preferred_username, upn, email). Falling back to Entra sub=%r.", - original_sub, - ) - claims["user"] = user - - scp_raw = claims.get("scp", "") - tiled_scope_set = set() - if scp_raw: - for scope in scp_raw.split(" "): - mapped_scopes = self.scopes_map.get(scope) - if mapped_scopes is None: - logger.warning("Unmapped Entra scope in 'scp': %s", scope) - continue - tiled_scope_set.update(mapped_scopes) - else: - for mapped_scopes in self.scopes_map.values(): - tiled_scope_set.update(mapped_scopes) - claims["scope"] = " ".join(tiled_scope_set) - return claims - - async def authenticate(self, request: Request) -> Optional[UserSessionState]: - code = request.query_params.get("code") - if not code: - logger.warning("Authentication failed: No authorization code parameter provided.") - return None - redirect_uri = f"{get_root_url(request)}{request.url.path}" - response = await exchange_code( - self.token_endpoint, - code, - self._client_id, - self._client_secret.get_secret_value(), - redirect_uri, - extra_scopes=self.extra_scopes, - ) - response_body = response.json() - if response.is_error: - logger.error("Authentication error: %r", response_body) - return None - id_token = response_body["id_token"] - access_token = response_body.get("access_token") - refresh_token = response_body.get("refresh_token") - try: - verified_body = self.decode_token(id_token, access_token) - except JWTError: - logger.exception( - "Authentication error. Unverified token: %r", - jwt.get_unverified_claims(id_token), - ) - return None - username = verified_body.get("user") or verified_body["sub"] - state: dict[str, Any] = {} - if access_token: - state["entra_access_token"] = access_token - if refresh_token: - state["entra_refresh_token"] = refresh_token - return UserSessionState(username, state) - - -async def exchange_code( - token_uri: str, - auth_code: str, - client_id: str, - client_secret: str, - redirect_uri: str, - extra_scopes: Optional[List[str]] = None, -) -> httpx.Response: - """Method that talks to an IdP to exchange a code for an access_token and/or id_token - Args: - token_url ([type]): [description] - auth_code ([type]): [description] - """ - scopes = {"openid", "offline_access"} - if extra_scopes: - scopes.update(extra_scopes) - auth_value = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() - response = httpx.post( - url=token_uri, - data={ - "grant_type": "authorization_code", - "client_id": client_id, - "redirect_uri": redirect_uri, - "code": auth_code, - "client_secret": client_secret, - "scope": " ".join(sorted(scopes)), - }, - headers={"Authorization": f"Basic {auth_value}"}, - ) - return response - - -class SAMLAuthenticator(ExternalAuthenticator): - - def __init__( - self, - saml_settings, # See EXAMPLE_SAML_SETTINGS below. - attribute_name: str, # which SAML attribute to use as 'id' for Identity - confirmation_message: str = "", - ): - self.saml_settings = saml_settings - self.attribute_name = attribute_name - self.confirmation_message = confirmation_message - self.authorization_endpoint = "/login" - - router = APIRouter() - - if not modules_available("onelogin"): - # The PyPI package name is 'python3-saml' - # but it imports as 'onelogin'. - # https://github.com/onelogin/python3-saml - raise ModuleNotFoundError("This SAMLAuthenticator requires 'python3-saml' to be installed.") - - from onelogin.saml2.auth import OneLogin_Saml2_Auth - - @router.get("/login") - async def saml_login(request: Request) -> RedirectResponse: - req = await prepare_saml_from_fastapi_request(request) - auth = OneLogin_Saml2_Auth(req, self.saml_settings) - callback_url = auth.login() - return RedirectResponse(url=callback_url) - - self.include_routers = [router] - - async def authenticate(self, request: Request) -> Optional[UserSessionState]: - if not modules_available("onelogin"): - raise ModuleNotFoundError("This SAMLAuthenticator requires the module 'oneline' to be installed.") - from onelogin.saml2.auth import OneLogin_Saml2_Auth - - req = await prepare_saml_from_fastapi_request(request, True) - auth = OneLogin_Saml2_Auth(req, self.saml_settings) - auth.process_response() # Process IdP response - errors = auth.get_errors() # This method receives an array with the errors - if errors: - raise Exception( - "Error when processing SAML Response: %s %s" % (", ".join(errors), auth.get_last_error_reason()) - ) - if auth.is_authenticated(): - # Return a string that the Identity can use as id. - attribute_as_list = auth.get_attributes()[self.attribute_name] - # Confused in what situation this would have more than one item.... - assert len(attribute_as_list) == 1 - return UserSessionState(attribute_as_list[0], {}) - else: - return None - - -async def prepare_saml_from_fastapi_request(request: Request) -> Mapping[str, str]: - form_data = await request.form() - rv = { - "http_host": request.client.host, - "server_port": request.url.port, - "script_name": request.url.path, - "post_data": {}, - "get_data": {}, - # Advanced request options - # "https": "", - # "request_uri": "", - # "query_string": "", - # "validate_signature_from_qs": False, - # "lowercase_urlencoding": False - } - if request.query_params: - rv["get_data"] = (request.query_params,) - if "SAMLResponse" in form_data: - SAMLResponse = form_data["SAMLResponse"] - rv["post_data"]["SAMLResponse"] = SAMLResponse - if "RelayState" in form_data: - RelayState = form_data["RelayState"] - rv["post_data"]["RelayState"] = RelayState - return rv - - -class LDAPAuthenticator(InternalAuthenticator): - """ - LDAP authenticator. - The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator - The parameter ``use_tls`` was added for convenience of testing. - - Parameters - ---------- - server_address: str or list(str) - Address(es) of the LDAP server(s) to contact. A string value may represent a single - server, a list of strings may represent one or more servers. If a server address - includes port, then the value of ``server_port`` is ignored, otherwise ``server_port`` - or the default port is used to access the server. - - Could be an IP address or hostname. - server_port: int or None - Port on which to contact the LDAP server. Default port is used if ``None``. - - Defaults to ``636`` if ``use_ssl`` is set, ``389`` otherwise. - use_ssl: boolean - Use SSL to communicate with the LDAP server. - - Deprecated in version 3 of LDAP. Your LDAP server must be configured to support this, however. - use_tls: boolean - Enable/disable TLS if ``use_ssl`` is False. By default TLS is enabled. It should not be disabled - in production systems. - - connect_timeout: float - Timeout used for connecting to the LDAP server. Default: 5. - - receive_timeout: float - Timeout used for communication with the LDAP server, e.g. this timeout is used to wait for - completion of 2FA. For smooth operation it should probably exceed timeout set at LDAP server. - Default: 60. - - bind_dn_template: list or str - Template from which to construct the full dn - when authenticating to LDAP. ``{username}`` is replaced - with the actual username used to log in. - - If your LDAP is set in such a way that the userdn can not - be formed from a template, but must be looked up with an attribute - (such as uid or ``sAMAccountName``), please see ``lookup_dn``. It might - be particularly relevant for ActiveDirectory installs. - - Unicode Example: - - .. code-block:: - - "uid={username},ou=people,dc=wikimedia,dc=org" - - List Example: - - .. code-block:: - - [ - "uid={username},ou=people,dc=wikimedia,dc=org", - "uid={username},ou=Developers,dc=wikimedia,dc=org" - ] - allowed_groups: list or None - List of LDAP group DNs that users could be members of to be granted access. - - If a user is in any one of the listed groups, then that user is granted access. - Membership is tested by fetching info about each group and looking for the User's - dn to be a value of one of `member` or `uniqueMember`, *or* if the username being - used to log in with is value of the `uid`. - - Set to an empty list or None to allow all users that have an LDAP account to log in, - without performing any group membership checks. - valid_username_regex: str - Regex for validating usernames - those that do not match this regex will be rejected. - - This is primarily used as a measure against LDAP injection, which has fatal security - considerations. The default works for most LDAP installations, but some users might need - to modify it to fit their custom installs. If you are modifying it, be sure to understand - the implications of allowing additional characters in usernames and what that means for - LDAP injection issues. See https://www.owasp.org/index.php/LDAP_injection for an overview - of LDAP injection. - lookup_dn: boolean - Form user's DN by looking up an entry from directory - - By default, LDAPAuthenticator finds the user's DN by using `bind_dn_template`. - However, in some installations, the user's DN does not contain the username, and - hence needs to be looked up. You can set this to True and then use ``user_search_base`` - and ``user_attribute`` to accomplish this. - user_search_base: str - Base for looking up user accounts in the directory, if `lookup_dn` is set to True. - - LDAPAuthenticator will search all objects matching under this base where the `user_attribute` - is set to the current username to form the userdn. - - For example, if all users objects existed under the base ou=people,dc=wikimedia,dc=org, and - the username users use is set with the attribute `uid`, you can use the following config: - - .. code-block:: - - c.LDAPAuthenticator.lookup_dn = True - c.LDAPAuthenticator.lookup_dn_search_filter = '({login_attr}={login})' - c.LDAPAuthenticator.lookup_dn_search_user = 'ldap_search_user_technical_account' - c.LDAPAuthenticator.lookup_dn_search_password = 'secret' - c.LDAPAuthenticator.user_search_base = 'ou=people,dc=wikimedia,dc=org' - c.LDAPAuthenticator.user_attribute = 'sAMAccountName' - c.LDAPAuthenticator.lookup_dn_user_dn_attribute = 'cn' - c.LDAPAuthenticator.bind_dn_template = '{username}' - user_attribute: str - Attribute containing user's name, if ``lookup_dn`` is set to True. - - See ``user_search_base`` for info on how this attribute is used. - - For most LDAP servers, this is uid. For Active Directory, it is - sAMAccountName. - lookup_dn_search_filter: str or None - How to query LDAP for user name lookup, if ``lookup_dn`` is set to True. - lookup_dn_search_user: str or None - Technical account for user lookup, if ``lookup_dn`` is set to True. - - If both lookup_dn_search_user and lookup_dn_search_password are None, - then anonymous LDAP query will be done. - lookup_dn_search_password: str or None - Technical account for user lookup, if ``lookup_dn`` is set to True. - lookup_dn_user_dn_attribute: str or None - Attribute containing user's name needed for building DN string, if ``lookup_dn`` is set to True. - - See ``user_search_base`` for info on how this attribute is used. - - For most LDAP servers, this is username. For Active Directory, it is cn. - escape_userdn: boolean - If set to True, escape special chars in userdn when authenticating in LDAP. - - On some LDAP servers, when userdn contains chars like '(', ')', '\' - authentication may fail when those chars - are not escaped. - search_filter: str - LDAP3 Search Filter whose results are allowed access - attributes: list or None - List of attributes to be searched - auth_state_attributes: list or None - List of attributes to be returned in auth_state for a user - use_lookup_dn_username: boolean - If set to true uses the ``lookup_dn_user_dn_attribute`` attribute as username instead of - the supplied one. - - This can be useful in an heterogeneous environment, when supplying a UNIX username - to authenticate against AD. - confirmation_message: str - May be displayed by client after successful login. - - Examples - -------- - - Using the authenticator class (the code runs in ``asyncio`` loop): - - .. code-block:: - - from bluesky_httpserver.authenticators import LDAPAuthenticator - authenticator = LDAPAuthenticator( - "localhost", 1389, bind_dn_template="cn={username},ou=users,dc=example,dc=org", use_tls=False - ) - await authenticator.authenticate("user01", "password1") - await authenticator.authenticate("user02", "password2") - - - Simple example of a config file (e.g. ``config_ldap.yml``): - - .. code-block:: - - uvicorn: - host: localhost - port: 60610 - authentication: - providers: - - provider: ldap_local - authenticator: bluesky_httpserver.authenticators:LDAPAuthenticator - args: - server_address: localhost - server_port: 1389 - bind_dn_template: "cn={username},ou=users,dc=example,dc=org" - use_tls: false - use_ssl: false - tiled_admins: - - provider: ldap_local - id: user02 - """ - - def __init__( - self, - server_address, - server_port=None, - *, - use_ssl=False, - use_tls=True, - connect_timeout=5, - receive_timeout=60, - bind_dn_template=None, - allowed_groups=None, - valid_username_regex=r"^[a-z][.a-z0-9_-]*$", - lookup_dn=False, - user_search_base=None, - user_attribute=None, - lookup_dn_search_filter="({login_attr}={login})", - lookup_dn_search_user=None, - lookup_dn_search_password=None, - lookup_dn_user_dn_attribute=None, - escape_userdn=False, - search_filter="", - attributes=None, - auth_state_attributes=None, - use_lookup_dn_username=True, - confirmation_message="", - ): - self.use_ssl = use_ssl - self.use_tls = use_tls - self.connect_timeout = connect_timeout - self.receive_timeout = receive_timeout - self.bind_dn_template = bind_dn_template - self.allowed_groups = allowed_groups - self.valid_username_regex = valid_username_regex - self.lookup_dn = lookup_dn - self.user_search_base = user_search_base - self.user_attribute = user_attribute - self.lookup_dn_search_filter = lookup_dn_search_filter - self.lookup_dn_search_user = lookup_dn_search_user - self.lookup_dn_search_password = lookup_dn_search_password - self.lookup_dn_user_dn_attribute = lookup_dn_user_dn_attribute - self.escape_userdn = escape_userdn - self.search_filter = search_filter - self.attributes = attributes if attributes else [] - self.auth_state_attributes = auth_state_attributes if auth_state_attributes else [] - self.use_lookup_dn_username = use_lookup_dn_username - - if isinstance(server_address, str): - server_address_list = [server_address] - elif isinstance(server_address, Iterable): - server_address_list = list(server_address) - else: - raise TypeError( - f"Unsupported type of `server_address` (list): server_address={server_address} " - f"type(server_address)={type(server_address)}" - ) - if not server_address_list: - raise ValueError("No servers are specified: 'server_address' is an empty list") - - self.server_address_list = server_address_list - self.server_port = server_port if server_port is not None else self._server_port_default() - self.confirmation_message = confirmation_message - - def _server_port_default(self): - if self.use_ssl: - return 636 # default SSL port for LDAP - else: - return 389 # default plaintext port for LDAP - - async def resolve_username(self, username_supplied_by_user): - import ldap3 - - search_dn = self.lookup_dn_search_user - if self.escape_userdn: - search_dn = ldap3.utils.conv.escape_filter_chars(search_dn) - conn = await asyncio.get_running_loop().run_in_executor( - None, self.get_connection, search_dn, self.lookup_dn_search_password - ) - is_bound = await asyncio.get_running_loop().run_in_executor(None, conn.bind) - if not is_bound: - msg = "Failed to connect to LDAP server with search user '{search_dn}'" - self.log.warning(msg.format(search_dn=search_dn)) - return (None, None) - - search_filter = self.lookup_dn_search_filter.format( - login_attr=self.user_attribute, login=username_supplied_by_user - ) - msg = "\n".join( - [ - "Looking up user with:", - " search_base = '{search_base}'", - " search_filter = '{search_filter}'", - " attributes = '{attributes}'", - ] - ) - logger.debug( - msg.format( - search_base=self.user_search_base, - search_filter=search_filter, - attributes=self.user_attribute, - ) - ) - - search_func = functools.partial( - conn.search, - search_base=self.user_search_base, - search_scope=ldap3.SUBTREE, - search_filter=search_filter, - attributes=[self.lookup_dn_user_dn_attribute], - ) - await asyncio.get_running_loop().run_in_executor(None, search_func) - - response = conn.response - if len(response) == 0 or "attributes" not in response[0].keys(): - msg = "No entry found for user '{username}' " "when looking up attribute '{attribute}'" - logger.warning(msg.format(username=username_supplied_by_user, attribute=self.user_attribute)) - return (None, None) - - user_dn = response[0]["attributes"][self.lookup_dn_user_dn_attribute] - if isinstance(user_dn, list): - if len(user_dn) == 0: - return (None, None) - elif len(user_dn) == 1: - user_dn = user_dn[0] - else: - msg = ( - "A lookup of the username '{username}' returned a list " - "of entries for the attribute '{attribute}'. Only the " - "first among these ('{first_entry}') was used. The other " - "entries ({other_entries}) were ignored." - ) - logger.warning( - msg.format( - username=username_supplied_by_user, - attribute=self.lookup_dn_user_dn_attribute, - first_entry=user_dn[0], - other_entries=", ".join(user_dn[1:]), - ) - ) - user_dn = user_dn[0] - - return (user_dn, response[0]["dn"]) - - def get_connection(self, userdn, password): - import ldap3 - - # NOTE: setting 'active=False' essentially disables exclusion of inactive servers from the pool. - # It probably does not matter if the pool contains only one server, but it could have implications - # when there are multiple servers in the pool. It is not clear what those implications are. - # But using the default 'activate=True' results in the thread being blocked indefinitely - # at the step of creating 'ldap3.Connection' regardless of timeouts in case all the servers are - # inactive (e.g. the pool has one server and it is unaccessible), which is unacceptable. - # Further investigation may be needed in the future. - server_pool = ldap3.ServerPool(None, ldap3.RANDOM, active=False) - for address in self.server_address_list: - if re.search(r".+:\d+", address): - # Port is found in the address - address_split = address.split(":") - server_addr = ":".join(address_split[:-1]) - server_port = int(address_split[-1]) - else: - # Use the default port - server_addr = address - server_port = self.server_port - - server = ldap3.Server( - server_addr, - port=server_port, - use_ssl=self.use_ssl, - connect_timeout=self.connect_timeout, - ) - server_pool.add(server) - - auto_bind_no_ssl = ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS - auto_bind = ldap3.AUTO_BIND_NO_TLS if self.use_ssl else auto_bind_no_ssl - conn = ldap3.Connection( - server_pool, - user=userdn, - password=password, - auto_bind=auto_bind, - receive_timeout=self.receive_timeout, - ) - return conn - - async def get_user_attributes(self, conn, userdn): - attrs = {} - if self.auth_state_attributes: - search_func = functools.partial( - conn.search, - userdn, - "(objectClass=*)", - attributes=self.auth_state_attributes, - ) - found = await asyncio.get_running_loop().run_in_executor(None, search_func) - if found: - attrs = conn.entries[0].entry_attributes_as_dict - return attrs - - async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: - import ldap3 - - username_saved = username # Save the user name passed as a parameter - - # Protect against invalid usernames as well as LDAP injection attacks - if not re.match(self.valid_username_regex, username): - logger.warning( - "username:%s Illegal characters in username, must match regex %s", - username, - self.valid_username_regex, - ) - return None - - # No empty passwords! - if password is None or password.strip() == "": - logger.warning("username:%s Login denied for blank password", username) - return None - - # bind_dn_template should be of type List[str] - bind_dn_template = self.bind_dn_template - if isinstance(bind_dn_template, str): - bind_dn_template = [bind_dn_template] - - # sanity check - if not self.lookup_dn and not bind_dn_template: - logger.warning("Login not allowed, please configure 'lookup_dn' or 'bind_dn_template'.") - return None - - if self.lookup_dn: - username, resolved_dn = await self.resolve_username(username) - if not username: - return None - if str(self.lookup_dn_user_dn_attribute).upper() == "CN": - # Only escape commas if the lookup attribute is CN - username = re.subn(r"([^\\]),", r"\1\,", username)[0] - if not bind_dn_template: - bind_dn_template = [resolved_dn] - - is_bound = False - for dn in bind_dn_template: - if not dn: - logger.warning("Ignoring blank 'bind_dn_template' entry!") - continue - userdn = dn.format(username=username) - if self.escape_userdn: - userdn = ldap3.utils.conv.escape_filter_chars(userdn) - msg = "Attempting to bind {username} with {userdn}" - logger.debug(msg.format(username=username, userdn=userdn)) - msg = "Status of user bind {username} with {userdn} : {is_bound}" - try: - conn = await asyncio.get_running_loop().run_in_executor( - None, self.get_connection, userdn, password - ) - except ldap3.core.exceptions.LDAPBindError as exc: - is_bound = False - msg += "\n{exc_type}: {exc_msg}".format( - exc_type=exc.__class__.__name__, - exc_msg=exc.args[0] if exc.args else "", - ) - else: - if conn.bound: - is_bound = True - else: - is_bound = await asyncio.get_running_loop().run_in_executor(None, conn.bind) - - msg = msg.format(username=username, userdn=userdn, is_bound=is_bound) - logger.debug(msg) - if is_bound: - break - - if not is_bound: - msg = "Invalid password for user '{username}'" - logger.warning(msg.format(username=username)) - return None - - if self.search_filter: - search_filter = self.search_filter.format(userattr=self.user_attribute, username=username) - - search_func = functools.partial( - conn.search, - search_base=self.user_search_base, - search_scope=ldap3.SUBTREE, - search_filter=search_filter, - attributes=self.attributes, - ) - await asyncio.get_running_loop().run_in_executor(None, search_func) - - n_users = len(conn.response) - if n_users == 0: - msg = "User with '{userattr}={username}' not found in directory" - logger.warning(msg.format(userattr=self.user_attribute, username=username)) - return None - if n_users > 1: - msg = "Duplicate users found! " "{n_users} users found with '{userattr}={username}'" - logger.warning(msg.format(userattr=self.user_attribute, username=username, n_users=n_users)) - return None - - if self.allowed_groups: - logger.debug("username:%s Using dn %s", username, userdn) - found = False - for group in self.allowed_groups: - group_filter = "(|" "(member={userdn})" "(uniqueMember={userdn})" "(memberUid={uid})" ")" - group_filter = group_filter.format(userdn=userdn, uid=username) - group_attributes = ["member", "uniqueMember", "memberUid"] - - search_func = functools.partial( - conn.search, - group, - search_scope=ldap3.BASE, - search_filter=group_filter, - attributes=group_attributes, - ) - found = await asyncio.get_running_loop().run_in_executor(None, search_func) - if found: - break - - if not found: - # If we reach here, then none of the groups matched - msg = "username:{username} User not in any of the allowed groups" - logger.warning(msg.format(username=username)) - return None - - if not self.use_lookup_dn_username: - username = username_saved - user_info = await self.get_user_attributes(conn, userdn) - if user_info: - logger.debug("username:%s attributes:%s", username, user_info) - # this path might never have been worked out...is it ever hit? - return UserSessionState(username, user_info) - return UserSessionState(username, {}) +__all__ = [ + "DictionaryAuthenticator", + "DummyAuthenticator", + "EntraAuthenticator", + "ExternalAuthenticator", + "InternalAuthenticator", + "LDAPAuthenticator", + "OIDCAuthenticator", + "PAMAuthenticator", + "ProxiedOIDCAuthenticator", + "SAMLAuthenticator", + "UserSessionState", +] diff --git a/bluesky_httpserver/config_schemas/service_configuration.yml b/bluesky_httpserver/config_schemas/service_configuration.yml index a76e4d3..e7e5148 100644 --- a/bluesky_httpserver/config_schemas/service_configuration.yml +++ b/bluesky_httpserver/config_schemas/service_configuration.yml @@ -83,8 +83,10 @@ properties: description: | Type of Authenticator to use. - These are typically from the bluesky_httpserver.authenticators module, + These are typically from the bluesky_authentication.authenticators module, though user-defined ones may be used as well. + Legacy import paths under bluesky_httpserver.authenticators remain + supported for backward compatibility. This is given as an import path. In an import path, packages/modules are separated by dots, and the object itself it separated by a colon. @@ -92,7 +94,7 @@ properties: Example: ```yaml - authenticator: bluesky_httpserver.authenticators:DummyAuthenticator + authenticator: bluesky_authentication.authenticators:DummyAuthenticator ``` args: type: object @@ -103,7 +105,7 @@ properties: Example: ```yaml - authenticator: bluesky_httpserver.authenticators:PAMAuthenticator + authenticator: bluesky_authentication.authenticators:PAMAuthenticator args: service: "custom_service" ``` diff --git a/bluesky_httpserver/tests/test_access_control.py b/bluesky_httpserver/tests/test_access_control.py index e6afdf0..007bb61 100644 --- a/bluesky_httpserver/tests/test_access_control.py +++ b/bluesky_httpserver/tests/test_access_control.py @@ -47,7 +47,7 @@ allow_anonymous_access: True providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password @@ -61,7 +61,7 @@ allow_anonymous_access: False providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password @@ -177,7 +177,7 @@ allow_anonymous_access: False providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password @@ -688,7 +688,7 @@ def test_authentication_and_authorization_08( authentication: providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password diff --git a/bluesky_httpserver/tests/test_access_policies.py b/bluesky_httpserver/tests/test_access_policies.py index 374d711..867b906 100644 --- a/bluesky_httpserver/tests/test_access_policies.py +++ b/bluesky_httpserver/tests/test_access_policies.py @@ -465,7 +465,7 @@ async def read_info(): authentication: providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password diff --git a/bluesky_httpserver/tests/test_auth_api.py b/bluesky_httpserver/tests/test_auth_api.py index ad23e5a..fb1630f 100644 --- a/bluesky_httpserver/tests/test_auth_api.py +++ b/bluesky_httpserver/tests/test_auth_api.py @@ -13,7 +13,7 @@ allow_anonymous_access: True providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password diff --git a/bluesky_httpserver/tests/test_auth_for_websockets.py b/bluesky_httpserver/tests/test_auth_for_websockets.py index 17fb498..f449a2c 100644 --- a/bluesky_httpserver/tests/test_auth_for_websockets.py +++ b/bluesky_httpserver/tests/test_auth_for_websockets.py @@ -22,7 +22,7 @@ allow_anonymous_access: True providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: bob_password diff --git a/bluesky_httpserver/tests/test_authenticators.py b/bluesky_httpserver/tests/test_authenticators.py deleted file mode 100644 index 7397550..0000000 --- a/bluesky_httpserver/tests/test_authenticators.py +++ /dev/null @@ -1,359 +0,0 @@ -import asyncio -import os -import time -from typing import Any, Tuple - -import httpx -import pytest -from cryptography.hazmat.primitives.asymmetric import rsa -from jose import ExpiredSignatureError, jwt -from jose.backends import RSAKey -from respx import MockRouter -from starlette.datastructures import URL, QueryParams -from starlette.requests import Request - -from .._authentication import build_authorize_route -from ..authenticators import ( - LDAPAuthenticator, - OIDCAuthenticator, - ProxiedOIDCAuthenticator, - UserSessionState, - exchange_code, -) - -LDAP_TEST_HOST = os.environ.get("QSERVER_TEST_LDAP_HOST", "localhost") -LDAP_TEST_PORT = int(os.environ.get("QSERVER_TEST_LDAP_PORT", "1389")) -LDAP_TEST_ALT_HOST = os.environ.get("QSERVER_TEST_LDAP_ALT_HOST") -if not LDAP_TEST_ALT_HOST: - LDAP_TEST_ALT_HOST = "127.0.0.1" if LDAP_TEST_HOST == "localhost" else LDAP_TEST_HOST - - -# fmt: off - - -@pytest.mark.parametrize("ldap_server_address, ldap_server_port", [ - (LDAP_TEST_HOST, LDAP_TEST_PORT), - (f"{LDAP_TEST_HOST}:{LDAP_TEST_PORT}", 904), # Random port, ignored - (f"{LDAP_TEST_HOST}:{LDAP_TEST_PORT}", None), - (LDAP_TEST_ALT_HOST, LDAP_TEST_PORT), - (f"{LDAP_TEST_ALT_HOST}:{LDAP_TEST_PORT}", 904), - ([LDAP_TEST_HOST], LDAP_TEST_PORT), - ([LDAP_TEST_HOST, LDAP_TEST_ALT_HOST], LDAP_TEST_PORT), - ([LDAP_TEST_HOST, f"{LDAP_TEST_ALT_HOST}:{LDAP_TEST_PORT}"], LDAP_TEST_PORT), - ([f"{LDAP_TEST_HOST}:{LDAP_TEST_PORT}", f"{LDAP_TEST_ALT_HOST}:{LDAP_TEST_PORT}"], None), -]) -# fmt: on -@pytest.mark.parametrize("use_tls,use_ssl", [(False, False)]) -def test_LDAPAuthenticator_01(use_tls, use_ssl, ldap_server_address, ldap_server_port): - """ - Basic test for ``LDAPAuthenticator``. - - TODO: The test could be extended with enabled TLS or SSL, but it requires configuration - of the LDAP server. - """ - authenticator = LDAPAuthenticator( - ldap_server_address, - ldap_server_port, - bind_dn_template="cn={username},ou=users,dc=example,dc=org", - use_tls=use_tls, - use_ssl=use_ssl, - ) - - async def testing(): - assert await authenticator.authenticate("user01", "password1") == UserSessionState("user01", {}) - assert await authenticator.authenticate("user02", "password2") == UserSessionState("user02", {}) - assert await authenticator.authenticate("user02a", "password2") is None - assert await authenticator.authenticate("user02", "password2a") is None - - asyncio.run(testing()) - - -@pytest.fixture -def oidc_well_known_url(oidc_base_url: str) -> str: - return f"{oidc_base_url}.well-known/openid-configuration" - - -@pytest.fixture -def keys() -> Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: - private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) - public_key = private_key.public_key() - return (private_key, public_key) - - -@pytest.fixture -def json_web_keyset(keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]) -> list[dict[str, Any]]: - _, public_key = keys - return [RSAKey(key=public_key, algorithm="RS256").to_dict()] - - -@pytest.fixture -def mock_oidc_server( - respx_mock: MockRouter, - oidc_well_known_url: str, - well_known_response: dict[str, Any], - json_web_keyset: list[dict[str, Any]], -) -> MockRouter: - respx_mock.get(oidc_well_known_url).mock(return_value=httpx.Response(httpx.codes.OK, json=well_known_response)) - respx_mock.get(well_known_response["jwks_uri"]).mock( - return_value=httpx.Response(httpx.codes.OK, json={"keys": json_web_keyset}) - ) - return respx_mock - - -def token(issued: bool, expired: bool) -> dict[str, str]: - now = time.time() - return { - "aud": "tiled", - "exp": (now - 1500) if expired else (now + 1500), - "iat": (now - 1500) if issued else (now + 1500), - "iss": "https://example.com/realms/example", - "sub": "Jane Doe", - } - - -def encrypted_token(token_data: dict[str, str], private_key: rsa.RSAPrivateKey) -> str: - return jwt.encode( - token_data, - key=private_key, - algorithm="RS256", - headers={"kid": "secret"}, - ) - - -def test_oidc_authenticator_caching( - mock_oidc_server: MockRouter, - oidc_well_known_url: str, - well_known_response: dict[str, Any], - json_web_keyset: list[dict[str, Any]], -): - authenticator = OIDCAuthenticator("tiled", "tiled", "secret", well_known_uri=oidc_well_known_url) - assert authenticator.client_id == "tiled" - assert authenticator.authorization_endpoint == well_known_response["authorization_endpoint"] - assert authenticator.id_token_signing_alg_values_supported == well_known_response[ - "id_token_signing_alg_values_supported" - ] - assert authenticator.issuer == well_known_response["issuer"] - assert authenticator.jwks_uri == well_known_response["jwks_uri"] - assert authenticator.token_endpoint == well_known_response["token_endpoint"] - assert authenticator.device_authorization_endpoint == well_known_response["device_authorization_endpoint"] - assert authenticator.end_session_endpoint == well_known_response["end_session_endpoint"] - - assert len(mock_oidc_server.calls) == 1 - call_request = mock_oidc_server.calls[0].request - assert call_request.method == "GET" - assert call_request.url == oidc_well_known_url - - assert authenticator.keys() == json_web_keyset - assert len(mock_oidc_server.calls) == 2 - keys_request = mock_oidc_server.calls[1].request - assert keys_request.method == "GET" - assert keys_request.url == well_known_response["jwks_uri"] - - for _ in range(10): - assert authenticator.keys() == json_web_keyset - - assert len(mock_oidc_server.calls) == 2 - - -@pytest.mark.parametrize("issued", [True, False]) -@pytest.mark.parametrize("expired", [True, False]) -def test_oidc_decoding( - mock_oidc_server: MockRouter, - oidc_well_known_url: str, - issued: bool, - expired: bool, - keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey], -): - private_key, _ = keys - authenticator = OIDCAuthenticator("tiled", "tiled", "secret", well_known_uri=oidc_well_known_url) - access_token = token(issued, expired) - encrypted_access_token = encrypted_token(access_token, private_key) - - if not expired: - assert authenticator.decode_token(encrypted_access_token) == access_token - else: - with pytest.raises(ExpiredSignatureError): - authenticator.decode_token(encrypted_access_token) - - -@pytest.mark.asyncio -async def test_proxied_oidc_token_retrieval(oidc_well_known_url: str, mock_oidc_server: MockRouter): - authenticator = ProxiedOIDCAuthenticator("tiled", "tiled", oidc_well_known_url, - device_flow_client_id="tiled-cli") - test_request = httpx.Request("GET", "http://example.com", headers={"Authorization": "bearer FOO"}) - - assert "FOO" == await authenticator.oauth2_schema(test_request) - - -def create_mock_oidc_request(query_params=None): - if query_params is None: - query_params = {} - - class MockRequest: - def __init__(self, request_query_params): - self.query_params = QueryParams(request_query_params) - self.scope = { - "type": "http", - "scheme": "http", - "server": ("localhost", 8000), - "path": "/api/v1/auth/provider/orcid/code", - "headers": [], - } - self.headers = {"host": "localhost:8000"} - self.url = URL("http://localhost:8000/api/v1/auth/provider/orcid/code") - - return MockRequest(query_params) - - -@pytest.mark.asyncio -async def test_OIDCAuthenticator_mock( - mock_oidc_server: MockRouter, - oidc_well_known_url: str, - well_known_response: dict[str, Any], - monkeypatch, -): - mock_jwt_payload = { - "sub": "0009-0008-8698-7745", - "aud": "APP-TEST-CLIENT-ID", - "iss": well_known_response["issuer"], - "exp": 9999999999, - "iat": 1000000000, - "given_name": "Test User", - } - - mock_oidc_server.post(well_known_response["token_endpoint"]).mock( - return_value=httpx.Response( - 200, - json={ - "access_token": "mock-access-token", - "id_token": "mock-id-token", - "token_type": "bearer", - }, - ) - ) - - authenticator = OIDCAuthenticator( - audience="APP-TEST-CLIENT-ID", - client_id="APP-TEST-CLIENT-ID", - client_secret="test-secret", - well_known_uri=oidc_well_known_url, - ) - - mock_request = create_mock_oidc_request({"code": "test-auth-code"}) - - decode_calls = {} - - def mock_decode_token(id_token, access_token=None): - decode_calls["id_token"] = id_token - decode_calls["access_token"] = access_token - return mock_jwt_payload - - monkeypatch.setattr(authenticator, "decode_token", mock_decode_token) - - user_session = await authenticator.authenticate(mock_request) - - assert user_session is not None - assert user_session.user_name == "0009-0008-8698-7745" - assert decode_calls["id_token"] == "mock-id-token" - assert decode_calls["access_token"] == "mock-access-token" - - -@pytest.mark.asyncio -async def test_OIDCAuthenticator_missing_code_parameter(oidc_well_known_url: str): - authenticator = OIDCAuthenticator( - audience="APP-TEST-CLIENT-ID", - client_id="APP-TEST-CLIENT-ID", - client_secret="test-secret", - well_known_uri=oidc_well_known_url, - ) - - mock_request = create_mock_oidc_request({}) - - result = await authenticator.authenticate(mock_request) - assert result is None - - -@pytest.mark.asyncio -async def test_OIDCAuthenticator_token_exchange_failure( - oidc_well_known_url: str, - mock_oidc_server, - well_known_response, -): - mock_oidc_server.post(well_known_response["token_endpoint"]).mock( - return_value=httpx.Response( - 400, - json={ - "error": "invalid_client", - "error_description": "Client not found: APP-TEST-CLIENT-ID", - }, - ) - ) - - authenticator = OIDCAuthenticator( - audience="APP-TEST-CLIENT-ID", - client_id="APP-TEST-CLIENT-ID", - client_secret="test-secret", - well_known_uri=oidc_well_known_url, - ) - - mock_request = create_mock_oidc_request({"code": "invalid-code"}) - - result = await authenticator.authenticate(mock_request) - assert result is None - - -@pytest.mark.asyncio -async def test_exchange_code_requests_offline_access(monkeypatch): - captured = {} - - def mock_post(*, url, data, headers): - captured["url"] = url - captured["data"] = data - captured["headers"] = headers - return httpx.Response(200, json={"id_token": "X", "access_token": "Y"}) - - monkeypatch.setattr("httpx.post", mock_post) - - await exchange_code( - token_uri="https://idp.example/token", - auth_code="authcode", - client_id="client-id", - client_secret="client-secret", - redirect_uri="https://server.example/callback", - extra_scopes=["api://example/access_as_user"], - ) - - assert captured["url"] == "https://idp.example/token" - assert set(captured["data"]["scope"].split(" ")) == { - "openid", - "offline_access", - "api://example/access_as_user", - } - - -@pytest.mark.asyncio -async def test_authorize_route_requests_extra_scopes_and_prompt(): - class _Authenticator: - client_id = "test-client" - extra_scopes = ["api://example/access_as_user"] - authorization_endpoint = httpx.URL("https://idp.example/auth") - - route = build_authorize_route(_Authenticator(), "oidc") - request = Request( - { - "type": "http", - "scheme": "http", - "path": "/api/auth/provider/oidc/authorize", - "root_path": "", - "query_string": b"", - "headers": [(b"host", b"localhost:8000")], - "server": ("localhost", 8000), - "client": ("127.0.0.1", 54321), - } - ) - response = await route(request) - location = response.headers["location"] - assert "prompt=login" in location - assert "offline_access" in location - assert "openid" in location - assert "api%3A%2F%2Fexample%2Faccess_as_user" in location diff --git a/bluesky_httpserver/tests/test_oidc_authenticators.py b/bluesky_httpserver/tests/test_oidc_authenticators.py deleted file mode 100644 index cedf611..0000000 --- a/bluesky_httpserver/tests/test_oidc_authenticators.py +++ /dev/null @@ -1,280 +0,0 @@ -"""Tests for OIDC Authenticator functionality.""" - -import time -from typing import Any, Tuple - -import httpx -import pytest -from cryptography.hazmat.primitives.asymmetric import rsa -from jose import ExpiredSignatureError, jwt -from jose.backends import RSAKey -from respx import MockRouter - -from bluesky_httpserver.authenticators import EntraAuthenticator, OIDCAuthenticator, ProxiedOIDCAuthenticator - - -@pytest.fixture -def oidc_well_known_url(oidc_base_url: str) -> str: - return f"{oidc_base_url}.well-known/openid-configuration" - - -@pytest.fixture -def keys() -> Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: - """Generate RSA key pair for testing.""" - private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) - public_key = private_key.public_key() - return (private_key, public_key) - - -@pytest.fixture -def json_web_keyset(keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]) -> list[dict[str, Any]]: - """Create a JSON Web Key Set from the test keys.""" - _, public_key = keys - return [RSAKey(key=public_key, algorithm="RS256").to_dict()] - - -@pytest.fixture -def mock_oidc_server( - respx_mock: MockRouter, - oidc_well_known_url: str, - well_known_response: dict[str, Any], - json_web_keyset: list[dict[str, Any]], -) -> MockRouter: - """Set up mock OIDC server endpoints.""" - respx_mock.get(oidc_well_known_url).mock(return_value=httpx.Response(httpx.codes.OK, json=well_known_response)) - respx_mock.get(well_known_response["jwks_uri"]).mock( - return_value=httpx.Response(httpx.codes.OK, json={"keys": json_web_keyset}) - ) - return respx_mock - - -def create_token(issued: bool, expired: bool) -> dict[str, Any]: - """Create a test JWT token.""" - now = time.time() - return { - "aud": "test_client", - "exp": (now - 1500) if expired else (now + 1500), - "iat": (now - 1500) if issued else (now + 1500), - "iss": "https://example.com/realms/example", - "sub": "test_user", - } - - -def encrypt_token(token: dict[str, Any], private_key: rsa.RSAPrivateKey) -> str: - """Encrypt a token with the test private key.""" - return jwt.encode( - token, - key=private_key, - algorithm="RS256", - headers={"kid": "test_key"}, - ) - - -@pytest.mark.filterwarnings("ignore::DeprecationWarning") -class TestOIDCAuthenticator: - """Tests for OIDCAuthenticator class.""" - - def test_oidc_authenticator_caching( - self, - mock_oidc_server: MockRouter, - oidc_well_known_url: str, - well_known_response: dict[str, Any], - json_web_keyset: list[dict[str, Any]], - ): - """Test that OIDC configuration is cached after first fetch.""" - authenticator = OIDCAuthenticator( - audience="test_client", - client_id="test_client", - client_secret="secret", - well_known_uri=oidc_well_known_url, - ) - - # Access multiple properties to ensure caching works - assert authenticator.client_id == "test_client" - assert authenticator.authorization_endpoint == well_known_response["authorization_endpoint"] - assert ( - authenticator.id_token_signing_alg_values_supported - == well_known_response["id_token_signing_alg_values_supported"] - ) - assert authenticator.issuer == well_known_response["issuer"] - assert authenticator.jwks_uri == well_known_response["jwks_uri"] - assert authenticator.token_endpoint == well_known_response["token_endpoint"] - assert authenticator.device_authorization_endpoint == well_known_response["device_authorization_endpoint"] - assert authenticator.end_session_endpoint == well_known_response["end_session_endpoint"] - - # Should only call well-known endpoint once due to caching - assert len(mock_oidc_server.calls) == 1 - call_request = mock_oidc_server.calls[0].request - assert call_request.method == "GET" - assert call_request.url == oidc_well_known_url - - # Keys should also be cached - assert authenticator.keys() == json_web_keyset - assert len(mock_oidc_server.calls) == 2 # Now also fetched JWKS - - # Multiple calls should still be cached - for _ in range(5): - assert authenticator.keys() == json_web_keyset - assert len(mock_oidc_server.calls) == 2 # No new calls - - @pytest.mark.parametrize("issued", [True, False]) - @pytest.mark.parametrize("expired", [True, False]) - def test_oidc_token_decoding( - self, - mock_oidc_server: MockRouter, - oidc_well_known_url: str, - issued: bool, - expired: bool, - keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey], - ): - """Test token decoding with various validity scenarios.""" - private_key, _ = keys - authenticator = OIDCAuthenticator( - audience="test_client", - client_id="test_client", - client_secret="secret", - well_known_uri=oidc_well_known_url, - ) - - token = create_token(issued, expired) - encrypted = encrypt_token(token, private_key) - - if not expired: - # Non-expired tokens should decode successfully - decoded = authenticator.decode_token(encrypted) - assert decoded["sub"] == "test_user" - assert decoded["aud"] == "test_client" - else: - # Expired tokens should raise an error - with pytest.raises(ExpiredSignatureError): - authenticator.decode_token(encrypted) - - def test_oidc_authenticator_properties( - self, - mock_oidc_server: MockRouter, - oidc_well_known_url: str, - well_known_response: dict[str, Any], - ): - """Test that all authenticator properties are correctly set.""" - authenticator = OIDCAuthenticator( - audience="my_audience", - client_id="my_client_id", - client_secret="my_secret", - well_known_uri=oidc_well_known_url, - confirmation_message="Logged in as {id}", - redirect_on_success="https://app.example.com/success", - redirect_on_failure="https://app.example.com/failure", - ) - - assert authenticator.client_id == "my_client_id" - assert authenticator.confirmation_message == "Logged in as {id}" - assert authenticator.redirect_on_success == "https://app.example.com/success" - assert authenticator.redirect_on_failure == "https://app.example.com/failure" - - -@pytest.mark.filterwarnings("ignore::DeprecationWarning") -class TestProxiedOIDCAuthenticator: - """Tests for ProxiedOIDCAuthenticator class.""" - - @pytest.mark.asyncio - async def test_proxied_oidc_oauth2_schema( - self, - mock_oidc_server: MockRouter, - oidc_well_known_url: str, - ): - """Test that ProxiedOIDCAuthenticator extracts bearer token correctly.""" - authenticator = ProxiedOIDCAuthenticator( - audience="test_client", - client_id="test_client", - well_known_uri=oidc_well_known_url, - device_flow_client_id="test_cli_client", - ) - - # Create a mock request with Authorization header - test_request = httpx.Request( - "GET", - "http://example.com/api/test", - headers={"Authorization": "Bearer TEST_TOKEN"}, - ) - - # The oauth2_schema should extract the bearer token - token = await authenticator.oauth2_schema(test_request) - assert token == "TEST_TOKEN" - - def test_proxied_oidc_with_scopes( - self, - mock_oidc_server: MockRouter, - oidc_well_known_url: str, - ): - """Test ProxiedOIDCAuthenticator with custom scopes.""" - authenticator = ProxiedOIDCAuthenticator( - audience="test_client", - client_id="test_client", - well_known_uri=oidc_well_known_url, - device_flow_client_id="test_cli_client", - scopes=["openid", "profile", "email"], - ) - - assert authenticator.scopes == ["openid", "profile", "email"] - assert authenticator.device_flow_client_id == "test_cli_client" - - -@pytest.mark.filterwarnings("ignore::DeprecationWarning") -class TestEntraAuthenticator: - def test_entra_scope_mapping_and_username( - self, - mock_oidc_server: MockRouter, - oidc_well_known_url: str, - keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey], - ): - private_key, _ = keys - authenticator = EntraAuthenticator( - audience="test_client", - client_id="test_client", - well_known_uri=oidc_well_known_url, - device_flow_client_id="test_cli_client", - scopes_map={"User.Read": ["read:monitor"]}, - ) - token_claims = { - "aud": "test_client", - "exp": time.time() + 1500, - "iat": time.time() - 1, - "iss": "https://example.com/realms/example", - "sub": "entra-subject", - "preferred_username": "alice@example.org", - "scp": "User.Read", - } - encoded = encrypt_token(token_claims, private_key) - decoded = authenticator.decode_token(encoded) - assert decoded["user"] == "alice" - assert set(decoded["scope"].split(" ")) == {"read:monitor"} - assert decoded["entra_sub"] == "entra-subject" - - def test_entra_unmapped_scope_warning( - self, - caplog, - mock_oidc_server: MockRouter, - oidc_well_known_url: str, - keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey], - ): - private_key, _ = keys - authenticator = EntraAuthenticator( - audience="test_client", - client_id="test_client", - well_known_uri=oidc_well_known_url, - device_flow_client_id="test_cli_client", - scopes_map={"Known.Scope": ["read:monitor"]}, - ) - token_claims = { - "aud": "test_client", - "exp": time.time() + 1500, - "iat": time.time() - 1, - "iss": "https://example.com/realms/example", - "sub": "entra-subject", - "scp": "Unknown.Scope", - } - encoded = encrypt_token(token_claims, private_key) - with caplog.at_level("WARNING"): - decoded = authenticator.decode_token(encoded) - assert decoded["scope"] == "" - assert any("Unmapped Entra scope" in record.message for record in caplog.records) diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst index 8852a31..989fd31 100644 --- a/docs/source/configuration.rst +++ b/docs/source/configuration.rst @@ -14,6 +14,12 @@ and allow high level of customization of functionality, using configuration YML allows greater flexibility and is considered a preferable way of configuring the server in production deployments. +.. note:: + + Canonical authenticator import paths are in ``bluesky_authentication.authenticators``. + Legacy paths in ``bluesky_httpserver.authenticators`` remain supported for backward + compatibility. + Environment variable for passing the path to server configuration file(s): - ``QSERVER_HTTP_SERVER_CONFIG`` - path to a single YML file or a directory with multiple YML files. @@ -218,7 +224,7 @@ authorization policy and enabled public access:: allow_anonymous_access: True providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: ${BOB_PASSWORD} @@ -269,7 +275,7 @@ respectively. The configuration does not enable public access. :: authentication: providers: - provider: ldap - authenticator: bluesky_httpserver.authenticators:LDAPAuthenticator + authenticator: bluesky_authentication.authenticators:LDAPAuthenticator args: server_address: localhost server_port: 1389 @@ -325,7 +331,7 @@ Example configuration (Microsoft Entra ID):: authentication: providers: - provider: entra - authenticator: bluesky_httpserver.authenticators:OIDCAuthenticator + authenticator: bluesky_authentication.authenticators:OIDCAuthenticator args: audience: 00000000-0000-0000-0000-000000000000 client_id: 00000000-0000-0000-0000-000000000000 @@ -346,7 +352,7 @@ Example configuration (Google):: authentication: providers: - provider: google - authenticator: bluesky_httpserver.authenticators:OIDCAuthenticator + authenticator: bluesky_authentication.authenticators:OIDCAuthenticator args: audience: client_id: diff --git a/docs/source/usage.rst b/docs/source/usage.rst index bcae133..07cd1f3 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -106,6 +106,12 @@ API calls with invalid token or API key are rejected even if public access is en Authentication API for Users ============================ +.. note:: + + Canonical authenticator import paths are in ``bluesky_authentication.authenticators``. + Legacy paths in ``bluesky_httpserver.authenticators`` remain supported for backward + compatibility. + Logging into the Server (Requesting Token) ------------------------------------------ @@ -122,7 +128,7 @@ is an example of a config file sets up ``DictionaryAPIAccessControl`` as a provi authentication: providers: - provider: toy - authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + authenticator: bluesky_authentication.authenticators:DictionaryAuthenticator args: users_to_passwords: bob: ${BOB_PASSWORD} diff --git a/requirements.txt b/requirements.txt index 1377ef0..5d234b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ alembic +bluesky-authentication bluesky-queueserver bluesky-queueserver-api cachetools From b0d1acab0fc051ae25e166c9dee6ca25ee7d4423 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Thu, 25 Jun 2026 08:53:54 -0600 Subject: [PATCH 16/18] Fixing pre-commit checks --- bluesky_httpserver/app.py | 2 +- bluesky_httpserver/authenticators.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/bluesky_httpserver/app.py b/bluesky_httpserver/app.py index 29e2273..84470d1 100644 --- a/bluesky_httpserver/app.py +++ b/bluesky_httpserver/app.py @@ -9,12 +9,12 @@ import urllib.parse from functools import lru_cache, partial +from bluesky_authentication.protocols import ExternalAuthenticator, InternalAuthenticator from bluesky_queueserver.manager.comms import validate_zmq_key from bluesky_queueserver_api.zmq.aio import REManagerAPI from fastapi import APIRouter, FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi -from bluesky_authentication.protocols import ExternalAuthenticator, InternalAuthenticator from .console_output import CollectPublishedConsoleOutput, ConsoleOutputStream, SystemInfoStream from .core import PatchedStreamingResponse diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index 09c073a..6d4ad01 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -1,14 +1,5 @@ import warnings - -warnings.warn( - "Importing authenticators from 'bluesky_httpserver.authenticators' is deprecated " - "and will be removed in a future release. Use 'bluesky_authentication.authenticators' " - "and 'bluesky_authentication.protocols' instead.", - DeprecationWarning, - stacklevel=2, -) - from bluesky_authentication.authenticators import ( # noqa: F401 DictionaryAuthenticator, DummyAuthenticator, @@ -25,6 +16,14 @@ UserSessionState, ) +warnings.warn( + "Importing authenticators from 'bluesky_httpserver.authenticators' is deprecated " + "and will be removed in a future release. Use 'bluesky_authentication.authenticators' " + "and 'bluesky_authentication.protocols' instead.", + DeprecationWarning, + stacklevel=2, +) + __all__ = [ "DictionaryAuthenticator", "DummyAuthenticator", From 4b6242cd2db768f4a4b290cd7297e9b562d3eedd Mon Sep 17 00:00:00 2001 From: David Pastl Date: Thu, 25 Jun 2026 13:09:48 -0600 Subject: [PATCH 17/18] Updating to use common auth --- .github/workflows/docs.yml | 2 ++ .github/workflows/docs_publish.yml | 2 ++ .github/workflows/testing.yml | 2 ++ 3 files changed, 6 insertions(+) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 012a480..1a806f7 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -32,6 +32,8 @@ jobs: pip install . popd + pip install --upgrade "git+https://github.com/davidpcls/bluesky-authentication.git@main" + pip install . pip install -r requirements-dev.txt pip list diff --git a/.github/workflows/docs_publish.yml b/.github/workflows/docs_publish.yml index 49d73c3..fc1f0d0 100644 --- a/.github/workflows/docs_publish.yml +++ b/.github/workflows/docs_publish.yml @@ -44,6 +44,8 @@ jobs: pip install . popd + pip install --upgrade "git+https://github.com/davidpcls/bluesky-authentication.git@main" + pip install . pip install -r requirements-dev.txt pip list diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index adef4fc..cf3a820 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -56,6 +56,8 @@ jobs: pip install . popd + pip install --upgrade "git+https://github.com/davidpcls/bluesky-authentication.git@main" + pip install --upgrade pip pip install . pip install -r requirements-dev.txt From eb4e6ca3c68c78e61f8c267e4cbfe68c95fcce98 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Thu, 25 Jun 2026 15:21:57 -0600 Subject: [PATCH 18/18] Fixing missing factory method --- bluesky_httpserver/tests/test_access_control.py | 2 +- bluesky_httpserver/tests/test_access_policies.py | 2 +- bluesky_httpserver/tests/test_auth_api.py | 2 +- bluesky_httpserver/tests/test_auth_for_websockets.py | 2 +- bluesky_httpserver/tests/test_core_api_fs.py | 1 + bluesky_httpserver/tests/test_core_api_main.py | 1 + bluesky_httpserver/tests/test_server.py | 1 + 7 files changed, 7 insertions(+), 4 deletions(-) diff --git a/bluesky_httpserver/tests/test_access_control.py b/bluesky_httpserver/tests/test_access_control.py index 007bb61..244cf46 100644 --- a/bluesky_httpserver/tests/test_access_control.py +++ b/bluesky_httpserver/tests/test_access_control.py @@ -3,7 +3,7 @@ import pprint import pytest -from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd # noqa F401 +from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd, re_manager_factory # noqa F401 from bluesky_httpserver.authorization._defaults import ( _DEFAULT_RESOURCE_ACCESS_GROUP, diff --git a/bluesky_httpserver/tests/test_access_policies.py b/bluesky_httpserver/tests/test_access_policies.py index 867b906..627a8b4 100644 --- a/bluesky_httpserver/tests/test_access_policies.py +++ b/bluesky_httpserver/tests/test_access_policies.py @@ -6,7 +6,7 @@ import pytest import requests -from bluesky_queueserver.manager.tests.common import re_manager # noqa F401 +from bluesky_queueserver.manager.tests.common import re_manager, re_manager_factory # noqa F401 from xprocess import ProcessStarter from bluesky_httpserver.authorization import ( diff --git a/bluesky_httpserver/tests/test_auth_api.py b/bluesky_httpserver/tests/test_auth_api.py index fb1630f..117a1e3 100644 --- a/bluesky_httpserver/tests/test_auth_api.py +++ b/bluesky_httpserver/tests/test_auth_api.py @@ -1,7 +1,7 @@ import pprint import time as ttime -from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd # noqa F401 +from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd, re_manager_factory # noqa F401 from bluesky_httpserver.authorization._defaults import _DEFAULT_ROLES diff --git a/bluesky_httpserver/tests/test_auth_for_websockets.py b/bluesky_httpserver/tests/test_auth_for_websockets.py index f449a2c..91f69d2 100644 --- a/bluesky_httpserver/tests/test_auth_for_websockets.py +++ b/bluesky_httpserver/tests/test_auth_for_websockets.py @@ -4,7 +4,7 @@ import time as ttime import pytest -from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd # noqa F401 +from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd, re_manager_factory # noqa F401 from websockets.sync.client import connect from .conftest import fastapi_server_fs # noqa: F401 diff --git a/bluesky_httpserver/tests/test_core_api_fs.py b/bluesky_httpserver/tests/test_core_api_fs.py index a00d99a..ace9205 100644 --- a/bluesky_httpserver/tests/test_core_api_fs.py +++ b/bluesky_httpserver/tests/test_core_api_fs.py @@ -7,6 +7,7 @@ copy_default_profile_collection, re_manager, re_manager_cmd, + re_manager_factory, re_manager_pc_copy, set_qserver_zmq_address, set_qserver_zmq_public_key, diff --git a/bluesky_httpserver/tests/test_core_api_main.py b/bluesky_httpserver/tests/test_core_api_main.py index 0c471bd..815669e 100644 --- a/bluesky_httpserver/tests/test_core_api_main.py +++ b/bluesky_httpserver/tests/test_core_api_main.py @@ -11,6 +11,7 @@ ip_kernel_simple_client, re_manager, re_manager_cmd, + re_manager_factory, re_manager_pc_copy, ) diff --git a/bluesky_httpserver/tests/test_server.py b/bluesky_httpserver/tests/test_server.py index 33b82a2..10f0d6c 100644 --- a/bluesky_httpserver/tests/test_server.py +++ b/bluesky_httpserver/tests/test_server.py @@ -8,6 +8,7 @@ copy_default_profile_collection, re_manager, re_manager_cmd, + re_manager_factory, re_manager_pc_copy, set_qserver_zmq_address, set_qserver_zmq_public_key,