# LLM Observability: Part 1 (#1058)
## Issue

https://github.com/langchain4j/langchain4j/issues/199

## Change

- Added `ModelListener`, `ChatLanguageModelRequest`, and `ChatLanguageModelResponse` that are compatible with the [OTEL LLM semconv draft](https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md) (i.e., they carry all the attributes it requires).
- Added an option to attach multiple `ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>` instances to `OpenAiChatModel` and `OpenAiStreamingChatModel` (pilot module).

### `ChatLanguageModelRequest`

```java
public class ChatLanguageModelRequest {

    private final String model;
    private final Double temperature;
    private final Double topP;
    private final Integer maxTokens;
    private final List<ChatMessage> messages;
    private final List<ToolSpecification> toolSpecifications;
}
```

### `ChatLanguageModelResponse`

```java
public class ChatLanguageModelResponse {

    private final String id;
    private final String model;
    private final TokenUsage tokenUsage;
    private final FinishReason finishReason;
    private final AiMessage aiMessage;
}
```

## Example

```java
ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
        new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {

            @Override
            public void onRequest(ChatLanguageModelRequest request) {
                // handle request
            }

            @Override
            public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
                // handle response
            }

            @Override
            public void onError(Throwable error, ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
                // handle error
            }
        };

OpenAiChatModel model = OpenAiChatModel.builder()
        .apiKey(...)
        .listeners(singletonList(modelListener))
        .build();
```

## General checklist

- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green
- [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features)
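### Illustration: mapping to OTEL attributes

For orientation, below is a sketch of how a listener could translate these objects into the `gen_ai.*` span attributes from the semconv draft linked above. The attribute keys are taken from the draft and may still change; the plain `Map` stands in for whatever tracing API is actually used, so this sketch deliberately avoids a dependency on the OTEL SDK.

```java
import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
import dev.langchain4j.model.listener.ModelListener;

import java.util.HashMap;
import java.util.Map;

class OtelAttributesListener implements ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> {

    @Override
    public void onRequest(ChatLanguageModelRequest request) {
        Map<String, Object> attributes = new HashMap<>();
        attributes.put("gen_ai.request.model", request.model());
        attributes.put("gen_ai.request.temperature", request.temperature());
        attributes.put("gen_ai.request.top_p", request.topP());
        attributes.put("gen_ai.request.max_tokens", request.maxTokens());
        // hand the attributes to the span of your tracing backend here
    }

    @Override
    public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
        Map<String, Object> attributes = new HashMap<>();
        attributes.put("gen_ai.response.id", response.id());
        attributes.put("gen_ai.response.model", response.model());
        if (response.tokenUsage() != null) {
            attributes.put("gen_ai.usage.prompt_tokens", response.tokenUsage().inputTokenCount());
            attributes.put("gen_ai.usage.completion_tokens", response.tokenUsage().outputTokenCount());
        }
        // hand the attributes to the span of your tracing backend here
    }
}
```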
This commit is contained in: parent a49ac33519, commit 6818e279bf.
**`pom.xml`** (package excludes; the two `listener` packages are new):

```diff
@@ -125,6 +125,8 @@
             <rule>
                 <excludes>
                     <exclude>dev.langchain4j.data.document</exclude>
+                    <exclude>dev.langchain4j.model.chat.listener</exclude>
+                    <exclude>dev.langchain4j.model.listener</exclude>
                     <exclude>dev.langchain4j.store.embedding</exclude>
                     <exclude>dev.langchain4j.store.embedding.filter</exclude>
                     <exclude>dev.langchain4j.store.embedding.filter.logical</exclude>
```
**`ChatLanguageModelRequest.java`** (new file, `@@ -0,0 +1,67 @@`):

```java
package dev.langchain4j.model.chat.listener;

import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.listener.ModelListener;
import lombok.Builder;

import java.util.List;

import static dev.langchain4j.internal.Utils.copyIfNotNull;

/**
 * A request to the {@link ChatLanguageModel} or {@link StreamingChatLanguageModel},
 * intended to be used with {@link ModelListener}.
 */
@Experimental
public class ChatLanguageModelRequest {

    private final String model;
    private final Double temperature;
    private final Double topP;
    private final Integer maxTokens;
    private final List<ChatMessage> messages;
    private final List<ToolSpecification> toolSpecifications;

    @Builder
    public ChatLanguageModelRequest(String model,
                                    Double temperature,
                                    Double topP,
                                    Integer maxTokens,
                                    List<ChatMessage> messages,
                                    List<ToolSpecification> toolSpecifications) {
        this.model = model;
        this.temperature = temperature;
        this.topP = topP;
        this.maxTokens = maxTokens;
        this.messages = copyIfNotNull(messages);
        this.toolSpecifications = copyIfNotNull(toolSpecifications);
    }

    public String model() {
        return model;
    }

    public Double temperature() {
        return temperature;
    }

    public Double topP() {
        return topP;
    }

    public Integer maxTokens() {
        return maxTokens;
    }

    public List<ChatMessage> messages() {
        return messages;
    }

    public List<ToolSpecification> toolSpecifications() {
        return toolSpecifications;
    }
}
```
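Since the class carries Lombok's `@Builder`, instances can be constructed fluently. A minimal usage sketch (the values are arbitrary; `singletonList` is `java.util.Collections.singletonList`):

```java
ChatLanguageModelRequest request = ChatLanguageModelRequest.builder()
        .model("gpt-3.5-turbo")
        .temperature(0.7)
        .topP(1.0)
        .maxTokens(20)
        .messages(singletonList(UserMessage.from("hello")))
        .build();
// messages and toolSpecifications are copied on construction (copyIfNotNull),
// so later changes to the caller's lists do not leak into the request
```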
**`ChatLanguageModelResponse.java`** (new file, `@@ -0,0 +1,57 @@`):

```java
package dev.langchain4j.model.chat.listener;

import dev.langchain4j.Experimental;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.listener.ModelListener;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;

/**
 * A response from the {@link ChatLanguageModel} or {@link StreamingChatLanguageModel},
 * intended to be used with {@link ModelListener}.
 */
@Experimental
public class ChatLanguageModelResponse {

    private final String id;
    private final String model;
    private final TokenUsage tokenUsage;
    private final FinishReason finishReason;
    private final AiMessage aiMessage;

    @Builder
    public ChatLanguageModelResponse(String id,
                                     String model,
                                     TokenUsage tokenUsage,
                                     FinishReason finishReason,
                                     AiMessage aiMessage) {
        this.id = id;
        this.model = model;
        this.tokenUsage = tokenUsage;
        this.finishReason = finishReason;
        this.aiMessage = aiMessage;
    }

    public String id() {
        return id;
    }

    public String model() {
        return model;
    }

    public TokenUsage tokenUsage() {
        return tokenUsage;
    }

    public FinishReason finishReason() {
        return finishReason;
    }

    public AiMessage aiMessage() {
        return aiMessage;
    }
}
```
**`ModelListener.java`** (new file, `@@ -0,0 +1,51 @@`):

```java
package dev.langchain4j.model.listener;

import dev.langchain4j.Experimental;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;

/**
 * A generic model listener.
 * It can listen for requests to and responses from various model types,
 * such as {@link ChatLanguageModel}, {@link StreamingChatLanguageModel}, {@link EmbeddingModel}, etc.
 */
@Experimental
public interface ModelListener<Request, Response> {

    /**
     * This method is called before the request is sent to the model.
     *
     * @param request The request to the model.
     */
    @Experimental
    default void onRequest(Request request) {

    }

    /**
     * This method is called after the response is received from the model.
     *
     * @param response The response from the model.
     * @param request  The request this response corresponds to.
     */
    @Experimental
    default void onResponse(Response response, Request request) {

    }

    /**
     * This method is called when an error occurs.
     * <br>
     * When streaming (e.g., using {@link StreamingChatLanguageModel}),
     * the {@code response} might contain a partial response that was received before the error occurred.
     *
     * @param error    The error that occurred.
     * @param response The partial response, if available.
     * @param request  The request this error corresponds to.
     */
    @Experimental
    default void onError(Throwable error, Response response, Request request) {

    }
}
```
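Because all three callbacks have default no-op implementations, a listener only needs to override the callbacks it cares about. A minimal sketch that records token usage only (the `System.out` logging is illustrative):

```java
ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> usageListener =
        new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {

            @Override
            public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
                // onRequest() and onError() keep their default no-op behavior
                System.out.println(response.model() + " used " + response.tokenUsage());
            }
        };
```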
**`InternalOpenAiHelper.java`**:

```diff
@@ -10,6 +10,8 @@ import dev.langchain4j.data.message.Content;
 import dev.langchain4j.data.message.SystemMessage;
 import dev.langchain4j.data.message.UserMessage;
 import dev.langchain4j.data.message.*;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
 import dev.langchain4j.model.output.FinishReason;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
@@ -280,4 +282,33 @@ public class InternalOpenAiHelper {
     static Response<AiMessage> removeTokenUsage(Response<AiMessage> response) {
         return Response.from(response.content(), null, response.finishReason());
     }
+
+    static ChatLanguageModelRequest createModelListenerRequest(ChatCompletionRequest request,
+                                                               List<ChatMessage> messages,
+                                                               List<ToolSpecification> toolSpecifications) {
+        return ChatLanguageModelRequest.builder()
+                .model(request.model())
+                .temperature(request.temperature())
+                .topP(request.topP())
+                .maxTokens(request.maxTokens())
+                .messages(messages)
+                .toolSpecifications(toolSpecifications)
+                .build();
+    }
+
+    static ChatLanguageModelResponse createModelListenerResponse(String responseId,
+                                                                 String responseModel,
+                                                                 Response<AiMessage> response) {
+        if (response == null) {
+            return null;
+        }
+
+        return ChatLanguageModelResponse.builder()
+                .id(responseId)
+                .model(responseModel)
+                .tokenUsage(response.tokenUsage())
+                .finishReason(response.finishReason())
+                .aiMessage(response.content())
+                .build();
+    }
 }
```
**`OpenAiChatModel.java`**:

```diff
@@ -1,6 +1,7 @@
 package dev.langchain4j.model.openai;
 
 import dev.ai4j.openai4j.OpenAiClient;
+import dev.ai4j.openai4j.OpenAiHttpException;
 import dev.ai4j.openai4j.chat.ChatCompletionRequest;
 import dev.ai4j.openai4j.chat.ChatCompletionResponse;
 import dev.langchain4j.agent.tool.ToolSpecification;
@@ -9,12 +10,17 @@ import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.TokenCountEstimator;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
+import dev.langchain4j.model.listener.ModelListener;
 import dev.langchain4j.model.openai.spi.OpenAiChatModelBuilderFactory;
 import dev.langchain4j.model.output.Response;
 import lombok.Builder;
+import lombok.extern.slf4j.Slf4j;
 
 import java.net.Proxy;
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
@@ -24,12 +30,14 @@ import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
 import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
 import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 import static java.time.Duration.ofSeconds;
+import static java.util.Collections.emptyList;
 import static java.util.Collections.singletonList;
 
 /**
  * Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and gpt-4.
  * You can find description of parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
  */
+@Slf4j
 public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
 
     private final OpenAiClient client;
@@ -46,6 +54,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
     private final String user;
     private final Integer maxRetries;
     private final Tokenizer tokenizer;
+    private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners;
 
     @Builder
     public OpenAiChatModel(String baseUrl,
@@ -68,7 +77,8 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
                            Boolean logRequests,
                            Boolean logResponses,
                            Tokenizer tokenizer,
-                           Map<String, String> customHeaders) {
+                           Map<String, String> customHeaders,
+                           List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {
 
         baseUrl = getOrDefault(baseUrl, OPENAI_URL);
         if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
@@ -104,6 +114,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
         this.user = user;
         this.maxRetries = getOrDefault(maxRetries, 3);
         this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
+        this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
     }
 
     public String modelName() {
@@ -152,13 +163,56 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
 
         ChatCompletionRequest request = requestBuilder.build();
 
-        ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request).execute(), maxRetries);
+        ChatLanguageModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
+        listeners.forEach(listener -> {
+            try {
+                listener.onRequest(modelListenerRequest);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });
 
-        return Response.from(
-                aiMessageFrom(response),
-                tokenUsageFrom(response.usage()),
-                finishReasonFrom(response.choices().get(0).finishReason())
-        );
+        try {
+            ChatCompletionResponse chatCompletionResponse = withRetry(() -> client.chatCompletion(request).execute(), maxRetries);
+
+            Response<AiMessage> response = Response.from(
+                    aiMessageFrom(chatCompletionResponse),
+                    tokenUsageFrom(chatCompletionResponse.usage()),
+                    finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason())
+            );
+
+            ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
+                    chatCompletionResponse.id(),
+                    chatCompletionResponse.model(),
+                    response
+            );
+            listeners.forEach(listener -> {
+                try {
+                    listener.onResponse(modelListenerResponse, modelListenerRequest);
+                } catch (Exception e) {
+                    log.warn("Exception while calling model listener", e);
+                }
+            });
+
+            return response;
+        } catch (RuntimeException e) {
+
+            Throwable error;
+            if (e.getCause() instanceof OpenAiHttpException) {
+                error = e.getCause();
+            } else {
+                error = e;
+            }
+
+            listeners.forEach(listener -> {
+                try {
+                    listener.onError(error, null, modelListenerRequest);
+                } catch (Exception e2) {
+                    log.warn("Exception while calling model listener", e2);
+                }
+            });
+            throw e;
+        }
+    }
 
     @Override
```
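One consequence of the error handling above: because the model unwraps a `RuntimeException` whose cause is an `OpenAiHttpException` before notifying listeners, an error listener can check for the HTTP failure directly. A sketch (the `System.err` logging is illustrative):

```java
ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> errorListener =
        new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {

            @Override
            public void onError(Throwable error, ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
                if (error instanceof OpenAiHttpException) {
                    // record the failed call against the model it targeted
                    System.err.println("OpenAI call to " + request.model() + " failed: " + error.getMessage());
                }
            }
        };
```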
**`OpenAiStreamingChatModel.java`**:

```diff
@@ -12,21 +12,27 @@ import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.chat.TokenCountEstimator;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
+import dev.langchain4j.model.listener.ModelListener;
 import dev.langchain4j.model.openai.spi.OpenAiStreamingChatModelBuilderFactory;
 import dev.langchain4j.model.output.Response;
 import lombok.Builder;
+import lombok.extern.slf4j.Slf4j;
 
 import java.net.Proxy;
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.atomic.AtomicReference;
 
-import static dev.langchain4j.internal.Utils.getOrDefault;
-import static dev.langchain4j.internal.Utils.isNullOrEmpty;
+import static dev.langchain4j.internal.Utils.*;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
 import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
 import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 import static java.time.Duration.ofSeconds;
+import static java.util.Collections.emptyList;
 import static java.util.Collections.singletonList;
 
 /**
@@ -34,6 +40,7 @@ import static java.util.Collections.singletonList;
  * The model's response is streamed token by token and should be handled with {@link StreamingResponseHandler}.
  * You can find description of parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
  */
+@Slf4j
 public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
 
     private final OpenAiClient client;
@@ -50,6 +57,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
     private final String user;
     private final Tokenizer tokenizer;
     private final boolean isOpenAiModel;
+    private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners;
 
     @Builder
     public OpenAiStreamingChatModel(String baseUrl,
@@ -71,7 +79,8 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
                                     Boolean logRequests,
                                     Boolean logResponses,
                                     Tokenizer tokenizer,
-                                    Map<String, String> customHeaders) {
+                                    Map<String, String> customHeaders,
+                                    List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {
 
         timeout = getOrDefault(timeout, ofSeconds(60));
 
@@ -102,6 +111,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
         this.user = user;
         this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
         this.isOpenAiModel = isOpenAiModel(this.modelName);
+        this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
     }
 
     public String modelName() {
@@ -152,25 +162,81 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
 
         ChatCompletionRequest request = requestBuilder.build();
 
+        ChatLanguageModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
+        listeners.forEach(listener -> {
+            try {
+                listener.onRequest(modelListenerRequest);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });
+
         int inputTokenCount = countInputTokens(messages, toolSpecifications, toolThatMustBeExecuted);
         OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(inputTokenCount);
 
+        AtomicReference<String> responseId = new AtomicReference<>();
+        AtomicReference<String> responseModel = new AtomicReference<>();
+
         client.chatCompletion(request)
                 .onPartialResponse(partialResponse -> {
                     responseBuilder.append(partialResponse);
                     handle(partialResponse, handler);
+
+                    if (!isNullOrBlank(partialResponse.id())) {
+                        responseId.set(partialResponse.id());
+                    }
+                    if (!isNullOrBlank(partialResponse.model())) {
+                        responseModel.set(partialResponse.model());
+                    }
                 })
                 .onComplete(() -> {
-                    Response<AiMessage> response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null);
-                    if (!isOpenAiModel) {
-                        response = removeTokenUsage(response);
-                    }
+                    Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
+
+                    ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
+                            responseId.get(),
+                            responseModel.get(),
+                            response
+                    );
+                    listeners.forEach(listener -> {
+                        try {
+                            listener.onResponse(modelListenerResponse, modelListenerRequest);
+                        } catch (Exception e) {
+                            log.warn("Exception while calling model listener", e);
+                        }
+                    });
+
                     handler.onComplete(response);
                 })
-                .onError(handler::onError)
+                .onError(error -> {
+                    Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
+
+                    ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
+                            responseId.get(),
+                            responseModel.get(),
+                            response
+                    );
+                    listeners.forEach(listener -> {
+                        try {
+                            listener.onError(error, modelListenerResponse, modelListenerRequest);
+                        } catch (Exception e) {
+                            log.warn("Exception while calling model listener", e);
+                        }
+                    });
+
+                    handler.onError(error);
+                })
                 .execute();
     }
 
+    private Response<AiMessage> createResponse(OpenAiStreamingResponseBuilder responseBuilder,
+                                               ToolSpecification toolThatMustBeExecuted) {
+        Response<AiMessage> response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null);
+        if (isOpenAiModel) {
+            return response;
+        }
+        return removeTokenUsage(response);
+    }
+
     private int countInputTokens(List<ChatMessage> messages,
                                  List<ToolSpecification> toolSpecifications,
                                  ToolSpecification toolThatMustBeExecuted) {
```
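In contrast to the blocking model, which always passes `null` as the listener's `response` argument on error, the streaming variant builds a `ChatLanguageModelResponse` from whatever tokens arrived before the failure, so an error listener may see a partial response. A sketch of guarding for that case:

```java
ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> streamingErrorListener =
        new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {

            @Override
            public void onError(Throwable error, ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
                if (response != null && response.aiMessage() != null) {
                    // the stream failed mid-way; the text generated so far is still available
                    System.err.println("Partial response before failure: " + response.aiMessage().text());
                }
            }
        };
```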
**`OpenAiChatModelIT.java`**:

```diff
@@ -1,16 +1,21 @@
 package dev.langchain4j.model.openai;
 
+import dev.ai4j.openai4j.OpenAiHttpException;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.*;
 import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
+import dev.langchain4j.model.listener.ModelListener;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import org.junit.jupiter.api.Test;
 
 import java.util.Base64;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
 import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
@@ -23,6 +28,8 @@ import static dev.langchain4j.model.output.FinishReason.*;
 import static java.util.Arrays.asList;
 import static java.util.Collections.singletonList;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Fail.fail;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 
 class OpenAiChatModelIT {
 
@@ -216,6 +223,8 @@ class OpenAiChatModelIT {
                 .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
                 .modelName(GPT_3_5_TURBO_1106) // supports parallel function calling
                 .temperature(0.0)
+                .logRequests(true)
+                .logResponses(true)
                 .build();
 
         UserMessage userMessage = userMessage("2+2=? 3+3=?");
```

Two new integration tests are appended at the end of the class (`@@ -463,4 +472,132 @@`):

```java
    @Test
    void should_listen_request_and_response() {

        // given
        AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
        AtomicReference<ChatLanguageModelResponse> responseReference = new AtomicReference<>();

        ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
                new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {

                    @Override
                    public void onRequest(ChatLanguageModelRequest request) {
                        requestReference.set(request);
                    }

                    @Override
                    public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
                        responseReference.set(response);
                        assertThat(request).isSameAs(requestReference.get());
                    }

                    @Override
                    public void onError(Throwable error,
                                        ChatLanguageModelResponse response,
                                        ChatLanguageModelRequest request) {
                        fail("onError() must not be called");
                    }
                };

        OpenAiChatModelName modelName = GPT_3_5_TURBO;
        double temperature = 0.7;
        double topP = 1.0;
        int maxTokens = 7;

        OpenAiChatModel model = OpenAiChatModel.builder()
                .baseUrl(System.getenv("OPENAI_BASE_URL"))
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
                .modelName(modelName)
                .temperature(temperature)
                .topP(topP)
                .maxTokens(maxTokens)
                .logRequests(true)
                .logResponses(true)
                .listeners(singletonList(modelListener))
                .build();

        UserMessage userMessage = UserMessage.from("hello");

        ToolSpecification toolSpecification = ToolSpecification.builder()
                .name("add")
                .addParameter("a", INTEGER)
                .addParameter("b", INTEGER)
                .build();

        // when
        AiMessage aiMessage = model.generate(singletonList(userMessage), singletonList(toolSpecification)).content();

        // then
        ChatLanguageModelRequest request = requestReference.get();
        assertThat(request.model()).isEqualTo(modelName.toString());
        assertThat(request.temperature()).isEqualTo(temperature);
        assertThat(request.topP()).isEqualTo(topP);
        assertThat(request.maxTokens()).isEqualTo(maxTokens);
        assertThat(request.messages()).containsExactly(userMessage);
        assertThat(request.toolSpecifications()).containsExactly(toolSpecification);

        ChatLanguageModelResponse response = responseReference.get();
        assertThat(response.id()).isNotBlank();
        assertThat(response.model()).isNotBlank();
        assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
        assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
        assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0);
        assertThat(response.finishReason()).isNotNull();
        assertThat(response.aiMessage()).isEqualTo(aiMessage);
    }

    @Test
    void should_listen_error() {

        // given
        String wrongApiKey = "banana";

        AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
        AtomicReference<Throwable> errorReference = new AtomicReference<>();

        ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
                new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {

                    @Override
                    public void onRequest(ChatLanguageModelRequest request) {
                        requestReference.set(request);
                    }

                    @Override
                    public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
                        fail("onResponse() must not be called");
                    }

                    @Override
                    public void onError(Throwable error,
                                        ChatLanguageModelResponse response,
                                        ChatLanguageModelRequest request) {
                        errorReference.set(error);
                        assertThat(response).isNull();
                        assertThat(request).isSameAs(requestReference.get());
                    }
                };

        OpenAiChatModel model = OpenAiChatModel.builder()
                .apiKey(wrongApiKey)
                .maxRetries(0)
                .logRequests(true)
                .logResponses(true)
                .listeners(singletonList(modelListener))
                .build();

        String userMessage = "this message will fail";

        // when
        assertThrows(RuntimeException.class, () -> model.generate(userMessage));

        // then
        Throwable throwable = errorReference.get();
        assertThat(throwable).isExactlyInstanceOf(OpenAiHttpException.class);
        assertThat(throwable).hasMessageContaining("Incorrect API key provided");
    }
}
```
**`OpenAiStreamingChatModelIT.java`**:

```diff
@@ -1,5 +1,6 @@
 package dev.langchain4j.model.openai;
 
+import dev.ai4j.openai4j.OpenAiHttpException;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.*;
@@ -7,6 +8,9 @@ import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.chat.TestStreamingResponseHandler;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
+import dev.langchain4j.model.listener.ModelListener;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import org.assertj.core.data.Percentage;
@@ -15,6 +19,7 @@ import org.junit.jupiter.api.Test;
 import java.util.Base64;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
 import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
@@ -31,6 +36,7 @@ import static java.util.Arrays.asList;
 import static java.util.Collections.singletonList;
 import static java.util.concurrent.TimeUnit.SECONDS;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Fail.fail;
 import static org.assertj.core.data.Percentage.withPercentage;
 
 class OpenAiStreamingChatModelIT {
```

Two new integration tests are appended at the end of the class (`@@ -639,4 +645,154 @@`):

```java
    @Test
    void should_listen_request_and_response() {

        // given
        AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
        AtomicReference<ChatLanguageModelResponse> responseReference = new AtomicReference<>();

        ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
                new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {

                    @Override
                    public void onRequest(ChatLanguageModelRequest request) {
                        requestReference.set(request);
                    }

                    @Override
                    public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
                        responseReference.set(response);
                        assertThat(request).isSameAs(requestReference.get());
                    }

                    @Override
                    public void onError(Throwable error,
                                        ChatLanguageModelResponse response,
                                        ChatLanguageModelRequest request) {
                        fail("onError() must not be called");
                    }
                };

        OpenAiChatModelName modelName = GPT_3_5_TURBO;
        double temperature = 0.7;
        double topP = 1.0;
        int maxTokens = 7;

        StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder()
                .baseUrl(System.getenv("OPENAI_BASE_URL"))
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
                .modelName(modelName)
                .temperature(temperature)
                .topP(topP)
                .maxTokens(maxTokens)
                .logRequests(true)
                .logResponses(true)
                .listeners(singletonList(modelListener))
                .build();

        UserMessage userMessage = UserMessage.from("hello");

        ToolSpecification toolSpecification = ToolSpecification.builder()
                .name("add")
                .addParameter("a", INTEGER)
                .addParameter("b", INTEGER)
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        model.generate(singletonList(userMessage), singletonList(toolSpecification), handler);
        AiMessage aiMessage = handler.get().content();

        // then
        ChatLanguageModelRequest request = requestReference.get();
        assertThat(request.model()).isEqualTo(modelName.toString());
        assertThat(request.temperature()).isEqualTo(temperature);
        assertThat(request.topP()).isEqualTo(topP);
        assertThat(request.maxTokens()).isEqualTo(maxTokens);
        assertThat(request.messages()).containsExactly(userMessage);
        assertThat(request.toolSpecifications()).containsExactly(toolSpecification);

        ChatLanguageModelResponse response = responseReference.get();
        assertThat(response.id()).isNotBlank();
        assertThat(response.model()).isNotBlank();
        assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
        assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
        assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0);
        assertThat(response.finishReason()).isNotNull();
        assertThat(response.aiMessage()).isEqualTo(aiMessage);
    }

    @Test
    void should_listen_error() throws Exception {

        // given
        String wrongApiKey = "banana";

        AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
        AtomicReference<Throwable> errorReference = new AtomicReference<>();

        ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
                new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {

                    @Override
                    public void onRequest(ChatLanguageModelRequest request) {
                        requestReference.set(request);
                    }

                    @Override
                    public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
                        fail("onResponse() must not be called");
                    }

                    @Override
                    public void onError(Throwable error,
                                        ChatLanguageModelResponse response,
                                        ChatLanguageModelRequest request) {
                        errorReference.set(error);
                        assertThat(response).isNull(); // can be non-null if it fails in the middle of streaming
                        assertThat(request).isSameAs(requestReference.get());
                    }
                };

        StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder()
                .apiKey(wrongApiKey)
                .logRequests(true)
                .logResponses(true)
                .listeners(singletonList(modelListener))
                .build();

        String userMessage = "this message will fail";

        CompletableFuture<Throwable> future = new CompletableFuture<>();
        StreamingResponseHandler<AiMessage> handler = new StreamingResponseHandler<AiMessage>() {

            @Override
            public void onNext(String token) {
                fail("onNext() must not be called");
            }

            @Override
            public void onError(Throwable error) {
                future.complete(error);
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                fail("onComplete() must not be called");
            }
        };

        // when
        model.generate(userMessage, handler);
        Throwable throwable = future.get(5, SECONDS);

        // then
        assertThat(throwable).isExactlyInstanceOf(OpenAiHttpException.class);
        assertThat(throwable).hasMessageContaining("Incorrect API key provided");

        assertThat(errorReference.get()).isSameAs(throwable);
    }
}
```