# Updated model listener API to be more integration friendly (#1229)
## Issue

[ModelListener](https://github.com/langchain4j/langchain4j/pull/1058) does not allow passing data between methods of the same listener or between multiple listeners.

## Change

- Added attributes that allow passing data between methods of a listener or between multiple listeners.
- Changed the generic model listener to a model-specific listener.

## General checklist

- [ ] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green
- [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features)
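To illustrate the change, here is a minimal sketch of a `ChatModelListener` that uses the attributes map to pass data between `onRequest()` and `onResponse()`/`onError()`, in this case a start timestamp for latency measurement. The class name, the `startTime` key and the logging are illustrative choices, not part of the API:

```java
import dev.langchain4j.model.chat.listener.*;

// Measures request latency by stashing the start time in the shared attributes
// map during onRequest() and reading it back in onResponse()/onError().
public class LatencyTrackingListener implements ChatModelListener {

    @Override
    public void onRequest(ChatModelRequestContext requestContext) {
        requestContext.attributes().put("startTime", System.nanoTime());
    }

    @Override
    public void onResponse(ChatModelResponseContext responseContext) {
        long startTime = (long) responseContext.attributes().get("startTime");
        System.out.printf("Model %s answered in %d ms%n",
                responseContext.response().model(),
                (System.nanoTime() - startTime) / 1_000_000);
    }

    @Override
    public void onError(ChatModelErrorContext errorContext) {
        long startTime = (long) errorContext.attributes().get("startTime");
        System.out.printf("Request failed after %d ms: %s%n",
                (System.nanoTime() - startTime) / 1_000_000,
                errorContext.error().getMessage());
    }
}
```

Such a listener is registered the same way as in the integration tests below, e.g. `.listeners(singletonList(new LatencyTrackingListener()))` on the model builder.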
This commit is contained in:

- parent: caa125b657
- commit: e8ae23b51c
**ChatModelErrorContext.java** (new file, `@@ -0,0 +1,65 @@`):

```java
package dev.langchain4j.model.chat.listener;

import dev.langchain4j.Experimental;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;

import java.util.Map;

import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
 * The error context. It contains the error, corresponding {@link ChatModelRequest},
 * partial {@link ChatModelResponse} (if available) and attributes.
 * The attributes can be used to pass data between methods of a {@link ChatModelListener}
 * or between multiple {@link ChatModelListener}s.
 */
@Experimental
public class ChatModelErrorContext {

    private final Throwable error;
    private final ChatModelRequest request;
    private final ChatModelResponse partialResponse;
    private final Map<Object, Object> attributes;

    public ChatModelErrorContext(Throwable error,
                                 ChatModelRequest request,
                                 ChatModelResponse partialResponse,
                                 Map<Object, Object> attributes) {
        this.error = ensureNotNull(error, "error");
        this.request = ensureNotNull(request, "request");
        this.partialResponse = partialResponse;
        this.attributes = ensureNotNull(attributes, "attributes");
    }

    /**
     * @return The error that occurred.
     */
    public Throwable error() {
        return error;
    }

    /**
     * @return The request to the {@link ChatLanguageModel} the error corresponds to.
     */
    public ChatModelRequest request() {
        return request;
    }

    /**
     * @return The partial response from the {@link ChatLanguageModel}, if available.
     * When used with {@link StreamingChatLanguageModel}, it might contain the tokens
     * that were received before the error occurred.
     */
    public ChatModelResponse partialResponse() {
        return partialResponse;
    }

    /**
     * @return The attributes map. It can be used to pass data between methods of a {@link ChatModelListener}
     * or between multiple {@link ChatModelListener}s.
     */
    public Map<Object, Object> attributes() {
        return attributes;
    }
}
```
**ChatModelListener.java** (new file, `@@ -0,0 +1,50 @@`):

```java
package dev.langchain4j.model.chat.listener;

import dev.langchain4j.Experimental;
import dev.langchain4j.model.chat.ChatLanguageModel;

/**
 * A {@link ChatLanguageModel} listener that listens for requests, responses and errors.
 */
@Experimental
public interface ChatModelListener {

    /**
     * This method is called before the request is sent to the model.
     *
     * @param requestContext The request context. It contains the {@link ChatModelRequest} and attributes.
     *                       The attributes can be used to pass data between methods of this listener
     *                       or between multiple listeners.
     */
    @Experimental
    default void onRequest(ChatModelRequestContext requestContext) {

    }

    /**
     * This method is called after the response is received from the model.
     *
     * @param responseContext The response context.
     *                        It contains {@link ChatModelResponse}, corresponding {@link ChatModelRequest} and attributes.
     *                        The attributes can be used to pass data between methods of this listener
     *                        or between multiple listeners.
     */
    @Experimental
    default void onResponse(ChatModelResponseContext responseContext) {

    }

    /**
     * This method is called when an error occurs during interaction with the model.
     *
     * @param errorContext The error context.
     *                     It contains the error, corresponding {@link ChatModelRequest},
     *                     partial {@link ChatModelResponse} (if available) and attributes.
     *                     The attributes can be used to pass data between methods of this listener
     *                     or between multiple listeners.
     */
    @Experimental
    default void onError(ChatModelErrorContext errorContext) {

    }
}
```
**ChatModelRequest.java** (class renamed from `ChatLanguageModelRequest`):

```diff
@@ -5,7 +5,6 @@ import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
-import dev.langchain4j.model.listener.ModelListener;
 import lombok.Builder;
 
 import java.util.List;
@@ -14,10 +13,10 @@ import static dev.langchain4j.internal.Utils.copyIfNotNull;
 
 /**
  * A request to the {@link ChatLanguageModel} or {@link StreamingChatLanguageModel},
- * intended to be used with {@link ModelListener}.
+ * intended to be used with {@link ChatModelListener}.
  */
 @Experimental
-public class ChatLanguageModelRequest {
+public class ChatModelRequest {
 
     private final String model;
     private final Double temperature;
@@ -27,12 +26,12 @@ public class ChatLanguageModelRequest {
     private final List<ToolSpecification> toolSpecifications;
 
     @Builder
-    public ChatLanguageModelRequest(String model,
-                                    Double temperature,
-                                    Double topP,
-                                    Integer maxTokens,
-                                    List<ChatMessage> messages,
-                                    List<ToolSpecification> toolSpecifications) {
+    public ChatModelRequest(String model,
+                            Double temperature,
+                            Double topP,
+                            Integer maxTokens,
+                            List<ChatMessage> messages,
+                            List<ToolSpecification> toolSpecifications) {
         this.model = model;
         this.temperature = temperature;
         this.topP = topP;
```
**ChatModelRequestContext.java** (new file, `@@ -0,0 +1,40 @@`):

```java
package dev.langchain4j.model.chat.listener;

import dev.langchain4j.Experimental;
import dev.langchain4j.model.chat.ChatLanguageModel;

import java.util.Map;

import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
 * The request context. It contains the {@link ChatModelRequest} and attributes.
 * The attributes can be used to pass data between methods of a {@link ChatModelListener}
 * or between multiple {@link ChatModelListener}s.
 */
@Experimental
public class ChatModelRequestContext {

    private final ChatModelRequest request;
    private final Map<Object, Object> attributes;

    public ChatModelRequestContext(ChatModelRequest request, Map<Object, Object> attributes) {
        this.request = ensureNotNull(request, "request");
        this.attributes = ensureNotNull(attributes, "attributes");
    }

    /**
     * @return The request to the {@link ChatLanguageModel}.
     */
    public ChatModelRequest request() {
        return request;
    }

    /**
     * @return The attributes map. It can be used to pass data between methods of a {@link ChatModelListener}
     * or between multiple {@link ChatModelListener}s.
     */
    public Map<Object, Object> attributes() {
        return attributes;
    }
}
```
**ChatModelResponse.java** (class renamed from `ChatLanguageModelResponse`):

```diff
@@ -4,17 +4,16 @@ import dev.langchain4j.Experimental;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
-import dev.langchain4j.model.listener.ModelListener;
 import dev.langchain4j.model.output.FinishReason;
 import dev.langchain4j.model.output.TokenUsage;
 import lombok.Builder;
 
 /**
  * A response from the {@link ChatLanguageModel} or {@link StreamingChatLanguageModel},
- * intended to be used with {@link ModelListener}.
+ * intended to be used with {@link ChatModelListener}.
  */
 @Experimental
-public class ChatLanguageModelResponse {
+public class ChatModelResponse {
 
     private final String id;
     private final String model;
@@ -23,11 +22,11 @@ public class ChatLanguageModelResponse {
     private final AiMessage aiMessage;
 
     @Builder
-    public ChatLanguageModelResponse(String id,
-                                     String model,
-                                     TokenUsage tokenUsage,
-                                     FinishReason finishReason,
-                                     AiMessage aiMessage) {
+    public ChatModelResponse(String id,
+                             String model,
+                             TokenUsage tokenUsage,
+                             FinishReason finishReason,
+                             AiMessage aiMessage) {
         this.id = id;
         this.model = model;
         this.tokenUsage = tokenUsage;
```
**ChatModelResponseContext.java** (new file, `@@ -0,0 +1,51 @@`):

```java
package dev.langchain4j.model.chat.listener;

import dev.langchain4j.Experimental;
import dev.langchain4j.model.chat.ChatLanguageModel;

import java.util.Map;

import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
 * The response context. It contains {@link ChatModelResponse}, corresponding {@link ChatModelRequest} and attributes.
 * The attributes can be used to pass data between methods of a {@link ChatModelListener}
 * or between multiple {@link ChatModelListener}s.
 */
@Experimental
public class ChatModelResponseContext {

    private final ChatModelResponse response;
    private final ChatModelRequest request;
    private final Map<Object, Object> attributes;

    public ChatModelResponseContext(ChatModelResponse response,
                                    ChatModelRequest request,
                                    Map<Object, Object> attributes) {
        this.response = ensureNotNull(response, "response");
        this.request = ensureNotNull(request, "request");
        this.attributes = ensureNotNull(attributes, "attributes");
    }

    /**
     * @return The response from the {@link ChatLanguageModel}.
     */
    public ChatModelResponse response() {
        return response;
    }

    /**
     * @return The request to the {@link ChatLanguageModel} the response corresponds to.
     */
    public ChatModelRequest request() {
        return request;
    }

    /**
     * @return The attributes map. It can be used to pass data between methods of a {@link ChatModelListener}
     * or between multiple {@link ChatModelListener}s.
     */
    public Map<Object, Object> attributes() {
        return attributes;
    }
}
```
**ModelListener.java** (deleted, `@@ -1,51 +0,0 @@`):

```java
package dev.langchain4j.model.listener;

import dev.langchain4j.Experimental;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;

/**
 * A generic model listener.
 * It can listen for requests to and responses from various model types,
 * such as {@link ChatLanguageModel}, {@link StreamingChatLanguageModel}, {@link EmbeddingModel}, etc.
 */
@Experimental
public interface ModelListener<Request, Response> {

    /**
     * This method is called before the request is sent to the model.
     *
     * @param request The request to the model.
     */
    @Experimental
    default void onRequest(Request request) {

    }

    /**
     * This method is called after the response is received from the model.
     *
     * @param response The response from the model.
     * @param request  The request this response corresponds to.
     */
    @Experimental
    default void onResponse(Response response, Request request) {

    }

    /**
     * This method is called when an error occurs.
     * <br>
     * When streaming (e.g., using {@link StreamingChatLanguageModel}),
     * the {@code response} might contain a partial response that was received before the error occurred.
     *
     * @param error    The error that occurred.
     * @param response The partial response, if available.
     * @param request  The request this error corresponds to.
     */
    @Experimental
    default void onError(Throwable error, Response response, Request request) {

    }
}
```
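For integrators moving off the removed generic `ModelListener`, the mapping to the new API is mechanical: the raw request/response callback parameters become accessors on a context object. A minimal before/after sketch (bodies elided; the old API is shown in comments since it no longer compiles):

```java
import dev.langchain4j.model.chat.listener.*;

public class ListenerMigrationExample {

    // Old (removed): callbacks received the raw request/response objects directly.
    // ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> listener =
    //         new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {
    //             @Override
    //             public void onResponse(ChatLanguageModelResponse response,
    //                                    ChatLanguageModelRequest request) { /* ... */ }
    //         };

    // New: callbacks receive a context object bundling request, response and attributes.
    ChatModelListener listener = new ChatModelListener() {

        @Override
        public void onResponse(ChatModelResponseContext responseContext) {
            ChatModelRequest request = responseContext.request();    // was the 2nd parameter
            ChatModelResponse response = responseContext.response(); // was the 1st parameter
            // responseContext.attributes() is new: a mutable map shared across callbacks
        }
    };
}
```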
**InternalOpenAiHelper.java**:

```diff
@@ -10,8 +10,8 @@ import dev.langchain4j.data.message.Content;
-import dev.langchain4j.data.message.SystemMessage;
-import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.data.message.*;
-import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
-import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
+import dev.langchain4j.model.chat.listener.ChatModelRequest;
+import dev.langchain4j.model.chat.listener.ChatModelResponse;
 import dev.langchain4j.model.output.FinishReason;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
@@ -283,10 +283,10 @@ public class InternalOpenAiHelper {
         return Response.from(response.content(), null, response.finishReason());
     }
 
-    static ChatLanguageModelRequest createModelListenerRequest(ChatCompletionRequest request,
-                                                               List<ChatMessage> messages,
-                                                               List<ToolSpecification> toolSpecifications) {
-        return ChatLanguageModelRequest.builder()
+    static ChatModelRequest createModelListenerRequest(ChatCompletionRequest request,
+                                                       List<ChatMessage> messages,
+                                                       List<ToolSpecification> toolSpecifications) {
+        return ChatModelRequest.builder()
                 .model(request.model())
                 .temperature(request.temperature())
                 .topP(request.topP())
@@ -296,14 +296,14 @@ public class InternalOpenAiHelper {
                 .build();
     }
 
-    static ChatLanguageModelResponse createModelListenerResponse(String responseId,
-                                                                 String responseModel,
-                                                                 Response<AiMessage> response) {
+    static ChatModelResponse createModelListenerResponse(String responseId,
+                                                         String responseModel,
+                                                         Response<AiMessage> response) {
         if (response == null) {
             return null;
         }
 
-        return ChatLanguageModelResponse.builder()
+        return ChatModelResponse.builder()
                 .id(responseId)
                 .model(responseModel)
                 .tokenUsage(response.tokenUsage())
```
**OpenAiChatModel.java**:

```diff
@@ -10,9 +10,7 @@ import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.TokenCountEstimator;
-import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
-import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
-import dev.langchain4j.model.listener.ModelListener;
+import dev.langchain4j.model.chat.listener.*;
 import dev.langchain4j.model.openai.spi.OpenAiChatModelBuilderFactory;
 import dev.langchain4j.model.output.Response;
 import lombok.Builder;
@@ -23,6 +21,7 @@ import java.time.Duration;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 
 import static dev.langchain4j.internal.RetryUtils.withRetry;
 import static dev.langchain4j.internal.Utils.getOrDefault;
@@ -54,7 +53,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
     private final String user;
     private final Integer maxRetries;
     private final Tokenizer tokenizer;
-    private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners;
+    private final List<ChatModelListener> listeners;
 
     @Builder
     public OpenAiChatModel(String baseUrl,
@@ -78,7 +77,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
                            Boolean logResponses,
                            Tokenizer tokenizer,
                            Map<String, String> customHeaders,
-                           List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {
+                           List<ChatModelListener> listeners) {
 
         baseUrl = getOrDefault(baseUrl, OPENAI_URL);
         if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
@@ -163,10 +162,12 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
 
         ChatCompletionRequest request = requestBuilder.build();
 
-        ChatLanguageModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
+        ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
+        Map<Object, Object> attributes = new ConcurrentHashMap<>();
+        ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
         listeners.forEach(listener -> {
             try {
-                listener.onRequest(modelListenerRequest);
+                listener.onRequest(requestContext);
             } catch (Exception e) {
                 log.warn("Exception while calling model listener", e);
             }
@@ -181,14 +182,19 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
                 finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason())
         );
 
-        ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
+        ChatModelResponse modelListenerResponse = createModelListenerResponse(
                 chatCompletionResponse.id(),
                 chatCompletionResponse.model(),
                 response
         );
+        ChatModelResponseContext responseContext = new ChatModelResponseContext(
+                modelListenerResponse,
+                modelListenerRequest,
+                attributes
+        );
         listeners.forEach(listener -> {
             try {
-                listener.onResponse(modelListenerResponse, modelListenerRequest);
+                listener.onResponse(responseContext);
             } catch (Exception e) {
                 log.warn("Exception while calling model listener", e);
             }
@@ -204,13 +210,21 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
             error = e;
         }
 
+        ChatModelErrorContext errorContext = new ChatModelErrorContext(
+                error,
+                modelListenerRequest,
+                null,
+                attributes
+        );
+
         listeners.forEach(listener -> {
             try {
-                listener.onError(error, null, modelListenerRequest);
+                listener.onError(errorContext);
             } catch (Exception e2) {
                 log.warn("Exception while calling model listener", e2);
             }
         });
 
         throw e;
     }
 }
```
**OpenAiStreamingChatModel.java**:

```diff
@@ -12,9 +12,7 @@ import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.chat.TokenCountEstimator;
-import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
-import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
-import dev.langchain4j.model.listener.ModelListener;
+import dev.langchain4j.model.chat.listener.*;
 import dev.langchain4j.model.openai.spi.OpenAiStreamingChatModelBuilderFactory;
 import dev.langchain4j.model.output.Response;
 import lombok.Builder;
@@ -25,6 +23,7 @@ import java.time.Duration;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static dev.langchain4j.internal.Utils.*;
@@ -57,7 +56,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
     private final String user;
     private final Tokenizer tokenizer;
     private final boolean isOpenAiModel;
-    private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners;
+    private final List<ChatModelListener> listeners;
 
     @Builder
     public OpenAiStreamingChatModel(String baseUrl,
@@ -80,7 +79,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
                                    Boolean logResponses,
                                    Tokenizer tokenizer,
                                    Map<String, String> customHeaders,
-                                   List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {
+                                   List<ChatModelListener> listeners) {
 
         timeout = getOrDefault(timeout, ofSeconds(60));
 
@@ -162,10 +161,12 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
 
         ChatCompletionRequest request = requestBuilder.build();
 
-        ChatLanguageModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
+        ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
+        Map<Object, Object> attributes = new ConcurrentHashMap<>();
+        ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
         listeners.forEach(listener -> {
             try {
-                listener.onRequest(modelListenerRequest);
+                listener.onRequest(requestContext);
             } catch (Exception e) {
                 log.warn("Exception while calling model listener", e);
             }
@@ -192,14 +193,19 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
                 .onComplete(() -> {
                     Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
 
-                    ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
+                    ChatModelResponse modelListenerResponse = createModelListenerResponse(
                             responseId.get(),
                             responseModel.get(),
                             response
                     );
+                    ChatModelResponseContext responseContext = new ChatModelResponseContext(
+                            modelListenerResponse,
+                            modelListenerRequest,
+                            attributes
+                    );
                     listeners.forEach(listener -> {
                         try {
-                            listener.onResponse(modelListenerResponse, modelListenerRequest);
+                            listener.onResponse(responseContext);
                         } catch (Exception e) {
                             log.warn("Exception while calling model listener", e);
                         }
@@ -210,14 +216,22 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
                 .onError(error -> {
                     Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
 
-                    ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
+                    ChatModelResponse modelListenerPartialResponse = createModelListenerResponse(
                             responseId.get(),
                             responseModel.get(),
                             response
                     );
 
+                    ChatModelErrorContext errorContext = new ChatModelErrorContext(
+                            error,
+                            modelListenerRequest,
+                            modelListenerPartialResponse,
+                            attributes
+                    );
+
                     listeners.forEach(listener -> {
                         try {
-                            listener.onError(error, modelListenerResponse, modelListenerRequest);
+                            listener.onError(errorContext);
                         } catch (Exception e) {
                             log.warn("Exception while calling model listener", e);
                         }
```
**OpenAiChatModelIT.java**:

```diff
@@ -3,17 +3,10 @@ package dev.langchain4j.model.openai;
 import dev.ai4j.openai4j.OpenAiHttpException;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.agent.tool.ToolSpecification;
-import dev.langchain4j.data.message.AiMessage;
-import dev.langchain4j.data.message.ChatMessage;
-import dev.langchain4j.data.message.ImageContent;
-import dev.langchain4j.data.message.TextContent;
-import dev.langchain4j.data.message.ToolExecutionResultMessage;
-import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.data.message.*;
 import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.chat.ChatLanguageModel;
-import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
-import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
-import dev.langchain4j.model.listener.ModelListener;
+import dev.langchain4j.model.chat.listener.*;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import org.junit.jupiter.api.Test;
@@ -29,9 +22,7 @@ import static dev.langchain4j.internal.Utils.readBytes;
 import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO;
 import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O;
 import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO_1106;
-import static dev.langchain4j.model.output.FinishReason.LENGTH;
-import static dev.langchain4j.model.output.FinishReason.STOP;
-import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
+import static dev.langchain4j.model.output.FinishReason.*;
 import static java.util.Arrays.asList;
 import static java.util.Collections.singletonList;
 import static org.assertj.core.api.Assertions.assertThat;
@@ -484,30 +475,29 @@ class OpenAiChatModelIT {
     void should_listen_request_and_response() {
 
         // given
-        AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
-        AtomicReference<ChatLanguageModelResponse> responseReference = new AtomicReference<>();
+        AtomicReference<ChatModelRequest> requestReference = new AtomicReference<>();
+        AtomicReference<ChatModelResponse> responseReference = new AtomicReference<>();
 
-        ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
-                new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {
+        ChatModelListener listener = new ChatModelListener() {
 
-                    @Override
-                    public void onRequest(ChatLanguageModelRequest request) {
-                        requestReference.set(request);
-                    }
+            @Override
+            public void onRequest(ChatModelRequestContext requestContext) {
+                requestReference.set(requestContext.request());
+                requestContext.attributes().put("id", "12345");
+            }
 
-                    @Override
-                    public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
-                        responseReference.set(response);
-                        assertThat(request).isSameAs(requestReference.get());
-                    }
+            @Override
+            public void onResponse(ChatModelResponseContext responseContext) {
+                responseReference.set(responseContext.response());
+                assertThat(responseContext.request()).isSameAs(requestReference.get());
+                assertThat(responseContext.attributes().get("id")).isEqualTo("12345");
+            }
 
-                    @Override
-                    public void onError(Throwable error,
-                                        ChatLanguageModelResponse response,
-                                        ChatLanguageModelRequest request) {
-                        fail("onError() must not be called");
-                    }
-                };
+            @Override
+            public void onError(ChatModelErrorContext errorContext) {
+                fail("onError() must not be called");
+            }
+        };
 
         OpenAiChatModelName modelName = GPT_3_5_TURBO;
         double temperature = 0.7;
@@ -524,7 +514,7 @@ class OpenAiChatModelIT {
                 .maxTokens(maxTokens)
                 .logRequests(true)
                 .logResponses(true)
-                .listeners(singletonList(modelListener))
+                .listeners(singletonList(listener))
                 .build();
 
         UserMessage userMessage = UserMessage.from("hello");
@@ -539,7 +529,7 @@ class OpenAiChatModelIT {
         AiMessage aiMessage = model.generate(singletonList(userMessage), singletonList(toolSpecification)).content();
 
         // then
-        ChatLanguageModelRequest request = requestReference.get();
+        ChatModelRequest request = requestReference.get();
         assertThat(request.model()).isEqualTo(modelName.toString());
         assertThat(request.temperature()).isEqualTo(temperature);
         assertThat(request.topP()).isEqualTo(topP);
@@ -547,7 +537,7 @@ class OpenAiChatModelIT {
         assertThat(request.messages()).containsExactly(userMessage);
         assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
 
-        ChatLanguageModelResponse response = responseReference.get();
+        ChatModelResponse response = responseReference.get();
         assertThat(response.id()).isNotBlank();
         assertThat(response.model()).isNotBlank();
         assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
@@ -563,38 +553,37 @@ class OpenAiChatModelIT {
         // given
         String wrongApiKey = "banana";
 
-        AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
+        AtomicReference<ChatModelRequest> requestReference = new AtomicReference<>();
         AtomicReference<Throwable> errorReference = new AtomicReference<>();
 
-        ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
-                new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {
+        ChatModelListener listener = new ChatModelListener() {
 
-                    @Override
-                    public void onRequest(ChatLanguageModelRequest request) {
-                        requestReference.set(request);
-                    }
+            @Override
+            public void onRequest(ChatModelRequestContext requestContext) {
+                requestReference.set(requestContext.request());
+                requestContext.attributes().put("id", "12345");
+            }
 
-                    @Override
-                    public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
-                        fail("onResponse() must not be called");
-                    }
+            @Override
+            public void onResponse(ChatModelResponseContext responseContext) {
+                fail("onResponse() must not be called");
+            }
 
-                    @Override
-                    public void onError(Throwable error,
-                                        ChatLanguageModelResponse response,
-                                        ChatLanguageModelRequest request) {
-                        errorReference.set(error);
-                        assertThat(response).isNull();
-                        assertThat(request).isSameAs(requestReference.get());
-                    }
-                };
+            @Override
+            public void onError(ChatModelErrorContext errorContext) {
+                errorReference.set(errorContext.error());
+                assertThat(errorContext.request()).isSameAs(requestReference.get());
+                assertThat(errorContext.partialResponse()).isNull();
+                assertThat(errorContext.attributes().get("id")).isEqualTo("12345");
+            }
+        };
 
         OpenAiChatModel model = OpenAiChatModel.builder()
                 .apiKey(wrongApiKey)
                 .maxRetries(0)
                 .logRequests(true)
                 .logResponses(true)
-                .listeners(singletonList(modelListener))
+                .listeners(singletonList(listener))
                 .build();
 
         String userMessage = "this message will fail";
```
**OpenAiStreamingChatModelIT.java**:

```diff
@@ -8,9 +8,7 @@ import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.chat.TestStreamingResponseHandler;
-import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
-import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
-import dev.langchain4j.model.listener.ModelListener;
+import dev.langchain4j.model.chat.listener.*;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import org.assertj.core.data.Percentage;
@@ -650,30 +648,29 @@ class OpenAiStreamingChatModelIT {
     void should_listen_request_and_response() {
 
         // given
-        AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
-        AtomicReference<ChatLanguageModelResponse> responseReference = new AtomicReference<>();
+        AtomicReference<ChatModelRequest> requestReference = new AtomicReference<>();
+        AtomicReference<ChatModelResponse> responseReference = new AtomicReference<>();
 
-        ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
-                new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {
+        ChatModelListener listener = new ChatModelListener() {
 
-                    @Override
-                    public void onRequest(ChatLanguageModelRequest request) {
-                        requestReference.set(request);
-                    }
+            @Override
+            public void onRequest(ChatModelRequestContext requestContext) {
+                requestReference.set(requestContext.request());
+                requestContext.attributes().put("id", "12345");
+            }
 
-                    @Override
-                    public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
-                        responseReference.set(response);
-                        assertThat(request).isSameAs(requestReference.get());
-                    }
+            @Override
+            public void onResponse(ChatModelResponseContext responseContext) {
+                responseReference.set(responseContext.response());
+                assertThat(responseContext.request()).isSameAs(requestReference.get());
+                assertThat(responseContext.attributes().get("id")).isEqualTo("12345");
+            }
 
-                    @Override
-                    public void onError(Throwable error,
-                                        ChatLanguageModelResponse response,
-                                        ChatLanguageModelRequest request) {
-                        fail("onError() must not be called");
-                    }
-                };
+            @Override
+            public void onError(ChatModelErrorContext errorContext) {
+                fail("onError() must not be called");
+            }
+        };
 
         OpenAiChatModelName modelName = GPT_3_5_TURBO;
         double temperature = 0.7;
@@ -690,7 +687,7 @@ class OpenAiStreamingChatModelIT {
                 .maxTokens(maxTokens)
                 .logRequests(true)
                 .logResponses(true)
-                .listeners(singletonList(modelListener))
+                .listeners(singletonList(listener))
                 .build();
 
         UserMessage userMessage = UserMessage.from("hello");
@@ -707,7 +704,7 @@ class OpenAiStreamingChatModelIT {
         AiMessage aiMessage = handler.get().content();
 
         // then
-        ChatLanguageModelRequest request = requestReference.get();
+        ChatModelRequest request = requestReference.get();
         assertThat(request.model()).isEqualTo(modelName.toString());
         assertThat(request.temperature()).isEqualTo(temperature);
         assertThat(request.topP()).isEqualTo(topP);
@@ -715,7 +712,7 @@ class OpenAiStreamingChatModelIT {
         assertThat(request.messages()).containsExactly(userMessage);
         assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
 
-        ChatLanguageModelResponse response = responseReference.get();
+        ChatModelResponse response = responseReference.get();
         assertThat(response.id()).isNotBlank();
         assertThat(response.model()).isNotBlank();
         assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
@@ -731,37 +728,36 @@ class OpenAiStreamingChatModelIT {
         // given
         String wrongApiKey = "banana";
 
-        AtomicReference<ChatLanguageModelRequest> requestReference = new AtomicReference<>();
+        AtomicReference<ChatModelRequest> requestReference = new AtomicReference<>();
         AtomicReference<Throwable> errorReference = new AtomicReference<>();
 
-        ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse> modelListener =
-                new ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>() {
+        ChatModelListener listener = new ChatModelListener() {
 
-                    @Override
-                    public void onRequest(ChatLanguageModelRequest request) {
-                        requestReference.set(request);
-                    }
+            @Override
+            public void onRequest(ChatModelRequestContext requestContext) {
+                requestReference.set(requestContext.request());
+                requestContext.attributes().put("id", "12345");
+            }
 
-                    @Override
-                    public void onResponse(ChatLanguageModelResponse response, ChatLanguageModelRequest request) {
-                        fail("onResponse() must not be called");
-                    }
+            @Override
+            public void onResponse(ChatModelResponseContext responseContext) {
+                fail("onResponse() must not be called");
+            }
 
-                    @Override
-                    public void onError(Throwable error,
-                                        ChatLanguageModelResponse response,
-                                        ChatLanguageModelRequest request) {
-                        errorReference.set(error);
-                        assertThat(response).isNull(); // can be non-null if it fails in the middle of streaming
-                        assertThat(request).isSameAs(requestReference.get());
-                    }
-                };
+            @Override
+            public void onError(ChatModelErrorContext errorContext) {
+                errorReference.set(errorContext.error());
+                assertThat(errorContext.request()).isSameAs(requestReference.get());
+                assertThat(errorContext.partialResponse()).isNull(); // can be non-null if it fails in the middle of streaming
+                assertThat(errorContext.attributes().get("id")).isEqualTo("12345");
+            }
+        };
 
         StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder()
                 .apiKey(wrongApiKey)
                 .logRequests(true)
                 .logResponses(true)
-                .listeners(singletonList(modelListener))
+                .listeners(singletonList(listener))
                 .build();
 
         String userMessage = "this message will fail";
```