From 69680bbeae11578199eca4efcaf5ecddea2dd552 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 14 Apr 2026 08:48:02 -0700 Subject: [PATCH] fix: Fix ADK Runner race condition for sequential tool execution Ensure that events are appended to the session and processed sequentially before proceeding to the next step in BaseLlmFlow. PiperOrigin-RevId: 899609964 --- .../adk/flows/llmflows/BaseLlmFlow.java | 25 +--- .../google/adk/flows/llmflows/Functions.java | 25 ++-- .../java/com/google/adk/runner/Runner.java | 38 ++---- .../com/google/adk/runner/RunnerTest.java | 123 +----------------- 4 files changed, 27 insertions(+), 184 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index fdda5219d..fffeab698 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -461,31 +461,14 @@ public Flowable run(InvocationContext invocationContext) { private Flowable run( Context spanContext, InvocationContext invocationContext, int stepsCompleted) { - Flowable currentStepEvents = runOneStep(spanContext, invocationContext); - - Flowable processedEvents = - currentStepEvents - .concatMap( - event -> - invocationContext - .sessionService() - .appendEvent(invocationContext.session(), event) - .flatMap( - registeredEvent -> - invocationContext - .pluginManager() - .onEventCallback(invocationContext, registeredEvent) - .defaultIfEmpty(registeredEvent)) - .toFlowable()) - .cache(); - + Flowable currentStepEvents = runOneStep(spanContext, invocationContext).cache(); if (stepsCompleted + 1 >= maxSteps) { logger.debug("Ending flow execution because max steps reached."); - return processedEvents; + return currentStepEvents; } - return processedEvents.concatWith( - processedEvents + return currentStepEvents.concatWith( + currentStepEvents .toList() .flatMapPublisher( eventList -> { diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 49af2a122..0b0e5b4d5 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -71,10 +71,8 @@ public final class Functions { private static final Logger logger = LoggerFactory.getLogger(Functions.class); /** Generates a unique ID for a function call. */ - public static String generateClientFunctionCallId(FunctionCall functionCall) { - String source = - functionCall.name().orElse("") + functionCall.args().orElse(ImmutableMap.of()).toString(); - return AF_FUNCTION_CALL_ID_PREFIX + UUID.nameUUIDFromBytes(source.getBytes()).toString(); + public static String generateClientFunctionCallId() { + return AF_FUNCTION_CALL_ID_PREFIX + UUID.randomUUID(); } /** @@ -103,7 +101,7 @@ public static void populateClientFunctionCallId(Event modelResponseEvent) { FunctionCall functionCall = part.functionCall().get(); if (functionCall.id().isEmpty() || functionCall.id().get().isEmpty()) { FunctionCall updatedFunctionCall = - functionCall.toBuilder().id(generateClientFunctionCallId(functionCall)).build(); + functionCall.toBuilder().id(generateClientFunctionCallId()).build(); newParts.add(part.toBuilder().functionCall(updatedFunctionCall).build()); modified = true; } else { @@ -623,7 +621,7 @@ private static Event buildResponseEvent( .build(); return Event.builder() - .id(toolContext.functionCallId().orElseGet(Event::generateEventId)) + .id(Event.generateEventId()) .invocationId(invocationContext.invocationId()) .author(invocationContext.agent().name()) .branch(invocationContext.branch().orElse(null)) @@ -659,7 +657,7 @@ public static Optional generateRequestConfirmationEvent( .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)) .entrySet()) { - FunctionCall.Builder builder = + FunctionCall requestConfirmationFunctionCall = FunctionCall.builder() .name(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME) .args( @@ -667,9 +665,9 @@ public static Optional generateRequestConfirmationEvent( "originalFunctionCall", functionCallsById.get(entry.getKey()), "toolConfirmation", - entry.getValue())); - FunctionCall requestConfirmationFunctionCall = - builder.id(generateClientFunctionCallId(builder.build())).build(); + entry.getValue())) + .id(generateClientFunctionCallId()) + .build(); longRunningToolIds.add(requestConfirmationFunctionCall.id().get()); parts.add(Part.builder().functionCall(requestConfirmationFunctionCall).build()); @@ -682,15 +680,8 @@ public static Optional generateRequestConfirmationEvent( var contentBuilder = Content.builder().parts(parts); functionResponseEvent.content().flatMap(Content::role).ifPresent(contentBuilder::role); - String deterministicId = - "req-conf-" - + functionResponseEvent.actions().requestedToolConfirmations().keySet().stream() - .sorted() - .collect(java.util.stream.Collectors.joining("-")); - return Optional.of( Event.builder() - .id(deterministicId) .invocationId(invocationContext.invocationId()) .author(invocationContext.agent().name()) .branch(invocationContext.branch().orElse(null)) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index f6fe08c2b..44a281f72 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -68,12 +68,9 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import org.jspecify.annotations.Nullable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** The main class for the GenAI Agents runner. */ public class Runner { - private static final Logger logger = LoggerFactory.getLogger(Runner.class); private final BaseAgent agent; private final String appName; private final BaseArtifactService artifactService; @@ -573,28 +570,19 @@ private Flowable runAgentWithUpdatedSession( .agent() .runAsync(contextWithUpdatedSession) .concatMap( - agentEvent -> { - // TODO: remove this hack after deprecating runAsync with Session. - copySessionStates(updatedSession, initialContext.session()); - - // TODO: b/502182243 - Investigate if appendEvent should be made idempotent in - // SessionService to avoid this check. - if (updatedSession.events().stream() - .anyMatch(e -> e.id() != null && e.id().equals(agentEvent.id()))) { - logger.debug("Event {} already in session, skipping append", agentEvent.id()); - return io.reactivex.rxjava3.core.Flowable.just(agentEvent); - } - return this.sessionService - .appendEvent(updatedSession, agentEvent) - .flatMap( - registeredEvent -> { - return contextWithUpdatedSession - .pluginManager() - .onEventCallback(contextWithUpdatedSession, registeredEvent) - .defaultIfEmpty(registeredEvent); - }) - .toFlowable(); - }); + agentEvent -> + this.sessionService + .appendEvent(updatedSession, agentEvent) + .flatMap( + registeredEvent -> { + // TODO: remove this hack after deprecating runAsync with Session. + copySessionStates(updatedSession, initialContext.session()); + return contextWithUpdatedSession + .pluginManager() + .onEventCallback(contextWithUpdatedSession, registeredEvent) + .defaultIfEmpty(registeredEvent); + }) + .toFlowable()); // If beforeRunCallback returns content, emit it and skip agent Context capturedContext = Context.current(); diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 5f3c7295d..ff75c97b0 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -46,12 +46,9 @@ import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; -import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; import com.google.adk.sessions.BaseSessionService; -import com.google.adk.sessions.GetSessionConfig; -import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; import com.google.adk.summarizer.EventsCompactionConfig; @@ -83,7 +80,6 @@ import java.time.Instant; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.UUID; @@ -592,22 +588,12 @@ public void onToolErrorCallback_error() { @Test public void onEventCallback_success() { when(plugin.onEventCallback(any(), any())) - .thenAnswer( - invocation -> { - Event event = invocation.getArgument(1); - return Maybe.just( - Event.builder() - .id(event.id()) - .invocationId(event.invocationId()) - .author("model") - .content(createContent("from plugin")) - .build()); - }); + .thenReturn(Maybe.just(TestUtils.createEvent("form plugin"))); List events = runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); - assertThat(simplifyEvents(events)).containsExactly("model: from plugin"); + assertThat(simplifyEvents(events)).containsExactly("author: content for event form plugin"); verify(plugin).onEventCallback(any(), any()); } @@ -1700,109 +1686,4 @@ public void runner_executesSaveArtifactFlow() { // agent was run assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); } - - @Test - public void runAsync_ensuresSequentialConsistencyForTools() { - // Arrange - TestLlm testLlm = - createTestLlm( - createFunctionCallLlmResponse("call_1", "tool1", ImmutableMap.of("arg", "value1")), - createTextLlmResponse("Final response")); - - LlmAgent agent = - createTestAgentBuilder(testLlm) - .tools( - ImmutableList.of( - FunctionTool.create(RaceConditionTools.class, "tool1"), - FunctionTool.create(RaceConditionTools.class, "tool2"))) - .build(); - - BaseSessionService delegate = new InMemorySessionService(); - BaseSessionService delayedSessionService = createDelayedSessionService(delegate, 0); - - Runner runner = - Runner.builder() - .app(App.builder().name("test").rootAgent(agent).build()) - .sessionService(delayedSessionService) - .build(); - Session session = runner.sessionService().createSession("test", "user").blockingGet(); - - // Act - var unused = - runner - .runAsync("user", session.id(), Content.fromParts(Part.fromText("start"))) - .toList() - .blockingGet(); - - // Assert - ImmutableList requests = ImmutableList.copyOf(testLlm.getRequests()); - assertThat(requests).hasSize(2); - - // Second request should contain the result of tool1 - LlmRequest secondRequest = requests.get(1); - List history = secondRequest.contents(); - - boolean foundToolResponse = - history.stream() - .flatMap(content -> content.parts().stream().flatMap(List::stream)) - .filter(part -> part.functionResponse().isPresent()) - .map(part -> part.functionResponse().get()) - .anyMatch( - response -> - response.name().orElse("").equals("tool1") - && response - .response() - .map( - r -> - java.util.Objects.equals( - r, ImmutableMap.of("result", "result_value1"))) - .orElse(false)); - - assertThat(foundToolResponse).isTrue(); - } - - @SuppressWarnings({"unchecked", "deprecation"}) - private static BaseSessionService createDelayedSessionService( - BaseSessionService delegate, long delayMs) { - BaseSessionService delayedSessionService = mock(BaseSessionService.class); - when(delayedSessionService.createSession(anyString(), anyString(), any(Map.class), anyString())) - .thenAnswer( - inv -> - delegate.createSession( - (String) inv.getArgument(0), - (String) inv.getArgument(1), - (Map) inv.getArgument(2), - (String) inv.getArgument(3))); - when(delayedSessionService.createSession(anyString(), anyString())) - .thenAnswer( - inv -> - delegate.createSession((String) inv.getArgument(0), (String) inv.getArgument(1))); - when(delayedSessionService.getSession(anyString(), anyString(), anyString(), any())) - .thenAnswer( - inv -> - delegate.getSession( - (String) inv.getArgument(0), - (String) inv.getArgument(1), - (String) inv.getArgument(2), - (Optional) inv.getArgument(3))); - when(delayedSessionService.appendEvent(any(), any())) - .thenAnswer( - inv -> - delegate - .appendEvent(inv.getArgument(0), inv.getArgument(1)) - .delay(delayMs, MILLISECONDS)); - return delayedSessionService; - } - - public static class RaceConditionTools { - private RaceConditionTools() {} - - public static ImmutableMap tool1(String arg) { - return ImmutableMap.of("result", "result_" + arg); - } - - public static ImmutableMap tool2(String input) { - return ImmutableMap.of("status", "received_" + input); - } - } }