FEATURE: Anthropic streaming with tools (#1795)

## Issue
Closes #1738

## Change
- Added AnthropicContentBlockType an enum that specifies content to
follow (text, or tool_use).
- Added AnthropicToolChoice, an enum
- Introduced AnthropicToolExecutionRequestBuilder to build tool
execution requests.
- Updated
`langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json`
- Update handling of streamed content for tools
- Added integration tests

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [x] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
This commit is contained in:
Milan Le 2024-09-23 11:57:33 +02:00 committed by GitHub
parent bcfaf735da
commit d1f5775f8b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 593 additions and 48 deletions

View File

@ -81,7 +81,7 @@ Identical to the `AnthropicChatModel`, see above.
## Tools
Anthropic supports [tools](/tutorials/tools), but only in a non-streaming mode.
Anthropic supports [tools](/tutorials/tools) in both streaming and non-streaming mode.
Anthropic documentation on tools can be found [here](https://docs.anthropic.com/claude/docs/tool-use).

View File

@ -1,5 +1,6 @@
package dev.langchain4j.model.anthropic;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
@ -8,6 +9,7 @@ import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice;
import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
@ -27,12 +29,13 @@ import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicMessages;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicSystemPrompt;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.*;
import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
/**
* Represents an Anthropic language model with a Messages (chat) API.
@ -141,21 +144,46 @@ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel {
@Override
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
generate(messages, null, null, handler);
}
@Override
public void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler<AiMessage> handler) {
generate(messages, toolSpecifications, null, handler);
}
@Override
public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
generate(messages, null, toolSpecification, handler);
}
private void generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted,
StreamingResponseHandler<AiMessage> handler) {
List<ChatMessage> sanitizedMessages = sanitizeMessages(messages);
String systemPrompt = toAnthropicSystemPrompt(messages);
ensureNotNull(handler, "handler");
AnthropicCreateMessageRequest request = AnthropicCreateMessageRequest.builder()
AnthropicCreateMessageRequest.AnthropicCreateMessageRequestBuilder requestBuilder = AnthropicCreateMessageRequest.builder()
.stream(true)
.model(modelName)
.messages(toAnthropicMessages(sanitizedMessages))
.system(systemPrompt)
.maxTokens(maxTokens)
.stopSequences(stopSequences)
.stream(true)
.temperature(temperature)
.topP(topP)
.topK(topK)
.build();
.topK(topK);
if (toolThatMustBeExecuted != null) {
requestBuilder.tools(toAnthropicTools(singletonList(toolThatMustBeExecuted)));
requestBuilder.toolChoice(AnthropicToolChoice.from(toolThatMustBeExecuted.name()));
} else if (!isNullOrEmpty(toolSpecifications)) {
requestBuilder.tools(toAnthropicTools(toolSpecifications));
}
AnthropicCreateMessageRequest request = requestBuilder.build();
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages);
Map<Object, Object> attributes = new ConcurrentHashMap<>();

View File

@ -14,7 +14,7 @@ import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;
@JsonNaming(SnakeCaseStrategy.class)
public class AnthropicContent {
public String type;
public AnthropicContentBlockType type;
// when type = "text"
public String text;

View File

@ -0,0 +1,10 @@
package dev.langchain4j.model.anthropic.internal.api;
import com.fasterxml.jackson.annotation.JsonProperty;
public enum AnthropicContentBlockType {
@JsonProperty("text")
TEXT,
@JsonProperty("tool_use")
TOOL_USE
}

View File

@ -32,4 +32,5 @@ public class AnthropicCreateMessageRequest {
public Double topP;
public Integer topK;
public List<AnthropicTool> tools;
public AnthropicToolChoice toolChoice;
}

View File

@ -15,6 +15,7 @@ public class AnthropicDelta {
// when AnthropicStreamingData.type = "content_block_delta"
public String type;
public String text;
public String partialJson;
// when AnthropicStreamingData.type = "message_delta"
public String stopReason;

View File

@ -0,0 +1,51 @@
package dev.langchain4j.model.anthropic.internal.api;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoiceType.AUTO;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoiceType.TOOL;
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public class AnthropicToolChoice {
@JsonProperty
private AnthropicToolChoiceType type = AUTO;
@JsonProperty
private String name;
private AnthropicToolChoice(Builder builder) {
this.type = builder.type;
this.name = builder.name;
}
public static AnthropicToolChoice from(String functionName) {
return new Builder().name(functionName).type(TOOL).build();
}
public static final class Builder {
private AnthropicToolChoiceType type;
private String name;
private Builder() {
}
public Builder name(String name) {
this.name = name;
return this;
}
public Builder type(AnthropicToolChoiceType type) {
this.type = type;
return this;
}
public AnthropicToolChoice build() {
return new AnthropicToolChoice(this);
}
}
}

View File

@ -0,0 +1,13 @@
package dev.langchain4j.model.anthropic.internal.api;
import com.fasterxml.jackson.annotation.JsonProperty;
public enum AnthropicToolChoiceType {
@JsonProperty("auto")
AUTO,
@JsonProperty("any")
ANY,
@JsonProperty("tool")
TOOL;
}

View File

@ -0,0 +1,28 @@
package dev.langchain4j.model.anthropic.internal.client;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
class AnthropicToolExecutionRequestBuilder {
private final String id;
private final String name;
private final StringBuilder argumentsBuilder = new StringBuilder();
public AnthropicToolExecutionRequestBuilder(String id, String name) {
this.id = id;
this.name = name;
}
public void appendArguments(String partialArguments) {
this.argumentsBuilder.append(partialArguments);
}
public ToolExecutionRequest build() {
return ToolExecutionRequest
.builder()
.id(id)
.name(name)
.arguments(argumentsBuilder.toString())
.build();
}
}

View File

@ -1,18 +1,11 @@
package dev.langchain4j.model.anthropic.internal.client;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.anthropic.internal.api.AnthropicApi;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
import dev.langchain4j.model.anthropic.internal.api.AnthropicDelta;
import dev.langchain4j.model.anthropic.internal.api.AnthropicResponseMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicStreamingData;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
import dev.langchain4j.model.anthropic.internal.api.*;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import okhttp3.OkHttpClient;
@ -31,9 +24,11 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT;
import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
@ -143,10 +138,11 @@ public class DefaultAnthropicClient extends AnthropicClient {
private final ReentrantLock lock = new ReentrantLock();
final List<String> contents = synchronizedList(new ArrayList<>());
volatile StringBuffer currentContentBuilder = new StringBuffer();
private final AtomicReference<AnthropicContentBlockType> currentContentBlockStartType = new AtomicReference<>();
final AtomicInteger inputTokenCount = new AtomicInteger();
final AtomicInteger outputTokenCount = new AtomicInteger();
private final Map<Integer, AnthropicToolExecutionRequestBuilder> toolExecutionRequestBuilderMap = new ConcurrentHashMap<>();
AtomicReference<String> responseId = new AtomicReference<>();
AtomicReference<String> responseModel = new AtomicReference<>();
@ -231,22 +227,46 @@ public class DefaultAnthropicClient extends AnthropicClient {
}
private void handleContentBlockStart(AnthropicStreamingData data) {
if (data.contentBlock != null && "text".equals(data.contentBlock.type)) {
if (data.contentBlock == null) {
return;
}
currentContentBlockStartType.set(data.contentBlock.type);
if (currentContentBlockStartType.get() == AnthropicContentBlockType.TEXT) {
String text = data.contentBlock.text;
if (isNotNullOrEmpty(text)) {
currentContentBuilder().append(text);
handler.onNext(text);
}
} else if (currentContentBlockStartType.get() == AnthropicContentBlockType.TOOL_USE) {
toolExecutionRequestBuilderMap.putIfAbsent(
data.index,
new AnthropicToolExecutionRequestBuilder(data.contentBlock.id, data.contentBlock.name)
);
}
}
private void handleContentBlockDelta(AnthropicStreamingData data) {
if (data.delta != null && "text_delta".equals(data.delta.type)) {
if (data.delta == null) {
return;
}
if (currentContentBlockStartType.get() == AnthropicContentBlockType.TEXT) {
String text = data.delta.text;
if (isNotNullOrEmpty(text)) {
currentContentBuilder().append(text);
handler.onNext(text);
}
} else if (currentContentBlockStartType.get() == AnthropicContentBlockType.TOOL_USE) {
String partialJson = data.delta.partialJson;
if (isNotNullOrEmpty(partialJson)) {
Integer toolExecutionsIndex = data.index;
if (toolExecutionsIndex != null) {
AnthropicToolExecutionRequestBuilder toolExecutionRequestBuilder = toolExecutionRequestBuilderMap.get(toolExecutionsIndex);
toolExecutionRequestBuilder.appendArguments(partialJson);
}
}
}
}
@ -268,15 +288,37 @@ public class DefaultAnthropicClient extends AnthropicClient {
}
private void handleMessageStop() {
Response<AiMessage> response = Response.from(
AiMessage.from(String.join("\n", contents)),
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
toFinishReason(stopReason),
createMetadata()
);
Response<AiMessage> response = build();
handler.onComplete(response);
}
private Response<AiMessage> build() {
if (!toolExecutionRequestBuilderMap.isEmpty()) {
List<ToolExecutionRequest> toolExecutionRequests = toolExecutionRequestBuilderMap
.values().stream()
.map(AnthropicToolExecutionRequestBuilder::build)
.collect(Collectors.toList());
return Response.from(
AiMessage.from(toolExecutionRequests),
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
toFinishReason(stopReason),
createMetadata()
);
}
String content = String.join("\n", contents);
if (!content.isEmpty()) {
return Response.from(
AiMessage.from(content),
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
toFinishReason(stopReason),
createMetadata()
);
}
return null;
}
private Map<String, Object> createMetadata() {
Map<String, Object> metadata = new HashMap<>();
if (responseId.get() != null) {

View File

@ -130,12 +130,12 @@ public class AnthropicMapper {
public static AiMessage toAiMessage(List<AnthropicContent> contents) {
String text = contents.stream()
.filter(content -> "text".equals(content.type))
.filter(content -> AnthropicContentBlockType.TEXT == content.type)
.map(content -> content.text)
.collect(joining("\n"));
List<ToolExecutionRequest> toolExecutionRequests = contents.stream()
.filter(content -> "tool_use".equals(content.type))
.filter(content -> AnthropicContentBlockType.TOOL_USE == content.type)
.map(content -> {
try {
return ToolExecutionRequest.builder()

View File

@ -1,6 +1,6 @@
[
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicContent",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicContent",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -9,7 +9,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicCreateMessageRequest",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -18,7 +18,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicCreateMessageResponse",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -27,7 +27,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicDelta",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -36,7 +36,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicImageContent",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicDelta",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -45,7 +45,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicImageContentSource",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicImageContent",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -54,7 +54,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicMessage",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicImageContentSource",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -63,7 +63,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicMessageContent",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicMessage",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -72,7 +72,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicResponseMessage",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -81,7 +81,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicRole",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicResponseMessage",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -90,7 +90,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicStreamingData",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicRole",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -99,7 +99,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicTextContent",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicStreamingData",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -108,7 +108,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicTool",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -117,7 +117,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolResultContent",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicTool",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -126,7 +126,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolSchema",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -135,7 +135,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolUseContent",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoiceType",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -144,7 +144,151 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicUsage",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolSchema",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicUsage",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicClient",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicClientBuilderFactory",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicHttpException",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicRequestLoggingInterceptor",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicResponseLoggingInterceptor",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolExecutionRequestBuilder",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.DefaultAnthropicClient",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.AnthropicChatModel",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.AnthropicChatModelName",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.AnthropicStreamingChatModel",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.InternalAnthropicHelper",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,

View File

@ -1,26 +1,35 @@
package dev.langchain4j.model.anthropic;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import java.time.Duration;
import java.util.Base64;
import java.util.List;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
import static dev.langchain4j.data.message.SystemMessage.systemMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.internal.Utils.readBytes;
import static dev.langchain4j.model.anthropic.AnthropicChatModelIT.CAT_IMAGE_URL;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
import static java.lang.System.getenv;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -33,6 +42,21 @@ class AnthropicStreamingChatModelIT {
.logResponses(true)
.build();
ToolSpecification calculator = ToolSpecification.builder()
.name("calculator")
.description("returns a sum of two numbers")
.addParameter("first", INTEGER)
.addParameter("second", INTEGER)
.build();
ToolSpecification weather = ToolSpecification.builder()
.name("weather")
.description("returns a weather forecast for a given location")
// TODO simplify defining nested properties
.addParameter("location", OBJECT, property("properties", singletonMap("city", singletonMap("type", "string"))))
.build();
@Test
void should_stream_answer_and_return_token_usage_and_finish_reason_stop() {
@ -142,4 +166,207 @@ class AnthropicStreamingChatModelIT {
.hasMessage("Anthropic API key must be defined. " +
"It can be generated here: https://console.anthropic.com/settings/keys");
}
@ParameterizedTest
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools")
void should_execute_a_tool_then_stream_answer(AnthropicChatModelName modelName) {
// given
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
.modelName(modelName)
.maxTokens(200)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build();
UserMessage userMessage = userMessage("2+2=?");
List<ToolSpecification> toolSpecifications = singletonList(calculator);
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(singletonList(userMessage), toolSpecifications, handler);
// then
Response<AiMessage> response = handler.get();
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isNull();
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
assertThat(toolExecutionRequests).hasSize(1);
ToolExecutionRequest toolExecutionRequest = toolExecutionRequests.get(0);
assertThat(toolExecutionRequest.name()).isEqualTo("calculator");
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
assertTokenUsage(response.tokenUsage());
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
// given
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, "4");
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
// when
TestStreamingResponseHandler<AiMessage> secondHandler = new TestStreamingResponseHandler<>();
model.generate(messages, toolSpecifications, secondHandler);
Response<AiMessage> secondResponse = secondHandler.get();
// then
AiMessage secondAiMessage = secondResponse.content();
assertThat(secondAiMessage.text()).contains("4");
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
assertTokenUsage(secondResponse.tokenUsage());
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}
@Test
void must_execute_a_tool() {
// given
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
.maxTokens(200)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build();
UserMessage userMessage = userMessage("2+2=?");
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
// when
model.generate(singletonList(userMessage), calculator, handler);
// then
Response<AiMessage> response = handler.get();
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isNull();
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
assertThat(toolExecutionRequests).hasSize(1);
ToolExecutionRequest toolExecutionRequest = toolExecutionRequests.get(0);
assertThat(toolExecutionRequest.name()).isEqualTo("calculator");
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
assertTokenUsage(response.tokenUsage());
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
}
@Disabled("Parallel execution of tools is not supported in the streaming mode yet.")
@ParameterizedTest
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools")
void should_execute_multiple_tools_in_parallel_then_answer(AnthropicChatModelName modelName) {
// given
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
.modelName(modelName)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build();
SystemMessage systemMessage = systemMessage("Do not think, nor explain step by step what you do. Output the result only.");
UserMessage userMessage = userMessage("How much is 2+2 and 3+3? Call tools in parallel!");
List<ToolSpecification> toolSpecifications = singletonList(calculator);
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(asList(systemMessage, userMessage), toolSpecifications, handler);
// then
Response<AiMessage> response = handler.get();
AiMessage aiMessage = response.content();
assertThat(aiMessage.hasToolExecutionRequests()).isTrue();
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
assertThat(toolExecutionRequests).hasSize(2);
ToolExecutionRequest toolExecutionRequest1 = aiMessage.toolExecutionRequests().get(0);
assertThat(toolExecutionRequest1.name()).isEqualTo("calculator");
assertThat(toolExecutionRequest1.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
ToolExecutionRequest toolExecutionRequest2 = aiMessage.toolExecutionRequests().get(1);
assertThat(toolExecutionRequest2.name()).isEqualTo("calculator");
assertThat(toolExecutionRequest2.arguments()).isEqualToIgnoringWhitespace("{\"first\": 3, \"second\": 3}");
assertTokenUsage(response.tokenUsage());
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
// given
ToolExecutionResultMessage toolExecutionResultMessage1 = ToolExecutionResultMessage.from(toolExecutionRequest1, "4");
ToolExecutionResultMessage toolExecutionResultMessage2 = ToolExecutionResultMessage.from(toolExecutionRequest2, "6");
List<ChatMessage> messages = asList(systemMessage, userMessage, aiMessage, toolExecutionResultMessage1, toolExecutionResultMessage2);
// when
TestStreamingResponseHandler<AiMessage> secondHandler = new TestStreamingResponseHandler<>();
model.generate(messages, toolSpecifications, secondHandler);
Response<AiMessage> secondResponse = secondHandler.get();
// then
AiMessage secondAiMessage = secondResponse.content();
assertThat(secondAiMessage.text()).contains("4", "6");
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
assertTokenUsage(secondResponse.tokenUsage());
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}
@ParameterizedTest
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools")
void should_execute_a_tool_with_nested_properties_then_answer(AnthropicChatModelName modelName) {
// given
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
.modelName(modelName)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build();
UserMessage userMessage = userMessage("What is the weather in Berlin in Celsius?");
List<ToolSpecification> toolSpecifications = singletonList(weather);
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(singletonList(userMessage), toolSpecifications, handler);
// then
Response<AiMessage> response = handler.get();
AiMessage aiMessage = response.content();
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
assertThat(toolExecutionRequest.id()).isNotBlank();
assertThat(toolExecutionRequest.name()).isEqualTo("weather");
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"location\": {\"city\": \"Berlin\"}}");
assertTokenUsage(response.tokenUsage());
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
// given
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, "Super hot, 42 Celsius");
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
// when
TestStreamingResponseHandler<AiMessage> secondHandler = new TestStreamingResponseHandler<>();
model.generate(messages, toolSpecifications, secondHandler);
Response<AiMessage> secondResponse = secondHandler.get();
// then
AiMessage secondAiMessage = secondResponse.content();
assertThat(secondAiMessage.text()).contains("42");
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
assertTokenUsage(secondResponse.tokenUsage());
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}
private static void assertTokenUsage(@NotNull TokenUsage tokenUsage) {
assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
}
}