FEATURE: Anthropic streaming with tools (#1795)
This commit is contained in:
parent
d1f5775f8b
commit
a91ea8ae4f
|
@ -32,7 +32,10 @@ import static dev.langchain4j.internal.Utils.getOrDefault;
|
|||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307;
|
||||
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.*;
|
||||
import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createModelListenerRequest;
|
||||
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicMessages;
|
||||
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicSystemPrompt;
|
||||
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicTools;
|
||||
import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages;
|
||||
import static java.util.Collections.emptyList;
|
||||
import static java.util.Collections.singletonList;
|
||||
|
@ -185,7 +188,7 @@ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel {
|
|||
|
||||
AnthropicCreateMessageRequest request = requestBuilder.build();
|
||||
|
||||
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages);
|
||||
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
|
||||
Map<Object, Object> attributes = new ConcurrentHashMap<>();
|
||||
ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
|
||||
listeners.forEach(listener -> {
|
||||
|
@ -248,16 +251,4 @@ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel {
|
|||
|
||||
client.createMessage(request, listenerHandler);
|
||||
}
|
||||
|
||||
|
||||
static ChatModelRequest createModelListenerRequest(AnthropicCreateMessageRequest request,
|
||||
List<ChatMessage> messages) {
|
||||
return ChatModelRequest.builder()
|
||||
.model(request.getModel())
|
||||
.temperature(request.getTemperature())
|
||||
.topP(request.getTopP())
|
||||
.maxTokens(request.getMaxTokens())
|
||||
.messages(messages)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,17 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
|||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.internal.Utils;
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.anthropic.internal.api.*;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicApi;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicDelta;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicResponseMessage;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicStreamingData;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
|
||||
import dev.langchain4j.model.output.FinishReason;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import okhttp3.OkHttpClient;
|
||||
|
@ -28,15 +38,17 @@ import java.util.concurrent.ConcurrentHashMap;
|
|||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT;
|
||||
import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TEXT;
|
||||
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TOOL_USE;
|
||||
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toFinishReason;
|
||||
import static java.util.Collections.synchronizedList;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
public class DefaultAnthropicClient extends AnthropicClient {
|
||||
|
||||
|
@ -135,16 +147,18 @@ public class DefaultAnthropicClient extends AnthropicClient {
|
|||
|
||||
EventSourceListener eventSourceListener = new EventSourceListener() {
|
||||
|
||||
private final ReentrantLock lock = new ReentrantLock();
|
||||
final ReentrantLock lock = new ReentrantLock();
|
||||
final List<String> contents = synchronizedList(new ArrayList<>());
|
||||
volatile StringBuffer currentContentBuilder = new StringBuffer();
|
||||
private final AtomicReference<AnthropicContentBlockType> currentContentBlockStartType = new AtomicReference<>();
|
||||
|
||||
final AtomicReference<AnthropicContentBlockType> currentContentBlockStartType = new AtomicReference<>();
|
||||
final Map<Integer, AnthropicToolExecutionRequestBuilder> toolExecutionRequestBuilderMap = new ConcurrentHashMap<>();
|
||||
|
||||
final AtomicInteger inputTokenCount = new AtomicInteger();
|
||||
final AtomicInteger outputTokenCount = new AtomicInteger();
|
||||
private final Map<Integer, AnthropicToolExecutionRequestBuilder> toolExecutionRequestBuilderMap = new ConcurrentHashMap<>();
|
||||
AtomicReference<String> responseId = new AtomicReference<>();
|
||||
AtomicReference<String> responseModel = new AtomicReference<>();
|
||||
|
||||
final AtomicReference<String> responseId = new AtomicReference<>();
|
||||
final AtomicReference<String> responseModel = new AtomicReference<>();
|
||||
|
||||
volatile String stopReason;
|
||||
|
||||
|
@ -233,13 +247,13 @@ public class DefaultAnthropicClient extends AnthropicClient {
|
|||
|
||||
currentContentBlockStartType.set(data.contentBlock.type);
|
||||
|
||||
if (currentContentBlockStartType.get() == AnthropicContentBlockType.TEXT) {
|
||||
if (currentContentBlockStartType.get() == TEXT) {
|
||||
String text = data.contentBlock.text;
|
||||
if (isNotNullOrEmpty(text)) {
|
||||
currentContentBuilder().append(text);
|
||||
handler.onNext(text);
|
||||
}
|
||||
} else if (currentContentBlockStartType.get() == AnthropicContentBlockType.TOOL_USE) {
|
||||
} else if (currentContentBlockStartType.get() == TOOL_USE) {
|
||||
toolExecutionRequestBuilderMap.putIfAbsent(
|
||||
data.index,
|
||||
new AnthropicToolExecutionRequestBuilder(data.contentBlock.id, data.contentBlock.name)
|
||||
|
@ -252,13 +266,13 @@ public class DefaultAnthropicClient extends AnthropicClient {
|
|||
return;
|
||||
}
|
||||
|
||||
if (currentContentBlockStartType.get() == AnthropicContentBlockType.TEXT) {
|
||||
if (currentContentBlockStartType.get() == TEXT) {
|
||||
String text = data.delta.text;
|
||||
if (isNotNullOrEmpty(text)) {
|
||||
currentContentBuilder().append(text);
|
||||
handler.onNext(text);
|
||||
}
|
||||
} else if (currentContentBlockStartType.get() == AnthropicContentBlockType.TOOL_USE) {
|
||||
} else if (currentContentBlockStartType.get() == TOOL_USE) {
|
||||
String partialJson = data.delta.partialJson;
|
||||
if (isNotNullOrEmpty(partialJson)) {
|
||||
Integer toolExecutionsIndex = data.index;
|
||||
|
@ -293,30 +307,36 @@ public class DefaultAnthropicClient extends AnthropicClient {
|
|||
}
|
||||
|
||||
private Response<AiMessage> build() {
|
||||
if (!toolExecutionRequestBuilderMap.isEmpty()) {
|
||||
|
||||
String text = String.join("\n", contents);
|
||||
TokenUsage tokenUsage = new TokenUsage(inputTokenCount.get(), outputTokenCount.get());
|
||||
FinishReason finishReason = toFinishReason(stopReason);
|
||||
Map<String, Object> metadata = createMetadata();
|
||||
|
||||
if (toolExecutionRequestBuilderMap.isEmpty()) {
|
||||
return Response.from(
|
||||
AiMessage.from(text),
|
||||
tokenUsage,
|
||||
finishReason,
|
||||
metadata
|
||||
);
|
||||
} else {
|
||||
List<ToolExecutionRequest> toolExecutionRequests = toolExecutionRequestBuilderMap
|
||||
.values().stream()
|
||||
.map(AnthropicToolExecutionRequestBuilder::build)
|
||||
.collect(Collectors.toList());
|
||||
.collect(toList());
|
||||
|
||||
AiMessage aiMessage = isNullOrBlank(text)
|
||||
? AiMessage.from(toolExecutionRequests)
|
||||
: AiMessage.from(text, toolExecutionRequests);
|
||||
|
||||
return Response.from(
|
||||
AiMessage.from(toolExecutionRequests),
|
||||
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
|
||||
toFinishReason(stopReason),
|
||||
createMetadata()
|
||||
aiMessage,
|
||||
tokenUsage,
|
||||
finishReason,
|
||||
metadata
|
||||
);
|
||||
}
|
||||
|
||||
String content = String.join("\n", contents);
|
||||
if (!content.isEmpty()) {
|
||||
return Response.from(
|
||||
AiMessage.from(content),
|
||||
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
|
||||
toFinishReason(stopReason),
|
||||
createMetadata()
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private Map<String, Object> createMetadata() {
|
||||
|
|
|
@ -6,8 +6,23 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
|||
import dev.langchain4j.agent.tool.ToolParameters;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.image.Image;
|
||||
import dev.langchain4j.data.message.*;
|
||||
import dev.langchain4j.model.anthropic.internal.api.*;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.ImageContent;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
import dev.langchain4j.data.message.TextContent;
|
||||
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicImageContent;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicTool;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolSchema;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent;
|
||||
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
|
||||
import dev.langchain4j.model.output.FinishReason;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
|
||||
|
@ -16,11 +31,18 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
|
||||
import static dev.langchain4j.internal.Exceptions.illegalArgument;
|
||||
import static dev.langchain4j.internal.Utils.*;
|
||||
import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TEXT;
|
||||
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TOOL_USE;
|
||||
import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.ASSISTANT;
|
||||
import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.USER;
|
||||
import static dev.langchain4j.model.output.FinishReason.*;
|
||||
import static dev.langchain4j.model.output.FinishReason.LENGTH;
|
||||
import static dev.langchain4j.model.output.FinishReason.OTHER;
|
||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
|
||||
import static java.util.Collections.emptyList;
|
||||
import static java.util.Collections.emptyMap;
|
||||
import static java.util.stream.Collectors.joining;
|
||||
|
@ -130,12 +152,12 @@ public class AnthropicMapper {
|
|||
public static AiMessage toAiMessage(List<AnthropicContent> contents) {
|
||||
|
||||
String text = contents.stream()
|
||||
.filter(content -> AnthropicContentBlockType.TEXT == content.type)
|
||||
.filter(content -> content.type == TEXT)
|
||||
.map(content -> content.text)
|
||||
.collect(joining("\n"));
|
||||
|
||||
List<ToolExecutionRequest> toolExecutionRequests = contents.stream()
|
||||
.filter(content -> AnthropicContentBlockType.TOOL_USE == content.type)
|
||||
.filter(content -> content.type == TOOL_USE)
|
||||
.map(content -> {
|
||||
try {
|
||||
return ToolExecutionRequest.builder()
|
||||
|
@ -150,7 +172,7 @@ public class AnthropicMapper {
|
|||
.collect(toList());
|
||||
|
||||
if (isNotNullOrBlank(text) && !isNullOrEmpty(toolExecutionRequests)) {
|
||||
return new AiMessage(text, toolExecutionRequests);
|
||||
return AiMessage.from(text, toolExecutionRequests);
|
||||
} else if (!isNullOrEmpty(toolExecutionRequests)) {
|
||||
return AiMessage.from(toolExecutionRequests);
|
||||
} else {
|
||||
|
|
|
@ -126,7 +126,7 @@
|
|||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice",
|
||||
"name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
|
@ -178,122 +178,5 @@
|
|||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicClient",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicClientBuilderFactory",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicHttpException",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicRequestLoggingInterceptor",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicResponseLoggingInterceptor",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolExecutionRequestBuilder",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.client.DefaultAnthropicClient",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.AnthropicChatModel",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.AnthropicChatModelName",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.AnthropicStreamingChatModel",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
},
|
||||
{
|
||||
"name": "dev.langchain4j.model.anthropic.InternalAnthropicHelper",
|
||||
"allDeclaredConstructors": true,
|
||||
"allPublicConstructors": true,
|
||||
"allDeclaredMethods": true,
|
||||
"allPublicMethods": true,
|
||||
"allDeclaredFields": true,
|
||||
"allPublicFields": true
|
||||
}
|
||||
]
|
|
@ -2,7 +2,13 @@ package dev.langchain4j.model.anthropic;
|
|||
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.*;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.ImageContent;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
import dev.langchain4j.data.message.TextContent;
|
||||
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
|
@ -17,12 +23,18 @@ import java.util.Base64;
|
|||
import java.util.List;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.OBJECT;
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.property;
|
||||
import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static dev.langchain4j.internal.Utils.readBytes;
|
||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_5_SONNET_20240620;
|
||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
|
||||
import static dev.langchain4j.model.output.FinishReason.*;
|
||||
import static dev.langchain4j.model.output.FinishReason.LENGTH;
|
||||
import static dev.langchain4j.model.output.FinishReason.OTHER;
|
||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Arrays.stream;
|
||||
import static java.util.Collections.singletonList;
|
||||
|
@ -336,14 +348,13 @@ class AnthropicChatModelIT {
|
|||
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models_supporting_tools")
|
||||
void should_execute_multiple_tools_in_parallel_then_answer(AnthropicChatModelName modelName) {
|
||||
@Test
|
||||
void should_execute_multiple_tools_in_parallel_then_answer() {
|
||||
|
||||
// given
|
||||
ChatLanguageModel model = AnthropicChatModel.builder()
|
||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||
.modelName(modelName)
|
||||
.modelName(CLAUDE_3_5_SONNET_20240620)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
|
@ -399,14 +410,13 @@ class AnthropicChatModelIT {
|
|||
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models_supporting_tools")
|
||||
void should_execute_a_tool_with_nested_properties_then_answer(AnthropicChatModelName modelName) {
|
||||
@Test
|
||||
void should_execute_a_tool_with_nested_properties_then_answer() {
|
||||
|
||||
// given
|
||||
ChatLanguageModel model = AnthropicChatModel.builder()
|
||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||
.modelName(modelName)
|
||||
.modelName(CLAUDE_3_5_SONNET_20240620)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
|
|
|
@ -2,13 +2,17 @@ package dev.langchain4j.model.anthropic;
|
|||
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.*;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.ImageContent;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
|
@ -18,11 +22,14 @@ import java.time.Duration;
|
|||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.OBJECT;
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.property;
|
||||
import static dev.langchain4j.data.message.SystemMessage.systemMessage;
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static dev.langchain4j.internal.Utils.readBytes;
|
||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelIT.CAT_IMAGE_URL;
|
||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_5_SONNET_20240620;
|
||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
|
||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
|
||||
|
@ -175,11 +182,11 @@ class AnthropicStreamingChatModelIT {
|
|||
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||
.modelName(modelName)
|
||||
.maxTokens(200)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build();
|
||||
|
||||
UserMessage userMessage = userMessage("2+2=?");
|
||||
List<ToolSpecification> toolSpecifications = singletonList(calculator);
|
||||
|
||||
|
@ -190,7 +197,6 @@ class AnthropicStreamingChatModelIT {
|
|||
// then
|
||||
Response<AiMessage> response = handler.get();
|
||||
AiMessage aiMessage = response.content();
|
||||
assertThat(aiMessage.text()).isNull();
|
||||
|
||||
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
|
||||
assertThat(toolExecutionRequests).hasSize(1);
|
||||
|
@ -226,11 +232,12 @@ class AnthropicStreamingChatModelIT {
|
|||
// given
|
||||
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||
.maxTokens(200)
|
||||
.modelName(CLAUDE_3_5_SONNET_20240620)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build();
|
||||
|
||||
UserMessage userMessage = userMessage("2+2=?");
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
|
||||
|
@ -254,19 +261,18 @@ class AnthropicStreamingChatModelIT {
|
|||
}
|
||||
|
||||
|
||||
@Disabled("Parallel execution of tools is not supported in the streaming mode yet.")
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools")
|
||||
void should_execute_multiple_tools_in_parallel_then_answer(AnthropicChatModelName modelName) {
|
||||
@Test
|
||||
void should_execute_multiple_tools_in_parallel_then_answer() {
|
||||
|
||||
// given
|
||||
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||
.modelName(modelName)
|
||||
.modelName(CLAUDE_3_5_SONNET_20240620)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build();
|
||||
|
||||
SystemMessage systemMessage = systemMessage("Do not think, nor explain step by step what you do. Output the result only.");
|
||||
UserMessage userMessage = userMessage("How much is 2+2 and 3+3? Call tools in parallel!");
|
||||
List<ToolSpecification> toolSpecifications = singletonList(calculator);
|
||||
|
@ -313,18 +319,18 @@ class AnthropicStreamingChatModelIT {
|
|||
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools")
|
||||
void should_execute_a_tool_with_nested_properties_then_answer(AnthropicChatModelName modelName) {
|
||||
@Test
|
||||
void should_execute_a_tool_with_nested_properties_then_answer() {
|
||||
|
||||
// given
|
||||
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||
.modelName(modelName)
|
||||
.modelName(CLAUDE_3_5_SONNET_20240620)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build();
|
||||
|
||||
UserMessage userMessage = userMessage("What is the weather in Berlin in Celsius?");
|
||||
List<ToolSpecification> toolSpecifications = singletonList(weather);
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
|||
import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
|
||||
import dev.langchain4j.model.chat.listener.ChatModelListener;
|
||||
|
||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
|
||||
import static java.util.Collections.singletonList;
|
||||
|
||||
class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
|
||||
|
@ -25,7 +26,7 @@ class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT
|
|||
|
||||
@Override
|
||||
protected String modelName() {
|
||||
return AnthropicChatModelName.CLAUDE_3_SONNET_20240229.toString();
|
||||
return CLAUDE_3_SONNET_20240229.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -42,9 +43,4 @@ class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT
|
|||
protected Class<? extends Exception> expectedExceptionClass() {
|
||||
return AnthropicHttpException.class;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsTools() {
|
||||
return false; // TODO remove this method after https://github.com/langchain4j/langchain4j/pull/1795 is merged
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,7 +30,6 @@ public class TestStreamingResponseHandler<T> implements StreamingResponseHandler
|
|||
AiMessage aiMessage = (AiMessage) response.content();
|
||||
if (aiMessage.hasToolExecutionRequests()){
|
||||
assertThat(aiMessage.toolExecutionRequests().size()).isGreaterThan(0);
|
||||
assertThat(aiMessage.text()).isNull();
|
||||
} else {
|
||||
assertThat(aiMessage.text()).isEqualTo(expectedTextContent);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue