# LLM Observability: Part 1 (#1058)

## Issue
https://github.com/langchain4j/langchain4j/issues/199


## Change
- Added `ModelListener`, `ChatLanguageModelRequest`, and `ChatLanguageModelResponse`, which are compatible with (carry all the attributes required by) the [OTEL LLM semconv draft](https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md); a mapping sketch follows the example below.
- Added an option to attach multiple `ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>` instances to `OpenAiChatModel` and `OpenAiStreamingChatModel` (pilot module).

### `ChatLanguageModelRequest`
```java
public class ChatLanguageModelRequest {

    private final String model;
    private final Double temperature;
    private final Double topP;
    private final Integer maxTokens;
    private final List<ChatMessage> messages;
    private final List<ToolSpecification> toolSpecifications;
}
```

### `ChatLanguageModelResponse`
```java
public class ChatLanguageModelResponse {

    private final String id;
    private final String model;
    private final TokenUsage tokenUsage;
    private final FinishReason finishReason;
    private final AiMessage aiMessage;
}
```

## Example
```java
ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
        new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {

            @Override
            public void onRequest(ChatLanguageModelRequest request) {
                // handle request
            }

            @Override
            public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
                // handle response
            }

            @Override
            public void onError(Throwable error, ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
                // handle error
            }
        };

OpenAiChatModel model = OpenAiChatModel.builder()
        .apiKey(...)
        .listeners(singletonList(modelListener))
        .build();
```
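
Since the request and response objects carry all the attributes required by the OTEL LLM semconv draft, a listener can map them straight onto a span. Below is a minimal sketch, assuming `io.opentelemetry:opentelemetry-api` is on the classpath; the `gen_ai.*` attribute names follow the draft at the time of writing and may change, and this listener is illustrative rather than part of the PR:

```java
import io.opentelemetry.api.GlobalOpenTelemetry;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.Tracer;

// Illustrative sketch (not part of this PR): exports each chat call as an OTEL span.
ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> otelListener =
        new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {

            private final Tracer tracer = GlobalOpenTelemetry.getTracer("langchain4j");

            @Override
            public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
                // The span is opened and closed here because this version of the API
                // has no per-invocation context to carry a span from onRequest().
                Span span = tracer.spanBuilder("chat " + request.model()).startSpan();
                span.setAttribute("gen_ai.request.model", request.model());
                if (request.temperature() != null) {
                    span.setAttribute("gen_ai.request.temperature", request.temperature());
                }
                span.setAttribute("gen_ai.response.id", response.id());
                span.setAttribute("gen_ai.response.model", response.model());
                if (response.tokenUsage() != null) {
                    // token counts assumed non-null when tokenUsage is present
                    span.setAttribute("gen_ai.usage.prompt_tokens", response.tokenUsage().inputTokenCount());
                    span.setAttribute("gen_ai.usage.completion_tokens", response.tokenUsage().outputTokenCount());
                }
                span.end();
            }
        };
```

Because all `ModelListener` methods have default implementations, only the callbacks of interest need to be overridden.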

## General checklist
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
## Diff

9 changed files with 636 additions and 15 deletions.

### Build configuration (compatibility-check excludes)

```diff
@@ -125,6 +125,8 @@
 <rule>
     <excludes>
         <exclude>dev.langchain4j.data.document</exclude>
+        <exclude>dev.langchain4j.model.chat.listener</exclude>
+        <exclude>dev.langchain4j.model.listener</exclude>
         <exclude>dev.langchain4j.store.embedding</exclude>
         <exclude>dev.langchain4j.store.embedding.filter</exclude>
         <exclude>dev.langchain4j.store.embedding.filter.logical</exclude>
```

### ChatLanguageModelRequest.java (new file)

```java
package dev.langchain4j.model.chat.listener;

import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.listener.ModelListener;
import lombok.Builder;

import java.util.List;

import static dev.langchain4j.internal.Utils.copyIfNotNull;

/**
 * A request to the {@link ChatLanguageModel} or {@link StreamingChatLanguageModel},
 * intended to be used with {@link ModelListener}.
 */
@Experimental
public class ChatLanguageModelRequest {

    private final String model;
    private final Double temperature;
    private final Double topP;
    private final Integer maxTokens;
    private final List<ChatMessage> messages;
    private final List<ToolSpecification> toolSpecifications;

    @Builder
    public ChatLanguageModelRequest(String model,
                                    Double temperature,
                                    Double topP,
                                    Integer maxTokens,
                                    List<ChatMessage> messages,
                                    List<ToolSpecification> toolSpecifications) {
        this.model = model;
        this.temperature = temperature;
        this.topP = topP;
        this.maxTokens = maxTokens;
        this.messages = copyIfNotNull(messages);
        this.toolSpecifications = copyIfNotNull(toolSpecifications);
    }

    public String model() {
        return model;
    }

    public Double temperature() {
        return temperature;
    }

    public Double topP() {
        return topP;
    }

    public Integer maxTokens() {
        return maxTokens;
    }

    public List<ChatMessage> messages() {
        return messages;
    }

    public List<ToolSpecification> toolSpecifications() {
        return toolSpecifications;
    }
}
```

### ChatLanguageModelResponse.java (new file)

```java
package dev.langchain4j.model.chat.listener;

import dev.langchain4j.Experimental;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.listener.ModelListener;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;

/**
 * A response from the {@link ChatLanguageModel} or {@link StreamingChatLanguageModel},
 * intended to be used with {@link ModelListener}.
 */
@Experimental
public class ChatLanguageModelResponse {

    private final String id;
    private final String model;
    private final TokenUsage tokenUsage;
    private final FinishReason finishReason;
    private final AiMessage aiMessage;

    @Builder
    public ChatLanguageModelResponse(String id,
                                     String model,
                                     TokenUsage tokenUsage,
                                     FinishReason finishReason,
                                     AiMessage aiMessage) {
        this.id = id;
        this.model = model;
        this.tokenUsage = tokenUsage;
        this.finishReason = finishReason;
        this.aiMessage = aiMessage;
    }

    public String id() {
        return id;
    }

    public String model() {
        return model;
    }

    public TokenUsage tokenUsage() {
        return tokenUsage;
    }

    public FinishReason finishReason() {
        return finishReason;
    }

    public AiMessage aiMessage() {
        return aiMessage;
    }
}
```

### ModelListener.java (new file)

```java
package dev.langchain4j.model.listener;

import dev.langchain4j.Experimental;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;

/**
 * A generic model listener.
 * It can listen for requests to and responses from various model types,
 * such as {@link ChatLanguageModel}, {@link StreamingChatLanguageModel}, {@link EmbeddingModel}, etc.
 */
@Experimental
public interface ModelListener<Request, Response> {

    /**
     * This method is called before the request is sent to the model.
     *
     * @param request The request to the model.
     */
    @Experimental
    default void onRequest(Request request) {
    }

    /**
     * This method is called after the response is received from the model.
     *
     * @param response The response from the model.
     * @param request  The request this response corresponds to.
     */
    @Experimental
    default void onResponse(Response response, Request request) {
    }

    /**
     * This method is called when an error occurs.
     * <br>
     * When streaming (e.g., using {@link StreamingChatLanguageModel}),
     * the {@code response} might contain a partial response that was received before the error occurred.
     *
     * @param error    The error that occurred.
     * @param response The partial response, if available.
     * @param request  The request this error corresponds to.
     */
    @Experimental
    default void onError(Throwable error, Response response, Request request) {
    }
}
```
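
Note the `onError()` contract above: during streaming the `response` argument may be a partial response, and it may be `null` when nothing was received before the failure. A minimal defensive sketch (illustrative, not part of the PR):

```java
// Illustrative sketch: guards against null/partial responses in onError().
ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> loggingListener =
        new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {

            @Override
            public void onError(Throwable error, ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
                // response may be null, or partial if streaming failed mid-stream
                String partial = response == null ? "<none>" : String.valueOf(response.aiMessage());
                System.err.printf("model=%s failed: %s; partial response: %s%n",
                        request.model(), error.getMessage(), partial);
            }
        };
```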

### InternalOpenAiHelper.java

```diff
@@ -10,6 +10,8 @@ import dev.langchain4j.data.message.Content;
 import dev.langchain4j.data.message.SystemMessage;
 import dev.langchain4j.data.message.UserMessage;
 import dev.langchain4j.data.message.*;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
 import dev.langchain4j.model.output.FinishReason;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
@@ -280,4 +282,33 @@ public class InternalOpenAiHelper {
     static Response<AiMessage> removeTokenUsage(Response<AiMessage> response) {
         return Response.from(response.content(), null, response.finishReason());
     }
+
+    static ChatLanguageModelRequest createModelListenerRequest(ChatCompletionRequest request,
+                                                               List<ChatMessage> messages,
+                                                               List<ToolSpecification> toolSpecifications) {
+        return ChatLanguageModelRequest.builder()
+                .model(request.model())
+                .temperature(request.temperature())
+                .topP(request.topP())
+                .maxTokens(request.maxTokens())
+                .messages(messages)
+                .toolSpecifications(toolSpecifications)
+                .build();
+    }
+
+    static ChatLanguageModelResponse createModelListenerResponse(String responseId,
+                                                                 String responseModel,
+                                                                 Response<AiMessage> response) {
+        if (response == null) {
+            return null;
+        }
+        return ChatLanguageModelResponse.builder()
+                .id(responseId)
+                .model(responseModel)
+                .tokenUsage(response.tokenUsage())
+                .finishReason(response.finishReason())
+                .aiMessage(response.content())
+                .build();
+    }
 }
```

### OpenAiChatModel.java

```diff
@@ -1,6 +1,7 @@
 package dev.langchain4j.model.openai;
 
 import dev.ai4j.openai4j.OpenAiClient;
+import dev.ai4j.openai4j.OpenAiHttpException;
 import dev.ai4j.openai4j.chat.ChatCompletionRequest;
 import dev.ai4j.openai4j.chat.ChatCompletionResponse;
 import dev.langchain4j.agent.tool.ToolSpecification;
@@ -9,12 +10,17 @@ import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.TokenCountEstimator;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
+import dev.langchain4j.model.listener.ModelListener;
 import dev.langchain4j.model.openai.spi.OpenAiChatModelBuilderFactory;
 import dev.langchain4j.model.output.Response;
 import lombok.Builder;
+import lombok.extern.slf4j.Slf4j;
 
 import java.net.Proxy;
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
@@ -24,12 +30,14 @@ import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
 import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
 import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 import static java.time.Duration.ofSeconds;
+import static java.util.Collections.emptyList;
 import static java.util.Collections.singletonList;
 
 /**
  * Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and gpt-4.
  * You can find description of parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
  */
+@Slf4j
 public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
 
     private final OpenAiClient client;
@@ -46,6 +54,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
     private final String user;
     private final Integer maxRetries;
     private final Tokenizer tokenizer;
+    private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners;
 
     @Builder
     public OpenAiChatModel(String baseUrl,
@@ -68,7 +77,8 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
                            Boolean logRequests,
                            Boolean logResponses,
                            Tokenizer tokenizer,
-                           Map<String, String> customHeaders) {
+                           Map<String, String> customHeaders,
+                           List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {
 
         baseUrl = getOrDefault(baseUrl, OPENAI_URL);
         if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
@@ -104,6 +114,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
         this.user = user;
         this.maxRetries = getOrDefault(maxRetries, 3);
         this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
+        this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
     }
 
     public String modelName() {
@@ -152,13 +163,56 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
 
         ChatCompletionRequest request = requestBuilder.build();
 
-        ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request).execute(), maxRetries);
+        ChatLanguageModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
+        listeners.forEach(listener -> {
+            try {
+                listener.onRequest(modelListenerRequest);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });
 
-        return Response.from(
-                aiMessageFrom(response),
-                tokenUsageFrom(response.usage()),
-                finishReasonFrom(response.choices().get(0).finishReason())
-        );
+        try {
+            ChatCompletionResponse chatCompletionResponse = withRetry(() -> client.chatCompletion(request).execute(), maxRetries);
+
+            Response<AiMessage> response = Response.from(
+                    aiMessageFrom(chatCompletionResponse),
+                    tokenUsageFrom(chatCompletionResponse.usage()),
+                    finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason())
+            );
+
+            ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
+                    chatCompletionResponse.id(),
+                    chatCompletionResponse.model(),
+                    response
+            );
+            listeners.forEach(listener -> {
+                try {
+                    listener.onResponse(modelListenerResponse, modelListenerRequest);
+                } catch (Exception e) {
+                    log.warn("Exception while calling model listener", e);
+                }
+            });
+
+            return response;
+        } catch (RuntimeException e) {
+
+            Throwable error;
+            if (e.getCause() instanceof OpenAiHttpException) {
+                error = e.getCause();
+            } else {
+                error = e;
+            }
+
+            listeners.forEach(listener -> {
+                try {
+                    listener.onError(error, null, modelListenerRequest);
+                } catch (Exception e2) {
+                    log.warn("Exception while calling model listener", e2);
+                }
+            });
+
+            throw e;
+        }
     }
 
     @Override
```

### OpenAiStreamingChatModel.java

```diff
@@ -12,21 +12,27 @@ import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.chat.TokenCountEstimator;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
+import dev.langchain4j.model.listener.ModelListener;
 import dev.langchain4j.model.openai.spi.OpenAiStreamingChatModelBuilderFactory;
 import dev.langchain4j.model.output.Response;
 import lombok.Builder;
+import lombok.extern.slf4j.Slf4j;
 
 import java.net.Proxy;
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.atomic.AtomicReference;
 
-import static dev.langchain4j.internal.Utils.getOrDefault;
-import static dev.langchain4j.internal.Utils.isNullOrEmpty;
+import static dev.langchain4j.internal.Utils.*;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
 import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
 import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 import static java.time.Duration.ofSeconds;
+import static java.util.Collections.emptyList;
 import static java.util.Collections.singletonList;
 
 /**
@@ -34,6 +40,7 @@ import static java.util.Collections.singletonList;
  * The model's response is streamed token by token and should be handled with {@link StreamingResponseHandler}.
  * You can find description of parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
  */
+@Slf4j
 public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
 
     private final OpenAiClient client;
@@ -50,6 +57,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
     private final String user;
     private final Tokenizer tokenizer;
     private final boolean isOpenAiModel;
+    private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners;
 
     @Builder
     public OpenAiStreamingChatModel(String baseUrl,
@@ -71,7 +79,8 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
                                     Boolean logRequests,
                                     Boolean logResponses,
                                     Tokenizer tokenizer,
-                                    Map<String, String> customHeaders) {
+                                    Map<String, String> customHeaders,
+                                    List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {
 
         timeout = getOrDefault(timeout, ofSeconds(60));
 
@@ -102,6 +111,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
         this.user = user;
         this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
         this.isOpenAiModel = isOpenAiModel(this.modelName);
+        this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
     }
 
     public String modelName() {
@@ -152,25 +162,81 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
 
         ChatCompletionRequest request = requestBuilder.build();
 
+        ChatLanguageModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
+        listeners.forEach(listener -> {
+            try {
+                listener.onRequest(modelListenerRequest);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });
+
         int inputTokenCount = countInputTokens(messages, toolSpecifications, toolThatMustBeExecuted);
         OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(inputTokenCount);
 
+        AtomicReference<String> responseId = new AtomicReference<>();
+        AtomicReference<String> responseModel = new AtomicReference<>();
+
         client.chatCompletion(request)
                 .onPartialResponse(partialResponse -> {
                     responseBuilder.append(partialResponse);
                     handle(partialResponse, handler);
+
+                    if (!isNullOrBlank(partialResponse.id())) {
+                        responseId.set(partialResponse.id());
+                    }
+                    if (!isNullOrBlank(partialResponse.model())) {
+                        responseModel.set(partialResponse.model());
+                    }
                 })
                 .onComplete(() -> {
-                    Response<AiMessage> response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null);
-                    if (!isOpenAiModel) {
-                        response = removeTokenUsage(response);
-                    }
+                    Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
+
+                    ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
+                            responseId.get(),
+                            responseModel.get(),
+                            response
+                    );
+                    listeners.forEach(listener -> {
+                        try {
+                            listener.onResponse(modelListenerResponse, modelListenerRequest);
+                        } catch (Exception e) {
+                            log.warn("Exception while calling model listener", e);
+                        }
+                    });
+
                     handler.onComplete(response);
                 })
-                .onError(handler::onError)
+                .onError(error -> {
+                    Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
+
+                    ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
+                            responseId.get(),
+                            responseModel.get(),
+                            response
+                    );
+                    listeners.forEach(listener -> {
+                        try {
+                            listener.onError(error, modelListenerResponse, modelListenerRequest);
+                        } catch (Exception e) {
+                            log.warn("Exception while calling model listener", e);
+                        }
+                    });
+
+                    handler.onError(error);
+                })
                 .execute();
     }
 
+    private Response<AiMessage> createResponse(OpenAiStreamingResponseBuilder responseBuilder,
+                                               ToolSpecification toolThatMustBeExecuted) {
+        Response<AiMessage> response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null);
+        if (isOpenAiModel) {
+            return response;
+        }
+        return removeTokenUsage(response);
+    }
+
     private int countInputTokens(List<ChatMessage> messages,
                                  List<ToolSpecification> toolSpecifications,
                                  ToolSpecification toolThatMustBeExecuted) {
```

### OpenAiChatModelIT.java

```diff
@@ -1,16 +1,21 @@
 package dev.langchain4j.model.openai;
 
+import dev.ai4j.openai4j.OpenAiHttpException;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.*;
 import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
+import dev.langchain4j.model.listener.ModelListener;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import org.junit.jupiter.api.Test;
 
 import java.util.Base64;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
 import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
@@ -23,6 +28,8 @@ import static dev.langchain4j.model.output.FinishReason.*;
 import static java.util.Arrays.asList;
 import static java.util.Collections.singletonList;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Fail.fail;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 
 class OpenAiChatModelIT {
 
@@ -216,6 +223,8 @@ class OpenAiChatModelIT {
                 .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
                 .modelName(GPT_3_5_TURBO_1106) // supports parallel function calling
                 .temperature(0.0)
+                .logRequests(true)
+                .logResponses(true)
                 .build();
 
         UserMessage userMessage = userMessage("2+2=? 3+3=?");
@@ -463,4 +472,132 @@ class OpenAiChatModelIT {
         // then
         assertThat(tokenCount).isEqualTo(42);
     }
+
+    @Test
+    void should_listen_request_and_response() {
+
+        // given
+        AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
+        AtomicReference<ChatLanguageModelResponse> responseReference = new AtomicReference<>();
+
+        ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
+                new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {
+
+                    @Override
+                    public void onRequest(ChatLanguageModelRequest request) {
+                        requestReference.set(request);
+                    }
+
+                    @Override
+                    public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
+                        responseReference.set(response);
+                        assertThat(request).isSameAs(requestReference.get());
+                    }
+
+                    @Override
+                    public void onError(Throwable error,
+                                        ChatLanguageModelResponse response,
+                                        ChatLanguageModelRequest request) {
+                        fail("onError() must not be called");
+                    }
+                };
+
+        OpenAiChatModelName modelName = GPT_3_5_TURBO;
+        double temperature = 0.7;
+        double topP = 1.0;
+        int maxTokens = 7;
+
+        OpenAiChatModel model = OpenAiChatModel.builder()
+                .baseUrl(System.getenv("OPENAI_BASE_URL"))
+                .apiKey(System.getenv("OPENAI_API_KEY"))
+                .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
+                .modelName(modelName)
+                .temperature(temperature)
+                .topP(topP)
+                .maxTokens(maxTokens)
+                .logRequests(true)
+                .logResponses(true)
+                .listeners(singletonList(modelListener))
+                .build();
+
+        UserMessage userMessage = UserMessage.from("hello");
+
+        ToolSpecification toolSpecification = ToolSpecification.builder()
+                .name("add")
+                .addParameter("a", INTEGER)
+                .addParameter("b", INTEGER)
+                .build();
+
+        // when
+        AiMessage aiMessage = model.generate(singletonList(userMessage), singletonList(toolSpecification)).content();
+
+        // then
+        ChatLanguageModelRequest request = requestReference.get();
+        assertThat(request.model()).isEqualTo(modelName.toString());
+        assertThat(request.temperature()).isEqualTo(temperature);
+        assertThat(request.topP()).isEqualTo(topP);
+        assertThat(request.maxTokens()).isEqualTo(maxTokens);
+        assertThat(request.messages()).containsExactly(userMessage);
+        assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
+
+        ChatLanguageModelResponse response = responseReference.get();
+        assertThat(response.id()).isNotBlank();
+        assertThat(response.model()).isNotBlank();
+        assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
+        assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
+        assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0);
+        assertThat(response.finishReason()).isNotNull();
+        assertThat(response.aiMessage()).isEqualTo(aiMessage);
+    }
+
+    @Test
+    void should_listen_error() {
+
+        // given
+        String wrongApiKey = "banana";
+
+        AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
+        AtomicReference<Throwable> errorReference = new AtomicReference<>();
+
+        ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
+                new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {
+
+                    @Override
+                    public void onRequest(ChatLanguageModelRequest request) {
+                        requestReference.set(request);
+                    }
+
+                    @Override
+                    public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
+                        fail("onResponse() must not be called");
+                    }
+
+                    @Override
+                    public void onError(Throwable error,
+                                        ChatLanguageModelResponse response,
+                                        ChatLanguageModelRequest request) {
+                        errorReference.set(error);
+                        assertThat(response).isNull();
+                        assertThat(request).isSameAs(requestReference.get());
+                    }
+                };
+
+        OpenAiChatModel model = OpenAiChatModel.builder()
+                .apiKey(wrongApiKey)
+                .maxRetries(0)
+                .logRequests(true)
+                .logResponses(true)
+                .listeners(singletonList(modelListener))
+                .build();
+
+        String userMessage = "this message will fail";
+
+        // when
+        assertThrows(RuntimeException.class, () -> model.generate(userMessage));
+
+        // then
+        Throwable throwable = errorReference.get();
+        assertThat(throwable).isExactlyInstanceOf(OpenAiHttpException.class);
+        assertThat(throwable).hasMessageContaining("Incorrect API key provided");
+    }
 }
```

### OpenAiStreamingChatModelIT.java

```diff
@@ -1,5 +1,6 @@
 package dev.langchain4j.model.openai;
 
+import dev.ai4j.openai4j.OpenAiHttpException;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.*;
@@ -7,6 +8,9 @@ import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.chat.TestStreamingResponseHandler;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
+import dev.langchain4j.model.listener.ModelListener;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import org.assertj.core.data.Percentage;
@@ -15,6 +19,7 @@ import org.junit.jupiter.api.Test;
 import java.util.Base64;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
 import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
@@ -31,6 +36,7 @@ import static java.util.Arrays.asList;
 import static java.util.Collections.singletonList;
 import static java.util.concurrent.TimeUnit.SECONDS;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Fail.fail;
 import static org.assertj.core.data.Percentage.withPercentage;
 
 class OpenAiStreamingChatModelIT {
 
@@ -639,4 +645,154 @@ class OpenAiStreamingChatModelIT {
         // then
         assertThat(tokenCount).isEqualTo(42);
     }
+
+    @Test
+    void should_listen_request_and_response() {
+
+        // given
+        AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
+        AtomicReference<ChatLanguageModelResponse> responseReference = new AtomicReference<>();
+
+        ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
+                new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {
+
+                    @Override
+                    public void onRequest(ChatLanguageModelRequest request) {
+                        requestReference.set(request);
+                    }
+
+                    @Override
+                    public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
+                        responseReference.set(response);
+                        assertThat(request).isSameAs(requestReference.get());
+                    }
+
+                    @Override
+                    public void onError(Throwable error,
+                                        ChatLanguageModelResponse response,
+                                        ChatLanguageModelRequest request) {
+                        fail("onError() must not be called");
+                    }
+                };
+
+        OpenAiChatModelName modelName = GPT_3_5_TURBO;
+        double temperature = 0.7;
+        double topP = 1.0;
+        int maxTokens = 7;
+
+        StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder()
+                .baseUrl(System.getenv("OPENAI_BASE_URL"))
+                .apiKey(System.getenv("OPENAI_API_KEY"))
+                .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
+                .modelName(modelName)
+                .temperature(temperature)
+                .topP(topP)
+                .maxTokens(maxTokens)
+                .logRequests(true)
+                .logResponses(true)
+                .listeners(singletonList(modelListener))
+                .build();
+
+        UserMessage userMessage = UserMessage.from("hello");
+
+        ToolSpecification toolSpecification = ToolSpecification.builder()
+                .name("add")
+                .addParameter("a", INTEGER)
+                .addParameter("b", INTEGER)
+                .build();
+
+        // when
+        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
+        model.generate(singletonList(userMessage), singletonList(toolSpecification), handler);
+        AiMessage aiMessage = handler.get().content();
+
+        // then
+        ChatLanguageModelRequest request = requestReference.get();
+        assertThat(request.model()).isEqualTo(modelName.toString());
+        assertThat(request.temperature()).isEqualTo(temperature);
+        assertThat(request.topP()).isEqualTo(topP);
+        assertThat(request.maxTokens()).isEqualTo(maxTokens);
+        assertThat(request.messages()).containsExactly(userMessage);
+        assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
+
+        ChatLanguageModelResponse response = responseReference.get();
+        assertThat(response.id()).isNotBlank();
+        assertThat(response.model()).isNotBlank();
+        assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
+        assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
+        assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0);
+        assertThat(response.finishReason()).isNotNull();
+        assertThat(response.aiMessage()).isEqualTo(aiMessage);
+    }
+
+    @Test
+    void should_listen_error() throws Exception {
+
+        // given
+        String wrongApiKey = "banana";
+
+        AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
+        AtomicReference<Throwable> errorReference = new AtomicReference<>();
+
+        ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
+                new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {
+
+                    @Override
+                    public void onRequest(ChatLanguageModelRequest request) {
+                        requestReference.set(request);
+                    }
+
+                    @Override
+                    public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
+                        fail("onResponse() must not be called");
+                    }
+
+                    @Override
+                    public void onError(Throwable error,
+                                        ChatLanguageModelResponse response,
+                                        ChatLanguageModelRequest request) {
+                        errorReference.set(error);
+                        assertThat(response).isNull(); // can be non-null if it fails in the middle of streaming
+                        assertThat(request).isSameAs(requestReference.get());
+                    }
+                };
+
+        StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder()
+                .apiKey(wrongApiKey)
+                .logRequests(true)
+                .logResponses(true)
+                .listeners(singletonList(modelListener))
+                .build();
+
+        String userMessage = "this message will fail";
+
+        CompletableFuture<Throwable> future = new CompletableFuture<>();
+        StreamingResponseHandler<AiMessage> handler = new StreamingResponseHandler<AiMessage>() {
+
+            @Override
+            public void onNext(String token) {
+                fail("onNext() must not be called");
+            }
+
+            @Override
+            public void onError(Throwable error) {
+                future.complete(error);
+            }
+
+            @Override
+            public void onComplete(Response<AiMessage> response) {
+                fail("onComplete() must not be called");
+            }
+        };
+
+        // when
+        model.generate(userMessage, handler);
+        Throwable throwable = future.get(5, SECONDS);
+
+        // then
+        assertThat(throwable).isExactlyInstanceOf(OpenAiHttpException.class);
+        assertThat(throwable).hasMessageContaining("Incorrect API key provided");
+
+        assertThat(errorReference.get()).isSameAs(throwable);
+    }
 }
```