FEATURE: Anthropic streaming with tools (#1795)

This commit is contained in:
LangChain4j 2024-09-23 11:59:07 +02:00
parent d1f5775f8b
commit a91ea8ae4f
8 changed files with 128 additions and 201 deletions

View File

@ -32,7 +32,10 @@ import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307; import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.*; import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createModelListenerRequest;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicMessages;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicSystemPrompt;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicTools;
import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages; import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages;
import static java.util.Collections.emptyList; import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
@ -185,7 +188,7 @@ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel {
AnthropicCreateMessageRequest request = requestBuilder.build(); AnthropicCreateMessageRequest request = requestBuilder.build();
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages); ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
Map<Object, Object> attributes = new ConcurrentHashMap<>(); Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes); ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(listener -> { listeners.forEach(listener -> {
@ -248,16 +251,4 @@ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel {
client.createMessage(request, listenerHandler); client.createMessage(request, listenerHandler);
} }
static ChatModelRequest createModelListenerRequest(AnthropicCreateMessageRequest request,
List<ChatMessage> messages) {
return ChatModelRequest.builder()
.model(request.getModel())
.temperature(request.getTemperature())
.topP(request.getTopP())
.maxTokens(request.getMaxTokens())
.messages(messages)
.build();
}
} }

View File

@ -5,7 +5,17 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.internal.Utils; import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.StreamingResponseHandler; import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.anthropic.internal.api.*; import dev.langchain4j.model.anthropic.internal.api.AnthropicApi;
import dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
import dev.langchain4j.model.anthropic.internal.api.AnthropicDelta;
import dev.langchain4j.model.anthropic.internal.api.AnthropicResponseMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicStreamingData;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage; import dev.langchain4j.model.output.TokenUsage;
import okhttp3.OkHttpClient; import okhttp3.OkHttpClient;
@ -28,15 +38,17 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT; import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT;
import static dev.langchain4j.internal.Utils.isNotNullOrEmpty; import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
import static dev.langchain4j.internal.Utils.isNullOrBlank; import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TEXT;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TOOL_USE;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toFinishReason; import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toFinishReason;
import static java.util.Collections.synchronizedList; import static java.util.Collections.synchronizedList;
import static java.util.stream.Collectors.toList;
public class DefaultAnthropicClient extends AnthropicClient { public class DefaultAnthropicClient extends AnthropicClient {
@ -135,16 +147,18 @@ public class DefaultAnthropicClient extends AnthropicClient {
EventSourceListener eventSourceListener = new EventSourceListener() { EventSourceListener eventSourceListener = new EventSourceListener() {
private final ReentrantLock lock = new ReentrantLock(); final ReentrantLock lock = new ReentrantLock();
final List<String> contents = synchronizedList(new ArrayList<>()); final List<String> contents = synchronizedList(new ArrayList<>());
volatile StringBuffer currentContentBuilder = new StringBuffer(); volatile StringBuffer currentContentBuilder = new StringBuffer();
private final AtomicReference<AnthropicContentBlockType> currentContentBlockStartType = new AtomicReference<>();
final AtomicReference<AnthropicContentBlockType> currentContentBlockStartType = new AtomicReference<>();
final Map<Integer, AnthropicToolExecutionRequestBuilder> toolExecutionRequestBuilderMap = new ConcurrentHashMap<>();
final AtomicInteger inputTokenCount = new AtomicInteger(); final AtomicInteger inputTokenCount = new AtomicInteger();
final AtomicInteger outputTokenCount = new AtomicInteger(); final AtomicInteger outputTokenCount = new AtomicInteger();
private final Map<Integer, AnthropicToolExecutionRequestBuilder> toolExecutionRequestBuilderMap = new ConcurrentHashMap<>();
AtomicReference<String> responseId = new AtomicReference<>(); final AtomicReference<String> responseId = new AtomicReference<>();
AtomicReference<String> responseModel = new AtomicReference<>(); final AtomicReference<String> responseModel = new AtomicReference<>();
volatile String stopReason; volatile String stopReason;
@ -233,13 +247,13 @@ public class DefaultAnthropicClient extends AnthropicClient {
currentContentBlockStartType.set(data.contentBlock.type); currentContentBlockStartType.set(data.contentBlock.type);
if (currentContentBlockStartType.get() == AnthropicContentBlockType.TEXT) { if (currentContentBlockStartType.get() == TEXT) {
String text = data.contentBlock.text; String text = data.contentBlock.text;
if (isNotNullOrEmpty(text)) { if (isNotNullOrEmpty(text)) {
currentContentBuilder().append(text); currentContentBuilder().append(text);
handler.onNext(text); handler.onNext(text);
} }
} else if (currentContentBlockStartType.get() == AnthropicContentBlockType.TOOL_USE) { } else if (currentContentBlockStartType.get() == TOOL_USE) {
toolExecutionRequestBuilderMap.putIfAbsent( toolExecutionRequestBuilderMap.putIfAbsent(
data.index, data.index,
new AnthropicToolExecutionRequestBuilder(data.contentBlock.id, data.contentBlock.name) new AnthropicToolExecutionRequestBuilder(data.contentBlock.id, data.contentBlock.name)
@ -252,13 +266,13 @@ public class DefaultAnthropicClient extends AnthropicClient {
return; return;
} }
if (currentContentBlockStartType.get() == AnthropicContentBlockType.TEXT) { if (currentContentBlockStartType.get() == TEXT) {
String text = data.delta.text; String text = data.delta.text;
if (isNotNullOrEmpty(text)) { if (isNotNullOrEmpty(text)) {
currentContentBuilder().append(text); currentContentBuilder().append(text);
handler.onNext(text); handler.onNext(text);
} }
} else if (currentContentBlockStartType.get() == AnthropicContentBlockType.TOOL_USE) { } else if (currentContentBlockStartType.get() == TOOL_USE) {
String partialJson = data.delta.partialJson; String partialJson = data.delta.partialJson;
if (isNotNullOrEmpty(partialJson)) { if (isNotNullOrEmpty(partialJson)) {
Integer toolExecutionsIndex = data.index; Integer toolExecutionsIndex = data.index;
@ -293,30 +307,36 @@ public class DefaultAnthropicClient extends AnthropicClient {
} }
private Response<AiMessage> build() { private Response<AiMessage> build() {
if (!toolExecutionRequestBuilderMap.isEmpty()) {
String text = String.join("\n", contents);
TokenUsage tokenUsage = new TokenUsage(inputTokenCount.get(), outputTokenCount.get());
FinishReason finishReason = toFinishReason(stopReason);
Map<String, Object> metadata = createMetadata();
if (toolExecutionRequestBuilderMap.isEmpty()) {
return Response.from(
AiMessage.from(text),
tokenUsage,
finishReason,
metadata
);
} else {
List<ToolExecutionRequest> toolExecutionRequests = toolExecutionRequestBuilderMap List<ToolExecutionRequest> toolExecutionRequests = toolExecutionRequestBuilderMap
.values().stream() .values().stream()
.map(AnthropicToolExecutionRequestBuilder::build) .map(AnthropicToolExecutionRequestBuilder::build)
.collect(Collectors.toList()); .collect(toList());
AiMessage aiMessage = isNullOrBlank(text)
? AiMessage.from(toolExecutionRequests)
: AiMessage.from(text, toolExecutionRequests);
return Response.from( return Response.from(
AiMessage.from(toolExecutionRequests), aiMessage,
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()), tokenUsage,
toFinishReason(stopReason), finishReason,
createMetadata() metadata
); );
} }
String content = String.join("\n", contents);
if (!content.isEmpty()) {
return Response.from(
AiMessage.from(content),
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
toFinishReason(stopReason),
createMetadata()
);
}
return null;
} }
private Map<String, Object> createMetadata() { private Map<String, Object> createMetadata() {

View File

@ -6,8 +6,23 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters; import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.image.Image; import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.*; import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.anthropic.internal.api.*; import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicImageContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicTool;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolSchema;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage; import dev.langchain4j.model.output.TokenUsage;
@ -16,11 +31,18 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import static dev.langchain4j.internal.Exceptions.illegalArgument; import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.Utils.*; import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TEXT;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TOOL_USE;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.ASSISTANT; import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.ASSISTANT;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.USER; import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.USER;
import static dev.langchain4j.model.output.FinishReason.*; import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.OTHER;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
import static java.util.Collections.emptyList; import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap; import static java.util.Collections.emptyMap;
import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.joining;
@ -130,12 +152,12 @@ public class AnthropicMapper {
public static AiMessage toAiMessage(List<AnthropicContent> contents) { public static AiMessage toAiMessage(List<AnthropicContent> contents) {
String text = contents.stream() String text = contents.stream()
.filter(content -> AnthropicContentBlockType.TEXT == content.type) .filter(content -> content.type == TEXT)
.map(content -> content.text) .map(content -> content.text)
.collect(joining("\n")); .collect(joining("\n"));
List<ToolExecutionRequest> toolExecutionRequests = contents.stream() List<ToolExecutionRequest> toolExecutionRequests = contents.stream()
.filter(content -> AnthropicContentBlockType.TOOL_USE == content.type) .filter(content -> content.type == TOOL_USE)
.map(content -> { .map(content -> {
try { try {
return ToolExecutionRequest.builder() return ToolExecutionRequest.builder()
@ -150,7 +172,7 @@ public class AnthropicMapper {
.collect(toList()); .collect(toList());
if (isNotNullOrBlank(text) && !isNullOrEmpty(toolExecutionRequests)) { if (isNotNullOrBlank(text) && !isNullOrEmpty(toolExecutionRequests)) {
return new AiMessage(text, toolExecutionRequests); return AiMessage.from(text, toolExecutionRequests);
} else if (!isNullOrEmpty(toolExecutionRequests)) { } else if (!isNullOrEmpty(toolExecutionRequests)) {
return AiMessage.from(toolExecutionRequests); return AiMessage.from(toolExecutionRequests);
} else { } else {

View File

@ -126,7 +126,7 @@
"allPublicFields": true "allPublicFields": true
}, },
{ {
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice", "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice",
"allDeclaredConstructors": true, "allDeclaredConstructors": true,
"allPublicConstructors": true, "allPublicConstructors": true,
"allDeclaredMethods": true, "allDeclaredMethods": true,
@ -178,122 +178,5 @@
"allPublicMethods": true, "allPublicMethods": true,
"allDeclaredFields": true, "allDeclaredFields": true,
"allPublicFields": true "allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicClient",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicClientBuilderFactory",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicHttpException",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicRequestLoggingInterceptor",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicResponseLoggingInterceptor",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolExecutionRequestBuilder",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.DefaultAnthropicClient",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.AnthropicChatModel",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.AnthropicChatModelName",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.AnthropicStreamingChatModel",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.InternalAnthropicHelper",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
} }
] ]

View File

@ -2,7 +2,13 @@ package dev.langchain4j.model.anthropic;
import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*; import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage; import dev.langchain4j.model.output.TokenUsage;
@ -17,12 +23,18 @@ import java.util.Base64;
import java.util.List; import java.util.List;
import java.util.stream.Stream; import java.util.stream.Stream;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*; import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.OBJECT;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.property;
import static dev.langchain4j.data.message.ToolExecutionResultMessage.from; import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.internal.Utils.readBytes; import static dev.langchain4j.internal.Utils.readBytes;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_5_SONNET_20240620;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229; import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
import static dev.langchain4j.model.output.FinishReason.*; import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.OTHER;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
import static java.util.Arrays.stream; import static java.util.Arrays.stream;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
@ -336,14 +348,13 @@ class AnthropicChatModelIT {
assertThat(secondResponse.finishReason()).isEqualTo(STOP); assertThat(secondResponse.finishReason()).isEqualTo(STOP);
} }
@ParameterizedTest @Test
@MethodSource("models_supporting_tools") void should_execute_multiple_tools_in_parallel_then_answer() {
void should_execute_multiple_tools_in_parallel_then_answer(AnthropicChatModelName modelName) {
// given // given
ChatLanguageModel model = AnthropicChatModel.builder() ChatLanguageModel model = AnthropicChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY")) .apiKey(System.getenv("ANTHROPIC_API_KEY"))
.modelName(modelName) .modelName(CLAUDE_3_5_SONNET_20240620)
.temperature(0.0) .temperature(0.0)
.logRequests(true) .logRequests(true)
.logResponses(true) .logResponses(true)
@ -399,14 +410,13 @@ class AnthropicChatModelIT {
assertThat(secondResponse.finishReason()).isEqualTo(STOP); assertThat(secondResponse.finishReason()).isEqualTo(STOP);
} }
@ParameterizedTest @Test
@MethodSource("models_supporting_tools") void should_execute_a_tool_with_nested_properties_then_answer() {
void should_execute_a_tool_with_nested_properties_then_answer(AnthropicChatModelName modelName) {
// given // given
ChatLanguageModel model = AnthropicChatModel.builder() ChatLanguageModel model = AnthropicChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY")) .apiKey(System.getenv("ANTHROPIC_API_KEY"))
.modelName(modelName) .modelName(CLAUDE_3_5_SONNET_20240620)
.temperature(0.0) .temperature(0.0)
.logRequests(true) .logRequests(true)
.logResponses(true) .logResponses(true)

View File

@ -2,13 +2,17 @@ package dev.langchain4j.model.anthropic;
import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*; import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TestStreamingResponseHandler; import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage; import dev.langchain4j.model.output.TokenUsage;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.EnumSource;
@ -18,11 +22,14 @@ import java.time.Duration;
import java.util.Base64; import java.util.Base64;
import java.util.List; import java.util.List;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*; import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.OBJECT;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.property;
import static dev.langchain4j.data.message.SystemMessage.systemMessage; import static dev.langchain4j.data.message.SystemMessage.systemMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.internal.Utils.readBytes; import static dev.langchain4j.internal.Utils.readBytes;
import static dev.langchain4j.model.anthropic.AnthropicChatModelIT.CAT_IMAGE_URL; import static dev.langchain4j.model.anthropic.AnthropicChatModelIT.CAT_IMAGE_URL;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_5_SONNET_20240620;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229; import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
import static dev.langchain4j.model.output.FinishReason.STOP; import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION; import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
@ -175,11 +182,11 @@ class AnthropicStreamingChatModelIT {
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder() StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY")) .apiKey(System.getenv("ANTHROPIC_API_KEY"))
.modelName(modelName) .modelName(modelName)
.maxTokens(200)
.temperature(0.0) .temperature(0.0)
.logRequests(true) .logRequests(true)
.logResponses(true) .logResponses(true)
.build(); .build();
UserMessage userMessage = userMessage("2+2=?"); UserMessage userMessage = userMessage("2+2=?");
List<ToolSpecification> toolSpecifications = singletonList(calculator); List<ToolSpecification> toolSpecifications = singletonList(calculator);
@ -190,7 +197,6 @@ class AnthropicStreamingChatModelIT {
// then // then
Response<AiMessage> response = handler.get(); Response<AiMessage> response = handler.get();
AiMessage aiMessage = response.content(); AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isNull();
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests(); List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
assertThat(toolExecutionRequests).hasSize(1); assertThat(toolExecutionRequests).hasSize(1);
@ -226,11 +232,12 @@ class AnthropicStreamingChatModelIT {
// given // given
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder() StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY")) .apiKey(System.getenv("ANTHROPIC_API_KEY"))
.maxTokens(200) .modelName(CLAUDE_3_5_SONNET_20240620)
.temperature(0.0) .temperature(0.0)
.logRequests(true) .logRequests(true)
.logResponses(true) .logResponses(true)
.build(); .build();
UserMessage userMessage = userMessage("2+2=?"); UserMessage userMessage = userMessage("2+2=?");
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>(); TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
@ -254,19 +261,18 @@ class AnthropicStreamingChatModelIT {
} }
@Disabled("Parallel execution of tools is not supported in the streaming mode yet.") @Test
@ParameterizedTest void should_execute_multiple_tools_in_parallel_then_answer() {
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools")
void should_execute_multiple_tools_in_parallel_then_answer(AnthropicChatModelName modelName) {
// given // given
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder() StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY")) .apiKey(System.getenv("ANTHROPIC_API_KEY"))
.modelName(modelName) .modelName(CLAUDE_3_5_SONNET_20240620)
.temperature(0.0) .temperature(0.0)
.logRequests(true) .logRequests(true)
.logResponses(true) .logResponses(true)
.build(); .build();
SystemMessage systemMessage = systemMessage("Do not think, nor explain step by step what you do. Output the result only."); SystemMessage systemMessage = systemMessage("Do not think, nor explain step by step what you do. Output the result only.");
UserMessage userMessage = userMessage("How much is 2+2 and 3+3? Call tools in parallel!"); UserMessage userMessage = userMessage("How much is 2+2 and 3+3? Call tools in parallel!");
List<ToolSpecification> toolSpecifications = singletonList(calculator); List<ToolSpecification> toolSpecifications = singletonList(calculator);
@ -313,18 +319,18 @@ class AnthropicStreamingChatModelIT {
assertThat(secondResponse.finishReason()).isEqualTo(STOP); assertThat(secondResponse.finishReason()).isEqualTo(STOP);
} }
@ParameterizedTest @Test
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools") void should_execute_a_tool_with_nested_properties_then_answer() {
void should_execute_a_tool_with_nested_properties_then_answer(AnthropicChatModelName modelName) {
// given // given
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder() StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY")) .apiKey(System.getenv("ANTHROPIC_API_KEY"))
.modelName(modelName) .modelName(CLAUDE_3_5_SONNET_20240620)
.temperature(0.0) .temperature(0.0)
.logRequests(true) .logRequests(true)
.logResponses(true) .logResponses(true)
.build(); .build();
UserMessage userMessage = userMessage("What is the weather in Berlin in Celsius?"); UserMessage userMessage = userMessage("What is the weather in Berlin in Celsius?");
List<ToolSpecification> toolSpecifications = singletonList(weather); List<ToolSpecification> toolSpecifications = singletonList(weather);

View File

@ -5,6 +5,7 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatModelListenerIT; import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
import dev.langchain4j.model.chat.listener.ChatModelListener; import dev.langchain4j.model.chat.listener.ChatModelListener;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT { class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
@ -25,7 +26,7 @@ class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT
@Override @Override
protected String modelName() { protected String modelName() {
return AnthropicChatModelName.CLAUDE_3_SONNET_20240229.toString(); return CLAUDE_3_SONNET_20240229.toString();
} }
@Override @Override
@ -42,9 +43,4 @@ class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT
protected Class<? extends Exception> expectedExceptionClass() { protected Class<? extends Exception> expectedExceptionClass() {
return AnthropicHttpException.class; return AnthropicHttpException.class;
} }
@Override
protected boolean supportsTools() {
return false; // TODO remove this method after https://github.com/langchain4j/langchain4j/pull/1795 is merged
}
} }

View File

@ -30,7 +30,6 @@ public class TestStreamingResponseHandler<T> implements StreamingResponseHandler
AiMessage aiMessage = (AiMessage) response.content(); AiMessage aiMessage = (AiMessage) response.content();
if (aiMessage.hasToolExecutionRequests()){ if (aiMessage.hasToolExecutionRequests()){
assertThat(aiMessage.toolExecutionRequests().size()).isGreaterThan(0); assertThat(aiMessage.toolExecutionRequests().size()).isGreaterThan(0);
assertThat(aiMessage.text()).isNull();
} else { } else {
assertThat(aiMessage.text()).isEqualTo(expectedTextContent); assertThat(aiMessage.text()).isEqualTo(expectedTextContent);
} }