From 16f410c7888f913e6f9c8d6b1a6f1ac838f89f2f Mon Sep 17 00:00:00 2001
From: ZYinNJU <1754350460@qq.com>
Date: Tue, 17 Sep 2024 17:12:43 +0800
Subject: [PATCH] Ollama chat model listener (#1765)

## Issue
Closes #1756
Closes #1750

## Change
1. `OllamaChatModel` and `OllamaStreamingChatModel` now support `ChatModelListener` (see the usage sketch below).
2. Fix `OllamaStreamingLanguageModel` throwing an `EOFException` when the response content is too long.
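A minimal usage sketch of the new listener wiring (not part of the patch itself; the endpoint and model name are placeholders):

```java
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.ollama.OllamaChatModel;

import static java.util.Collections.singletonList;

public class OllamaListenerExample {

    public static void main(String[] args) {
        // onRequest/onResponse/onError are default methods on ChatModelListener,
        // so only the callbacks of interest need to be overridden.
        ChatModelListener listener = new ChatModelListener() {

            @Override
            public void onRequest(ChatModelRequestContext context) {
                System.out.println("Model requested: " + context.request().model());
            }

            @Override
            public void onResponse(ChatModelResponseContext context) {
                System.out.println("Token usage: " + context.response().tokenUsage());
            }
        };

        ChatLanguageModel model = OllamaChatModel.builder()
                .baseUrl("http://localhost:11434") // placeholder endpoint
                .modelName("llama3.1")             // placeholder model name
                .listeners(singletonList(listener))
                .build();

        System.out.println(model.generate("hello"));
    }
}
```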
## General checklist
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green
- [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features)
- [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
---
 .../azure/AzureOpenAiChatModelListenerIT.java |   2 +-
 ...ureOpenAiStreamingChatModelListenerIT.java |   2 +-
 .../model/chat/ChatModelListenerIT.java       |  38 ++++--
 .../chat/StreamingChatModelListenerIT.java    |  44 ++++---
 .../GoogleAiGeminiChatModelListenerIT.java    |   2 +-
 langchain4j-ollama/pom.xml                    |   6 +
 .../model/ollama/OllamaChatModel.java         |  68 ++++++----
 .../ollama/OllamaChatModelListenerUtils.java  | 118 +++++++++++++++++
 .../model/ollama/OllamaClient.java            | 122 ++++++++++--------
 .../model/ollama/OllamaMessagesUtils.java     |   3 +
 .../ollama/OllamaStreamingChatModel.java      |  19 ++-
 .../OllamaStreamingResponseBuilder.java       |  44 +++++++
 ...llamaToolsLanguageModelInfrastructure.java |   1 -
 .../ollama/OllamaChatModelListenerIT.java     |  57 ++++++++
 .../model/ollama/OllamaOpenAiChatModelIT.java |   6 +
 .../OllamaStreamingChatModelListenerIT.java   |  61 +++++++++
 .../openai/OpenAiChatModelListenerIT.java     |   2 +-
 .../OpenAiStreamingChatModelListenerIT.java   |   2 +-
 .../VertexAiGeminiChatModelListenerIT.java    |   2 +-
 ...xAiGeminiStreamingChatModelListenerIT.java |   2 +-
 20 files changed, 487 insertions(+), 114 deletions(-)
 create mode 100644 langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModelListenerUtils.java
 create mode 100644 langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingResponseBuilder.java
 create mode 100644 langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelListenerIT.java
 create mode 100644 langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelListenerIT.java

diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelListenerIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelListenerIT.java
index edf5b664a..1b6324748 100644
--- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelListenerIT.java
+++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelListenerIT.java
@@ -41,7 +41,7 @@ class AzureOpenAiChatModelListenerIT extends ChatModelListenerIT {
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
        return ClientAuthenticationException.class;
     }
 }
diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModelListenerIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModelListenerIT.java
index 54e517f3f..f63d6d7ca 100644
--- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModelListenerIT.java
+++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModelListenerIT.java
@@ -42,7 +42,7 @@ class AzureOpenAiStreamingChatModelListener
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
         return ClientAuthenticationException.class;
     }
 
diff --git a/langchain4j-core/src/test/java/dev/langchain4j/model/chat/ChatModelListenerIT.java b/langchain4j-core/src/test/java/dev/langchain4j/model/chat/ChatModelListenerIT.java
index 145f97a5e..6642c3f62 100644
--- a/langchain4j-core/src/test/java/dev/langchain4j/model/chat/ChatModelListenerIT.java
+++ b/langchain4j-core/src/test/java/dev/langchain4j/model/chat/ChatModelListenerIT.java
@@ -71,7 +71,7 @@ public abstract class ChatModelListenerIT {
 
     protected abstract ChatLanguageModel createFailingModel(ChatModelListener listener);
 
-    protected abstract Class<?> expectedExceptionClass();
+    protected abstract Class<? extends Exception> expectedExceptionClass();
 
     @Test
     void should_listen_request_and_response() {
@@ -105,14 +105,22 @@
 
         UserMessage userMessage = UserMessage.from("hello");
 
-        ToolSpecification toolSpecification = ToolSpecification.builder()
-                .name("add")
-                .addParameter("a", INTEGER)
-                .addParameter("b", INTEGER)
-                .build();
+        ToolSpecification toolSpecification = null;
+        if (supportToolCalls()) {
+            toolSpecification = ToolSpecification.builder()
+                    .name("add")
+                    .addParameter("a", INTEGER)
+                    .addParameter("b", INTEGER)
+                    .build();
+        }
 
         // when
-        AiMessage aiMessage = model.generate(singletonList(userMessage), singletonList(toolSpecification)).content();
+        AiMessage aiMessage;
+        if (supportToolCalls()) {
+            aiMessage = model.generate(singletonList(userMessage), singletonList(toolSpecification)).content();
+        } else {
+            aiMessage = model.generate(singletonList(userMessage)).content();
+        }
 
         // then
         ChatModelRequest request = requestReference.get();
@@ -121,7 +129,9 @@
         assertThat(request.topP()).isEqualTo(topP());
         assertThat(request.maxTokens()).isEqualTo(maxTokens());
         assertThat(request.messages()).containsExactly(userMessage);
-        assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
+        if (supportToolCalls()) {
+            assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
+        }
 
         ChatModelResponse response = responseReference.get();
         if (assertResponseId()) {
@@ -131,14 +141,24 @@
         assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
         assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
         assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0);
-        assertThat(response.finishReason()).isNotNull();
+        if (assertFinishReason()) {
+            assertThat(response.finishReason()).isNotNull();
+        }
 
         assertThat(response.aiMessage()).isEqualTo(aiMessage);
     }
 
+    protected boolean supportToolCalls() {
+        return true;
+    }
+
     protected boolean assertResponseId() {
         return true;
     }
 
+    protected boolean assertFinishReason() {
+        return true;
+    }
+
     @Test
     void should_listen_error() {
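For reference, the capture pattern these ITs build on looks roughly like the sketch below; the class name is hypothetical, and the real tests inline this with `AtomicReference` fields rather than a named class:

```java
import dev.langchain4j.model.chat.listener.*;

import java.util.concurrent.atomic.AtomicReference;

// Sketch of the capture pattern the listener ITs rely on: the listener stashes
// each context's payload into an AtomicReference, and the test asserts on the
// captured values after generate() returns.
class CapturingChatModelListener implements ChatModelListener {

    final AtomicReference<ChatModelRequest> request = new AtomicReference<>();
    final AtomicReference<ChatModelResponse> response = new AtomicReference<>();
    final AtomicReference<Throwable> error = new AtomicReference<>();

    @Override
    public void onRequest(ChatModelRequestContext context) {
        request.set(context.request());
    }

    @Override
    public void onResponse(ChatModelResponseContext context) {
        response.set(context.response());
    }

    @Override
    public void onError(ChatModelErrorContext context) {
        error.set(context.error());
    }
}
```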
diff --git a/langchain4j-core/src/test/java/dev/langchain4j/model/chat/StreamingChatModelListenerIT.java b/langchain4j-core/src/test/java/dev/langchain4j/model/chat/StreamingChatModelListenerIT.java
index 16d953b44..87b5ba2d9 100644
--- a/langchain4j-core/src/test/java/dev/langchain4j/model/chat/StreamingChatModelListenerIT.java
+++ b/langchain4j-core/src/test/java/dev/langchain4j/model/chat/StreamingChatModelListenerIT.java
@@ -4,12 +4,7 @@ import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.UserMessage;
 import dev.langchain4j.model.StreamingResponseHandler;
-import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
-import dev.langchain4j.model.chat.listener.ChatModelListener;
-import dev.langchain4j.model.chat.listener.ChatModelRequest;
-import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
-import dev.langchain4j.model.chat.listener.ChatModelResponse;
-import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
+import dev.langchain4j.model.chat.listener.*;
 import dev.langchain4j.model.output.Response;
 import org.assertj.core.data.Percentage;
 import org.junit.jupiter.api.Test;
@@ -75,7 +70,7 @@ public abstract class StreamingChatModelListenerIT {
 
     protected abstract StreamingChatLanguageModel createFailingModel(ChatModelListener listener);
 
-    protected abstract Class<?> expectedExceptionClass();
+    protected abstract Class<? extends Exception> expectedExceptionClass();
 
     @Test
     void should_listen_request_and_response() {
@@ -109,15 +104,22 @@
 
         UserMessage userMessage = UserMessage.from("hello");
 
-        ToolSpecification toolSpecification = ToolSpecification.builder()
-                .name("add")
-                .addParameter("a", INTEGER)
-                .addParameter("b", INTEGER)
-                .build();
+        ToolSpecification toolSpecification = null;
+        if (supportToolCalls()) {
+            toolSpecification = ToolSpecification.builder()
+                    .name("add")
+                    .addParameter("a", INTEGER)
+                    .addParameter("b", INTEGER)
+                    .build();
+        }
 
         // when
         TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
-        model.generate(singletonList(userMessage), singletonList(toolSpecification), handler);
+        if (supportToolCalls()) {
+            model.generate(singletonList(userMessage), singletonList(toolSpecification), handler);
+        } else {
+            model.generate(singletonList(userMessage), handler);
+        }
         AiMessage aiMessage = handler.get().content();
 
         // then
@@ -127,7 +129,9 @@
         assertThat(request.topP()).isEqualTo(topP());
         assertThat(request.maxTokens()).isEqualTo(maxTokens());
         assertThat(request.messages()).containsExactly(userMessage);
-        assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
+        if (supportToolCalls()) {
+            assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
+        }
 
         ChatModelResponse response = responseReference.get();
         if (assertResponseId()) {
@@ -137,14 +141,24 @@
         assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
         assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
         assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0);
-        assertThat(response.finishReason()).isNotNull();
+        if (assertFinishReason()) {
+            assertThat(response.finishReason()).isNotNull();
+        }
 
         assertThat(response.aiMessage()).isEqualTo(aiMessage);
     }
 
+    protected boolean supportToolCalls() {
+        return true;
+    }
+
     protected boolean assertResponseId() {
         return true;
     }
 
+    protected boolean assertFinishReason() {
+        return true;
+    }
+
     @Test
     protected void should_listen_error() throws Exception {
diff --git a/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModelListenerIT.java b/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModelListenerIT.java
index ca80de158..fa44fbcd6 100644
--- a/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModelListenerIT.java
+++ b/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModelListenerIT.java
@@ -43,7 +43,7 @@ class GoogleAiGeminiChatModelListenerIT extends ChatModelListenerIT {
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
         return RuntimeException.class;
     }
 }
diff --git a/langchain4j-ollama/pom.xml b/langchain4j-ollama/pom.xml
index f04d7913f..84585883d 100644
--- a/langchain4j-ollama/pom.xml
+++ b/langchain4j-ollama/pom.xml
@@ -68,6 +68,12 @@
             <scope>test</scope>
         </dependency>
 
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-core</artifactId>
+            <scope>test</scope>
+        </dependency>
+
         <dependency>
             <groupId>org.mockito</groupId>
             <artifactId>mockito-junit-jupiter</artifactId>
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java
index 7716b5445..d40a321ac 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java
@@ -4,21 +4,29 @@ import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
+import dev.langchain4j.model.chat.listener.ChatModelRequest;
 import dev.langchain4j.model.ollama.spi.OllamaChatModelBuilderFactory;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 
 import static dev.langchain4j.internal.RetryUtils.withRetry;
 import static dev.langchain4j.internal.Utils.getOrDefault;
 import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
 import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
+import static dev.langchain4j.model.ollama.OllamaChatModelListenerUtils.*;
 import static dev.langchain4j.model.ollama.OllamaMessagesUtils.*;
 import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 import static java.time.Duration.ofSeconds;
+import static java.util.Collections.emptyList;
 
 /**
  * Ollama API reference
@@ -27,11 +35,14 @@ import static java.time.Duration.ofSeconds;
  */
 public class OllamaChatModel implements ChatLanguageModel {
 
+    private static final Logger log = LoggerFactory.getLogger(OllamaChatModel.class);
+
     private final OllamaClient client;
     private final String modelName;
     private final Options options;
     private final String format;
     private final Integer maxRetries;
+    private final List<ChatModelListener> listeners;
 
     public OllamaChatModel(String baseUrl,
                            String modelName,
@@ -48,7 +59,8 @@ public class OllamaChatModel implements ChatLanguageModel {
                            Integer maxRetries,
                            Map<String, String> customHeaders,
                            Boolean logRequests,
-                           Boolean logResponses) {
+                           Boolean logResponses,
+                           List<ChatModelListener> listeners) {
         this.client = OllamaClient.builder()
                 .baseUrl(baseUrl)
                .timeout(getOrDefault(timeout, ofSeconds(60)))
@@ -69,6 +81,7 @@ public class OllamaChatModel implements ChatLanguageModel {
                 .build();
         this.format = format;
         this.maxRetries = getOrDefault(maxRetries, 3);
+        this.listeners = new ArrayList<>(getOrDefault(listeners, emptyList()));
     }
 
     public static OllamaChatModelBuilder builder() {
@@ -82,26 +95,17 @@
     public Response<AiMessage> generate(List<ChatMessage> messages) {
         ensureNotEmpty(messages, "messages");
 
-        ChatRequest request = ChatRequest.builder()
-                .model(modelName)
-                .messages(toOllamaMessages(messages))
-                .options(options)
-                .format(format)
-                .stream(false)
-                .build();
-
-        ChatResponse response = withRetry(() -> client.chat(request), maxRetries);
-
-        return Response.from(
-                AiMessage.from(response.getMessage().getContent()),
-                new TokenUsage(response.getPromptEvalCount(), response.getEvalCount())
-        );
+        return doGenerate(messages, null);
     }
 
     @Override
     public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
         ensureNotEmpty(messages, "messages");
 
+        return doGenerate(messages, toolSpecifications);
+    }
+
+    private Response<AiMessage> doGenerate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
         ChatRequest request = ChatRequest.builder()
                 .model(modelName)
                 .messages(toOllamaMessages(messages))
@@ -111,14 +115,25 @@
                 .tools(toOllamaTools(toolSpecifications))
                 .build();
 
-        ChatResponse response = withRetry(() -> client.chat(request), maxRetries);
+        ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
+        Map<Object, Object> attributes = new ConcurrentHashMap<>();
+        onListenRequest(listeners, modelListenerRequest, attributes);
 
-        return Response.from(
-                response.getMessage().getToolCalls() != null ?
-                        AiMessage.from(toToolExecutionRequest(response.getMessage().getToolCalls())) :
-                        AiMessage.from(response.getMessage().getContent()),
-                new TokenUsage(response.getPromptEvalCount(), response.getEvalCount())
-        );
+        try {
+            ChatResponse chatResponse = withRetry(() -> client.chat(request), maxRetries);
+            Response<AiMessage> response = Response.from(
+                    chatResponse.getMessage().getToolCalls() != null ?
+                            AiMessage.from(toToolExecutionRequest(chatResponse.getMessage().getToolCalls())) :
+                            AiMessage.from(chatResponse.getMessage().getContent()),
+                    new TokenUsage(chatResponse.getPromptEvalCount(), chatResponse.getEvalCount())
+            );
+            onListenResponse(listeners, response, modelListenerRequest, attributes);
+
+            return response;
+        } catch (Exception e) {
+            onListenError(listeners, e, modelListenerRequest, null, attributes);
+            throw e;
+        }
     }
 
     public static class OllamaChatModelBuilder {
@@ -139,6 +154,7 @@
         private Map<String, String> customHeaders;
         private Boolean logRequests;
         private Boolean logResponses;
+        private List<ChatModelListener> listeners;
 
         public OllamaChatModelBuilder() {
             // This is public so it can be extended
@@ -225,6 +241,11 @@
             return this;
         }
 
+        public OllamaChatModelBuilder listeners(List<ChatModelListener> listeners) {
+            this.listeners = listeners;
+            return this;
+        }
+
         public OllamaChatModel build() {
             return new OllamaChatModel(
                     baseUrl,
@@ -242,7 +263,8 @@
                     maxRetries,
                     customHeaders,
                     logRequests,
-                    logResponses
+                    logResponses,
+                    listeners
             );
         }
     }
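Because the same attributes map created in `doGenerate` travels through `onListenRequest`, `onListenResponse`, and `onListenError`, a listener can correlate both sides of a call. A sketch, with `"startNanos"` as an arbitrary key chosen for this example:

```java
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;

// Sketch: the ConcurrentHashMap that doGenerate() hands to every context lets
// state written in onRequest be read back in onResponse, e.g. to measure latency.
class TimingListener implements ChatModelListener {

    @Override
    public void onRequest(ChatModelRequestContext context) {
        context.attributes().put("startNanos", System.nanoTime());
    }

    @Override
    public void onResponse(ChatModelResponseContext context) {
        Object start = context.attributes().get("startNanos");
        if (start instanceof Long) {
            long elapsedMillis = (System.nanoTime() - (Long) start) / 1_000_000;
            System.out.println("Ollama call took " + elapsedMillis + " ms");
        }
    }
}
```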
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModelListenerUtils.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModelListenerUtils.java
new file mode 100644
index 000000000..2a9586997
--- /dev/null
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModelListenerUtils.java
@@ -0,0 +1,118 @@
+package dev.langchain4j.model.ollama;
+
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.model.chat.listener.*;
+import dev.langchain4j.model.output.Response;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.List;
+import java.util.Map;
+
+class OllamaChatModelListenerUtils {
+
+    private static final Logger log = LoggerFactory.getLogger(OllamaChatModelListenerUtils.class);
+
+    private OllamaChatModelListenerUtils() throws InstantiationException {
+        throw new InstantiationException("Can't instantiate this utility class.");
+    }
+
+    /**
+     * Notifies all registered chat model listeners that a request is about to be sent.
+     *
+     * @param listeners            A list of {@link ChatModelListener} instances to be notified. Should not be null.
+     * @param modelListenerRequest The {@link ChatModelRequest} containing the request details.
+     * @param attributes           A map of additional attributes to be passed to the context.
+     */
+    static void onListenRequest(List<ChatModelListener> listeners, ChatModelRequest modelListenerRequest, Map<Object, Object> attributes) {
+        ChatModelRequestContext context = new ChatModelRequestContext(modelListenerRequest, attributes);
+        listeners.forEach(listener -> {
+            try {
+                listener.onRequest(context);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });
+    }
+
+    /**
+     * Notifies all registered chat model listeners that a response has been received.
+     *
+     * @param listeners            A list of {@link ChatModelListener} instances to be notified. Should not be null.
+     * @param response             The {@link Response} containing the response details.
+     * @param modelListenerRequest The original {@link ChatModelRequest} associated with the response.
+     * @param attributes           A map of additional attributes to be passed to the context.
+     */
+    static void onListenResponse(List<ChatModelListener> listeners, Response<AiMessage> response, ChatModelRequest modelListenerRequest, Map<Object, Object> attributes) {
+        ChatModelResponse modelListenerResponse = createModelListenerResponse(modelListenerRequest.model(), response);
+        ChatModelResponseContext context = new ChatModelResponseContext(
+                modelListenerResponse,
+                modelListenerRequest,
+                attributes
+        );
+        listeners.forEach(listener -> {
+            try {
+                listener.onResponse(context);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });
+    }
+
+    /**
+     * Notifies all registered chat model listeners that an error occurred.
+     *
+     * @param listeners            A list of {@link ChatModelListener} instances to be notified. Should not be null.
+     * @param error                The error that occurred during the chat.
+     * @param modelListenerRequest The original {@link ChatModelRequest} associated with the response.
+     * @param partialResponse      The partial {@link Response} received before the error; may be null.
+     * @param attributes           A map of additional attributes to be passed to the context.
+     */
+    static void onListenError(List<ChatModelListener> listeners, Throwable error, ChatModelRequest modelListenerRequest, Response<AiMessage> partialResponse, Map<Object, Object> attributes) {
+        ChatModelResponse partialModelListenerResponse = createModelListenerResponse(modelListenerRequest.model(), partialResponse);
+        ChatModelErrorContext context = new ChatModelErrorContext(
+                error,
+                modelListenerRequest,
+                partialModelListenerResponse,
+                attributes
+        );
+        listeners.forEach(listener -> {
+            try {
+                listener.onError(context);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });
+    }
+
+    static ChatModelRequest createModelListenerRequest(ChatRequest request,
+                                                       List<ChatMessage> messages,
+                                                       List<ToolSpecification> toolSpecifications) {
+        Options options = request.getOptions();
+
+        return ChatModelRequest.builder()
+                .model(request.getModel())
+                .temperature(options.getTemperature())
+                .topP(options.getTopP())
+                .maxTokens(options.getNumPredict())
+                .messages(messages)
+                .toolSpecifications(toolSpecifications)
+                .build();
+    }
+
+    static ChatModelResponse createModelListenerResponse(String responseModel,
+                                                         Response<AiMessage> response) {
+        if (response == null) {
+            return null;
+        }
+
+        return ChatModelResponse.builder()
+                .model(responseModel)
+                .tokenUsage(response.tokenUsage())
+                .finishReason(response.finishReason())
+                .aiMessage(response.content())
+                .build();
+    }
+}
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java
index a4dc7a1e9..6b4cf2265 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java
@@ -1,8 +1,11 @@
 package dev.langchain4j.model.ollama;
 
 import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.internal.Utils;
 import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
+import dev.langchain4j.model.chat.listener.ChatModelRequest;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import okhttp3.Interceptor;
@@ -22,10 +25,10 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.time.Duration;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Optional;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
 
+import static dev.langchain4j.model.ollama.OllamaChatModelListenerUtils.*;
 import static dev.langchain4j.model.ollama.OllamaJsonUtils.getObjectMapper;
 import static dev.langchain4j.model.ollama.OllamaJsonUtils.toObject;
 import static java.lang.Boolean.TRUE;
@@ -106,50 +109,6 @@ class OllamaClient {
     public void streamingCompletion(CompletionRequest request, StreamingResponseHandler<String> handler) {
         ollamaApi.streamingCompletion(request).enqueue(new Callback<ResponseBody>() {
 
-            @Override
-            public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
-                try (InputStream inputStream = retrofitResponse.body().byteStream()) {
-                    StringBuilder contentBuilder = new StringBuilder();
-                    while (true) {
-                        byte[] bytes = new byte[1024];
-                        int len = inputStream.read(bytes);
-                        String partialResponse = new String(bytes, 0, len);
-
-                        if (logStreamingResponses) {
-                            log.debug("Streaming partial response: {}", partialResponse);
-                        }
-
-                        CompletionResponse completionResponse = toObject(partialResponse, CompletionResponse.class);
-                        contentBuilder.append(completionResponse.getResponse());
-                        handler.onNext(completionResponse.getResponse());
-
-                        if (TRUE.equals(completionResponse.getDone())) {
-                            Response<String> response = Response.from(
-                                    contentBuilder.toString(),
-                                    new TokenUsage(
-                                            completionResponse.getPromptEvalCount(),
-                                            completionResponse.getEvalCount()
-                                    )
-                            );
-                            handler.onComplete(response);
-                            return;
-                        }
-                    }
-                } catch (Exception e) {
-                    handler.onError(e);
-                }
-            }
-
-            @Override
-            public void onFailure(Call<ResponseBody> call, Throwable throwable) {
-                handler.onError(throwable);
-            }
-        });
-    }
-
-    public void streamingChat(ChatRequest request, StreamingResponseHandler<AiMessage> handler) {
-        ollamaApi.streamingChat(request).enqueue(new Callback<ResponseBody>() {
-
             @Override
             public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
                 try (InputStream inputStream = retrofitResponse.body().byteStream()) {
@@ -162,17 +121,16 @@ class OllamaClient {
                             log.debug("Streaming partial response: {}", partialResponse);
                         }
 
-                        ChatResponse chatResponse = toObject(partialResponse, ChatResponse.class);
-                        String content = chatResponse.getMessage().getContent();
-                        contentBuilder.append(content);
-                        handler.onNext(content);
+                        CompletionResponse completionResponse = toObject(partialResponse, CompletionResponse.class);
+                        contentBuilder.append(completionResponse.getResponse());
+                        handler.onNext(completionResponse.getResponse());
 
-                        if (TRUE.equals(chatResponse.getDone())) {
-                            Response<AiMessage> response = Response.from(
-                                    AiMessage.from(contentBuilder.toString()),
+                        if (TRUE.equals(completionResponse.getDone())) {
+                            Response<String> response = Response.from(
+                                    contentBuilder.toString(),
                                     new TokenUsage(
-                                            chatResponse.getPromptEvalCount(),
-                                            chatResponse.getEvalCount()
+                                            completionResponse.getPromptEvalCount(),
+                                            completionResponse.getEvalCount()
                                     )
                             );
                             handler.onComplete(response);
@@ -192,6 +150,57 @@
         });
     }
 
+    public void streamingChat(ChatRequest request, StreamingResponseHandler<AiMessage> handler,
+                              List<ChatModelListener> listeners, List<ChatMessage> messages) {
+        ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, new ArrayList<>());
+        Map<Object, Object> attributes = new ConcurrentHashMap<>();
+        onListenRequest(listeners, modelListenerRequest, attributes);
+
+        OllamaStreamingResponseBuilder responseBuilder = new OllamaStreamingResponseBuilder();
+        ollamaApi.streamingChat(request).enqueue(new Callback<ResponseBody>() {
+
+            @Override
+            public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
+                try (InputStream inputStream = retrofitResponse.body().byteStream()) {
+                    try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
+                        while (true) {
+                            String partialResponse = reader.readLine();
+
+                            if (logStreamingResponses) {
+                                log.debug("Streaming partial response: {}", partialResponse);
+                            }
+
+                            ChatResponse chatResponse = toObject(partialResponse, ChatResponse.class);
+                            String content = chatResponse.getMessage().getContent();
+                            responseBuilder.append(chatResponse);
+                            handler.onNext(content);
+
+                            if (TRUE.equals(chatResponse.getDone())) {
+                                Response<AiMessage> response = responseBuilder.build();
+                                handler.onComplete(response);
+
+                                onListenResponse(listeners, response, modelListenerRequest, attributes);
+
+                                return;
+                            }
+                        }
+                    }
+                } catch (Exception e) {
+                    onListenError(listeners, e, modelListenerRequest, responseBuilder.build(), attributes);
+
+                    handler.onError(e);
+                }
+            }
+
+            @Override
+            public void onFailure(Call<ResponseBody> call, Throwable throwable) {
+                onListenError(listeners, throwable, modelListenerRequest, responseBuilder.build(), attributes);
+
+                handler.onError(throwable);
+            }
+        });
+    }
+
     public EmbeddingResponse embed(EmbeddingRequest request) {
         try {
             retrofit2.Response<EmbeddingResponse> retrofitResponse = ollamaApi.embed(request).execute();
@@ -257,6 +266,7 @@
         }
     }
 
+
     private RuntimeException toException(retrofit2.Response<?> response) throws IOException {
         int code = response.code();
         String body = response.errorBody().string();
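The `EOFException` in #1750 came from the old loop removed above: it read fixed 1024-byte chunks and parsed each chunk as a complete JSON object, which fails as soon as one streamed object is longer than a chunk or straddles a read boundary. Ollama streams newline-delimited JSON, so reading line by line always hands the parser a complete object. A standalone sketch of that framing, using only the JDK (`parse` stands in for `OllamaJsonUtils.toObject`):

```java
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;

import static java.nio.charset.StandardCharsets.UTF_8;

// Standalone sketch of the NDJSON framing the fix relies on: each line of the
// HTTP body is exactly one complete JSON object, however long it is.
final class NdjsonDemo {

    public static void main(String[] args) throws IOException {
        String body = "{\"message\":{\"content\":\"Hel\"},\"done\":false}\n"
                + "{\"message\":{\"content\":\"lo\"},\"done\":true}\n";
        InputStream inputStream = new ByteArrayInputStream(body.getBytes(UTF_8));

        try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) { // null means the stream ended
                parse(line); // always a complete JSON object, never a truncated chunk
            }
        }
    }

    private static void parse(String json) {
        System.out.println("complete JSON object: " + json);
    }
}
```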
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaMessagesUtils.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaMessagesUtils.java
index 21a6d8ef8..3ba5e1ea2 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaMessagesUtils.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaMessagesUtils.java
@@ -34,6 +34,9 @@ class OllamaMessagesUtils {
     }
 
     static List<Tool> toOllamaTools(List<ToolSpecification> toolSpecifications) {
+        if (toolSpecifications == null) {
+            return null;
+        }
         return toolSpecifications.stream().map(toolSpecification ->
                 Tool.builder()
                         .function(Function.builder()
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java
index 08388fe24..9d411c1eb 100644
--- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java
@@ -4,9 +4,11 @@ import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
 import dev.langchain4j.model.ollama.spi.OllamaStreamingChatModelBuilderFactory;
 
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
@@ -16,6 +18,7 @@ import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
 import static dev.langchain4j.model.ollama.OllamaMessagesUtils.toOllamaMessages;
 import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 import static java.time.Duration.ofSeconds;
+import static java.util.Collections.emptyList;
 
 /**
  * Ollama API reference
@@ -28,6 +31,7 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
     private final String modelName;
     private final Options options;
     private final String format;
+    private final List<ChatModelListener> listeners;
 
     public OllamaStreamingChatModel(String baseUrl,
                                     String modelName,
@@ -43,7 +47,8 @@
                                     Duration timeout,
                                     Boolean logRequests,
                                     Boolean logResponses,
-                                    Map<String, String> customHeaders
+                                    Map<String, String> customHeaders,
+                                    List<ChatModelListener> listeners
     ) {
         this.client = OllamaClient.builder()
                 .baseUrl(baseUrl)
@@ -64,6 +69,7 @@
                 .stop(stop)
                 .build();
         this.format = format;
+        this.listeners = new ArrayList<>(getOrDefault(listeners, emptyList()));
     }
 
     public static OllamaStreamingChatModelBuilder builder() {
@@ -85,7 +91,7 @@
                 .stream(true)
                 .build();
 
-        client.streamingChat(request, handler);
+        client.streamingChat(request, handler, listeners, messages);
     }
 
     public static class OllamaStreamingChatModelBuilder {
@@ -105,6 +111,7 @@
         private Map<String, String> customHeaders;
         private Boolean logRequests;
         private Boolean logResponses;
+        private List<ChatModelListener> listeners;
 
         public OllamaStreamingChatModelBuilder() {
             // This is public so it can be extended
@@ -186,6 +193,11 @@
             return this;
         }
 
+        public OllamaStreamingChatModelBuilder listeners(List<ChatModelListener> listeners) {
+            this.listeners = listeners;
+            return this;
+        }
+
        public OllamaStreamingChatModel build() {
             return new OllamaStreamingChatModel(
                     baseUrl,
@@ -202,7 +214,8 @@
                     timeout,
                     logRequests,
                     logResponses,
-                    customHeaders
+                    customHeaders,
+                    listeners
             );
         }
     }
diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingResponseBuilder.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingResponseBuilder.java
new file mode 100644
index 000000000..f1fb3bdb5
--- /dev/null
+++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingResponseBuilder.java
@@ -0,0 +1,44 @@
+package dev.langchain4j.model.ollama;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
+
+/**
+ * This class needs to be thread safe because it is called when a streaming result comes back,
+ * and there is no guarantee that the thread handling a partial response is the same one that
+ * initiated the request; in fact, it almost certainly won't be.
+ */
+class OllamaStreamingResponseBuilder {
+
+    private StringBuffer contentBuilder = new StringBuffer();
+    private volatile TokenUsage tokenUsage;
+
+    void append(ChatResponse partialResponse) {
+        if (partialResponse == null) {
+            return;
+        }
+
+        if (partialResponse.getEvalCount() != null && partialResponse.getPromptEvalCount() != null) {
+            this.tokenUsage = new TokenUsage(
+                    partialResponse.getPromptEvalCount(),
+                    partialResponse.getEvalCount()
+            );
+        }
+
+        String content = partialResponse.getMessage().getContent();
+        if (content != null) {
+            contentBuilder.append(content);
+        }
+    }
+
+    Response<AiMessage> build() {
+        if (contentBuilder.toString().isEmpty()) {
+            return null;
+        }
+        return Response.from(
+                AiMessage.from(contentBuilder.toString()),
+                tokenUsage
+        );
+    }
+}
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaToolsLanguageModelInfrastructure.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaToolsLanguageModelInfrastructure.java
index e87d36fcc..807d28cd1 100644
--- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaToolsLanguageModelInfrastructure.java
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaToolsLanguageModelInfrastructure.java
@@ -14,5 +14,4 @@ class AbstractOllamaToolsLanguageModelInfrastructure {
     }
 
-
 }
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelListenerIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelListenerIT.java
new file mode 100644
index 000000000..163096c03
--- /dev/null
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelListenerIT.java
@@ -0,0 +1,57 @@
+package dev.langchain4j.model.ollama;
+
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.ChatModelListenerIT;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
+
+import static dev.langchain4j.model.ollama.OllamaImage.TOOL_MODEL;
+import static java.util.Collections.singletonList;
+
+class OllamaChatModelListenerIT extends ChatModelListenerIT {
+
+    @Override
+    protected ChatLanguageModel createModel(ChatModelListener listener) {
+        return OllamaChatModel.builder()
+                .baseUrl(AbstractOllamaToolsLanguageModelInfrastructure.ollama.getEndpoint())
+                .modelName(TOOL_MODEL)
+                .temperature(temperature())
+                .topP(topP())
+                .numPredict(maxTokens())
+                .logRequests(true)
+                .logResponses(true)
+                .listeners(singletonList(listener))
+                .build();
+    }
+
+    @Override
+    protected String modelName() {
+        return TOOL_MODEL;
+    }
+
+    @Override
+    protected ChatLanguageModel createFailingModel(ChatModelListener listener) {
+        return OllamaChatModel.builder()
+                .baseUrl(AbstractOllamaToolsLanguageModelInfrastructure.ollama.getEndpoint())
+                .modelName("banana")
+                .maxRetries(0)
+                .logRequests(true)
+                .logResponses(true)
+                .listeners(singletonList(listener))
+                .build();
+    }
+
+    @Override
+    protected Class<? extends Exception> expectedExceptionClass() {
+        return NullPointerException.class;
+    }
+
+    @Override
+    protected boolean assertResponseId() {
+        return false;
+    }
+
+    @Override
+    protected boolean assertFinishReason() {
+        return false;
+    }
+}
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaOpenAiChatModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaOpenAiChatModelIT.java
index c285c1bf9..745c696e5 100644
--- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaOpenAiChatModelIT.java
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaOpenAiChatModelIT.java
@@ -3,14 +3,20 @@ package dev.langchain4j.model.ollama;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.UserMessage;
 import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.listener.*;
 import dev.langchain4j.model.openai.OpenAiChatModel;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import org.junit.jupiter.api.Test;
 
+import java.util.concurrent.atomic.AtomicReference;
+
 import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
 import static dev.langchain4j.model.output.FinishReason.STOP;
+import static java.util.Collections.singletonList;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Fail.fail;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 
 /**
  * Tests if Ollama can be used via OpenAI API (langchain4j-open-ai module)
diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelListenerIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelListenerIT.java
new file mode 100644
index 000000000..56cc9bc32
--- /dev/null
+++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelListenerIT.java
@@ -0,0 +1,61 @@
+package dev.langchain4j.model.ollama;
+
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
+
+import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
+import static java.util.Collections.singletonList;
+
+public class OllamaStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
+
+    @Override
+    protected StreamingChatLanguageModel createModel(ChatModelListener listener) {
+        return OllamaStreamingChatModel.builder()
+                .baseUrl(AbstractOllamaLanguageModelInfrastructure.ollama.getEndpoint())
+                .modelName(TINY_DOLPHIN_MODEL)
+                .temperature(temperature())
+                .topP(topP())
+                .numPredict(maxTokens())
+                .logRequests(true)
+                .logResponses(true)
+                .listeners(singletonList(listener))
+                .build();
+    }
+
+    @Override
+    protected String modelName() {
+        return TINY_DOLPHIN_MODEL;
+    }
+
+    @Override
+    protected StreamingChatLanguageModel createFailingModel(ChatModelListener listener) {
+        return OllamaStreamingChatModel.builder()
+                .baseUrl(AbstractOllamaLanguageModelInfrastructure.ollama.getEndpoint())
+                .modelName("banana")
+                .logRequests(true)
+                .logResponses(true)
+                .listeners(singletonList(listener))
+                .build();
+    }
+
+    @Override
+    protected Class<? extends Exception> expectedExceptionClass() {
+        return NullPointerException.class;
+    }
+
+    @Override
+    protected boolean supportToolCalls() {
+        return false;
+    }
+
+    @Override
+    protected boolean assertResponseId() {
+        return false;
+    }
+
+    @Override
+    protected boolean assertFinishReason() {
+        return false;
+    }
+}
diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelListenerIT.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelListenerIT.java
index 17b825af5..a716797fb 100644
--- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelListenerIT.java
+++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelListenerIT.java
@@ -41,7 +41,7 @@ class OpenAiChatModelListenerIT extends ChatModelListenerIT {
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
         return OpenAiHttpException.class;
     }
 }
diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelListenerIT.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelListenerIT.java
index da2eb6caa..5e7f669f7 100644
--- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelListenerIT.java
+++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelListenerIT.java
@@ -42,7 +42,7 @@ class OpenAiStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
         return OpenAiHttpException.class;
     }
 }
\ No newline at end of file
diff --git a/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiChatModelListenerIT.java b/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiChatModelListenerIT.java
index 8d720a5e7..2bfd74a57 100644
--- a/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiChatModelListenerIT.java
+++ b/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiChatModelListenerIT.java
@@ -45,7 +45,7 @@ public class VertexAiGeminiChatModelListenerIT extends ChatModelListenerIT {
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
         return RuntimeException.class;
     }
 }
diff --git a/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiStreamingChatModelListenerIT.java b/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiStreamingChatModelListenerIT.java
index 43c7279f9..8668db609 100644
--- a/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiStreamingChatModelListenerIT.java
+++ b/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiStreamingChatModelListenerIT.java
@@ -46,7 +46,7 @@ public class VertexAiGeminiStreamingChatModelListenerIT extends StreamingChatMod
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
         return NotFoundException.class;
     }
 }
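Finally, a streaming counterpart to the usage sketch in the description above (again with placeholder endpoint and model name; the listener logs token usage once the stream completes):

```java
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.ollama.OllamaStreamingChatModel;
import dev.langchain4j.model.output.Response;

import static java.util.Collections.singletonList;

public class OllamaStreamingListenerExample {

    public static void main(String[] args) {
        ChatModelListener listener = new ChatModelListener() {

            @Override
            public void onResponse(ChatModelResponseContext context) {
                System.out.println("Token usage: " + context.response().tokenUsage());
            }
        };

        StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
                .baseUrl("http://localhost:11434") // placeholder endpoint
                .modelName("llama3.1")             // placeholder model name
                .listeners(singletonList(listener))
                .build();

        model.generate("hello", new StreamingResponseHandler<AiMessage>() {

            @Override
            public void onNext(String token) {
                System.out.print(token); // tokens arrive as the model streams them
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println();
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}
```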