Updated model listener API to be more integration-friendly (#1229)

## Issue
The generic [ModelListener](https://github.com/langchain4j/langchain4j/pull/1058)
does not allow passing data between methods of the same listener or
between multiple listeners.

## Change
- Added an attributes map that allows passing data between methods of the same
listener and between multiple listeners (see the usage sketch below)
- Replaced the generic `ModelListener` with the model-specific `ChatModelListener`
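
A minimal usage sketch of the new listener API, mirroring the integration tests
in this diff (reading the API key from an environment variable is a placeholder
for whatever configuration you use):

```java
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.openai.OpenAiChatModel;

import static java.util.Collections.singletonList;

public class ChatModelListenerExample {

    public static void main(String[] args) {

        ChatModelListener listener = new ChatModelListener() {

            @Override
            public void onRequest(ChatModelRequestContext requestContext) {
                // put data into the shared attributes map to pass it to the
                // other methods of this listener (or to other listeners)
                requestContext.attributes().put("id", "12345");
            }

            @Override
            public void onResponse(ChatModelResponseContext responseContext) {
                // the same attributes map instance is available here
                System.out.println(responseContext.attributes().get("id")); // "12345"
            }

            @Override
            public void onError(ChatModelErrorContext errorContext) {
                System.out.println("failed: " + errorContext.error().getMessage());
            }
        };

        OpenAiChatModel model = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY")) // placeholder
                .listeners(singletonList(listener))
                .build();

        System.out.println(model.generate("hello"));
    }
}
```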


## General checklist
- [ ] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
Commit e8ae23b51c (parent caa125b657) by LangChain4j, 2024-06-06 13:26:53 +02:00
12 changed files with 368 additions and 202 deletions

View File

@ -0,0 +1,65 @@
package dev.langchain4j.model.chat.listener;
import dev.langchain4j.Experimental;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import java.util.Map;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
/**
* The error context. It contains the error, corresponding {@link ChatModelRequest},
* partial {@link ChatModelResponse} (if available) and attributes.
* The attributes can be used to pass data between methods of a {@link ChatModelListener}
* or between multiple {@link ChatModelListener}s.
*/
@Experimental
public class ChatModelErrorContext {
private final Throwable error;
private final ChatModelRequest request;
private final ChatModelResponse partialResponse;
private final Map<Object, Object> attributes;
public ChatModelErrorContext(Throwable error,
ChatModelRequest request,
ChatModelResponse partialResponse,
Map<Object, Object> attributes) {
this.error = ensureNotNull(error, "error");
this.request = ensureNotNull(request, "request");
this.partialResponse = partialResponse;
this.attributes = ensureNotNull(attributes, "attributes");
}
/**
* @return The error that occurred.
*/
public Throwable error() {
return error;
}
/**
* @return The request to the {@link ChatLanguageModel} the error corresponds to.
*/
public ChatModelRequest request() {
return request;
}
/**
* @return The partial response from the {@link ChatLanguageModel}, if available.
* When used with {@link StreamingChatLanguageModel}, it might contain the tokens
* that were received before the error occurred.
*/
public ChatModelResponse partialResponse() {
return partialResponse;
}
/**
* @return The attributes map. It can be used to pass data between methods of a {@link ChatModelListener}
* or between multiple {@link ChatModelListener}s.
*/
public Map<Object, Object> attributes() {
return attributes;
}
}

View File

@ -0,0 +1,50 @@
package dev.langchain4j.model.chat.listener;
import dev.langchain4j.Experimental;
import dev.langchain4j.model.chat.ChatLanguageModel;
/**
* A {@link ChatLanguageModel} listener that listens for requests, responses and errors.
*/
@Experimental
public interface ChatModelListener {
/**
* This method is called before the request is sent to the model.
*
* @param requestContext The request context. It contains the {@link ChatModelRequest} and attributes.
* The attributes can be used to pass data between methods of this listener
* or between multiple listeners.
*/
@Experimental
default void onRequest(ChatModelRequestContext requestContext) {
}
/**
* This method is called after the response is received from the model.
*
* @param responseContext The response context.
* It contains {@link ChatModelResponse}, corresponding {@link ChatModelRequest} and attributes.
* The attributes can be used to pass data between methods of this listener
* or between multiple listeners.
*/
@Experimental
default void onResponse(ChatModelResponseContext responseContext) {
}
/**
* This method is called when an error occurs during interaction with the model.
*
* @param errorContext The error context.
* It contains the error, corresponding {@link ChatModelRequest},
* partial {@link ChatModelResponse} (if available) and attributes.
* The attributes can be used to pass data between methods of this listener
* or between multiple listeners.
*/
@Experimental
default void onError(ChatModelErrorContext errorContext) {
}
}
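
All three callbacks of a given invocation receive the same attributes map, so a
listener can correlate a request with its response or error. A hypothetical
latency-tracking listener, for illustration (the `"startNanos"` key is made up
for this sketch):

```java
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;

public class TimingListener implements ChatModelListener {

    @Override
    public void onRequest(ChatModelRequestContext requestContext) {
        // stash the start time in the shared attributes map
        requestContext.attributes().put("startNanos", System.nanoTime());
    }

    @Override
    public void onResponse(ChatModelResponseContext responseContext) {
        // read it back once the response for the same request arrives
        long startNanos = (long) responseContext.attributes().get("startNanos");
        long elapsedMs = (System.nanoTime() - startNanos) / 1_000_000;
        System.out.printf("model=%s latency=%dms%n",
                responseContext.response().model(), elapsedMs);
    }

    @Override
    public void onError(ChatModelErrorContext errorContext) {
        System.out.println("request failed: " + errorContext.error().getMessage());
    }
}
```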

View File

@ -5,7 +5,6 @@ import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.listener.ModelListener;
import lombok.Builder;
import java.util.List;
@ -14,10 +13,10 @@ import static dev.langchain4j.internal.Utils.copyIfNotNull;
/**
* A request to the {@link ChatLanguageModel} or {@link StreamingChatLanguageModel},
* intended to be used with {@link ModelListener}.
* intended to be used with {@link ChatModelListener}.
*/
@Experimental
public class ChatLanguageModelRequest {
public class ChatModelRequest {
private final String model;
private final Double temperature;
@ -27,12 +26,12 @@ public class ChatLanguageModelRequest {
private final List<ToolSpecification> toolSpecifications;
@Builder
public ChatLanguageModelRequest(String model,
Double temperature,
Double topP,
Integer maxTokens,
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
public ChatModelRequest(String model,
Double temperature,
Double topP,
Integer maxTokens,
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
this.model = model;
this.temperature = temperature;
this.topP = topP;

View File

@ -0,0 +1,40 @@
package dev.langchain4j.model.chat.listener;
import dev.langchain4j.Experimental;
import dev.langchain4j.model.chat.ChatLanguageModel;
import java.util.Map;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
/**
* The request context. It contains the {@link ChatModelRequest} and attributes.
* The attributes can be used to pass data between methods of a {@link ChatModelListener}
* or between multiple {@link ChatModelListener}s.
*/
@Experimental
public class ChatModelRequestContext {
private final ChatModelRequest request;
private final Map<Object, Object> attributes;
public ChatModelRequestContext(ChatModelRequest request, Map<Object, Object> attributes) {
this.request = ensureNotNull(request, "request");
this.attributes = ensureNotNull(attributes, "attributes");
}
/**
* @return The request to the {@link ChatLanguageModel}.
*/
public ChatModelRequest request() {
return request;
}
/**
* @return The attributes map. It can be used to pass data between methods of a {@link ChatModelListener}
* or between multiple {@link ChatModelListener}s.
*/
public Map<Object, Object> attributes() {
return attributes;
}
}

View File

@ -4,17 +4,16 @@ import dev.langchain4j.Experimental;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.listener.ModelListener;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;
/**
* A response from the {@link ChatLanguageModel} or {@link StreamingChatLanguageModel},
* intended to be used with {@link ModelListener}.
* intended to be used with {@link ChatModelListener}.
*/
@Experimental
public class ChatLanguageModelResponse {
public class ChatModelResponse {
private final String id;
private final String model;
@ -23,11 +22,11 @@ public class ChatLanguageModelResponse {
private final AiMessage aiMessage;
@Builder
public ChatLanguageModelResponse(String id,
String model,
TokenUsage tokenUsage,
FinishReason finishReason,
AiMessage aiMessage) {
public ChatModelResponse(String id,
String model,
TokenUsage tokenUsage,
FinishReason finishReason,
AiMessage aiMessage) {
this.id = id;
this.model = model;
this.tokenUsage = tokenUsage;

View File

@ -0,0 +1,51 @@
package dev.langchain4j.model.chat.listener;
import dev.langchain4j.Experimental;
import dev.langchain4j.model.chat.ChatLanguageModel;
import java.util.Map;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
/**
* The response context. It contains {@link ChatModelResponse}, corresponding {@link ChatModelRequest} and attributes.
* The attributes can be used to pass data between methods of a {@link ChatModelListener}
* or between multiple {@link ChatModelListener}s.
*/
@Experimental
public class ChatModelResponseContext {
private final ChatModelResponse response;
private final ChatModelRequest request;
private final Map<Object, Object> attributes;
public ChatModelResponseContext(ChatModelResponse response,
ChatModelRequest request,
Map<Object, Object> attributes) {
this.response = ensureNotNull(response, "response");
this.request = ensureNotNull(request, "request");
this.attributes = ensureNotNull(attributes, "attributes");
}
/**
* @return The response from the {@link ChatLanguageModel}.
*/
public ChatModelResponse response() {
return response;
}
/**
* @return The request to the {@link ChatLanguageModel} the response corresponds to.
*/
public ChatModelRequest request() {
return request;
}
/**
* @return The attributes map. It can be used to pass data between methods of a {@link ChatModelListener}
* or between multiple {@link ChatModelListener}s.
*/
public Map<Object, Object> attributes() {
return attributes;
}
}

View File

@ -1,51 +0,0 @@
package dev.langchain4j.model.listener;
import dev.langchain4j.Experimental;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
/**
* A generic model listener.
* It can listen for requests to and responses from various model types,
* such as {@link ChatLanguageModel}, {@link StreamingChatLanguageModel}, {@link EmbeddingModel}, etc.
*/
@Experimental
public interface ModelListener<Request, Response> {
/**
* This method is called before the request is sent to the model.
*
* @param request The request to the model.
*/
@Experimental
default void onRequest(Request request) {
}
/**
* This method is called after the response is received from the model.
*
* @param response The response from the model.
* @param request The request this response corresponds to.
*/
@Experimental
default void onResponse(Response response, Request request) {
}
/**
* This method is called when an error occurs.
* <br>
* When streaming (e.g., using {@link StreamingChatLanguageModel}),
* the {@code response} might contain a partial response that was received before the error occurred.
*
* @param error The error that occurred.
* @param response The partial response, if available.
* @param request The request this error corresponds to.
*/
@Experimental
default void onError(Throwable error, Response response, Request request) {
}
}
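
For integrators migrating off the removed generic listener, the mapping is
mechanical: each callback's parameters now travel inside a context object. A
hedged sketch, assuming your listener previously implemented
`ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>` (both of
those classes are renamed in this PR):

```java
import dev.langchain4j.model.chat.listener.*;

// onRequest(Request)                    -> onRequest(ChatModelRequestContext)
// onResponse(Response, Request)         -> onResponse(ChatModelResponseContext)
// onError(Throwable, Response, Request) -> onError(ChatModelErrorContext)
public class MigratedListener implements ChatModelListener {

    @Override
    public void onError(ChatModelErrorContext errorContext) {
        Throwable error = errorContext.error();                     // was: 1st parameter
        ChatModelRequest request = errorContext.request();          // was: 3rd parameter
        ChatModelResponse partial = errorContext.partialResponse(); // was: 2nd parameter, may be null
        System.out.println("request to " + request.model() + " failed: " + error);
    }
}
```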

View File

@ -10,8 +10,8 @@ import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
@ -283,10 +283,10 @@ public class InternalOpenAiHelper {
return Response.from(response.content(), null, response.finishReason());
}
static ChatLanguageModelRequest createModelListenerRequest(ChatCompletionRequest request,
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
return ChatLanguageModelRequest.builder()
static ChatModelRequest createModelListenerRequest(ChatCompletionRequest request,
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
return ChatModelRequest.builder()
.model(request.model())
.temperature(request.temperature())
.topP(request.topP())
@ -296,14 +296,14 @@ public class InternalOpenAiHelper {
.build();
}
static ChatLanguageModelResponse createModelListenerResponse(String responseId,
String responseModel,
Response<AiMessage> response) {
static ChatModelResponse createModelListenerResponse(String responseId,
String responseModel,
Response<AiMessage> response) {
if (response == null) {
return null;
}
return ChatLanguageModelResponse.builder()
return ChatModelResponse.builder()
.id(responseId)
.model(responseModel)
.tokenUsage(response.tokenUsage())

View File

@ -10,9 +10,7 @@ import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
import dev.langchain4j.model.listener.ModelListener;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.openai.spi.OpenAiChatModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
@ -23,6 +21,7 @@ import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
@ -54,7 +53,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final String user;
private final Integer maxRetries;
private final Tokenizer tokenizer;
private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners;
private final List<ChatModelListener> listeners;
@Builder
public OpenAiChatModel(String baseUrl,
@ -78,7 +77,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
Boolean logResponses,
Tokenizer tokenizer,
Map<String, String> customHeaders,
List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {
List<ChatModelListener> listeners) {
baseUrl = getOrDefault(baseUrl, OPENAI_URL);
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
@ -163,10 +162,12 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
ChatCompletionRequest request = requestBuilder.build();
ChatLanguageModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(listener -> {
try {
listener.onRequest(modelListenerRequest);
listener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
@ -181,14 +182,19 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason())
);
ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
ChatModelResponse modelListenerResponse = createModelListenerResponse(
chatCompletionResponse.id(),
chatCompletionResponse.model(),
response
);
ChatModelResponseContext responseContext = new ChatModelResponseContext(
modelListenerResponse,
modelListenerRequest,
attributes
);
listeners.forEach(listener -> {
try {
listener.onResponse(modelListenerResponse, modelListenerRequest);
listener.onResponse(responseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
@ -204,13 +210,21 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
error = e;
}
ChatModelErrorContext errorContext = new ChatModelErrorContext(
error,
modelListenerRequest,
null,
attributes
);
listeners.forEach(listener -> {
try {
listener.onError(error, null, modelListenerRequest);
listener.onError(errorContext);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});
throw e;
}
}

View File

@ -12,9 +12,7 @@ import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
import dev.langchain4j.model.listener.ModelListener;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.openai.spi.OpenAiStreamingChatModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
@ -25,6 +23,7 @@ import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import static dev.langchain4j.internal.Utils.*;
@ -57,7 +56,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
private final String user;
private final Tokenizer tokenizer;
private final boolean isOpenAiModel;
private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners;
private final List<ChatModelListener> listeners;
@Builder
public OpenAiStreamingChatModel(String baseUrl,
@ -80,7 +79,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
Boolean logResponses,
Tokenizer tokenizer,
Map<String, String> customHeaders,
List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {
List<ChatModelListener> listeners) {
timeout = getOrDefault(timeout, ofSeconds(60));
@ -162,10 +161,12 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
ChatCompletionRequest request = requestBuilder.build();
ChatLanguageModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(listener -> {
try {
listener.onRequest(modelListenerRequest);
listener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
@ -192,14 +193,19 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
.onComplete(() -> {
Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
ChatModelResponse modelListenerResponse = createModelListenerResponse(
responseId.get(),
responseModel.get(),
response
);
ChatModelResponseContext responseContext = new ChatModelResponseContext(
modelListenerResponse,
modelListenerRequest,
attributes
);
listeners.forEach(listener -> {
try {
listener.onResponse(modelListenerResponse, modelListenerRequest);
listener.onResponse(responseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
@ -210,14 +216,22 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
.onError(error -> {
Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
ChatModelResponse modelListenerPartialResponse = createModelListenerResponse(
responseId.get(),
responseModel.get(),
response
);
ChatModelErrorContext errorContext = new ChatModelErrorContext(
error,
modelListenerRequest,
modelListenerPartialResponse,
attributes
);
listeners.forEach(listener -> {
try {
listener.onError(error, modelListenerResponse, modelListenerRequest);
listener.onError(errorContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}

View File

@ -3,17 +3,10 @@ package dev.langchain4j.model.openai;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
import dev.langchain4j.model.listener.ModelListener;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
@ -29,9 +22,7 @@ import static dev.langchain4j.internal.Utils.readBytes;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO_1106;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
import static dev.langchain4j.model.output.FinishReason.*;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
@ -484,30 +475,29 @@ class OpenAiChatModelIT {
void should_listen_request_and_response() {
// given
AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
AtomicReference<ChatLanguageModelResponse> responseReference = new AtomicReference<>();
AtomicReference<ChatModelRequest> requestReference = new AtomicReference<>();
AtomicReference<ChatModelResponse> responseReference = new AtomicReference<>();
ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {
ChatModelListener listener = new ChatModelListener() {
@Override
public void onRequest(ChatLanguageModelRequest request) {
requestReference.set(request);
}
@Override
public void onRequest(ChatModelRequestContext requestContext) {
requestReference.set(requestContext.request());
requestContext.attributes().put("id", "12345");
}
@Override
public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
responseReference.set(response);
assertThat(request).isSameAs(requestReference.get());
}
@Override
public void onResponse(ChatModelResponseContext responseContext) {
responseReference.set(responseContext.response());
assertThat(responseContext.request()).isSameAs(requestReference.get());
assertThat(responseContext.attributes().get("id")).isEqualTo("12345");
}
@Override
public void onError(Throwable error,
ChatLanguageModelResponse response,
ChatLanguageModelRequest request) {
fail("onError() must not be called");
}
};
@Override
public void onError(ChatModelErrorContext errorContext) {
fail("onError() must not be called");
}
};
OpenAiChatModelName modelName = GPT_3_5_TURBO;
double temperature = 0.7;
@ -524,7 +514,7 @@ class OpenAiChatModelIT {
.maxTokens(maxTokens)
.logRequests(true)
.logResponses(true)
.listeners(singletonList(modelListener))
.listeners(singletonList(listener))
.build();
UserMessage userMessage = UserMessage.from("hello");
@ -539,7 +529,7 @@ class OpenAiChatModelIT {
AiMessage aiMessage = model.generate(singletonList(userMessage), singletonList(toolSpecification)).content();
// then
ChatLanguageModelRequest request = requestReference.get();
ChatModelRequest request = requestReference.get();
assertThat(request.model()).isEqualTo(modelName.toString());
assertThat(request.temperature()).isEqualTo(temperature);
assertThat(request.topP()).isEqualTo(topP);
@ -547,7 +537,7 @@ class OpenAiChatModelIT {
assertThat(request.messages()).containsExactly(userMessage);
assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
ChatLanguageModelResponse response = responseReference.get();
ChatModelResponse response = responseReference.get();
assertThat(response.id()).isNotBlank();
assertThat(response.model()).isNotBlank();
assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
@ -563,38 +553,37 @@ class OpenAiChatModelIT {
// given
String wrongApiKey = "banana";
AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
AtomicReference<ChatModelRequest> requestReference = new AtomicReference<>();
AtomicReference<Throwable> errorReference = new AtomicReference<>();
ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {
ChatModelListener listener = new ChatModelListener() {
@Override
public void onRequest(ChatLanguageModelRequest request) {
requestReference.set(request);
}
@Override
public void onRequest(ChatModelRequestContext requestContext) {
requestReference.set(requestContext.request());
requestContext.attributes().put("id", "12345");
}
@Override
public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
fail("onResponse() must not be called");
}
@Override
public void onResponse(ChatModelResponseContext responseContext) {
fail("onResponse() must not be called");
}
@Override
public void onError(Throwable error,
ChatLanguageModelResponse response,
ChatLanguageModelRequest request) {
errorReference.set(error);
assertThat(response).isNull();
assertThat(request).isSameAs(requestReference.get());
}
};
@Override
public void onError(ChatModelErrorContext errorContext) {
errorReference.set(errorContext.error());
assertThat(errorContext.request()).isSameAs(requestReference.get());
assertThat(errorContext.partialResponse()).isNull();
assertThat(errorContext.attributes().get("id")).isEqualTo("12345");
}
};
OpenAiChatModel model = OpenAiChatModel.builder()
.apiKey(wrongApiKey)
.maxRetries(0)
.logRequests(true)
.logResponses(true)
.listeners(singletonList(modelListener))
.listeners(singletonList(listener))
.build();
String userMessage = "this message will fail";

View File

@ -8,9 +8,7 @@ import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
import dev.langchain4j.model.listener.ModelListener;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.assertj.core.data.Percentage;
@ -650,30 +648,29 @@ class OpenAiStreamingChatModelIT {
void should_listen_request_and_response() {
// given
AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
AtomicReference<ChatLanguageModelResponse> responseReference = new AtomicReference<>();
AtomicReference<ChatModelRequest> requestReference = new AtomicReference<>();
AtomicReference<ChatModelResponse> responseReference = new AtomicReference<>();
ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {
ChatModelListener listener = new ChatModelListener() {
@Override
public void onRequest(ChatLanguageModelRequest request) {
requestReference.set(request);
}
@Override
public void onRequest(ChatModelRequestContext requestContext) {
requestReference.set(requestContext.request());
requestContext.attributes().put("id", "12345");
}
@Override
public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
responseReference.set(response);
assertThat(request).isSameAs(requestReference.get());
}
@Override
public void onResponse(ChatModelResponseContext responseContext) {
responseReference.set(responseContext.response());
assertThat(responseContext.request()).isSameAs(requestReference.get());
assertThat(responseContext.attributes().get("id")).isEqualTo("12345");
}
@Override
public void onError(Throwable error,
ChatLanguageModelResponse response,
ChatLanguageModelRequest request) {
fail("onError() must not be called");
}
};
@Override
public void onError(ChatModelErrorContext errorContext) {
fail("onError() must not be called");
}
};
OpenAiChatModelName modelName = GPT_3_5_TURBO;
double temperature = 0.7;
@ -690,7 +687,7 @@ class OpenAiStreamingChatModelIT {
.maxTokens(maxTokens)
.logRequests(true)
.logResponses(true)
.listeners(singletonList(modelListener))
.listeners(singletonList(listener))
.build();
UserMessage userMessage = UserMessage.from("hello");
@ -707,7 +704,7 @@ class OpenAiStreamingChatModelIT {
AiMessage aiMessage = handler.get().content();
// then
ChatLanguageModelRequest request = requestReference.get();
ChatModelRequest request = requestReference.get();
assertThat(request.model()).isEqualTo(modelName.toString());
assertThat(request.temperature()).isEqualTo(temperature);
assertThat(request.topP()).isEqualTo(topP);
@ -715,7 +712,7 @@ class OpenAiStreamingChatModelIT {
assertThat(request.messages()).containsExactly(userMessage);
assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
ChatLanguageModelResponse response = responseReference.get();
ChatModelResponse response = responseReference.get();
assertThat(response.id()).isNotBlank();
assertThat(response.model()).isNotBlank();
assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
@ -731,37 +728,36 @@ class OpenAiStreamingChatModelIT {
// given
String wrongApiKey = "banana";
AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
AtomicReference<ChatModelRequest> requestReference = new AtomicReference<>();
AtomicReference<Throwable> errorReference = new AtomicReference<>();
ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {
ChatModelListener listener = new ChatModelListener() {
@Override
public void onRequest(ChatLanguageModelRequest request) {
requestReference.set(request);
}
@Override
public void onRequest(ChatModelRequestContext requestContext) {
requestReference.set(requestContext.request());
requestContext.attributes().put("id", "12345");
}
@Override
public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
fail("onResponse() must not be called");
}
@Override
public void onResponse(ChatModelResponseContext responseContext) {
fail("onResponse() must not be called");
}
@Override
public void onError(Throwable error,
ChatLanguageModelResponse response,
ChatLanguageModelRequest request) {
errorReference.set(error);
assertThat(response).isNull(); // can be non-null if it fails in the middle of streaming
assertThat(request).isSameAs(requestReference.get());
}
};
@Override
public void onError(ChatModelErrorContext errorContext) {
errorReference.set(errorContext.error());
assertThat(errorContext.request()).isSameAs(requestReference.get());
assertThat(errorContext.partialResponse()).isNull(); // can be non-null if it fails in the middle of streaming
assertThat(errorContext.attributes().get("id")).isEqualTo("12345");
}
};
StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder()
.apiKey(wrongApiKey)
.logRequests(true)
.logResponses(true)
.listeners(singletonList(modelListener))
.listeners(singletonList(listener))
.build();
String userMessage = "this message will fail";