FEATURE: Anthropic streaming with tools (#1795)
This commit is contained in:
parent
d1f5775f8b
commit
a91ea8ae4f
|
@ -32,7 +32,10 @@ import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307;
|
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307;
|
||||||
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.*;
|
import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createModelListenerRequest;
|
||||||
|
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicMessages;
|
||||||
|
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicSystemPrompt;
|
||||||
|
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicTools;
|
||||||
import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages;
|
import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages;
|
||||||
import static java.util.Collections.emptyList;
|
import static java.util.Collections.emptyList;
|
||||||
import static java.util.Collections.singletonList;
|
import static java.util.Collections.singletonList;
|
||||||
|
@ -185,7 +188,7 @@ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel {
|
||||||
|
|
||||||
AnthropicCreateMessageRequest request = requestBuilder.build();
|
AnthropicCreateMessageRequest request = requestBuilder.build();
|
||||||
|
|
||||||
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages);
|
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
|
||||||
Map<Object, Object> attributes = new ConcurrentHashMap<>();
|
Map<Object, Object> attributes = new ConcurrentHashMap<>();
|
||||||
ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
|
ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
|
||||||
listeners.forEach(listener -> {
|
listeners.forEach(listener -> {
|
||||||
|
@ -248,16 +251,4 @@ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel {
|
||||||
|
|
||||||
client.createMessage(request, listenerHandler);
|
client.createMessage(request, listenerHandler);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static ChatModelRequest createModelListenerRequest(AnthropicCreateMessageRequest request,
|
|
||||||
List<ChatMessage> messages) {
|
|
||||||
return ChatModelRequest.builder()
|
|
||||||
.model(request.getModel())
|
|
||||||
.temperature(request.getTemperature())
|
|
||||||
.topP(request.getTopP())
|
|
||||||
.maxTokens(request.getMaxTokens())
|
|
||||||
.messages(messages)
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,17 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.internal.Utils;
|
import dev.langchain4j.internal.Utils;
|
||||||
import dev.langchain4j.model.StreamingResponseHandler;
|
import dev.langchain4j.model.StreamingResponseHandler;
|
||||||
import dev.langchain4j.model.anthropic.internal.api.*;
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicApi;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicDelta;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicResponseMessage;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicStreamingData;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
|
||||||
|
import dev.langchain4j.model.output.FinishReason;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.model.output.TokenUsage;
|
import dev.langchain4j.model.output.TokenUsage;
|
||||||
import okhttp3.OkHttpClient;
|
import okhttp3.OkHttpClient;
|
||||||
|
@ -28,15 +38,17 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.concurrent.atomic.AtomicReference;
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
import java.util.concurrent.locks.ReentrantLock;
|
import java.util.concurrent.locks.ReentrantLock;
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT;
|
import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT;
|
||||||
import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
|
import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
|
||||||
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
||||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||||
|
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TEXT;
|
||||||
|
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TOOL_USE;
|
||||||
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toFinishReason;
|
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toFinishReason;
|
||||||
import static java.util.Collections.synchronizedList;
|
import static java.util.Collections.synchronizedList;
|
||||||
|
import static java.util.stream.Collectors.toList;
|
||||||
|
|
||||||
public class DefaultAnthropicClient extends AnthropicClient {
|
public class DefaultAnthropicClient extends AnthropicClient {
|
||||||
|
|
||||||
|
@ -135,16 +147,18 @@ public class DefaultAnthropicClient extends AnthropicClient {
|
||||||
|
|
||||||
EventSourceListener eventSourceListener = new EventSourceListener() {
|
EventSourceListener eventSourceListener = new EventSourceListener() {
|
||||||
|
|
||||||
private final ReentrantLock lock = new ReentrantLock();
|
final ReentrantLock lock = new ReentrantLock();
|
||||||
final List<String> contents = synchronizedList(new ArrayList<>());
|
final List<String> contents = synchronizedList(new ArrayList<>());
|
||||||
volatile StringBuffer currentContentBuilder = new StringBuffer();
|
volatile StringBuffer currentContentBuilder = new StringBuffer();
|
||||||
private final AtomicReference<AnthropicContentBlockType> currentContentBlockStartType = new AtomicReference<>();
|
|
||||||
|
final AtomicReference<AnthropicContentBlockType> currentContentBlockStartType = new AtomicReference<>();
|
||||||
|
final Map<Integer, AnthropicToolExecutionRequestBuilder> toolExecutionRequestBuilderMap = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
final AtomicInteger inputTokenCount = new AtomicInteger();
|
final AtomicInteger inputTokenCount = new AtomicInteger();
|
||||||
final AtomicInteger outputTokenCount = new AtomicInteger();
|
final AtomicInteger outputTokenCount = new AtomicInteger();
|
||||||
private final Map<Integer, AnthropicToolExecutionRequestBuilder> toolExecutionRequestBuilderMap = new ConcurrentHashMap<>();
|
|
||||||
AtomicReference<String> responseId = new AtomicReference<>();
|
final AtomicReference<String> responseId = new AtomicReference<>();
|
||||||
AtomicReference<String> responseModel = new AtomicReference<>();
|
final AtomicReference<String> responseModel = new AtomicReference<>();
|
||||||
|
|
||||||
volatile String stopReason;
|
volatile String stopReason;
|
||||||
|
|
||||||
|
@ -233,13 +247,13 @@ public class DefaultAnthropicClient extends AnthropicClient {
|
||||||
|
|
||||||
currentContentBlockStartType.set(data.contentBlock.type);
|
currentContentBlockStartType.set(data.contentBlock.type);
|
||||||
|
|
||||||
if (currentContentBlockStartType.get() == AnthropicContentBlockType.TEXT) {
|
if (currentContentBlockStartType.get() == TEXT) {
|
||||||
String text = data.contentBlock.text;
|
String text = data.contentBlock.text;
|
||||||
if (isNotNullOrEmpty(text)) {
|
if (isNotNullOrEmpty(text)) {
|
||||||
currentContentBuilder().append(text);
|
currentContentBuilder().append(text);
|
||||||
handler.onNext(text);
|
handler.onNext(text);
|
||||||
}
|
}
|
||||||
} else if (currentContentBlockStartType.get() == AnthropicContentBlockType.TOOL_USE) {
|
} else if (currentContentBlockStartType.get() == TOOL_USE) {
|
||||||
toolExecutionRequestBuilderMap.putIfAbsent(
|
toolExecutionRequestBuilderMap.putIfAbsent(
|
||||||
data.index,
|
data.index,
|
||||||
new AnthropicToolExecutionRequestBuilder(data.contentBlock.id, data.contentBlock.name)
|
new AnthropicToolExecutionRequestBuilder(data.contentBlock.id, data.contentBlock.name)
|
||||||
|
@ -252,13 +266,13 @@ public class DefaultAnthropicClient extends AnthropicClient {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (currentContentBlockStartType.get() == AnthropicContentBlockType.TEXT) {
|
if (currentContentBlockStartType.get() == TEXT) {
|
||||||
String text = data.delta.text;
|
String text = data.delta.text;
|
||||||
if (isNotNullOrEmpty(text)) {
|
if (isNotNullOrEmpty(text)) {
|
||||||
currentContentBuilder().append(text);
|
currentContentBuilder().append(text);
|
||||||
handler.onNext(text);
|
handler.onNext(text);
|
||||||
}
|
}
|
||||||
} else if (currentContentBlockStartType.get() == AnthropicContentBlockType.TOOL_USE) {
|
} else if (currentContentBlockStartType.get() == TOOL_USE) {
|
||||||
String partialJson = data.delta.partialJson;
|
String partialJson = data.delta.partialJson;
|
||||||
if (isNotNullOrEmpty(partialJson)) {
|
if (isNotNullOrEmpty(partialJson)) {
|
||||||
Integer toolExecutionsIndex = data.index;
|
Integer toolExecutionsIndex = data.index;
|
||||||
|
@ -293,30 +307,36 @@ public class DefaultAnthropicClient extends AnthropicClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
private Response<AiMessage> build() {
|
private Response<AiMessage> build() {
|
||||||
if (!toolExecutionRequestBuilderMap.isEmpty()) {
|
|
||||||
|
String text = String.join("\n", contents);
|
||||||
|
TokenUsage tokenUsage = new TokenUsage(inputTokenCount.get(), outputTokenCount.get());
|
||||||
|
FinishReason finishReason = toFinishReason(stopReason);
|
||||||
|
Map<String, Object> metadata = createMetadata();
|
||||||
|
|
||||||
|
if (toolExecutionRequestBuilderMap.isEmpty()) {
|
||||||
|
return Response.from(
|
||||||
|
AiMessage.from(text),
|
||||||
|
tokenUsage,
|
||||||
|
finishReason,
|
||||||
|
metadata
|
||||||
|
);
|
||||||
|
} else {
|
||||||
List<ToolExecutionRequest> toolExecutionRequests = toolExecutionRequestBuilderMap
|
List<ToolExecutionRequest> toolExecutionRequests = toolExecutionRequestBuilderMap
|
||||||
.values().stream()
|
.values().stream()
|
||||||
.map(AnthropicToolExecutionRequestBuilder::build)
|
.map(AnthropicToolExecutionRequestBuilder::build)
|
||||||
.collect(Collectors.toList());
|
.collect(toList());
|
||||||
|
|
||||||
|
AiMessage aiMessage = isNullOrBlank(text)
|
||||||
|
? AiMessage.from(toolExecutionRequests)
|
||||||
|
: AiMessage.from(text, toolExecutionRequests);
|
||||||
|
|
||||||
return Response.from(
|
return Response.from(
|
||||||
AiMessage.from(toolExecutionRequests),
|
aiMessage,
|
||||||
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
|
tokenUsage,
|
||||||
toFinishReason(stopReason),
|
finishReason,
|
||||||
createMetadata()
|
metadata
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
String content = String.join("\n", contents);
|
|
||||||
if (!content.isEmpty()) {
|
|
||||||
return Response.from(
|
|
||||||
AiMessage.from(content),
|
|
||||||
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
|
|
||||||
toFinishReason(stopReason),
|
|
||||||
createMetadata()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Map<String, Object> createMetadata() {
|
private Map<String, Object> createMetadata() {
|
||||||
|
|
|
@ -6,8 +6,23 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||||
import dev.langchain4j.agent.tool.ToolParameters;
|
import dev.langchain4j.agent.tool.ToolParameters;
|
||||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||||
import dev.langchain4j.data.image.Image;
|
import dev.langchain4j.data.image.Image;
|
||||||
import dev.langchain4j.data.message.*;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.model.anthropic.internal.api.*;
|
import dev.langchain4j.data.message.ChatMessage;
|
||||||
|
import dev.langchain4j.data.message.ImageContent;
|
||||||
|
import dev.langchain4j.data.message.SystemMessage;
|
||||||
|
import dev.langchain4j.data.message.TextContent;
|
||||||
|
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||||
|
import dev.langchain4j.data.message.UserMessage;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicImageContent;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicTool;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolSchema;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent;
|
||||||
|
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
|
||||||
import dev.langchain4j.model.output.FinishReason;
|
import dev.langchain4j.model.output.FinishReason;
|
||||||
import dev.langchain4j.model.output.TokenUsage;
|
import dev.langchain4j.model.output.TokenUsage;
|
||||||
|
|
||||||
|
@ -16,11 +31,18 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import static dev.langchain4j.internal.Exceptions.illegalArgument;
|
import static dev.langchain4j.internal.Exceptions.illegalArgument;
|
||||||
import static dev.langchain4j.internal.Utils.*;
|
import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
|
||||||
|
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
||||||
|
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||||
|
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TEXT;
|
||||||
|
import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TOOL_USE;
|
||||||
import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.ASSISTANT;
|
import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.ASSISTANT;
|
||||||
import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.USER;
|
import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.USER;
|
||||||
import static dev.langchain4j.model.output.FinishReason.*;
|
import static dev.langchain4j.model.output.FinishReason.LENGTH;
|
||||||
|
import static dev.langchain4j.model.output.FinishReason.OTHER;
|
||||||
|
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||||
|
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
|
||||||
import static java.util.Collections.emptyList;
|
import static java.util.Collections.emptyList;
|
||||||
import static java.util.Collections.emptyMap;
|
import static java.util.Collections.emptyMap;
|
||||||
import static java.util.stream.Collectors.joining;
|
import static java.util.stream.Collectors.joining;
|
||||||
|
@ -130,12 +152,12 @@ public class AnthropicMapper {
|
||||||
public static AiMessage toAiMessage(List<AnthropicContent> contents) {
|
public static AiMessage toAiMessage(List<AnthropicContent> contents) {
|
||||||
|
|
||||||
String text = contents.stream()
|
String text = contents.stream()
|
||||||
.filter(content -> AnthropicContentBlockType.TEXT == content.type)
|
.filter(content -> content.type == TEXT)
|
||||||
.map(content -> content.text)
|
.map(content -> content.text)
|
||||||
.collect(joining("\n"));
|
.collect(joining("\n"));
|
||||||
|
|
||||||
List<ToolExecutionRequest> toolExecutionRequests = contents.stream()
|
List<ToolExecutionRequest> toolExecutionRequests = contents.stream()
|
||||||
.filter(content -> AnthropicContentBlockType.TOOL_USE == content.type)
|
.filter(content -> content.type == TOOL_USE)
|
||||||
.map(content -> {
|
.map(content -> {
|
||||||
try {
|
try {
|
||||||
return ToolExecutionRequest.builder()
|
return ToolExecutionRequest.builder()
|
||||||
|
@ -150,7 +172,7 @@ public class AnthropicMapper {
|
||||||
.collect(toList());
|
.collect(toList());
|
||||||
|
|
||||||
if (isNotNullOrBlank(text) && !isNullOrEmpty(toolExecutionRequests)) {
|
if (isNotNullOrBlank(text) && !isNullOrEmpty(toolExecutionRequests)) {
|
||||||
return new AiMessage(text, toolExecutionRequests);
|
return AiMessage.from(text, toolExecutionRequests);
|
||||||
} else if (!isNullOrEmpty(toolExecutionRequests)) {
|
} else if (!isNullOrEmpty(toolExecutionRequests)) {
|
||||||
return AiMessage.from(toolExecutionRequests);
|
return AiMessage.from(toolExecutionRequests);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -178,122 +178,5 @@
|
||||||
"allPublicMethods": true,
|
"allPublicMethods": true,
|
||||||
"allDeclaredFields": true,
|
"allDeclaredFields": true,
|
||||||
"allPublicFields": true
|
"allPublicFields": true
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicClient",
|
|
||||||
"allDeclaredConstructors": true,
|
|
||||||
"allPublicConstructors": true,
|
|
||||||
"allDeclaredMethods": true,
|
|
||||||
"allPublicMethods": true,
|
|
||||||
"allDeclaredFields": true,
|
|
||||||
"allPublicFields": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicClientBuilderFactory",
|
|
||||||
"allDeclaredConstructors": true,
|
|
||||||
"allPublicConstructors": true,
|
|
||||||
"allDeclaredMethods": true,
|
|
||||||
"allPublicMethods": true,
|
|
||||||
"allDeclaredFields": true,
|
|
||||||
"allPublicFields": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicHttpException",
|
|
||||||
"allDeclaredConstructors": true,
|
|
||||||
"allPublicConstructors": true,
|
|
||||||
"allDeclaredMethods": true,
|
|
||||||
"allPublicMethods": true,
|
|
||||||
"allDeclaredFields": true,
|
|
||||||
"allPublicFields": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicRequestLoggingInterceptor",
|
|
||||||
"allDeclaredConstructors": true,
|
|
||||||
"allPublicConstructors": true,
|
|
||||||
"allDeclaredMethods": true,
|
|
||||||
"allPublicMethods": true,
|
|
||||||
"allDeclaredFields": true,
|
|
||||||
"allPublicFields": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicResponseLoggingInterceptor",
|
|
||||||
"allDeclaredConstructors": true,
|
|
||||||
"allPublicConstructors": true,
|
|
||||||
"allDeclaredMethods": true,
|
|
||||||
"allPublicMethods": true,
|
|
||||||
"allDeclaredFields": true,
|
|
||||||
"allPublicFields": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolExecutionRequestBuilder",
|
|
||||||
"allDeclaredConstructors": true,
|
|
||||||
"allPublicConstructors": true,
|
|
||||||
"allDeclaredMethods": true,
|
|
||||||
"allPublicMethods": true,
|
|
||||||
"allDeclaredFields": true,
|
|
||||||
"allPublicFields": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "dev.langchain4j.model.anthropic.internal.client.DefaultAnthropicClient",
|
|
||||||
"allDeclaredConstructors": true,
|
|
||||||
"allPublicConstructors": true,
|
|
||||||
"allDeclaredMethods": true,
|
|
||||||
"allPublicMethods": true,
|
|
||||||
"allDeclaredFields": true,
|
|
||||||
"allPublicFields": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper",
|
|
||||||
"allDeclaredConstructors": true,
|
|
||||||
"allPublicConstructors": true,
|
|
||||||
"allDeclaredMethods": true,
|
|
||||||
"allPublicMethods": true,
|
|
||||||
"allDeclaredFields": true,
|
|
||||||
"allPublicFields": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer",
|
|
||||||
"allDeclaredConstructors": true,
|
|
||||||
"allPublicConstructors": true,
|
|
||||||
"allDeclaredMethods": true,
|
|
||||||
"allPublicMethods": true,
|
|
||||||
"allDeclaredFields": true,
|
|
||||||
"allPublicFields": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "dev.langchain4j.model.anthropic.AnthropicChatModel",
|
|
||||||
"allDeclaredConstructors": true,
|
|
||||||
"allPublicConstructors": true,
|
|
||||||
"allDeclaredMethods": true,
|
|
||||||
"allPublicMethods": true,
|
|
||||||
"allDeclaredFields": true,
|
|
||||||
"allPublicFields": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "dev.langchain4j.model.anthropic.AnthropicChatModelName",
|
|
||||||
"allDeclaredConstructors": true,
|
|
||||||
"allPublicConstructors": true,
|
|
||||||
"allDeclaredMethods": true,
|
|
||||||
"allPublicMethods": true,
|
|
||||||
"allDeclaredFields": true,
|
|
||||||
"allPublicFields": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "dev.langchain4j.model.anthropic.AnthropicStreamingChatModel",
|
|
||||||
"allDeclaredConstructors": true,
|
|
||||||
"allPublicConstructors": true,
|
|
||||||
"allDeclaredMethods": true,
|
|
||||||
"allPublicMethods": true,
|
|
||||||
"allDeclaredFields": true,
|
|
||||||
"allPublicFields": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "dev.langchain4j.model.anthropic.InternalAnthropicHelper",
|
|
||||||
"allDeclaredConstructors": true,
|
|
||||||
"allPublicConstructors": true,
|
|
||||||
"allDeclaredMethods": true,
|
|
||||||
"allPublicMethods": true,
|
|
||||||
"allDeclaredFields": true,
|
|
||||||
"allPublicFields": true
|
|
||||||
}
|
}
|
||||||
]
|
]
|
|
@ -2,7 +2,13 @@ package dev.langchain4j.model.anthropic;
|
||||||
|
|
||||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||||
import dev.langchain4j.data.message.*;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
|
import dev.langchain4j.data.message.ChatMessage;
|
||||||
|
import dev.langchain4j.data.message.ImageContent;
|
||||||
|
import dev.langchain4j.data.message.SystemMessage;
|
||||||
|
import dev.langchain4j.data.message.TextContent;
|
||||||
|
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||||
|
import dev.langchain4j.data.message.UserMessage;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.model.output.TokenUsage;
|
import dev.langchain4j.model.output.TokenUsage;
|
||||||
|
@ -17,12 +23,18 @@ import java.util.Base64;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
|
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
|
||||||
|
import static dev.langchain4j.agent.tool.JsonSchemaProperty.OBJECT;
|
||||||
|
import static dev.langchain4j.agent.tool.JsonSchemaProperty.property;
|
||||||
import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
|
import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
|
||||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||||
import static dev.langchain4j.internal.Utils.readBytes;
|
import static dev.langchain4j.internal.Utils.readBytes;
|
||||||
|
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_5_SONNET_20240620;
|
||||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
|
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
|
||||||
import static dev.langchain4j.model.output.FinishReason.*;
|
import static dev.langchain4j.model.output.FinishReason.LENGTH;
|
||||||
|
import static dev.langchain4j.model.output.FinishReason.OTHER;
|
||||||
|
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||||
|
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
|
||||||
import static java.util.Arrays.asList;
|
import static java.util.Arrays.asList;
|
||||||
import static java.util.Arrays.stream;
|
import static java.util.Arrays.stream;
|
||||||
import static java.util.Collections.singletonList;
|
import static java.util.Collections.singletonList;
|
||||||
|
@ -336,14 +348,13 @@ class AnthropicChatModelIT {
|
||||||
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
|
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@Test
|
||||||
@MethodSource("models_supporting_tools")
|
void should_execute_multiple_tools_in_parallel_then_answer() {
|
||||||
void should_execute_multiple_tools_in_parallel_then_answer(AnthropicChatModelName modelName) {
|
|
||||||
|
|
||||||
// given
|
// given
|
||||||
ChatLanguageModel model = AnthropicChatModel.builder()
|
ChatLanguageModel model = AnthropicChatModel.builder()
|
||||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||||
.modelName(modelName)
|
.modelName(CLAUDE_3_5_SONNET_20240620)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.logRequests(true)
|
.logRequests(true)
|
||||||
.logResponses(true)
|
.logResponses(true)
|
||||||
|
@ -399,14 +410,13 @@ class AnthropicChatModelIT {
|
||||||
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
|
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@Test
|
||||||
@MethodSource("models_supporting_tools")
|
void should_execute_a_tool_with_nested_properties_then_answer() {
|
||||||
void should_execute_a_tool_with_nested_properties_then_answer(AnthropicChatModelName modelName) {
|
|
||||||
|
|
||||||
// given
|
// given
|
||||||
ChatLanguageModel model = AnthropicChatModel.builder()
|
ChatLanguageModel model = AnthropicChatModel.builder()
|
||||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||||
.modelName(modelName)
|
.modelName(CLAUDE_3_5_SONNET_20240620)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.logRequests(true)
|
.logRequests(true)
|
||||||
.logResponses(true)
|
.logResponses(true)
|
||||||
|
|
|
@ -2,13 +2,17 @@ package dev.langchain4j.model.anthropic;
|
||||||
|
|
||||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||||
import dev.langchain4j.data.message.*;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
|
import dev.langchain4j.data.message.ChatMessage;
|
||||||
|
import dev.langchain4j.data.message.ImageContent;
|
||||||
|
import dev.langchain4j.data.message.SystemMessage;
|
||||||
|
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||||
|
import dev.langchain4j.data.message.UserMessage;
|
||||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||||
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
|
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.model.output.TokenUsage;
|
import dev.langchain4j.model.output.TokenUsage;
|
||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
import org.junit.jupiter.api.Disabled;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.EnumSource;
|
import org.junit.jupiter.params.provider.EnumSource;
|
||||||
|
@ -18,11 +22,14 @@ import java.time.Duration;
|
||||||
import java.util.Base64;
|
import java.util.Base64;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
|
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
|
||||||
|
import static dev.langchain4j.agent.tool.JsonSchemaProperty.OBJECT;
|
||||||
|
import static dev.langchain4j.agent.tool.JsonSchemaProperty.property;
|
||||||
import static dev.langchain4j.data.message.SystemMessage.systemMessage;
|
import static dev.langchain4j.data.message.SystemMessage.systemMessage;
|
||||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||||
import static dev.langchain4j.internal.Utils.readBytes;
|
import static dev.langchain4j.internal.Utils.readBytes;
|
||||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelIT.CAT_IMAGE_URL;
|
import static dev.langchain4j.model.anthropic.AnthropicChatModelIT.CAT_IMAGE_URL;
|
||||||
|
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_5_SONNET_20240620;
|
||||||
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
|
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
|
||||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||||
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
|
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
|
||||||
|
@ -175,11 +182,11 @@ class AnthropicStreamingChatModelIT {
|
||||||
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
||||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||||
.modelName(modelName)
|
.modelName(modelName)
|
||||||
.maxTokens(200)
|
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.logRequests(true)
|
.logRequests(true)
|
||||||
.logResponses(true)
|
.logResponses(true)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
UserMessage userMessage = userMessage("2+2=?");
|
UserMessage userMessage = userMessage("2+2=?");
|
||||||
List<ToolSpecification> toolSpecifications = singletonList(calculator);
|
List<ToolSpecification> toolSpecifications = singletonList(calculator);
|
||||||
|
|
||||||
|
@ -190,7 +197,6 @@ class AnthropicStreamingChatModelIT {
|
||||||
// then
|
// then
|
||||||
Response<AiMessage> response = handler.get();
|
Response<AiMessage> response = handler.get();
|
||||||
AiMessage aiMessage = response.content();
|
AiMessage aiMessage = response.content();
|
||||||
assertThat(aiMessage.text()).isNull();
|
|
||||||
|
|
||||||
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
|
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
|
||||||
assertThat(toolExecutionRequests).hasSize(1);
|
assertThat(toolExecutionRequests).hasSize(1);
|
||||||
|
@ -226,11 +232,12 @@ class AnthropicStreamingChatModelIT {
|
||||||
// given
|
// given
|
||||||
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
||||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||||
.maxTokens(200)
|
.modelName(CLAUDE_3_5_SONNET_20240620)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.logRequests(true)
|
.logRequests(true)
|
||||||
.logResponses(true)
|
.logResponses(true)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
UserMessage userMessage = userMessage("2+2=?");
|
UserMessage userMessage = userMessage("2+2=?");
|
||||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||||
|
|
||||||
|
@ -254,19 +261,18 @@ class AnthropicStreamingChatModelIT {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Disabled("Parallel execution of tools is not supported in the streaming mode yet.")
|
@Test
|
||||||
@ParameterizedTest
|
void should_execute_multiple_tools_in_parallel_then_answer() {
|
||||||
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools")
|
|
||||||
void should_execute_multiple_tools_in_parallel_then_answer(AnthropicChatModelName modelName) {
|
|
||||||
|
|
||||||
// given
|
// given
|
||||||
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
||||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||||
.modelName(modelName)
|
.modelName(CLAUDE_3_5_SONNET_20240620)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.logRequests(true)
|
.logRequests(true)
|
||||||
.logResponses(true)
|
.logResponses(true)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
SystemMessage systemMessage = systemMessage("Do not think, nor explain step by step what you do. Output the result only.");
|
SystemMessage systemMessage = systemMessage("Do not think, nor explain step by step what you do. Output the result only.");
|
||||||
UserMessage userMessage = userMessage("How much is 2+2 and 3+3? Call tools in parallel!");
|
UserMessage userMessage = userMessage("How much is 2+2 and 3+3? Call tools in parallel!");
|
||||||
List<ToolSpecification> toolSpecifications = singletonList(calculator);
|
List<ToolSpecification> toolSpecifications = singletonList(calculator);
|
||||||
|
@ -313,18 +319,18 @@ class AnthropicStreamingChatModelIT {
|
||||||
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
|
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@Test
|
||||||
@MethodSource("dev.langchain4j.model.anthropic.AnthropicChatModelIT#models_supporting_tools")
|
void should_execute_a_tool_with_nested_properties_then_answer() {
|
||||||
void should_execute_a_tool_with_nested_properties_then_answer(AnthropicChatModelName modelName) {
|
|
||||||
|
|
||||||
// given
|
// given
|
||||||
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
|
||||||
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
|
||||||
.modelName(modelName)
|
.modelName(CLAUDE_3_5_SONNET_20240620)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.logRequests(true)
|
.logRequests(true)
|
||||||
.logResponses(true)
|
.logResponses(true)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
UserMessage userMessage = userMessage("What is the weather in Berlin in Celsius?");
|
UserMessage userMessage = userMessage("What is the weather in Berlin in Celsius?");
|
||||||
List<ToolSpecification> toolSpecifications = singletonList(weather);
|
List<ToolSpecification> toolSpecifications = singletonList(weather);
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||||
import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
|
import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
|
||||||
import dev.langchain4j.model.chat.listener.ChatModelListener;
|
import dev.langchain4j.model.chat.listener.ChatModelListener;
|
||||||
|
|
||||||
|
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
|
||||||
import static java.util.Collections.singletonList;
|
import static java.util.Collections.singletonList;
|
||||||
|
|
||||||
class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
|
class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
|
||||||
|
@ -25,7 +26,7 @@ class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected String modelName() {
|
protected String modelName() {
|
||||||
return AnthropicChatModelName.CLAUDE_3_SONNET_20240229.toString();
|
return CLAUDE_3_SONNET_20240229.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -42,9 +43,4 @@ class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT
|
||||||
protected Class<? extends Exception> expectedExceptionClass() {
|
protected Class<? extends Exception> expectedExceptionClass() {
|
||||||
return AnthropicHttpException.class;
|
return AnthropicHttpException.class;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
protected boolean supportsTools() {
|
|
||||||
return false; // TODO remove this method after https://github.com/langchain4j/langchain4j/pull/1795 is merged
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,6 @@ public class TestStreamingResponseHandler<T> implements StreamingResponseHandler
|
||||||
AiMessage aiMessage = (AiMessage) response.content();
|
AiMessage aiMessage = (AiMessage) response.content();
|
||||||
if (aiMessage.hasToolExecutionRequests()){
|
if (aiMessage.hasToolExecutionRequests()){
|
||||||
assertThat(aiMessage.toolExecutionRequests().size()).isGreaterThan(0);
|
assertThat(aiMessage.toolExecutionRequests().size()).isGreaterThan(0);
|
||||||
assertThat(aiMessage.text()).isNull();
|
|
||||||
} else {
|
} else {
|
||||||
assertThat(aiMessage.text()).isEqualTo(expectedTextContent);
|
assertThat(aiMessage.text()).isEqualTo(expectedTextContent);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue