FEATURE: Anthropic streaming with tools (#1795)
## Issue Closes #1738 ## Change - Added AnthropicContentBlockType an enum that specifies content to follow (text, or tool_use). - Added AnthropicToolChoice, an enum - Introduced AnthropicToolExecutionRequestBuilder to build tool execution requests. - Updated `langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json` - Update handling of streamed content for tools - Added integration tests ## General checklist <!-- Please double-check the following points and mark them like this: [X] --> - [x] There are no breaking changes - [x] I have added unit and integration tests for my change - [x] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [ ] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green <!-- Before adding documentation and example(s) (below), please wait until the PR is reviewed and approved. --> - [x] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
This commit is contained in:
parent
bcfaf735da
commit
d1f5775f8b
|
@ -81,7 +81,7 @@ Identical to the `AnthropicChatModel`, see above.
|
|||
|
||||
## Tools
|
||||
|
||||
Anthropic supports [tools](/tutorials/tools), but only in a non-streaming mode.
|
||||
Anthropic supports [tools](/tutorials/tools) in both streaming and non-streaming mode.
|
||||
|
||||
Anthropic documentation on tools can be found [here](https://docs.anthropic.com/claude/docs/tool-use).
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package dev.langchain4j.model.anthropic;
|
||||
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.image.Image;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
|
@ -8,6 +9,7 @@ import dev.langchain4j.data.message.SystemMessage;
|
|||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice;
|
||||
import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
|
||||
|
@ -27,12 +29,13 @@ import java.util.Map;
|
|||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307;
|
||||
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicMessages;
|
||||
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicSystemPrompt;
|
||||
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.*;
|
||||
import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages;
|
||||
import static java.util.Collections.emptyList;
|
||||
import static java.util.Collections.singletonList;
|
||||
|
||||
/**
|
||||
* Represents an Anthropic language model with a Messages (chat) API.
|
||||
|
@ -141,21 +144,46 @@ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel {
|
|||
|
||||
@Override
|
||||
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
|
||||
generate(messages, null, null, handler);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler<AiMessage> handler) {
|
||||
generate(messages, toolSpecifications, null, handler);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
|
||||
generate(messages, null, toolSpecification, handler);
|
||||
}
|
||||
|
||||
private void generate(List<ChatMessage> messages,
|
||||
List<ToolSpecification> toolSpecifications,
|
||||
ToolSpecification toolThatMustBeExecuted,
|
||||
StreamingResponseHandler<AiMessage> handler) {
|
||||
List<ChatMessage> sanitizedMessages = sanitizeMessages(messages);
|
||||
String systemPrompt = toAnthropicSystemPrompt(messages);
|
||||
ensureNotNull(handler, "handler");
|
||||
|
||||
AnthropicCreateMessageRequest request = AnthropicCreateMessageRequest.builder()
|
||||
AnthropicCreateMessageRequest.AnthropicCreateMessageRequestBuilder requestBuilder = AnthropicCreateMessageRequest.builder()
|
||||
.stream(true)
|
||||
.model(modelName)
|
||||
.messages(toAnthropicMessages(sanitizedMessages))
|
||||
.system(systemPrompt)
|
||||
.maxTokens(maxTokens)
|
||||
.stopSequences(stopSequences)
|
||||
.stream(true)
|
||||
.temperature(temperature)
|
||||
.topP(topP)
|
||||
.topK(topK)
|
||||
.build();
|
||||
.topK(topK);
|
||||
|
||||
if (toolThatMustBeExecuted != null) {
|
||||
requestBuilder.tools(toAnthropicTools(singletonList(toolThatMustBeExecuted)));
|
||||
requestBuilder.toolChoice(AnthropicToolChoice.from(toolThatMustBeExecuted.name()));
|
||||
} else if (!isNullOrEmpty(toolSpecifications)) {
|
||||
requestBuilder.tools(toAnthropicTools(toolSpecifications));
|
||||
}
|
||||
|
||||
AnthropicCreateMessageRequest request = requestBuilder.build();
|
||||
|
||||
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages);
|
||||
Map<Object, Object> attributes = new ConcurrentHashMap<>();
|
||||
|
|
|
@ -14,7 +14,7 @@ import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;
|
|||
@JsonNaming(SnakeCaseStrategy.class)
|
||||
public class AnthropicContent {
|
||||
|
||||
public String type;
|
||||
public AnthropicContentBlockType type;
|
||||
|
||||
// when type = "text"
|
||||
public String text;
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
package dev.langchain4j.model.anthropic.internal.api;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
public enum AnthropicContentBlockType {
|
||||
@JsonProperty("text")
|
||||
TEXT,
|
||||
@JsonProperty("tool_use")
|
||||
TOOL_USE
|
||||
}
|
|
@ -32,4 +32,5 @@ public class AnthropicCreateMessageRequest {
|
|||
public Double topP;
|
||||
public Integer topK;
|
||||
public List<AnthropicTool> tools;
|
||||
public AnthropicToolChoice toolChoice;
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ public class AnthropicDelta {
|
|||
// when AnthropicStreamingData.type = "content_block_delta"
|
||||
public String type;
|
||||
public String text;
|
||||
public String partialJson;
|
||||
|
||||
// when AnthropicStreamingData.type = "message_delta"
|
||||
public String stopReason;
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
package dev.langchain4j.model.anthropic.internal.api;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
|
||||
import com.fasterxml.jackson.databind.annotation.JsonNaming;
|
||||
|
||||
import static dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoiceType.AUTO;
|
||||
import static dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoiceType.TOOL;
|
||||
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
|
||||
public class AnthropicToolChoice {
|
||||
|
||||
@JsonProperty
|
||||
private AnthropicToolChoiceType type = AUTO;
|
||||
|
||||
@JsonProperty
|
||||
private String name;
|
||||
|
||||
private AnthropicToolChoice(Builder builder) {
|
||||
this.type = builder.type;
|
||||
this.name = builder.name;
|
||||
}
|
||||
|
||||
public static AnthropicToolChoice from(String functionName) {
|
||||
return new Builder().name(functionName).type(TOOL).build();
|
||||
}
|
||||
|
||||
public static final class Builder {
|
||||
private AnthropicToolChoiceType type;
|
||||
private String name;
|
||||
|
||||
private Builder() {
|
||||
}
|
||||
|
||||
public Builder name(String name) {
|
||||
this.name = name;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder type(AnthropicToolChoiceType type) {
|
||||
this.type = type;
|
||||
return this;
|
||||
}
|
||||
|
||||
public AnthropicToolChoice build() {
|
||||
return new AnthropicToolChoice(this);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package dev.langchain4j.model.anthropic.internal.api;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
public enum AnthropicToolChoiceType {
|
||||
|
||||
@JsonProperty("auto")
|
||||
AUTO,
|
||||
@JsonProperty("any")
|
||||
ANY,
|
||||
@JsonProperty("tool")
|
||||
TOOL;
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package dev.langchain4j.model.anthropic.internal.client;
|
||||
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
|
||||
class AnthropicToolExecutionRequestBuilder {
|
||||
|
||||
private final String id;
|
||||
private final String name;
|
||||
private final StringBuilder argumentsBuilder = new StringBuilder();
|
||||
|
||||
public AnthropicToolExecutionRequestBuilder(String id, String name) {
|
||||
this.id = id;
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public void appendArguments(String partialArguments) {
|
||||
this.argumentsBuilder.append(partialArguments);
|
||||
}
|
||||
|
||||
public ToolExecutionRequest build() {
|
||||
return ToolExecutionRequest
|
||||
.builder()
|
||||
.id(id)
|
||||
.name(name)
|
||||
.arguments(argumentsBuilder.toString())
|
||||
.build();
|
||||
}
|
||||
}
|
|
@ -1,18 +1,11 @@
|
|||
package dev.langchain4j.model.anthropic.internal.client;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.internal.Utils;
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicApi;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicDelta;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicResponseMessage;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicStreamingData;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
|
||||
import dev.langchain4j.model.anthropic.internal.api.*;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import okhttp3.OkHttpClient;
|
||||
|
@ -31,9 +24,11 @@ import java.util.ArrayList;
|
|||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT;
|
||||
import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
|
||||
|
@ -143,10 +138,11 @@ public class DefaultAnthropicClient extends AnthropicClient {
|
|||
private final ReentrantLock lock = new ReentrantLock();
|
||||
final List<String> contents = synchronizedList(new ArrayList<>());
|
||||
volatile StringBuffer currentContentBuilder = new StringBuffer();
|
||||
private final AtomicReference<AnthropicContentBlockType> currentContentBlockStartType = new AtomicReference<>();
|
||||
|
||||
final AtomicInteger inputTokenCount = new AtomicInteger();
|
||||
final AtomicInteger outputTokenCount = new AtomicInteger();
|
||||
|
||||
private final Map<Integer, AnthropicToolExecutionRequestBuilder> toolExecutionRequestBuilderMap = new ConcurrentHashMap<>();
|
||||
AtomicReference<String> responseId = new AtomicReference<>();
|
||||
AtomicReference<String> responseModel = new AtomicReference<>();
|
||||
|
||||
|
@ -231,22 +227,46 @@ public class DefaultAnthropicClient extends AnthropicClient {
|
|||
}
|
||||
|
||||
private void handleContentBlockStart(AnthropicStreamingData data) {
|
||||
if (data.contentBlock != null && "text".equals(data.contentBlock.type)) {
|
||||
if (data.contentBlock == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
currentContentBlockStartType.set(data.contentBlock.type);
|
||||
|
||||
if (currentContentBlockStartType.get() == AnthropicContentBlockType.TEXT) {
|
||||
String text = data.contentBlock.text;
|
||||
if (isNotNullOrEmpty(text)) {
|
||||
currentContentBuilder().append(text);
|
||||
handler.onNext(text);
|
||||
}
|
||||
} else if (currentContentBlockStartType.get() == AnthropicContentBlockType.TOOL_USE) {
|
||||
toolExecutionRequestBuilderMap.putIfAbsent(
|
||||
data.index,
|
||||
new AnthropicToolExecutionRequestBuilder(data.contentBlock.id, data.contentBlock.name)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private void handleContentBlockDelta(AnthropicStreamingData data) {
|
||||
if (data.delta != null && "text_delta".equals(data.delta.type)) {
|
||||
if (data.delta == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (currentContentBlockStartType.get() == AnthropicContentBlockType.TEXT) {
|
||||
String text = data.delta.text;
|
||||
if (isNotNullOrEmpty(text)) {
|
||||
currentContentBuilder().append(text);
|
||||
handler.onNext(text);
|
||||
}
|
||||
} else if (currentContentBlockStartType.get() == AnthropicContentBlockType.TOOL_USE) {
|
||||
String partialJson = data.delta.partialJson;
|
||||
if (isNotNullOrEmpty(partialJson)) {
|
||||
Integer toolExecutionsIndex = data.index;
|
||||
if (toolExecutionsIndex != null) {
|
||||
AnthropicToolExecutionRequestBuilder toolExecutionRequestBuilder = toolExecutionRequestBuilderMap.get(toolExecutionsIndex);
|
||||
toolExecutionRequestBuilder.appendArguments(partialJson);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -268,15 +288,37 @@ public class DefaultAnthropicClient extends AnthropicClient {
|
|||
}
|
||||
|
||||
private void handleMessageStop() {
|
||||
Response<AiMessage> response = Response.from(
|
||||
AiMessage.from(String.join("\n", contents)),
|
||||
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
|
||||
toFinishReason(stopReason),
|
||||
createMetadata()
|
||||
);
|
||||
Response<AiMessage> response = build();
|
||||
handler.onComplete(response);
|
||||
}
|
||||
|
||||
private Response<AiMessage> build() {
|
||||
if (!toolExecutionRequestBuilderMap.isEmpty()) {
|
||||
List<ToolExecutionRequest> toolExecutionRequests = toolExecutionRequestBuilderMap
|
||||
.values().stream()
|
||||
.map(AnthropicToolExecutionRequestBuilder::build)
|
||||
.collect(Collectors.toList());
|
||||
return Response.from(
|
||||
AiMessage.from(toolExecutionRequests),
|
||||
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
|
||||
toFinishReason(stopReason),
|
||||
createMetadata()
|
||||
);
|
||||
}
|
||||
|
||||
String content = String.join("\n", contents);
|
||||
if (!content.isEmpty()) {
|
||||
return Response.from(
|
||||
AiMessage.from(content),
|
||||
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
|
||||
toFinishReason(stopReason),
|
||||
createMetadata()
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private Map<String, Object> createMetadata() {
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
if (responseId.get() != null) {
|
||||
|
|
|
@ -130,12 +130,12 @@ public class AnthropicMapper {
|
|||
public static AiMessage toAiMessage(List<AnthropicContent> contents) {
|
||||
|
||||
String text = contents.stream()
|
||||
.filter(content -> "text".equals(content.type))
|
||||
.filter(content -> AnthropicContentBlockType.TEXT == content.type)
|
||||
.map(content -> content.text)
|
||||
.collect(joining("\n"));
|
||||
|
||||
List<ToolExecutionRequest> toolExecutionRequests = contents.stream()
|
||||
.filter(content -> "tool_use".equals(content.type))
|
||||
.filter(content -> AnthropicContentBlockType.TOOL_USE == content.type)
|
||||
.map(content -> {
|
||||
try {
|
||||
return ToolExecutionRequest.builder()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicContent",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicContent",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -9,7 +9,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicCreateMessageRequest",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -18,7 +18,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicCreateMessageResponse",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -27,7 +27,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicDelta",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -36,7 +36,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicImageContent",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicDelta",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -45,7 +45,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicImageContentSource",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicImageContent",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -54,7 +54,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicMessage",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicImageContentSource",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -63,7 +63,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicMessageContent",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicMessage",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -72,7 +72,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicResponseMessage",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -81,7 +81,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicRole",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicResponseMessage",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -90,7 +90,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicStreamingData",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicRole",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -99,7 +99,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicTextContent",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicStreamingData",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -108,7 +108,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicTool",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -117,7 +117,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolResultContent",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicTool",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -126,7 +126,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolSchema",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -135,7 +135,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolUseContent",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoiceType",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -144,7 +144,151 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicUsage",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolSchema",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicUsage",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicClient",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicClientBuilderFactory",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicHttpException",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicRequestLoggingInterceptor",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicResponseLoggingInterceptor",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolExecutionRequestBuilder",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.DefaultAnthropicClient",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.AnthropicChatModel",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.AnthropicChatModelName",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.AnthropicStreamingChatModel",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.InternalAnthropicHelper",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
|
|
@ -1,26 +1,35 @@
|
|||
package dev.langchain4j.model.anthropic;
|
||||
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ImageContent;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.*;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
|
||||
import static dev.langchain4j.data.message.SystemMessage.systemMessage;
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static dev.langchain4j.internal.Utils.readBytes;
|
||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelIT.CAT_IMAGE_URL;
|
||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
|
||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
|
||||
import static java.lang.System.getenv;
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.Collections.singletonMap;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
|
@ -33,6 +42,21 @@ class AnthropicStreamingChatModelIT {
|
|||
.logResponses(true)
|
||||
.build();
|
||||
|
||||
ToolSpecification calculator = ToolSpecification.builder()
|
||||
.name("calculator")
|
||||
.description("returns a sum of two numbers")
|
||||
.addParameter("first", INTEGER)
|
||||
.addParameter("second", INTEGER)
|
||||
.build();
|
||||
|
||||
ToolSpecification weather = ToolSpecification.builder()
|
||||
.name("weather")
|
||||
.description("returns a weather forecast for a given location")
|
||||
// TODO simplify defining nested properties
|
||||
.addParameter("location", OBJECT, property("properties", singletonMap("city", singletonMap("type", "string"))))
|
||||
.build();
|
||||
|
||||
|
||||
@Test
|
||||
void should_stream_answer_and_return_token_usage_and_finish_reason_stop() {
|
||||
|
||||
|
@ -142,4 +166,207 @@ class AnthropicStreamingChatModelIT {
|
|||
.hasMessage("Anthropic API key must be defined. " +
|
||||
"It can be generated here: https://console.anthropic.com/settings/keys");
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools")
|
||||
void should_execute_a_tool_then_stream_answer(AnthropicChatModelName modelName) {
|
||||
|
||||
// given
|
||||
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||
.modelName(modelName)
|
||||
.maxTokens(200)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build();
|
||||
UserMessage userMessage = userMessage("2+2=?");
|
||||
List<ToolSpecification> toolSpecifications = singletonList(calculator);
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
model.generate(singletonList(userMessage), toolSpecifications, handler);
|
||||
|
||||
// then
|
||||
Response<AiMessage> response = handler.get();
|
||||
AiMessage aiMessage = response.content();
|
||||
assertThat(aiMessage.text()).isNull();
|
||||
|
||||
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
|
||||
assertThat(toolExecutionRequests).hasSize(1);
|
||||
|
||||
ToolExecutionRequest toolExecutionRequest = toolExecutionRequests.get(0);
|
||||
assertThat(toolExecutionRequest.name()).isEqualTo("calculator");
|
||||
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
|
||||
|
||||
assertTokenUsage(response.tokenUsage());
|
||||
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
|
||||
|
||||
// given
|
||||
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, "4");
|
||||
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> secondHandler = new TestStreamingResponseHandler<>();
|
||||
model.generate(messages, toolSpecifications, secondHandler);
|
||||
Response<AiMessage> secondResponse = secondHandler.get();
|
||||
|
||||
// then
|
||||
AiMessage secondAiMessage = secondResponse.content();
|
||||
assertThat(secondAiMessage.text()).contains("4");
|
||||
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
|
||||
|
||||
assertTokenUsage(secondResponse.tokenUsage());
|
||||
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
|
||||
}
|
||||
|
||||
@Test
|
||||
void must_execute_a_tool() {
|
||||
|
||||
// given
|
||||
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||
.maxTokens(200)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build();
|
||||
UserMessage userMessage = userMessage("2+2=?");
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
|
||||
// when
|
||||
model.generate(singletonList(userMessage), calculator, handler);
|
||||
|
||||
// then
|
||||
Response<AiMessage> response = handler.get();
|
||||
AiMessage aiMessage = response.content();
|
||||
assertThat(aiMessage.text()).isNull();
|
||||
|
||||
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
|
||||
assertThat(toolExecutionRequests).hasSize(1);
|
||||
|
||||
ToolExecutionRequest toolExecutionRequest = toolExecutionRequests.get(0);
|
||||
assertThat(toolExecutionRequest.name()).isEqualTo("calculator");
|
||||
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
|
||||
|
||||
assertTokenUsage(response.tokenUsage());
|
||||
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
|
||||
}
|
||||
|
||||
|
||||
@Disabled("Parallel execution of tools is not supported in the streaming mode yet.")
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools")
|
||||
void should_execute_multiple_tools_in_parallel_then_answer(AnthropicChatModelName modelName) {
|
||||
|
||||
// given
|
||||
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||
.modelName(modelName)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build();
|
||||
SystemMessage systemMessage = systemMessage("Do not think, nor explain step by step what you do. Output the result only.");
|
||||
UserMessage userMessage = userMessage("How much is 2+2 and 3+3? Call tools in parallel!");
|
||||
List<ToolSpecification> toolSpecifications = singletonList(calculator);
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
model.generate(asList(systemMessage, userMessage), toolSpecifications, handler);
|
||||
|
||||
// then
|
||||
Response<AiMessage> response = handler.get();
|
||||
AiMessage aiMessage = response.content();
|
||||
|
||||
assertThat(aiMessage.hasToolExecutionRequests()).isTrue();
|
||||
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
|
||||
assertThat(toolExecutionRequests).hasSize(2);
|
||||
|
||||
ToolExecutionRequest toolExecutionRequest1 = aiMessage.toolExecutionRequests().get(0);
|
||||
assertThat(toolExecutionRequest1.name()).isEqualTo("calculator");
|
||||
assertThat(toolExecutionRequest1.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
|
||||
|
||||
ToolExecutionRequest toolExecutionRequest2 = aiMessage.toolExecutionRequests().get(1);
|
||||
assertThat(toolExecutionRequest2.name()).isEqualTo("calculator");
|
||||
assertThat(toolExecutionRequest2.arguments()).isEqualToIgnoringWhitespace("{\"first\": 3, \"second\": 3}");
|
||||
|
||||
assertTokenUsage(response.tokenUsage());
|
||||
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
|
||||
|
||||
// given
|
||||
ToolExecutionResultMessage toolExecutionResultMessage1 = ToolExecutionResultMessage.from(toolExecutionRequest1, "4");
|
||||
ToolExecutionResultMessage toolExecutionResultMessage2 = ToolExecutionResultMessage.from(toolExecutionRequest2, "6");
|
||||
List<ChatMessage> messages = asList(systemMessage, userMessage, aiMessage, toolExecutionResultMessage1, toolExecutionResultMessage2);
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> secondHandler = new TestStreamingResponseHandler<>();
|
||||
model.generate(messages, toolSpecifications, secondHandler);
|
||||
Response<AiMessage> secondResponse = secondHandler.get();
|
||||
|
||||
// then
|
||||
AiMessage secondAiMessage = secondResponse.content();
|
||||
assertThat(secondAiMessage.text()).contains("4", "6");
|
||||
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
|
||||
|
||||
assertTokenUsage(secondResponse.tokenUsage());
|
||||
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools")
|
||||
void should_execute_a_tool_with_nested_properties_then_answer(AnthropicChatModelName modelName) {
|
||||
|
||||
// given
|
||||
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||
.modelName(modelName)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build();
|
||||
UserMessage userMessage = userMessage("What is the weather in Berlin in Celsius?");
|
||||
List<ToolSpecification> toolSpecifications = singletonList(weather);
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
model.generate(singletonList(userMessage), toolSpecifications, handler);
|
||||
|
||||
// then
|
||||
Response<AiMessage> response = handler.get();
|
||||
AiMessage aiMessage = response.content();
|
||||
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
|
||||
|
||||
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
|
||||
assertThat(toolExecutionRequest.id()).isNotBlank();
|
||||
assertThat(toolExecutionRequest.name()).isEqualTo("weather");
|
||||
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"location\": {\"city\": \"Berlin\"}}");
|
||||
|
||||
assertTokenUsage(response.tokenUsage());
|
||||
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
|
||||
|
||||
// given
|
||||
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, "Super hot, 42 Celsius");
|
||||
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> secondHandler = new TestStreamingResponseHandler<>();
|
||||
model.generate(messages, toolSpecifications, secondHandler);
|
||||
Response<AiMessage> secondResponse = secondHandler.get();
|
||||
|
||||
// then
|
||||
AiMessage secondAiMessage = secondResponse.content();
|
||||
assertThat(secondAiMessage.text()).contains("42");
|
||||
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
|
||||
|
||||
assertTokenUsage(secondResponse.tokenUsage());
|
||||
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
|
||||
}
|
||||
|
||||
private static void assertTokenUsage(@NotNull TokenUsage tokenUsage) {
|
||||
assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0);
|
||||
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
|
||||
assertThat(tokenUsage.totalTokenCount())
|
||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue