From bfa28270c493e473cd13efc4fec0fa368036cdce Mon Sep 17 00:00:00 2001 From: Raquel Barbadillo Date: Mon, 15 Jun 2026 18:43:35 +0200 Subject: [PATCH] Fix SDK OpenTelemetry opt-in behavior --- src/mistralai/client/_hooks/tracing.py | 36 +++--- .../extra/observability/telemetry.py | 9 +- .../extra/tests/test_otel_tracing.py | 105 ++++++++++++++++-- 3 files changed, 127 insertions(+), 23 deletions(-) diff --git a/src/mistralai/client/_hooks/tracing.py b/src/mistralai/client/_hooks/tracing.py index 7e2bb17d..54b04c77 100644 --- a/src/mistralai/client/_hooks/tracing.py +++ b/src/mistralai/client/_hooks/tracing.py @@ -44,29 +44,30 @@ def __init__(self) -> None: self._auto_telemetry_provider: Optional[Any] = None self._telemetry_finalizer: Optional[weakref.finalize] = None self._telemetry_auto_disabled: bool = False + self._telemetry_use_global_provider: bool = False self.tracing_enabled, self.tracer = get_or_create_otel_tracer() def before_request( self, hook_ctx: BeforeRequestContext, request: httpx.Request ) -> Union[httpx.Request, Exception]: - # The GenAI span is created in this hook, but HTTPX creates its own - # auto-instrumented span later inside send(). Wrap the configured - # clients so each request's stored GenAI span is current only while - # that request is being sent. - self._ensure_client_send_wrapped(getattr(hook_ctx.config, "client", None)) - self._ensure_async_client_send_wrapped( - getattr(hook_ctx.config, "async_client", None) - ) - configure_telemetry_for_hook( + telemetry_configured = configure_telemetry_for_hook( self, hook_ctx.config, respect_global_provider=True, ) - # Refresh tracer/provider per request so tracing can be enabled if the - # application configures OpenTelemetry after the client is instantiated. - self.tracing_enabled, self.tracer = get_or_create_otel_tracer( - provider=self.tracer_provider, + should_trace = ( + telemetry_configured + or self.tracer_provider is not None + or self._telemetry_use_global_provider ) + if should_trace: + # Refresh tracer/provider per request so tracing can be enabled if the + # application configures OpenTelemetry after the client is instantiated. + self.tracing_enabled, self.tracer = get_or_create_otel_tracer( + provider=self.tracer_provider, + ) + else: + self.tracing_enabled = False request, span = get_traced_request_and_span( tracing_enabled=self.tracing_enabled, tracer=self.tracer, @@ -75,6 +76,15 @@ def before_request( request=request, ) self._store_span_on_request(request, span) + # The GenAI span is created in this hook, but HTTPX creates its own + # auto-instrumented span later inside send(). Wrap the configured + # clients so each request's stored GenAI span is current only while + # that request is being sent. + if span is not None: + self._ensure_client_send_wrapped(getattr(hook_ctx.config, "client", None)) + self._ensure_async_client_send_wrapped( + getattr(hook_ctx.config, "async_client", None) + ) return request @staticmethod diff --git a/src/mistralai/extra/observability/telemetry.py b/src/mistralai/extra/observability/telemetry.py index 822b62b9..22f2922d 100644 --- a/src/mistralai/extra/observability/telemetry.py +++ b/src/mistralai/extra/observability/telemetry.py @@ -135,7 +135,9 @@ def configure_telemetry_for_hook( ) -> bool: """Configure telemetry for a tracing hook when the user has opted in.""" # Fast path: already resolved and no explicit override requested. - if hook._auto_telemetry_provider is not None and telemetry is None: + if telemetry is None and ( + hook._auto_telemetry_provider is not None or hook._telemetry_use_global_provider + ): return True if telemetry is None and hook._telemetry_auto_disabled: return False @@ -155,6 +157,7 @@ def configure_telemetry_for_hook( ) if provider_mode is None: _shutdown_telemetry_provider(hook) + hook._telemetry_use_global_provider = False hook._telemetry_auto_disabled = True return False @@ -176,6 +179,7 @@ def configure_telemetry_for_hook( "configure_telemetry(client, provider='dedicated') to attach an " "SDK-owned provider for this client." ) + hook._telemetry_use_global_provider = False hook._telemetry_auto_disabled = True return False @@ -294,6 +298,7 @@ def _attach_telemetry_provider( _shutdown_telemetry_provider(hook) hook.tracer_provider = provider hook._auto_telemetry_provider = provider + hook._telemetry_use_global_provider = False hook._telemetry_auto_disabled = False hook._telemetry_finalizer = weakref.finalize( finalizer_owner, provider.shutdown @@ -306,6 +311,7 @@ def _attach_custom_tracer_provider( ) -> None: _shutdown_telemetry_provider(hook) hook.tracer_provider = provider + hook._telemetry_use_global_provider = False hook._telemetry_auto_disabled = False @@ -323,6 +329,7 @@ def _use_global_tracer_provider( _shutdown_telemetry_provider(hook) hook.tracer_provider = None + hook._telemetry_use_global_provider = True hook._telemetry_auto_disabled = True return True diff --git a/src/mistralai/extra/tests/test_otel_tracing.py b/src/mistralai/extra/tests/test_otel_tracing.py index b60b05cd..b1543ae7 100644 --- a/src/mistralai/extra/tests/test_otel_tracing.py +++ b/src/mistralai/extra/tests/test_otel_tracing.py @@ -15,11 +15,12 @@ from collections import Counter from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer import json +import os import threading import unittest from datetime import datetime, timezone from typing import cast -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import httpx from opentelemetry import context as context_api @@ -118,9 +119,16 @@ def _make_httpx_response(body: dict, status_code: int = 200) -> httpx.Response: return resp -def _make_hook_context(operation_id: str) -> HookContext: +def _make_hook_context( + operation_id: str, + telemetry: bool | str | None = "global", +) -> HookContext: + config = MagicMock() + config.telemetry = telemetry + config.client = None + config.async_client = None return HookContext( - config=MagicMock(), + config=config, base_url="https://api.mistral.ai", operation_id=operation_id, oauth2_scopes=None, @@ -254,6 +262,7 @@ def _run_hook_lifecycle( request_body, response_body, streaming: bool = False, + telemetry: bool | str | None = "global", ): """Drive the real TracingHook: before_request → after_success. @@ -266,7 +275,7 @@ def _run_hook_lifecycle( so the span is finalised before returning. """ hook = TracingHook() - hook_ctx = _make_hook_context(operation_id) + hook_ctx = _make_hook_context(operation_id, telemetry=telemetry) req_dict = ( _dump(request_body) if hasattr(request_body, "model_dump") else request_body @@ -309,10 +318,11 @@ def _run_hook_error_lifecycle( response_body: dict, status_code: int = 400, error: Exception | None = None, + telemetry: bool | str | None = "global", ): """Drive the real TracingHook: before_request → after_error.""" hook = TracingHook() - hook_ctx = _make_hook_context(operation_id) + hook_ctx = _make_hook_context(operation_id, telemetry=telemetry) req_dict = ( _dump(request_body) if hasattr(request_body, "model_dump") else request_body @@ -343,6 +353,36 @@ def assertSpanAttributes(self, span, expected: dict): actual = {k: span.attributes[k] for k in expected} self.assertEqual(expected, actual) + def test_global_provider_alone_does_not_enable_sdk_span(self): + request = ChatCompletionRequest( + model="mistral-small-latest", + messages=[UserMessage(content="Hello")], + ) + response = ChatCompletionResponse( + id="cmpl-no-telemetry-env", + object="chat.completion", + model="mistral-small-latest", + created=1700000015, + choices=[ + ChatCompletionChoice( + index=0, + message=AssistantMessage(content="Hi!", tool_calls=None), + finish_reason="stop", + ), + ], + usage=UsageInfo(prompt_tokens=5, completion_tokens=2, total_tokens=7), + ) + + with patch.dict(os.environ, {}, clear=True): + self._run_hook_lifecycle( + "chat_completion_v1_chat_completions_post", + request, + response, + telemetry=None, + ) + + self.assertEqual(len(self._get_finished_spans()), 0) + # -- Simple chat completion ------------------------------------------------ def test_simple_chat_completion(self): @@ -1754,6 +1794,7 @@ async def _mock_handler(request: httpx.Request) -> httpx.Response: api_key="test-key", async_client=async_client, ) + client.sdk_configuration.__dict__["telemetry"] = "global" async def _run(): return await asyncio.gather( @@ -1804,6 +1845,48 @@ async def _run(): # -- HTTPX auto-instrumentation parenting --------------------------------- + def test_app_otel_does_not_enable_mistral_span_without_mistral_telemetry(self): + instrumentor = HTTPXClientInstrumentor() + instrumentor.instrument() + tracer = trace.get_tracer("test-workflow-parenting") + + try: + with patch.dict(os.environ, {}, clear=True): + with ( + _ChatCompletionTestServer() as server, + httpx.Client() as http_client, + ): + client = Mistral( + api_key="test-key", + client=http_client, + server_url=server.url, + ) + with tracer.start_as_current_span( + "ExecuteActivity:generate_site_diagnostic" + ) as activity_span: + client.chat.complete( + model="mistral-small-latest", + messages=_make_user_messages("hello"), + ) + + self.assertEqual( + trace.get_current_span().get_span_context().span_id, + activity_span.get_span_context().span_id, + ) + finally: + instrumentor.uninstrument() + + spans = self._get_finished_spans() + activity = next( + s for s in spans if s.name == "ExecuteActivity:generate_site_diagnostic" + ) + genai_spans = [s for s in spans if s.name == "chat mistral-small-latest"] + post_spans = [s for s in spans if s.name == "POST"] + + self.assertEqual(genai_spans, []) + self.assertEqual(len(post_spans), 1) + self.assertEqual(post_spans[0].parent.span_id, activity.context.span_id) + def test_httpx_auto_instrumented_span_is_child_of_genai_span(self): instrumentor = HTTPXClientInstrumentor() instrumentor.instrument() @@ -1816,6 +1899,7 @@ def test_httpx_auto_instrumented_span_is_child_of_genai_span(self): client=http_client, server_url=server.url, ) + client.sdk_configuration.__dict__["telemetry"] = "global" with tracer.start_as_current_span( "ExecuteActivity:generate_site_diagnostic" ) as activity_span: @@ -1858,6 +1942,7 @@ def raise_connect_error(request: httpx.Request) -> httpx.Response: client=http_client, server_url="https://api.mistral.ai", ) + client.sdk_configuration.__dict__["telemetry"] = "global" with tracer.start_as_current_span( "ExecuteActivity:generate_site_diagnostic" ) as activity_span: @@ -1901,6 +1986,7 @@ async def _run(server_url: str): async_client=async_client, server_url=server_url, ) + client.sdk_configuration.__dict__["telemetry"] = "global" with tracer.start_as_current_span( "ExecuteActivity:generate_site_diagnostic" @@ -1957,7 +2043,7 @@ def test_custom_provider_captures_spans(self): hook = TracingHook() hook.tracer_provider = custom_provider - hook_ctx = _make_hook_context("chat_completion") + hook_ctx = _make_hook_context("chat_completion", telemetry=None) request_body = _dump( ChatCompletionRequest( @@ -2002,12 +2088,13 @@ def test_custom_provider_captures_spans(self): ] self.assertEqual(len(global_spans), 0) - def test_fallback_to_global_provider(self): - """When tracer_provider is None (default), spans go to the global provider.""" + def test_global_telemetry_uses_global_provider(self): + """When telemetry is set to global, spans go to the global provider.""" _EXPORTER.clear() hook = TracingHook() - # tracer_provider defaults to None — should use global provider + # tracer_provider defaults to None; telemetry="global" opts into the + # configured global provider. self.assertIsNone(hook.tracer_provider) hook_ctx = _make_hook_context("chat_completion")