diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 5b154862e..fc1f0f31e 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -323,11 +323,12 @@ public Flowable runAsync(InvocationContext parentContext) { private Flowable run( InvocationContext parentContext, Function> runImplementation) { - Context otelContext = Context.current(); return Flowable.using( - () -> - Instrumentation.recordAgentInvocation( - createInvocationContext(parentContext), this, otelContext), + () -> { + Context otelContext = Context.current(); + return Instrumentation.recordAgentInvocation( + createInvocationContext(parentContext), this, otelContext); + }, agentInvocation -> { InvocationContext invocationContext = agentInvocation.getCtx(); Flowable mainAndAfterEvents = diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 3b28761a1..887cb9761 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -218,74 +218,87 @@ private Flowable callLlm( Event eventForCallbackUsage) { LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); - return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) - .toFlowable() - .concatMap( - llmResp -> - postprocess( - context, - eventForCallbackUsage, - llmRequestBuilder.build(), - llmResp, - spanContext)) - .switchIfEmpty( - Flowable.defer( - () -> { - LlmAgent agent = (LlmAgent) context.agent(); - BaseLlm llm = - agent.resolvedModel().model().isPresent() - ? agent.resolvedModel().model().get() - : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - LlmRequest finalLlmRequest = llmRequestBuilder.build(); - - Span span = - Tracing.getTracer() - .spanBuilder("call_llm") - .setParent(spanContext) - .startSpan(); - Context callLlmContext = spanContext.with(span); - - Flowable flowable = - llm.generateContent( - finalLlmRequest, - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, - llmRequestBuilder, - eventForCallbackUsage, - exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnError( - error -> { - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) + return Flowable.defer( + () -> { + Span span = + Tracing.getTracer().spanBuilder("call_llm").setParent(spanContext).startSpan(); + Context callLlmContext = spanContext.with(span); + + return Tracing.traceFlowable( + callLlmContext, + span, + () -> + handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) + .toFlowable() .concatMap( - llmResp -> - handleAfterModelCallback(context, llmResp, eventForCallbackUsage) - .toFlowable()) - .flatMap( llmResp -> postprocess( context, eventForCallbackUsage, - finalLlmRequest, + llmRequestBuilder.build(), llmResp, callLlmContext) .doOnSubscribe( - s -> + subscription -> traceCallLlm( span, context, eventForCallbackUsage.id(), - finalLlmRequest, - llmResp))); - - return Tracing.traceFlowable(callLlmContext, span, () -> flowable); - })); + llmRequestBuilder.build(), + llmResp))) + .switchIfEmpty( + Flowable.defer( + () -> { + LlmAgent agent = (LlmAgent) context.agent(); + BaseLlm llm = + agent.resolvedModel().model().isPresent() + ? agent.resolvedModel().model().get() + : LlmRegistry.getLlm( + agent.resolvedModel().modelName().get()); + LlmRequest finalLlmRequest = llmRequestBuilder.build(); + + return llm.generateContent( + finalLlmRequest, + context.runConfig().streamingMode() + == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, + llmRequestBuilder, + eventForCallbackUsage, + exception) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnError( + error -> { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .concatMap( + llmResp -> + handleAfterModelCallback( + context, llmResp, eventForCallbackUsage) + .toFlowable()) + .flatMap( + llmResp -> + postprocess( + context, + eventForCallbackUsage, + finalLlmRequest, + llmResp, + callLlmContext) + .doOnSubscribe( + subscription -> + traceCallLlm( + span, + context, + eventForCallbackUsage.id(), + finalLlmRequest, + llmResp))); + }))) + .compose(Tracing.withContext(spanContext)); + }); } /** @@ -667,10 +680,12 @@ public void onError(Throwable e) { "Agent not found: " + event.actions().transferToAgent().get()); } Flowable nextAgentEvents = - nextAgent - .get() - .runLive(invocationContext) - .compose(Tracing.withContext(spanContext)); + Flowable.defer( + () -> { + try (Scope scope = spanContext.makeCurrent()) { + return nextAgent.get().runLive(invocationContext); + } + }); events = Flowable.concat(events, nextAgentEvents); } return events; @@ -693,11 +708,12 @@ public void onError(Throwable e) { }); return Tracing.traceFlowable( - callLlmContext, - span, - () -> - receiveFlow.takeWhile( - event -> !event.actions().endInvocation().orElse(false))); + callLlmContext, + span, + () -> + receiveFlow.takeWhile( + event -> !event.actions().endInvocation().orElse(false))) + .compose(Tracing.withContext(spanContext)); })); } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index e13ce1235..37c1c2478 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -485,9 +485,9 @@ protected Flowable runAsyncImpl( Preconditions.checkNotNull(session, "session cannot be null"); Preconditions.checkNotNull(newMessage, "newMessage cannot be null"); Preconditions.checkNotNull(runConfig, "runConfig cannot be null"); - Context capturedContext = Context.current(); return Flowable.defer( () -> { + Context capturedContext = Context.current(); BaseAgent rootAgent = this.agent; String invocationId = InvocationContext.newInvocationContextId(); diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 97e69d08b..d018ed997 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -543,7 +543,8 @@ public Publisher apply(Flowable upstream) { return Flowable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - Flowable pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + lifecycle.start(); + Flowable pipeline = upstream; if (onSuccessConsumer != null) { pipeline = pipeline.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t)); } @@ -556,7 +557,8 @@ public SingleSource apply(Single upstream) { return Single.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - Single pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + lifecycle.start(); + Single pipeline = upstream; if (onSuccessConsumer != null) { pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); } @@ -569,7 +571,8 @@ public MaybeSource apply(Maybe upstream) { return Maybe.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - Maybe pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + lifecycle.start(); + Maybe pipeline = upstream; if (onSuccessConsumer != null) { pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); } @@ -582,7 +585,8 @@ public CompletableSource apply(Completable upstream) { return Completable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + lifecycle.start(); + return upstream.doFinally(lifecycle::end); }); } } diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 331ae77b2..83fc90566 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -62,6 +62,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -215,6 +216,70 @@ public void testTraceTransformer() throws InterruptedException { assertTrue(transformerSpanData.hasEnded()); } + @Test + public void testTraceTransformerStartsSpanBeforeSubscribingToDeferredUpstream() + throws InterruptedException { + Span parentSpan = tracer.spanBuilder("parent").startSpan(); + AtomicReference flowableSpanId = new AtomicReference<>(); + AtomicReference singleSpanId = new AtomicReference<>(); + AtomicReference maybeSpanId = new AtomicReference<>(); + AtomicReference completableSpanId = new AtomicReference<>(); + + try (Scope s = parentSpan.makeCurrent()) { + Flowable.defer( + () -> { + flowableSpanId.set(Span.current().getSpanContext().getSpanId()); + return Flowable.just(1); + }) + .compose(Tracing.trace("flowable-transformer")) + .test() + .await() + .assertComplete(); + + Single.defer( + () -> { + singleSpanId.set(Span.current().getSpanContext().getSpanId()); + return Single.just(1); + }) + .compose(Tracing.trace("single-transformer")) + .test() + .await() + .assertComplete(); + + Maybe.defer( + () -> { + maybeSpanId.set(Span.current().getSpanContext().getSpanId()); + return Maybe.just(1); + }) + .compose(Tracing.trace("maybe-transformer")) + .test() + .await() + .assertComplete(); + + Completable.defer( + () -> { + completableSpanId.set(Span.current().getSpanContext().getSpanId()); + return Completable.complete(); + }) + .compose(Tracing.trace("completable-transformer")) + .test() + .await() + .assertComplete(); + } finally { + parentSpan.end(); + } + + SpanData parentSpanData = findSpanByName("parent"); + assertDeferredUpstreamSawTransformerSpan( + parentSpanData, findSpanByName("flowable-transformer"), flowableSpanId); + assertDeferredUpstreamSawTransformerSpan( + parentSpanData, findSpanByName("single-transformer"), singleSpanId); + assertDeferredUpstreamSawTransformerSpan( + parentSpanData, findSpanByName("maybe-transformer"), maybeSpanId); + assertDeferredUpstreamSawTransformerSpan( + parentSpanData, findSpanByName("completable-transformer"), completableSpanId); + } + @Test public void testTraceAgentInvocation() { Span span = tracer.spanBuilder("test").startSpan(); @@ -464,6 +529,38 @@ public void runnerRunLive_propagatesContext() throws InterruptedException { assertParent(invocation, agentSpan); } + @Test + public void testModelCallbacksObserveCallLlmSpan() throws InterruptedException { + TestLlm testLlm = + TestUtils.createTestLlm( + TestUtils.createLlmResponse(Content.fromParts(Part.fromText("response")))); + AtomicReference beforeModelSpanId = new AtomicReference<>(); + AtomicReference afterModelSpanId = new AtomicReference<>(); + + LlmAgent agentWithCallbacks = + LlmAgent.builder() + .name("test_agent") + .description("description") + .model(testLlm) + .beforeModelCallback( + (callbackContext, llmRequest) -> { + beforeModelSpanId.set(Span.current().getSpanContext().getSpanId()); + return Maybe.empty(); + }) + .afterModelCallback( + (callbackContext, llmResponse) -> { + afterModelSpanId.set(Span.current().getSpanContext().getSpanId()); + return Maybe.empty(); + }) + .build(); + + runAgent(agentWithCallbacks); + + SpanData callLlm = findSpanByName("call_llm"); + assertEquals(callLlm.getSpanContext().getSpanId(), beforeModelSpanId.get()); + assertEquals(callLlm.getSpanContext().getSpanId(), afterModelSpanId.get()); + } + @Test public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { // This test verifies the trace hierarchy created when an agent calls an LLM, @@ -594,9 +691,9 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { assertParent(agentASpan, agentACallLlm1); // ├── execute_tool transfer_to_agent assertParent(agentACallLlm1, executeTool); - // └── invoke_agent AgentB - assertParent(agentACallLlm1, agentBSpan); - // └── call_llm 2 + // └── invoke_agent AgentB + assertParent(agentASpan, agentBSpan); + // └── call_llm 2 assertParent(agentBSpan, agentBCallLlm); } @@ -645,6 +742,13 @@ private void assertParent(SpanData parent, SpanData child) { assertEquals(parent.getSpanContext().getSpanId(), child.getParentSpanContext().getSpanId()); } + private void assertDeferredUpstreamSawTransformerSpan( + SpanData parent, SpanData transformer, AtomicReference observedSpanId) { + assertParent(parent, transformer); + assertTrue(transformer.hasEnded()); + assertEquals(transformer.getSpanContext().getSpanId(), observedSpanId.get()); + } + /** * Finds a span by name, polling multiple times. *