FEATURE: Anthropic streaming with tools (#1795)

This commit is contained in:
LangChain4j 2024-09-23 11:59:07 +02:00
parent d1f5775f8b
commit a91ea8ae4f
8 changed files with 128 additions and 201 deletions

View File

@ -32,7 +32,10 @@ import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.*;
import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createModelListenerRequest;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicMessages;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicSystemPrompt;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicTools;
import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
@ -185,7 +188,7 @@ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel {
AnthropicCreateMessageRequest request = requestBuilder.build();
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages);
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(listener -> {
@ -248,16 +251,4 @@ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel {
client.createMessage(request, listenerHandler);
}
static ChatModelRequest createModelListenerRequest(AnthropicCreateMessageRequest request,
List<ChatMessage> messages) {
return ChatModelRequest.builder()
.model(request.getModel())
.temperature(request.getTemperature())
.topP(request.getTopP())
.maxTokens(request.getMaxTokens())
.messages(messages)
.build();
}
}

View File

@ -5,7 +5,17 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.anthropic.internal.api.*;
import dev.langchain4j.model.anthropic.internal.api.AnthropicApi;
import dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
import dev.langchain4j.model.anthropic.internal.api.AnthropicDelta;
import dev.langchain4j.model.anthropic.internal.api.AnthropicResponseMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicStreamingData;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import okhttp3.OkHttpClient;
@ -28,15 +38,17 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT;
import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TEXT;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TOOL_USE;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toFinishReason;
import static java.util.Collections.synchronizedList;
import static java.util.stream.Collectors.toList;
public class DefaultAnthropicClient extends AnthropicClient {
@ -135,16 +147,18 @@ public class DefaultAnthropicClient extends AnthropicClient {
EventSourceListener eventSourceListener = new EventSourceListener() {
private final ReentrantLock lock = new ReentrantLock();
final ReentrantLock lock = new ReentrantLock();
final List<String> contents = synchronizedList(new ArrayList<>());
volatile StringBuffer currentContentBuilder = new StringBuffer();
private final AtomicReference<AnthropicContentBlockType> currentContentBlockStartType = new AtomicReference<>();
final AtomicReference<AnthropicContentBlockType> currentContentBlockStartType = new AtomicReference<>();
final Map<Integer, AnthropicToolExecutionRequestBuilder> toolExecutionRequestBuilderMap = new ConcurrentHashMap<>();
final AtomicInteger inputTokenCount = new AtomicInteger();
final AtomicInteger outputTokenCount = new AtomicInteger();
private final Map<Integer, AnthropicToolExecutionRequestBuilder> toolExecutionRequestBuilderMap = new ConcurrentHashMap<>();
AtomicReference<String> responseId = new AtomicReference<>();
AtomicReference<String> responseModel = new AtomicReference<>();
final AtomicReference<String> responseId = new AtomicReference<>();
final AtomicReference<String> responseModel = new AtomicReference<>();
volatile String stopReason;
@ -233,13 +247,13 @@ public class DefaultAnthropicClient extends AnthropicClient {
currentContentBlockStartType.set(data.contentBlock.type);
if (currentContentBlockStartType.get() == AnthropicContentBlockType.TEXT) {
if (currentContentBlockStartType.get() == TEXT) {
String text = data.contentBlock.text;
if (isNotNullOrEmpty(text)) {
currentContentBuilder().append(text);
handler.onNext(text);
}
} else if (currentContentBlockStartType.get() == AnthropicContentBlockType.TOOL_USE) {
} else if (currentContentBlockStartType.get() == TOOL_USE) {
toolExecutionRequestBuilderMap.putIfAbsent(
data.index,
new AnthropicToolExecutionRequestBuilder(data.contentBlock.id, data.contentBlock.name)
@ -252,13 +266,13 @@ public class DefaultAnthropicClient extends AnthropicClient {
return;
}
if (currentContentBlockStartType.get() == AnthropicContentBlockType.TEXT) {
if (currentContentBlockStartType.get() == TEXT) {
String text = data.delta.text;
if (isNotNullOrEmpty(text)) {
currentContentBuilder().append(text);
handler.onNext(text);
}
} else if (currentContentBlockStartType.get() == AnthropicContentBlockType.TOOL_USE) {
} else if (currentContentBlockStartType.get() == TOOL_USE) {
String partialJson = data.delta.partialJson;
if (isNotNullOrEmpty(partialJson)) {
Integer toolExecutionsIndex = data.index;
@ -293,30 +307,36 @@ public class DefaultAnthropicClient extends AnthropicClient {
}
private Response<AiMessage> build() {
if (!toolExecutionRequestBuilderMap.isEmpty()) {
String text = String.join("\n", contents);
TokenUsage tokenUsage = new TokenUsage(inputTokenCount.get(), outputTokenCount.get());
FinishReason finishReason = toFinishReason(stopReason);
Map<String, Object> metadata = createMetadata();
if (toolExecutionRequestBuilderMap.isEmpty()) {
return Response.from(
AiMessage.from(text),
tokenUsage,
finishReason,
metadata
);
} else {
List<ToolExecutionRequest> toolExecutionRequests = toolExecutionRequestBuilderMap
.values().stream()
.map(AnthropicToolExecutionRequestBuilder::build)
.collect(Collectors.toList());
.collect(toList());
AiMessage aiMessage = isNullOrBlank(text)
? AiMessage.from(toolExecutionRequests)
: AiMessage.from(text, toolExecutionRequests);
return Response.from(
AiMessage.from(toolExecutionRequests),
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
toFinishReason(stopReason),
createMetadata()
aiMessage,
tokenUsage,
finishReason,
metadata
);
}
String content = String.join("\n", contents);
if (!content.isEmpty()) {
return Response.from(
AiMessage.from(content),
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
toFinishReason(stopReason),
createMetadata()
);
}
return null;
}
private Map<String, Object> createMetadata() {

View File

@ -6,8 +6,23 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.anthropic.internal.api.*;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicImageContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicTool;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolSchema;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
@ -16,11 +31,18 @@ import java.util.List;
import java.util.Map;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TEXT;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TOOL_USE;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.ASSISTANT;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.USER;
import static dev.langchain4j.model.output.FinishReason.*;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.OTHER;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.stream.Collectors.joining;
@ -130,12 +152,12 @@ public class AnthropicMapper {
public static AiMessage toAiMessage(List<AnthropicContent> contents) {
String text = contents.stream()
.filter(content -> AnthropicContentBlockType.TEXT == content.type)
.filter(content -> content.type == TEXT)
.map(content -> content.text)
.collect(joining("\n"));
List<ToolExecutionRequest> toolExecutionRequests = contents.stream()
.filter(content -> AnthropicContentBlockType.TOOL_USE == content.type)
.filter(content -> content.type == TOOL_USE)
.map(content -> {
try {
return ToolExecutionRequest.builder()
@ -150,7 +172,7 @@ public class AnthropicMapper {
.collect(toList());
if (isNotNullOrBlank(text) && !isNullOrEmpty(toolExecutionRequests)) {
return new AiMessage(text, toolExecutionRequests);
return AiMessage.from(text, toolExecutionRequests);
} else if (!isNullOrEmpty(toolExecutionRequests)) {
return AiMessage.from(toolExecutionRequests);
} else {

View File

@ -126,7 +126,7 @@
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice",
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
@ -178,122 +178,5 @@
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicClient",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicClientBuilderFactory",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicHttpException",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicRequestLoggingInterceptor",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicResponseLoggingInterceptor",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolExecutionRequestBuilder",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.client.DefaultAnthropicClient",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.AnthropicChatModel",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.AnthropicChatModelName",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.AnthropicStreamingChatModel",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
},
{
"name": "dev.langchain4j.model.anthropic.InternalAnthropicHelper",
"allDeclaredConstructors": true,
"allPublicConstructors": true,
"allDeclaredMethods": true,
"allPublicMethods": true,
"allDeclaredFields": true,
"allPublicFields": true
}
]

View File

@ -2,7 +2,13 @@ package dev.langchain4j.model.anthropic;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
@ -17,12 +23,18 @@ import java.util.Base64;
import java.util.List;
import java.util.stream.Stream;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.OBJECT;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.property;
import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.internal.Utils.readBytes;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_5_SONNET_20240620;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
import static dev.langchain4j.model.output.FinishReason.*;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.OTHER;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
import static java.util.Arrays.asList;
import static java.util.Arrays.stream;
import static java.util.Collections.singletonList;
@ -336,14 +348,13 @@ class AnthropicChatModelIT {
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}
@ParameterizedTest
@MethodSource("models_supporting_tools")
void should_execute_multiple_tools_in_parallel_then_answer(AnthropicChatModelName modelName) {
@Test
void should_execute_multiple_tools_in_parallel_then_answer() {
// given
ChatLanguageModel model = AnthropicChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
.modelName(modelName)
.modelName(CLAUDE_3_5_SONNET_20240620)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
@ -399,14 +410,13 @@ class AnthropicChatModelIT {
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}
@ParameterizedTest
@MethodSource("models_supporting_tools")
void should_execute_a_tool_with_nested_properties_then_answer(AnthropicChatModelName modelName) {
@Test
void should_execute_a_tool_with_nested_properties_then_answer() {
// given
ChatLanguageModel model = AnthropicChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
.modelName(modelName)
.modelName(CLAUDE_3_5_SONNET_20240620)
.temperature(0.0)
.logRequests(true)
.logResponses(true)

View File

@ -2,13 +2,17 @@ package dev.langchain4j.model.anthropic;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
@ -18,11 +22,14 @@ import java.time.Duration;
import java.util.Base64;
import java.util.List;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.OBJECT;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.property;
import static dev.langchain4j.data.message.SystemMessage.systemMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.internal.Utils.readBytes;
import static dev.langchain4j.model.anthropic.AnthropicChatModelIT.CAT_IMAGE_URL;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_5_SONNET_20240620;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
@ -175,11 +182,11 @@ class AnthropicStreamingChatModelIT {
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
.modelName(modelName)
.maxTokens(200)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build();
UserMessage userMessage = userMessage("2+2=?");
List<ToolSpecification> toolSpecifications = singletonList(calculator);
@ -190,7 +197,6 @@ class AnthropicStreamingChatModelIT {
// then
Response<AiMessage> response = handler.get();
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isNull();
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
assertThat(toolExecutionRequests).hasSize(1);
@ -226,11 +232,12 @@ class AnthropicStreamingChatModelIT {
// given
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
.maxTokens(200)
.modelName(CLAUDE_3_5_SONNET_20240620)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build();
UserMessage userMessage = userMessage("2+2=?");
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
@ -254,19 +261,18 @@ class AnthropicStreamingChatModelIT {
}
@Disabled("Parallel execution of tools is not supported in the streaming mode yet.")
@ParameterizedTest
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools")
void should_execute_multiple_tools_in_parallel_then_answer(AnthropicChatModelName modelName) {
@Test
void should_execute_multiple_tools_in_parallel_then_answer() {
// given
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
.modelName(modelName)
.modelName(CLAUDE_3_5_SONNET_20240620)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build();
SystemMessage systemMessage = systemMessage("Do not think, nor explain step by step what you do. Output the result only.");
UserMessage userMessage = userMessage("How much is 2+2 and 3+3? Call tools in parallel!");
List<ToolSpecification> toolSpecifications = singletonList(calculator);
@ -313,18 +319,18 @@ class AnthropicStreamingChatModelIT {
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}
@ParameterizedTest
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools")
void should_execute_a_tool_with_nested_properties_then_answer(AnthropicChatModelName modelName) {
@Test
void should_execute_a_tool_with_nested_properties_then_answer() {
// given
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
.modelName(modelName)
.modelName(CLAUDE_3_5_SONNET_20240620)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build();
UserMessage userMessage = userMessage("What is the weather in Berlin in Celsius?");
List<ToolSpecification> toolSpecifications = singletonList(weather);

View File

@ -5,6 +5,7 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
import static java.util.Collections.singletonList;
class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
@ -25,7 +26,7 @@ class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT
@Override
protected String modelName() {
return AnthropicChatModelName.CLAUDE_3_SONNET_20240229.toString();
return CLAUDE_3_SONNET_20240229.toString();
}
@Override
@ -42,9 +43,4 @@ class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT
protected Class<? extends Exception> expectedExceptionClass() {
return AnthropicHttpException.class;
}
@Override
protected boolean supportsTools() {
return false; // TODO remove this method after https://github.com/langchain4j/langchain4j/pull/1795 is merged
}
}

View File

@ -30,7 +30,6 @@ public class TestStreamingResponseHandler<T> implements StreamingResponseHandler
AiMessage aiMessage = (AiMessage) response.content();
if (aiMessage.hasToolExecutionRequests()){
assertThat(aiMessage.toolExecutionRequests().size()).isGreaterThan(0);
assertThat(aiMessage.text()).isNull();
} else {
assertThat(aiMessage.text()).isEqualTo(expectedTextContent);
}