Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 38 additions & 18 deletions core/src/main/java/com/google/adk/flows/llmflows/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.core.Scheduler;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.disposables.Disposable;
import io.reactivex.rxjava3.functions.Function;
import io.reactivex.rxjava3.schedulers.Schedulers;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -153,15 +155,8 @@ public static Maybe<Event> handleFunctionCalls(
Function<FunctionCall, Maybe<Event>> functionCallMapper =
getFunctionCallMapper(invocationContext, tools, toolConfirmations, false, parentContext);

Observable<Event> functionResponseEventsObservable;
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
functionResponseEventsObservable =
Observable.fromIterable(validFunctionCalls).concatMapMaybe(functionCallMapper);
} else {
functionResponseEventsObservable =
Observable.fromIterable(validFunctionCalls)
.concatMapEager(call -> functionCallMapper.apply(call).toObservable());
}
Observable<Event> functionResponseEventsObservable =
buildToolExecutionObservable(invocationContext, validFunctionCalls, functionCallMapper);
return functionResponseEventsObservable
.toList()
.toMaybe()
Expand Down Expand Up @@ -224,15 +219,8 @@ public static Maybe<Event> handleFunctionCallsLive(
Function<FunctionCall, Maybe<Event>> functionCallMapper =
getFunctionCallMapper(invocationContext, tools, toolConfirmations, true, parentContext);

Observable<Event> responseEventsObservable;
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
responseEventsObservable =
Observable.fromIterable(validFunctionCalls).concatMapMaybe(functionCallMapper);
} else {
responseEventsObservable =
Observable.fromIterable(validFunctionCalls)
.concatMapEager(call -> functionCallMapper.apply(call).toObservable());
}
Observable<Event> responseEventsObservable =
buildToolExecutionObservable(invocationContext, validFunctionCalls, functionCallMapper);

return responseEventsObservable
.toList()
Expand All @@ -247,6 +235,38 @@ public static Maybe<Event> handleFunctionCallsLive(
});
}

/**
* Builds the tool-execution {@link Observable}.
*
* <p>SEQUENTIAL (or a single call, where parallelism is moot) runs on the caller thread via
* {@code concatMapMaybe} to keep synchronous semantics. PARALLEL with multiple calls dispatches
* each tool on a worker so blocking calls run concurrently; {@code concatMapEager} preserves
* input order required by {@link #mergeParallelFunctionResponseEvents}.
*/
private static Observable<Event> buildToolExecutionObservable(
InvocationContext invocationContext,
List<FunctionCall> validFunctionCalls,
Function<FunctionCall, Maybe<Event>> functionCallMapper) {
boolean parallel =
invocationContext.runConfig().toolExecutionMode() != ToolExecutionMode.SEQUENTIAL
&& validFunctionCalls.size() > 1;
if (!parallel) {
return Observable.fromIterable(validFunctionCalls).concatMapMaybe(functionCallMapper);
}
Scheduler scheduler = resolveToolExecutionScheduler(invocationContext);
return Observable.fromIterable(validFunctionCalls)
.concatMapEager(
call -> functionCallMapper.apply(call).toObservable().subscribeOn(scheduler));
}

/** Agent executor if set, otherwise the IO scheduler. */
private static Scheduler resolveToolExecutionScheduler(InvocationContext invocationContext) {
if (invocationContext.agent() instanceof LlmAgent llmAgent) {
return llmAgent.executor().map(Schedulers::from).orElse(Schedulers.io());
}
return Schedulers.io();
}

private static Function<FunctionCall, Maybe<Event>> getFunctionCallMapper(
InvocationContext invocationContext,
Map<String, BaseTool> tools,
Expand Down
137 changes: 137 additions & 0 deletions core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,21 @@
import com.google.adk.agents.RunConfig.ToolExecutionMode;
import com.google.adk.events.Event;
import com.google.adk.testing.TestUtils;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.ToolContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionCall;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.Part;
import io.reactivex.rxjava3.core.Single;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -388,4 +397,132 @@ public void getAskUserConfirmationFunctionCalls_eventWithConfirmationFunctionCal
ImmutableList<FunctionCall> result = Functions.getAskUserConfirmationFunctionCalls(event);
assertThat(result).containsExactly(confirmationCall1, confirmationCall2);
}

@Test
public void handleFunctionCalls_parallel_blockingTools_runConcurrently_twoTools() {
runParallelBlockingToolsTest(/* toolCount= */ 2);
}

@Test
public void handleFunctionCalls_parallel_blockingTools_runConcurrently_threeTools() {
runParallelBlockingToolsTest(/* toolCount= */ 3);
}

@Test
public void handleFunctionCalls_parallel_blockingTools_runConcurrently_fiveTools() {
runParallelBlockingToolsTest(/* toolCount= */ 5);
}

/** Single-tool case bypasses the parallel scheduler path; must still return the correct event. */
@Test
public void handleFunctionCalls_parallel_blockingTool_singleTool() {
long sleepMillis = 200L;
InvocationContext invocationContext =
createInvocationContext(
createRootAgent(),
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build());
SleepingTool tool = new SleepingTool("slow_tool_1", sleepMillis);
Event event =
createEvent("event").toBuilder()
.content(
Content.fromParts(
Part.builder()
.functionCall(
FunctionCall.builder()
.id("call_1")
.name("slow_tool_1")
.args(ImmutableMap.of())
.build())
.build()))
.build();

Event functionResponseEvent =
Functions.handleFunctionCalls(
invocationContext, event, ImmutableMap.of("slow_tool_1", tool))
.blockingGet();

assertThat(functionResponseEvent).isNotNull();
assertThat(functionResponseEvent.content().get().parts().get())
.containsExactly(
Part.builder()
.functionResponse(
FunctionResponse.builder()
.id("call_1")
.name("slow_tool_1")
.response(ImmutableMap.of("tool", "slow_tool_1"))
.build())
.build());
}

/** Asserts that {@code toolCount} blocking tools in PARALLEL mode run faster than sequential. */
private static void runParallelBlockingToolsTest(int toolCount) {
long sleepMillis = 500L;
InvocationContext invocationContext =
createInvocationContext(
createRootAgent(),
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build());

Map<String, BaseTool> tools = new LinkedHashMap<>();
List<Part> callParts = new ArrayList<>();
List<Part> expectedResponseParts = new ArrayList<>();
for (int i = 1; i <= toolCount; i++) {
String toolName = "slow_tool_" + i;
String callId = "call_" + i;
tools.put(toolName, new SleepingTool(toolName, sleepMillis));
callParts.add(
Part.builder()
.functionCall(
FunctionCall.builder().id(callId).name(toolName).args(ImmutableMap.of()).build())
.build());
expectedResponseParts.add(
Part.builder()
.functionResponse(
FunctionResponse.builder()
.id(callId)
.name(toolName)
.response(ImmutableMap.of("tool", toolName))
.build())
.build());
}
Event event =
createEvent("event").toBuilder()
.content(Content.fromParts(callParts.toArray(new Part[0])))
.build();

long start = System.currentTimeMillis();
Event functionResponseEvent =
Functions.handleFunctionCalls(invocationContext, event, tools).blockingGet();
long durationMillis = System.currentTimeMillis() - start;

assertThat(functionResponseEvent).isNotNull();
assertThat(functionResponseEvent.content().get().parts().get())
.containsExactlyElementsIn(expectedResponseParts)
.inOrder();
// Sequential would be ~toolCount * sleepMillis; parallel is ~sleepMillis + fixed overhead.
assertThat(durationMillis).isLessThan((long) toolCount * sleepMillis);
}

/** Tool that blocks the executing thread for {@code sleepMillis} before returning. */
private static final class SleepingTool extends BaseTool {
private final long sleepMillis;

SleepingTool(String name, long sleepMillis) {
super(name, "Blocking tool used to verify parallel execution.");
this.sleepMillis = sleepMillis;
}

@Override
public Optional<FunctionDeclaration> declaration() {
return Optional.of(FunctionDeclaration.builder().name(name()).build());
}

@Override
public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContext toolContext) {
return Single.fromCallable(
() -> {
Thread.sleep(sleepMillis);
return ImmutableMap.<String, Object>of("tool", name());
});
}
}
}
Loading