Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions src/mistralai/client/_hooks/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,29 +44,30 @@ def __init__(self) -> None:
self._auto_telemetry_provider: Optional[Any] = None
self._telemetry_finalizer: Optional[weakref.finalize] = None
self._telemetry_auto_disabled: bool = False
self._telemetry_use_global_provider: bool = False
self.tracing_enabled, self.tracer = get_or_create_otel_tracer()

def before_request(
self, hook_ctx: BeforeRequestContext, request: httpx.Request
) -> Union[httpx.Request, Exception]:
# The GenAI span is created in this hook, but HTTPX creates its own
# auto-instrumented span later inside send(). Wrap the configured
# clients so each request's stored GenAI span is current only while
# that request is being sent.
self._ensure_client_send_wrapped(getattr(hook_ctx.config, "client", None))
self._ensure_async_client_send_wrapped(
getattr(hook_ctx.config, "async_client", None)
)
configure_telemetry_for_hook(
telemetry_configured = configure_telemetry_for_hook(
self,
hook_ctx.config,
respect_global_provider=True,
)
# Refresh tracer/provider per request so tracing can be enabled if the
# application configures OpenTelemetry after the client is instantiated.
self.tracing_enabled, self.tracer = get_or_create_otel_tracer(
provider=self.tracer_provider,
should_trace = (
telemetry_configured
or self.tracer_provider is not None
or self._telemetry_use_global_provider
)
if should_trace:
# Refresh tracer/provider per request so tracing can be enabled if the
# application configures OpenTelemetry after the client is instantiated.
self.tracing_enabled, self.tracer = get_or_create_otel_tracer(
provider=self.tracer_provider,
)
else:
self.tracing_enabled = False
request, span = get_traced_request_and_span(
tracing_enabled=self.tracing_enabled,
tracer=self.tracer,
Expand All @@ -75,6 +76,15 @@ def before_request(
request=request,
)
self._store_span_on_request(request, span)
# The GenAI span is created in this hook, but HTTPX creates its own
# auto-instrumented span later inside send(). Wrap the configured
# clients so each request's stored GenAI span is current only while
# that request is being sent.
if span is not None:
self._ensure_client_send_wrapped(getattr(hook_ctx.config, "client", None))
self._ensure_async_client_send_wrapped(
getattr(hook_ctx.config, "async_client", None)
)
return request

@staticmethod
Expand Down
9 changes: 8 additions & 1 deletion src/mistralai/extra/observability/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def configure_telemetry_for_hook(
) -> bool:
"""Configure telemetry for a tracing hook when the user has opted in."""
# Fast path: already resolved and no explicit override requested.
if hook._auto_telemetry_provider is not None and telemetry is None:
if telemetry is None and (
hook._auto_telemetry_provider is not None or hook._telemetry_use_global_provider
):
return True
if telemetry is None and hook._telemetry_auto_disabled:
return False
Expand All @@ -155,6 +157,7 @@ def configure_telemetry_for_hook(
)
if provider_mode is None:
_shutdown_telemetry_provider(hook)
hook._telemetry_use_global_provider = False
hook._telemetry_auto_disabled = True
return False

Expand All @@ -176,6 +179,7 @@ def configure_telemetry_for_hook(
"configure_telemetry(client, provider='dedicated') to attach an "
"SDK-owned provider for this client."
)
hook._telemetry_use_global_provider = False
hook._telemetry_auto_disabled = True
return False

Expand Down Expand Up @@ -294,6 +298,7 @@ def _attach_telemetry_provider(
_shutdown_telemetry_provider(hook)
hook.tracer_provider = provider
hook._auto_telemetry_provider = provider
hook._telemetry_use_global_provider = False
hook._telemetry_auto_disabled = False
hook._telemetry_finalizer = weakref.finalize(
finalizer_owner, provider.shutdown
Expand All @@ -306,6 +311,7 @@ def _attach_custom_tracer_provider(
) -> None:
_shutdown_telemetry_provider(hook)
hook.tracer_provider = provider
hook._telemetry_use_global_provider = False
hook._telemetry_auto_disabled = False


Expand All @@ -323,6 +329,7 @@ def _use_global_tracer_provider(

_shutdown_telemetry_provider(hook)
hook.tracer_provider = None
hook._telemetry_use_global_provider = True
hook._telemetry_auto_disabled = True
return True

Expand Down
105 changes: 96 additions & 9 deletions src/mistralai/extra/tests/test_otel_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from collections import Counter
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
import json
import os
import threading
import unittest
from datetime import datetime, timezone
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import httpx
from opentelemetry import context as context_api
Expand Down Expand Up @@ -118,9 +119,16 @@ def _make_httpx_response(body: dict, status_code: int = 200) -> httpx.Response:
return resp


def _make_hook_context(operation_id: str) -> HookContext:
def _make_hook_context(
operation_id: str,
telemetry: bool | str | None = "global",
) -> HookContext:
config = MagicMock()
config.telemetry = telemetry
config.client = None
config.async_client = None
return HookContext(
config=MagicMock(),
config=config,
base_url="https://api.mistral.ai",
operation_id=operation_id,
oauth2_scopes=None,
Expand Down Expand Up @@ -254,6 +262,7 @@ def _run_hook_lifecycle(
request_body,
response_body,
streaming: bool = False,
telemetry: bool | str | None = "global",
):
"""Drive the real TracingHook: before_request → after_success.

Expand All @@ -266,7 +275,7 @@ def _run_hook_lifecycle(
so the span is finalised before returning.
"""
hook = TracingHook()
hook_ctx = _make_hook_context(operation_id)
hook_ctx = _make_hook_context(operation_id, telemetry=telemetry)

req_dict = (
_dump(request_body) if hasattr(request_body, "model_dump") else request_body
Expand Down Expand Up @@ -309,10 +318,11 @@ def _run_hook_error_lifecycle(
response_body: dict,
status_code: int = 400,
error: Exception | None = None,
telemetry: bool | str | None = "global",
):
"""Drive the real TracingHook: before_request → after_error."""
hook = TracingHook()
hook_ctx = _make_hook_context(operation_id)
hook_ctx = _make_hook_context(operation_id, telemetry=telemetry)

req_dict = (
_dump(request_body) if hasattr(request_body, "model_dump") else request_body
Expand Down Expand Up @@ -343,6 +353,36 @@ def assertSpanAttributes(self, span, expected: dict):
actual = {k: span.attributes[k] for k in expected}
self.assertEqual(expected, actual)

def test_global_provider_alone_does_not_enable_sdk_span(self):
request = ChatCompletionRequest(
model="mistral-small-latest",
messages=[UserMessage(content="Hello")],
)
response = ChatCompletionResponse(
id="cmpl-no-telemetry-env",
object="chat.completion",
model="mistral-small-latest",
created=1700000015,
choices=[
ChatCompletionChoice(
index=0,
message=AssistantMessage(content="Hi!", tool_calls=None),
finish_reason="stop",
),
],
usage=UsageInfo(prompt_tokens=5, completion_tokens=2, total_tokens=7),
)

with patch.dict(os.environ, {}, clear=True):
self._run_hook_lifecycle(
"chat_completion_v1_chat_completions_post",
request,
response,
telemetry=None,
)

self.assertEqual(len(self._get_finished_spans()), 0)

# -- Simple chat completion ------------------------------------------------

def test_simple_chat_completion(self):
Expand Down Expand Up @@ -1754,6 +1794,7 @@ async def _mock_handler(request: httpx.Request) -> httpx.Response:
api_key="test-key",
async_client=async_client,
)
client.sdk_configuration.__dict__["telemetry"] = "global"

async def _run():
return await asyncio.gather(
Expand Down Expand Up @@ -1804,6 +1845,48 @@ async def _run():

# -- HTTPX auto-instrumentation parenting ---------------------------------

def test_app_otel_does_not_enable_mistral_span_without_mistral_telemetry(self):
instrumentor = HTTPXClientInstrumentor()
instrumentor.instrument()
tracer = trace.get_tracer("test-workflow-parenting")

try:
with patch.dict(os.environ, {}, clear=True):
with (
_ChatCompletionTestServer() as server,
httpx.Client() as http_client,
):
client = Mistral(
api_key="test-key",
client=http_client,
server_url=server.url,
)
with tracer.start_as_current_span(
"ExecuteActivity:generate_site_diagnostic"
) as activity_span:
client.chat.complete(
model="mistral-small-latest",
messages=_make_user_messages("hello"),
)

self.assertEqual(
trace.get_current_span().get_span_context().span_id,
activity_span.get_span_context().span_id,
)
finally:
instrumentor.uninstrument()

spans = self._get_finished_spans()
activity = next(
s for s in spans if s.name == "ExecuteActivity:generate_site_diagnostic"
)
genai_spans = [s for s in spans if s.name == "chat mistral-small-latest"]
post_spans = [s for s in spans if s.name == "POST"]

self.assertEqual(genai_spans, [])
self.assertEqual(len(post_spans), 1)
self.assertEqual(post_spans[0].parent.span_id, activity.context.span_id)

def test_httpx_auto_instrumented_span_is_child_of_genai_span(self):
instrumentor = HTTPXClientInstrumentor()
instrumentor.instrument()
Expand All @@ -1816,6 +1899,7 @@ def test_httpx_auto_instrumented_span_is_child_of_genai_span(self):
client=http_client,
server_url=server.url,
)
client.sdk_configuration.__dict__["telemetry"] = "global"
with tracer.start_as_current_span(
"ExecuteActivity:generate_site_diagnostic"
) as activity_span:
Expand Down Expand Up @@ -1858,6 +1942,7 @@ def raise_connect_error(request: httpx.Request) -> httpx.Response:
client=http_client,
server_url="https://api.mistral.ai",
)
client.sdk_configuration.__dict__["telemetry"] = "global"
with tracer.start_as_current_span(
"ExecuteActivity:generate_site_diagnostic"
) as activity_span:
Expand Down Expand Up @@ -1901,6 +1986,7 @@ async def _run(server_url: str):
async_client=async_client,
server_url=server_url,
)
client.sdk_configuration.__dict__["telemetry"] = "global"

with tracer.start_as_current_span(
"ExecuteActivity:generate_site_diagnostic"
Expand Down Expand Up @@ -1957,7 +2043,7 @@ def test_custom_provider_captures_spans(self):
hook = TracingHook()
hook.tracer_provider = custom_provider

hook_ctx = _make_hook_context("chat_completion")
hook_ctx = _make_hook_context("chat_completion", telemetry=None)

request_body = _dump(
ChatCompletionRequest(
Expand Down Expand Up @@ -2002,12 +2088,13 @@ def test_custom_provider_captures_spans(self):
]
self.assertEqual(len(global_spans), 0)

def test_fallback_to_global_provider(self):
"""When tracer_provider is None (default), spans go to the global provider."""
def test_global_telemetry_uses_global_provider(self):
"""When telemetry is set to global, spans go to the global provider."""
_EXPORTER.clear()

hook = TracingHook()
# tracer_provider defaults to None — should use global provider
# tracer_provider defaults to None; telemetry="global" opts into the
# configured global provider.
self.assertIsNone(hook.tracer_provider)

hook_ctx = _make_hook_context("chat_completion")
Expand Down
Loading