Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 128 additions & 6 deletions core/src/main/java/com/google/adk/models/Claude.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,23 @@
package com.google.adk.models;

import com.anthropic.client.AnthropicClient;
import com.anthropic.core.http.StreamResponse;
import com.anthropic.models.messages.ContentBlock;
import com.anthropic.models.messages.ContentBlockParam;
import com.anthropic.models.messages.Message;
import com.anthropic.models.messages.MessageCreateParams;
import com.anthropic.models.messages.MessageParam;
import com.anthropic.models.messages.MessageParam.Role;
import com.anthropic.models.messages.RawContentBlockDeltaEvent;
import com.anthropic.models.messages.RawContentBlockStartEvent;
import com.anthropic.models.messages.RawMessageStreamEvent;
import com.anthropic.models.messages.TextBlockParam;
import com.anthropic.models.messages.Tool;
import com.anthropic.models.messages.ToolChoice;
import com.anthropic.models.messages.ToolChoiceAuto;
import com.anthropic.models.messages.ToolResultBlockParam;
import com.anthropic.models.messages.ToolUnion;
import com.anthropic.models.messages.ToolUseBlock;
import com.anthropic.models.messages.ToolUseBlockParam;
import com.fasterxml.jackson.core.type.TypeReference;
import com.google.adk.JsonBaseModel;
Expand All @@ -53,8 +58,8 @@
/**
* Represents the Claude Generative AI model by Anthropic.
*
* <p>This class provides methods for interacting with Claude models. Streaming and live connections
* are not currently supported for Claude.
* <p>This class provides methods for interacting with Claude models, including streaming responses.
* Live connections are not currently supported for Claude.
*/
public class Claude extends BaseLlm {

Expand All @@ -81,7 +86,23 @@ public Claude(String modelName, AnthropicClient anthropicClient, int maxTokens)

@Override
public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stream) {
// TODO: Switch to streaming API.
MessageCreateParams params = buildMessageCreateParams(llmRequest);

if (stream) {
logger.debug("Sending streaming request to Claude model {}", params.model());
return Flowable.using(
() -> this.anthropicClient.messages().createStreaming(params),
streamResponse -> processStreamingResponse(streamResponse.stream()),
StreamResponse::close);
} else {
logger.debug("Sending request to Claude model {}", params.model());
var message = this.anthropicClient.messages().create(params);
logger.debug("Claude response: {}", message);
return Flowable.just(convertAnthropicResponseToLlmResponse(message));
}
}

private MessageCreateParams buildMessageCreateParams(LlmRequest llmRequest) {
List<MessageParam> messages =
llmRequest.contents().stream()
.map(this::contentToAnthropicMessageParam)
Expand Down Expand Up @@ -132,11 +153,112 @@ public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stre
paramsBuilder.toolChoice(toolChoice);
}

var message = this.anthropicClient.messages().create(paramsBuilder.build());
return paramsBuilder.build();
}

/**
* Converts a stream of raw Anthropic streaming events into a Flowable of {@link LlmResponse}.
*
* <p>Text deltas are emitted immediately as partial responses. Tool use blocks are accumulated
* and emitted as function calls when the block is complete.
*/
private Flowable<LlmResponse> processStreamingResponse(
java.util.stream.Stream<RawMessageStreamEvent> events) {
// Mutable state for accumulating tool call data across events.
// Keys are content block indices from the stream.
Map<Long, String> toolUseIds = new HashMap<>();
Map<Long, String> toolUseNames = new HashMap<>();
Map<Long, StringBuilder> toolUseInputJsons = new HashMap<>();

return Flowable.fromStream(events)
.concatMap(
event -> {
if (event.isContentBlockStart()) {
RawContentBlockStartEvent startEvent = event.asContentBlockStart();
long index = startEvent.index();
Optional<ToolUseBlock> toolUseOpt = startEvent.contentBlock().toolUse();
if (toolUseOpt.isPresent()) {
ToolUseBlock toolUse = toolUseOpt.get();
toolUseIds.put(index, toolUse.id());
toolUseNames.put(index, toolUse.name());
toolUseInputJsons.put(index, new StringBuilder());
}
return Flowable.<LlmResponse>empty();

} else if (event.isContentBlockDelta()) {
RawContentBlockDeltaEvent deltaEvent = event.asContentBlockDelta();
long index = deltaEvent.index();
var delta = deltaEvent.delta();

if (delta.isText()) {
String textChunk = delta.asText().text();
logger.trace("Claude streaming text chunk: {}", textChunk);
return Flowable.just(
LlmResponse.builder()
.content(
Content.builder()
.role("model")
.parts(ImmutableList.of(Part.builder().text(textChunk).build()))
.build())
.partial(true)
.build());

} else if (delta.isInputJson()) {
String jsonChunk = delta.asInputJson().partialJson();
StringBuilder accumulator = toolUseInputJsons.get(index);
if (accumulator != null) {
accumulator.append(jsonChunk);
}
return Flowable.<LlmResponse>empty();
}
return Flowable.<LlmResponse>empty();

} else if (event.isContentBlockStop()) {
long index = event.asContentBlockStop().index();
String id = toolUseIds.remove(index);
String name = toolUseNames.remove(index);
StringBuilder inputJsonBuilder = toolUseInputJsons.remove(index);

if (id != null && name != null && inputJsonBuilder != null) {
Map<String, Object> args;
try {
args =
JsonBaseModel.getMapper()
.readValue(
inputJsonBuilder.toString(),
new TypeReference<Map<String, Object>>() {});
} catch (Exception e) {
logger.warn(
"Failed to parse tool input JSON for tool '{}': {}", name, e.getMessage());
args = ImmutableMap.of();
}
logger.debug("Claude streaming tool call: id={}, name={}", id, name);
return Flowable.just(
LlmResponse.builder()
.content(
Content.builder()
.role("model")
.parts(
ImmutableList.of(
Part.builder()
.functionCall(
FunctionCall.builder()
.id(id)
.name(name)
.args(args)
.build())
.build()))
.build())
.build());
}
return Flowable.<LlmResponse>empty();

logger.debug("Claude response: {}", message);
} else if (event.isMessageStop()) {
return Flowable.just(LlmResponse.builder().turnComplete(true).build());
}

return Flowable.just(convertAnthropicResponseToLlmResponse(message));
return Flowable.<LlmResponse>empty();
});
}

private Role toClaudeRole(String role) {
Expand Down
Loading