Added ChatModelListener in anthropic model (#1791)

This commit is contained in:
LangChain4j 2024-09-19 11:44:28 +02:00
parent 6e191b794c
commit 0b29cb21e6
9 changed files with 134 additions and 97 deletions

View File

@@ -26,9 +26,9 @@ import java.util.concurrent.ConcurrentHashMap;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307;
import static dev.langchain4j.model.anthropic.internal.InternalAnthropicHelper.createErrorContext;
import static dev.langchain4j.model.anthropic.internal.InternalAnthropicHelper.createModelListenerRequest;
import static dev.langchain4j.model.anthropic.internal.InternalAnthropicHelper.createModelListenerResponse;
import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createErrorContext;
import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createModelListenerRequest;
import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createModelListenerResponse;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.*;
import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages;
import static java.util.Collections.emptyList;

View File

@@ -1,13 +1,14 @@
package dev.langchain4j.model.anthropic;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.*;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.anthropic.internal.InternalAnthropicHelper;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
import dev.langchain4j.model.anthropic.internal.client.AnthropicHttpException;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
@@ -28,11 +29,8 @@ import java.util.concurrent.ConcurrentHashMap;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAiMessage;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicMessages;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicSystemPrompt;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toFinishReason;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toTokenUsage;
import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages;
import static java.util.Collections.emptyList;
@@ -198,8 +196,8 @@ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel {
@Override
public void onComplete(Response<AiMessage> response) {
ChatModelResponse modelListenerResponse = InternalAnthropicHelper.createModelListenerResponse(
null,
null,
(String) response.metadata().get("id"),
(String) response.metadata().get("model"),
response
);
ChatModelResponseContext responseContext = new ChatModelResponseContext(

View File

@@ -1,4 +1,4 @@
package dev.langchain4j.model.anthropic.internal;
package dev.langchain4j.model.anthropic;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
@@ -13,9 +13,11 @@ import dev.langchain4j.model.output.Response;
import java.util.List;
import java.util.Map;
public class InternalAnthropicHelper {
class InternalAnthropicHelper {
public static ChatModelErrorContext createErrorContext(Throwable e, ChatModelRequest modelListenerRequest, Map<Object, Object> attributes) {
static ChatModelErrorContext createErrorContext(Throwable e,
ChatModelRequest modelListenerRequest,
Map<Object, Object> attributes) {
Throwable error;
if (e.getCause() instanceof AnthropicHttpException) {
error = e.getCause();
@@ -29,10 +31,9 @@ public class InternalAnthropicHelper {
null,
attributes
);
}
public static ChatModelRequest createModelListenerRequest(AnthropicCreateMessageRequest request,
static ChatModelRequest createModelListenerRequest(AnthropicCreateMessageRequest request,
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
return ChatModelRequest.builder()
@@ -45,7 +46,7 @@ public class InternalAnthropicHelper {
.build();
}
public static ChatModelResponse createModelListenerResponse(String responseId,
static ChatModelResponse createModelListenerResponse(String responseId,
String responseModel,
Response<AiMessage> response) {
if (response == null) {

View File

@@ -4,7 +4,15 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.anthropic.internal.api.*;
import dev.langchain4j.model.anthropic.internal.api.AnthropicApi;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
import dev.langchain4j.model.anthropic.internal.api.AnthropicDelta;
import dev.langchain4j.model.anthropic.internal.api.AnthropicResponseMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicStreamingData;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import okhttp3.OkHttpClient;
@@ -20,12 +28,17 @@ import retrofit2.converter.jackson.JacksonConverterFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT;
import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toFinishReason;
import static java.util.Collections.synchronizedList;
@@ -134,6 +147,9 @@ public class DefaultAnthropicClient extends AnthropicClient {
final AtomicInteger inputTokenCount = new AtomicInteger();
final AtomicInteger outputTokenCount = new AtomicInteger();
AtomicReference<String> responseId = new AtomicReference<>();
AtomicReference<String> responseModel = new AtomicReference<>();
volatile String stopReason;
private StringBuffer currentContentBuilder() {
@@ -191,8 +207,17 @@ public class DefaultAnthropicClient extends AnthropicClient {
}
private void handleMessageStart(AnthropicStreamingData data) {
if (data.message != null && data.message.usage != null) {
handleUsage(data.message.usage);
AnthropicResponseMessage message = data.message;
if (message != null) {
if (message.usage != null) {
handleUsage(message.usage);
}
if (message.id != null) {
responseId.set(message.id);
}
if (message.model != null) {
responseModel.set(message.model);
}
}
}
@@ -246,11 +271,23 @@ public class DefaultAnthropicClient extends AnthropicClient {
Response<AiMessage> response = Response.from(
AiMessage.from(String.join("\n", contents)),
new TokenUsage(inputTokenCount.get(), outputTokenCount.get()),
toFinishReason(stopReason)
toFinishReason(stopReason),
createMetadata()
);
handler.onComplete(response);
}
private Map<String, Object> createMetadata() {
Map<String, Object> metadata = new HashMap<>();
if (responseId.get() != null) {
metadata.put("id", responseId.get());
}
if (responseModel.get() != null) {
metadata.put("model", responseModel.get());
}
return metadata;
}
private void handleError(String dataString) {
handler.onError(new AnthropicHttpException(null, dataString));
}

View File

@@ -5,27 +5,19 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.ChatModelListenerIT;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import java.time.Duration;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
public class AnthropicChatModelListenerIT extends ChatModelListenerIT {
class AnthropicChatModelListenerIT extends ChatModelListenerIT {
@Override
protected ChatLanguageModel createModel(ChatModelListener listener) {
return AnthropicChatModel.builder()
.baseUrl("https://api.anthropic.com/v1/")
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
.version("2023-06-01")
.modelName(modelName())
.temperature(temperature())
.topP(topP())
.topK(1)
.maxTokens(maxTokens())
.stopSequences(asList("hello", "world"))
.timeout(Duration.ofSeconds(30))
.maxRetries(1)
.logRequests(true)
.logResponses(true)
.listeners(singletonList(listener))
@@ -40,14 +32,7 @@ public class AnthropicChatModelListenerIT extends ChatModelListenerIT {
@Override
protected ChatLanguageModel createFailingModel(ChatModelListener listener) {
return AnthropicChatModel.builder()
.apiKey("err")
.topP(topP())
.topK(1)
.maxTokens(maxTokens())
.modelName("test")
.stopSequences(asList("hello", "world"))
.timeout(Duration.ofSeconds(30))
.maxRetries(1)
.apiKey("banana")
.logRequests(true)
.logResponses(true)
.listeners(singletonList(listener))
@@ -58,14 +43,4 @@
protected Class<? extends Exception> expectedExceptionClass() {
return AnthropicHttpException.class;
}
@Override
protected boolean assertResponseId() {
return false;
}
@Override
protected boolean assertFinishReason() {
return false;
}
}

View File

@@ -5,26 +5,18 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import java.time.Duration;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
public class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
class AnthropicStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
@Override
protected StreamingChatLanguageModel createModel(ChatModelListener listener) {
return AnthropicStreamingChatModel.builder()
.baseUrl("https://api.anthropic.com/v1/")
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
.version("2023-06-01")
.modelName(modelName())
.temperature(temperature())
.topP(topP())
.topK(1)
.maxTokens(maxTokens())
.stopSequences(asList("hello", "world"))
.timeout(Duration.ofSeconds(3))
.logRequests(true)
.logResponses(true)
.listeners(singletonList(listener))
@@ -39,8 +31,7 @@ public class AnthropicStreamingChatModelListenerIT extends StreamingChatModelLis
@Override
protected StreamingChatLanguageModel createFailingModel(ChatModelListener listener) {
return AnthropicStreamingChatModel.builder()
.apiKey("err")
.timeout(Duration.ofSeconds(3))
.apiKey("banana")
.logRequests(true)
.logResponses(true)
.listeners(singletonList(listener))
@@ -52,23 +43,8 @@
return AnthropicHttpException.class;
}
@Override
protected boolean assertResponseModel() {
return false;
}
@Override
protected boolean supportsTools() {
return false;
}
@Override
protected boolean assertResponseId() {
return false;
}
@Override
protected boolean assertFinishReason() {
return false;
return false; // TODO remove this method after https://github.com/langchain4j/langchain4j/pull/1795 is merged
}
}

View File

@@ -1,12 +1,15 @@
package dev.langchain4j.model.output;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.util.Collections.emptyMap;
/**
* Represents the response from various types of models, including language, chat, embedding, and moderation models.
* This class encapsulates the generated content, token usage statistics, and finish reason.
* This class encapsulates the generated content, token usage statistics, finish reason, and response metadata.
*
* @param <T> The type of content generated by the model.
*/
@@ -15,6 +18,7 @@ public class Response<T> {
private final T content;
private final TokenUsage tokenUsage;
private final FinishReason finishReason;
private final Map<String, Object> metadata;
/**
* Create a new Response.
@@ -24,24 +28,38 @@
* @param content the content to wrap.
*/
public Response(T content) {
this(content, null, null);
this(content, null, null, emptyMap());
}
/**
* Create a new Response.
*
* @param content the content to wrap.
* @param tokenUsage the token usage statistics, or {@code null}.
* @param content the content to wrap.
* @param tokenUsage the token usage statistics, or {@code null}.
* @param finishReason the finish reason, or {@code null}.
*/
public Response(T content, TokenUsage tokenUsage, FinishReason finishReason) {
this(content, tokenUsage, finishReason, emptyMap());
}
/**
* Create a new Response.
*
* @param content the content to wrap.
* @param tokenUsage the token usage statistics, or {@code null}.
* @param finishReason the finish reason, or {@code null}.
* @param metadata the response metadata, or {@code null}.
*/
public Response(T content, TokenUsage tokenUsage, FinishReason finishReason, Map<String, Object> metadata) {
this.content = ensureNotNull(content, "content");
this.tokenUsage = tokenUsage;
this.finishReason = finishReason;
this.metadata = metadata == null ? emptyMap() : new HashMap<>(metadata);
}
/**
* Get the content.
*
* @return the content.
*/
public T content() {
@@ -50,6 +68,7 @@
/**
* Get the token usage statistics.
*
* @return the token usage statistics, or {@code null}.
*/
public TokenUsage tokenUsage() {
@@ -58,12 +77,22 @@
/**
* Get the finish reason.
*
* @return the finish reason, or {@code null}.
*/
public FinishReason finishReason() {
return finishReason;
}
/**
* Get the response metadata.
*
* @return the response metadata.
*/
public Map<String, Object> metadata() {
return metadata;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
@@ -71,12 +100,13 @@
Response<?> that = (Response<?>) o;
return Objects.equals(this.content, that.content)
&& Objects.equals(this.tokenUsage, that.tokenUsage)
&& Objects.equals(this.finishReason, that.finishReason);
&& Objects.equals(this.finishReason, that.finishReason)
&& Objects.equals(this.metadata, that.metadata);
}
@Override
public int hashCode() {
return Objects.hash(content, tokenUsage, finishReason);
return Objects.hash(content, tokenUsage, finishReason, metadata);
}
@Override
@@ -85,14 +115,16 @@
" content = " + content +
", tokenUsage = " + tokenUsage +
", finishReason = " + finishReason +
", metadata = " + metadata +
" }";
}
/**
* Create a new Response.
*
* @param content the content to wrap.
* @param <T> the type of content.
* @return the new Response.
* @param <T> the type of content.
*/
public static <T> Response<T> from(T content) {
return new Response<>(content);
@@ -100,10 +132,11 @@
/**
* Create a new Response.
* @param content the content to wrap.
*
* @param content the content to wrap.
* @param tokenUsage the token usage statistics.
* @param <T> the type of content.
* @return the new Response.
* @param <T> the type of content.
*/
public static <T> Response<T> from(T content, TokenUsage tokenUsage) {
return new Response<>(content, tokenUsage, null);
@@ -111,13 +144,28 @@
/**
* Create a new Response.
* @param content the content to wrap.
* @param tokenUsage the token usage statistics.
*
* @param content the content to wrap.
* @param tokenUsage the token usage statistics.
* @param finishReason the finish reason.
* @param <T> the type of content.
* @return the new Response.
* @param <T> the type of content.
*/
public static <T> Response<T> from(T content, TokenUsage tokenUsage, FinishReason finishReason) {
return new Response<>(content, tokenUsage, finishReason);
}
/**
* Create a new Response.
*
* @param content the content to wrap.
* @param tokenUsage the token usage statistics.
* @param finishReason the finish reason.
* @param metadata the response metadata.
* @param <T> the type of content.
* @return the new Response.
*/
public static <T> Response<T> from(T content, TokenUsage tokenUsage, FinishReason finishReason, Map<String, Object> metadata) {
return new Response<>(content, tokenUsage, finishReason, metadata);
}
}

View File

@@ -158,11 +158,11 @@ public abstract class StreamingChatModelListenerIT {
return true;
}
protected boolean assertResponseModel() {
protected boolean assertResponseId() {
return true;
}
protected boolean assertResponseId() {
protected boolean assertResponseModel() {
return true;
}

View File

@@ -11,7 +11,7 @@ class ResponseTest implements WithAssertions {
assertThat(response.content()).isEqualTo("content");
assertThat(response.tokenUsage()).isNull();
assertThat(response.finishReason()).isNull();
assertThat(response).hasToString("Response { content = content, tokenUsage = null, finishReason = null }");
assertThat(response).hasToString("Response { content = content, tokenUsage = null, finishReason = null, metadata = {} }");
}
{
TokenUsage tokenUsage = new TokenUsage(1, 2, 3);
@@ -22,9 +22,10 @@ class ResponseTest implements WithAssertions {
assertThat(response)
.hasToString(
"Response { " +
"content = content, tokenUsage = TokenUsage { " +
"inputTokenCount = 1, outputTokenCount = 2, totalTokenCount = 3 }, " +
"finishReason = null }");
"content = content, tokenUsage = TokenUsage { " +
"inputTokenCount = 1, outputTokenCount = 2, totalTokenCount = 3 }, " +
"finishReason = null, " +
"metadata = {} }");
}
{
TokenUsage tokenUsage = new TokenUsage(1, 2, 3);
@@ -37,7 +38,8 @@ class ResponseTest implements WithAssertions {
"Response { " +
"content = content, tokenUsage = TokenUsage { " +
"inputTokenCount = 1, outputTokenCount = 2, totalTokenCount = 3 }, " +
"finishReason = LENGTH }");
"finishReason = LENGTH, " +
"metadata = {} }");
}
}