diff --git a/core/src/main/java/com/google/adk/models/Claude.java b/core/src/main/java/com/google/adk/models/Claude.java index 79201c261..3576f768b 100644 --- a/core/src/main/java/com/google/adk/models/Claude.java +++ b/core/src/main/java/com/google/adk/models/Claude.java @@ -17,18 +17,23 @@ package com.google.adk.models; import com.anthropic.client.AnthropicClient; +import com.anthropic.core.http.StreamResponse; import com.anthropic.models.messages.ContentBlock; import com.anthropic.models.messages.ContentBlockParam; import com.anthropic.models.messages.Message; import com.anthropic.models.messages.MessageCreateParams; import com.anthropic.models.messages.MessageParam; import com.anthropic.models.messages.MessageParam.Role; +import com.anthropic.models.messages.RawContentBlockDeltaEvent; +import com.anthropic.models.messages.RawContentBlockStartEvent; +import com.anthropic.models.messages.RawMessageStreamEvent; import com.anthropic.models.messages.TextBlockParam; import com.anthropic.models.messages.Tool; import com.anthropic.models.messages.ToolChoice; import com.anthropic.models.messages.ToolChoiceAuto; import com.anthropic.models.messages.ToolResultBlockParam; import com.anthropic.models.messages.ToolUnion; +import com.anthropic.models.messages.ToolUseBlock; import com.anthropic.models.messages.ToolUseBlockParam; import com.fasterxml.jackson.core.type.TypeReference; import com.google.adk.JsonBaseModel; @@ -53,8 +58,8 @@ /** * Represents the Claude Generative AI model by Anthropic. * - *

This class provides methods for interacting with Claude models. Streaming and live connections - * are not currently supported for Claude. + *

This class provides methods for interacting with Claude models, including streaming responses. + * Live connections are not currently supported for Claude. */ public class Claude extends BaseLlm { @@ -81,7 +86,23 @@ public Claude(String modelName, AnthropicClient anthropicClient, int maxTokens) @Override public Flowable generateContent(LlmRequest llmRequest, boolean stream) { - // TODO: Switch to streaming API. + MessageCreateParams params = buildMessageCreateParams(llmRequest); + + if (stream) { + logger.debug("Sending streaming request to Claude model {}", params.model()); + return Flowable.using( + () -> this.anthropicClient.messages().createStreaming(params), + streamResponse -> processStreamingResponse(streamResponse.stream()), + StreamResponse::close); + } else { + logger.debug("Sending request to Claude model {}", params.model()); + var message = this.anthropicClient.messages().create(params); + logger.debug("Claude response: {}", message); + return Flowable.just(convertAnthropicResponseToLlmResponse(message)); + } + } + + private MessageCreateParams buildMessageCreateParams(LlmRequest llmRequest) { List messages = llmRequest.contents().stream() .map(this::contentToAnthropicMessageParam) @@ -132,11 +153,112 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre paramsBuilder.toolChoice(toolChoice); } - var message = this.anthropicClient.messages().create(paramsBuilder.build()); + return paramsBuilder.build(); + } + + /** + * Converts a stream of raw Anthropic streaming events into a Flowable of {@link LlmResponse}. + * + *

Text deltas are emitted immediately as partial responses. Tool use blocks are accumulated + * and emitted as function calls when the block is complete. + */ + private Flowable processStreamingResponse( + java.util.stream.Stream events) { + // Mutable state for accumulating tool call data across events. + // Keys are content block indices from the stream. + Map toolUseIds = new HashMap<>(); + Map toolUseNames = new HashMap<>(); + Map toolUseInputJsons = new HashMap<>(); + + return Flowable.fromStream(events) + .concatMap( + event -> { + if (event.isContentBlockStart()) { + RawContentBlockStartEvent startEvent = event.asContentBlockStart(); + long index = startEvent.index(); + Optional toolUseOpt = startEvent.contentBlock().toolUse(); + if (toolUseOpt.isPresent()) { + ToolUseBlock toolUse = toolUseOpt.get(); + toolUseIds.put(index, toolUse.id()); + toolUseNames.put(index, toolUse.name()); + toolUseInputJsons.put(index, new StringBuilder()); + } + return Flowable.empty(); + + } else if (event.isContentBlockDelta()) { + RawContentBlockDeltaEvent deltaEvent = event.asContentBlockDelta(); + long index = deltaEvent.index(); + var delta = deltaEvent.delta(); + + if (delta.isText()) { + String textChunk = delta.asText().text(); + logger.trace("Claude streaming text chunk: {}", textChunk); + return Flowable.just( + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(ImmutableList.of(Part.builder().text(textChunk).build())) + .build()) + .partial(true) + .build()); + + } else if (delta.isInputJson()) { + String jsonChunk = delta.asInputJson().partialJson(); + StringBuilder accumulator = toolUseInputJsons.get(index); + if (accumulator != null) { + accumulator.append(jsonChunk); + } + return Flowable.empty(); + } + return Flowable.empty(); + + } else if (event.isContentBlockStop()) { + long index = event.asContentBlockStop().index(); + String id = toolUseIds.remove(index); + String name = toolUseNames.remove(index); + StringBuilder inputJsonBuilder = toolUseInputJsons.remove(index); + + if (id != null && name != null && inputJsonBuilder != null) { + Map args; + try { + args = + JsonBaseModel.getMapper() + .readValue( + inputJsonBuilder.toString(), + new TypeReference>() {}); + } catch (Exception e) { + logger.warn( + "Failed to parse tool input JSON for tool '{}': {}", name, e.getMessage()); + args = ImmutableMap.of(); + } + logger.debug("Claude streaming tool call: id={}, name={}", id, name); + return Flowable.just( + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder() + .id(id) + .name(name) + .args(args) + .build()) + .build())) + .build()) + .build()); + } + return Flowable.empty(); - logger.debug("Claude response: {}", message); + } else if (event.isMessageStop()) { + return Flowable.just(LlmResponse.builder().turnComplete(true).build()); + } - return Flowable.just(convertAnthropicResponseToLlmResponse(message)); + return Flowable.empty(); + }); } private Role toClaudeRole(String role) { diff --git a/core/src/test/java/com/google/adk/models/ClaudeTest.java b/core/src/test/java/com/google/adk/models/ClaudeTest.java index febcaf4be..2c9ad2f15 100644 --- a/core/src/test/java/com/google/adk/models/ClaudeTest.java +++ b/core/src/test/java/com/google/adk/models/ClaudeTest.java @@ -17,15 +17,29 @@ package com.google.adk.models; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; import com.anthropic.client.AnthropicClient; +import com.anthropic.core.http.StreamResponse; import com.anthropic.models.messages.ContentBlockParam; +import com.anthropic.models.messages.DirectCaller; +import com.anthropic.models.messages.RawContentBlockDeltaEvent; +import com.anthropic.models.messages.RawContentBlockStartEvent; +import com.anthropic.models.messages.RawContentBlockStopEvent; +import com.anthropic.models.messages.RawMessageStopEvent; +import com.anthropic.models.messages.RawMessageStreamEvent; import com.anthropic.models.messages.ToolResultBlockParam; +import com.anthropic.models.messages.ToolUseBlock; +import com.anthropic.services.blocking.MessageService; import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import java.lang.reflect.Method; +import java.util.List; import java.util.Map; +import java.util.stream.Stream; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,19 +50,25 @@ public final class ClaudeTest { private Claude claude; + private AnthropicClient mockClient; + private MessageService mockMessageService; private Method partToAnthropicMessageBlockMethod; @Before public void setUp() throws Exception { - AnthropicClient mockClient = Mockito.mock(AnthropicClient.class); + mockClient = Mockito.mock(AnthropicClient.class); + mockMessageService = Mockito.mock(MessageService.class); + when(mockClient.messages()).thenReturn(mockMessageService); + claude = new Claude("claude-3-opus", mockClient); - // Access private method for testing the extraction logic partToAnthropicMessageBlockMethod = Claude.class.getDeclaredMethod("partToAnthropicMessageBlock", Part.class); partToAnthropicMessageBlockMethod.setAccessible(true); } + // --- Existing partToAnthropicMessageBlock tests --- + @Test public void testPartToAnthropicMessageBlock_mcpTool_legacyTextOutputKey() throws Exception { Map responseData = @@ -78,4 +98,318 @@ public void testPartToAnthropicMessageBlock_jsonFallback() throws Exception { ToolResultBlockParam toolResult = result.asToolResult(); assertThat(toolResult.content().get().asString()).contains("\"custom_key\":\"custom_value\""); } + + // --- Streaming tests --- + + @Test + public void testStreaming_textChunksEmittedAsPartialResponses() { + Stream events = + Stream.of( + RawMessageStreamEvent.ofContentBlockDelta( + new RawContentBlockDeltaEvent.Builder().index(0).textDelta("Hello").build()), + RawMessageStreamEvent.ofContentBlockDelta( + new RawContentBlockDeltaEvent.Builder().index(0).textDelta(", world!").build()), + RawMessageStreamEvent.ofMessageStop(new RawMessageStopEvent.Builder().build())); + + StreamResponse mockStreamResponse = mockStreamResponse(events); + when(mockMessageService.createStreaming( + any(com.anthropic.models.messages.MessageCreateParams.class))) + .thenReturn(mockStreamResponse); + + LlmRequest request = + LlmRequest.builder() + .contents( + List.of( + Content.builder() + .role("user") + .parts(List.of(Part.builder().text("Say hello").build())) + .build())) + .build(); + + List responses = + claude.generateContent(request, /* stream= */ true).toList().blockingGet(); + + // Filter out the turnComplete sentinel + List textResponses = + responses.stream() + .filter(r -> r.content().isPresent()) + .collect(java.util.stream.Collectors.toList()); + + assertThat(textResponses).hasSize(2); + assertThat(textResponses.get(0).content().get().parts().get().get(0).text().get()) + .isEqualTo("Hello"); + assertThat(textResponses.get(0).partial().get()).isTrue(); + assertThat(textResponses.get(1).content().get().parts().get().get(0).text().get()) + .isEqualTo(", world!"); + assertThat(textResponses.get(1).partial().get()).isTrue(); + } + + @Test + public void testStreaming_messageStopEmitsTurnComplete() { + Stream events = + Stream.of(RawMessageStreamEvent.ofMessageStop(new RawMessageStopEvent.Builder().build())); + + StreamResponse mockStreamResponse = mockStreamResponse(events); + when(mockMessageService.createStreaming( + any(com.anthropic.models.messages.MessageCreateParams.class))) + .thenReturn(mockStreamResponse); + + LlmRequest request = + LlmRequest.builder() + .contents( + List.of( + Content.builder() + .role("user") + .parts(List.of(Part.builder().text("Hi").build())) + .build())) + .build(); + + List responses = + claude.generateContent(request, /* stream= */ true).toList().blockingGet(); + + assertThat(responses).hasSize(1); + assertThat(responses.get(0).turnComplete().get()).isTrue(); + } + + @Test + public void testStreaming_toolCallAccumulatedAndEmittedOnBlockStop() { + ToolUseBlock toolUseBlock = + new ToolUseBlock.Builder() + .id("tool_abc") + .name("get_weather") + .caller(new DirectCaller.Builder().build()) + .input(com.anthropic.core.JsonValue.from(java.util.Collections.emptyMap())) + .build(); + + Stream events = + Stream.of( + RawMessageStreamEvent.ofContentBlockStart( + new RawContentBlockStartEvent.Builder() + .index(0) + .contentBlock(toolUseBlock) + .build()), + RawMessageStreamEvent.ofContentBlockDelta( + new RawContentBlockDeltaEvent.Builder() + .index(0) + .inputJsonDelta("{\"city\":") + .build()), + RawMessageStreamEvent.ofContentBlockDelta( + new RawContentBlockDeltaEvent.Builder() + .index(0) + .inputJsonDelta("\"London\"}") + .build()), + RawMessageStreamEvent.ofContentBlockStop( + new RawContentBlockStopEvent.Builder().index(0).build()), + RawMessageStreamEvent.ofMessageStop(new RawMessageStopEvent.Builder().build())); + + StreamResponse mockStreamResponse = mockStreamResponse(events); + when(mockMessageService.createStreaming( + any(com.anthropic.models.messages.MessageCreateParams.class))) + .thenReturn(mockStreamResponse); + + LlmRequest request = + LlmRequest.builder() + .contents( + List.of( + Content.builder() + .role("user") + .parts(List.of(Part.builder().text("What's the weather?").build())) + .build())) + .build(); + + List responses = + claude.generateContent(request, /* stream= */ true).toList().blockingGet(); + + List toolResponses = + responses.stream() + .filter(r -> r.content().isPresent()) + .collect(java.util.stream.Collectors.toList()); + + assertThat(toolResponses).hasSize(1); + Part functionCallPart = toolResponses.get(0).content().get().parts().get().get(0); + assertThat(functionCallPart.functionCall().isPresent()).isTrue(); + assertThat(functionCallPart.functionCall().get().id().get()).isEqualTo("tool_abc"); + assertThat(functionCallPart.functionCall().get().name().get()).isEqualTo("get_weather"); + assertThat(functionCallPart.functionCall().get().args().get()).containsEntry("city", "London"); + } + + @Test + public void testStreaming_mixedTextAndToolCall() { + ToolUseBlock toolUseBlock = + new ToolUseBlock.Builder() + .id("tool_xyz") + .name("search") + .caller(new DirectCaller.Builder().build()) + .input(com.anthropic.core.JsonValue.from(java.util.Collections.emptyMap())) + .build(); + + Stream events = + Stream.of( + // Text block first + RawMessageStreamEvent.ofContentBlockDelta( + new RawContentBlockDeltaEvent.Builder() + .index(0) + .textDelta("Let me search.") + .build()), + // Tool call block second + RawMessageStreamEvent.ofContentBlockStart( + new RawContentBlockStartEvent.Builder() + .index(1) + .contentBlock(toolUseBlock) + .build()), + RawMessageStreamEvent.ofContentBlockDelta( + new RawContentBlockDeltaEvent.Builder() + .index(1) + .inputJsonDelta("{\"query\":\"java\"}") + .build()), + RawMessageStreamEvent.ofContentBlockStop( + new RawContentBlockStopEvent.Builder().index(1).build()), + RawMessageStreamEvent.ofMessageStop(new RawMessageStopEvent.Builder().build())); + + StreamResponse mockStreamResponse = mockStreamResponse(events); + when(mockMessageService.createStreaming( + any(com.anthropic.models.messages.MessageCreateParams.class))) + .thenReturn(mockStreamResponse); + + LlmRequest request = + LlmRequest.builder() + .contents( + List.of( + Content.builder() + .role("user") + .parts(List.of(Part.builder().text("Search for java").build())) + .build())) + .build(); + + List responses = + claude.generateContent(request, /* stream= */ true).toList().blockingGet(); + + List contentResponses = + responses.stream() + .filter(r -> r.content().isPresent()) + .collect(java.util.stream.Collectors.toList()); + + assertThat(contentResponses).hasSize(2); + + // First: partial text + assertThat(contentResponses.get(0).content().get().parts().get().get(0).text().get()) + .isEqualTo("Let me search."); + assertThat(contentResponses.get(0).partial().get()).isTrue(); + + // Second: function call + assertThat( + contentResponses.get(1).content().get().parts().get().get(0).functionCall().isPresent()) + .isTrue(); + assertThat( + contentResponses + .get(1) + .content() + .get() + .parts() + .get() + .get(0) + .functionCall() + .get() + .name() + .get()) + .isEqualTo("search"); + } + + @Test + public void testStreaming_emptyStream_producesNoResponses() { + StreamResponse mockStreamResponse = mockStreamResponse(Stream.of()); + when(mockMessageService.createStreaming( + any(com.anthropic.models.messages.MessageCreateParams.class))) + .thenReturn(mockStreamResponse); + + LlmRequest request = + LlmRequest.builder() + .contents( + List.of( + Content.builder() + .role("user") + .parts(List.of(Part.builder().text("Hi").build())) + .build())) + .build(); + + List responses = + claude.generateContent(request, /* stream= */ true).toList().blockingGet(); + + assertThat(responses).isEmpty(); + } + + @Test + public void testStreaming_toolCallWithInvalidJson_fallsBackToEmptyArgs() { + ToolUseBlock toolUseBlock = + new ToolUseBlock.Builder() + .id("tool_bad") + .name("broken_tool") + .caller(new DirectCaller.Builder().build()) + .input(com.anthropic.core.JsonValue.from(java.util.Collections.emptyMap())) + .build(); + + Stream events = + Stream.of( + RawMessageStreamEvent.ofContentBlockStart( + new RawContentBlockStartEvent.Builder() + .index(0) + .contentBlock(toolUseBlock) + .build()), + RawMessageStreamEvent.ofContentBlockDelta( + new RawContentBlockDeltaEvent.Builder() + .index(0) + .inputJsonDelta("not-valid-json") + .build()), + RawMessageStreamEvent.ofContentBlockStop( + new RawContentBlockStopEvent.Builder().index(0).build())); + + StreamResponse mockStreamResponse = mockStreamResponse(events); + when(mockMessageService.createStreaming( + any(com.anthropic.models.messages.MessageCreateParams.class))) + .thenReturn(mockStreamResponse); + + LlmRequest request = + LlmRequest.builder() + .contents( + List.of( + Content.builder() + .role("user") + .parts(List.of(Part.builder().text("Run it").build())) + .build())) + .build(); + + List responses = + claude.generateContent(request, /* stream= */ true).toList().blockingGet(); + + List toolResponses = + responses.stream() + .filter(r -> r.content().isPresent()) + .collect(java.util.stream.Collectors.toList()); + + // Tool call is still emitted but with empty args on parse failure + assertThat(toolResponses).hasSize(1); + assertThat(toolResponses.get(0).content().get().parts().get().get(0).functionCall().isPresent()) + .isTrue(); + assertThat( + toolResponses + .get(0) + .content() + .get() + .parts() + .get() + .get(0) + .functionCall() + .get() + .args() + .get()) + .isEmpty(); + } + + @SuppressWarnings("unchecked") + private static StreamResponse mockStreamResponse( + Stream events) { + StreamResponse mock = Mockito.mock(StreamResponse.class); + when(mock.stream()).thenReturn(events); + return mock; + } }