Ollama chat model listener (#1765)
## Issue

Closes #1756
Closes #1750

## Change

1. `OllamaChatModel` and `OllamaStreamingChatModel` now support `ChatModelListener`
2. Fix `OllamaStreamingLanguageModel` throwing an `EOFException` when the response content is too long

## General checklist

<!-- Please double-check the following points and mark them like this: [X] -->

- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green

<!-- Before adding documentation and example(s) (below), please wait until the PR is reviewed and approved. -->

- [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features)
- [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
commit 16f410c788 (parent 07ce173fb3)
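For reviewers who want to try the new listener support, here is a minimal usage sketch. It is not part of this change; the base URL and model name are placeholders, and it assumes a locally running Ollama. It attaches a `ChatModelListener` through the new `listeners(...)` builder option shown in the diffs below:

```java
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.ollama.OllamaChatModel;

import static java.util.Collections.singletonList;

public class OllamaListenerExample {

    public static void main(String[] args) {
        ChatModelListener listener = new ChatModelListener() {

            @Override
            public void onRequest(ChatModelRequestContext context) {
                // called just before the request is sent to Ollama
                System.out.println("request to model: " + context.request().model());
            }

            @Override
            public void onResponse(ChatModelResponseContext context) {
                // called after a successful response
                System.out.println("token usage: " + context.response().tokenUsage());
            }

            @Override
            public void onError(ChatModelErrorContext context) {
                // called when the call fails
                System.out.println("error: " + context.error().getMessage());
            }
        };

        ChatLanguageModel model = OllamaChatModel.builder()
                .baseUrl("http://localhost:11434") // placeholder
                .modelName("llama3")               // placeholder
                .listeners(singletonList(listener))
                .build();

        System.out.println(model.generate("hello"));
    }
}
```

`OllamaStreamingChatModel` accepts the same `listeners(...)` option; there the response callback fires once the final streamed token has arrived.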
**AzureOpenAiChatModelListenerIT.java**

```diff
@@ -41,7 +41,7 @@ class AzureOpenAiChatModelListenerIT extends ChatModelListenerIT {
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
         return ClientAuthenticationException.class;
     }
 }
```
**AzureOpenAiStreamingChatModelListenerIT.java**

```diff
@@ -42,7 +42,7 @@ class AzureOpenAiStreamingChatModelListenerIT extends StreamingChatModelListener
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
         return ClientAuthenticationException.class;
     }
 
```
**ChatModelListenerIT.java**

```diff
@@ -71,7 +71,7 @@ public abstract class ChatModelListenerIT {
 
     protected abstract ChatLanguageModel createFailingModel(ChatModelListener listener);
 
-    protected abstract Class<?> expectedExceptionClass();
+    protected abstract Class<? extends Exception> expectedExceptionClass();
 
     @Test
     void should_listen_request_and_response() {
@@ -105,14 +105,22 @@ public abstract class ChatModelListenerIT {
 
         UserMessage userMessage = UserMessage.from("hello");
 
-        ToolSpecification toolSpecification = ToolSpecification.builder()
-                .name("add")
-                .addParameter("a", INTEGER)
-                .addParameter("b", INTEGER)
-                .build();
+        ToolSpecification toolSpecification = null;
+        if (supportToolCalls()) {
+            toolSpecification = ToolSpecification.builder()
+                    .name("add")
+                    .addParameter("a", INTEGER)
+                    .addParameter("b", INTEGER)
+                    .build();
+        }
 
         // when
-        AiMessage aiMessage = model.generate(singletonList(userMessage), singletonList(toolSpecification)).content();
+        AiMessage aiMessage;
+        if (supportToolCalls()) {
+            aiMessage = model.generate(singletonList(userMessage), singletonList(toolSpecification)).content();
+        } else {
+            aiMessage = model.generate(singletonList(userMessage)).content();
+        }
 
         // then
         ChatModelRequest request = requestReference.get();
@@ -121,7 +129,9 @@ public abstract class ChatModelListenerIT {
         assertThat(request.topP()).isEqualTo(topP());
         assertThat(request.maxTokens()).isEqualTo(maxTokens());
         assertThat(request.messages()).containsExactly(userMessage);
-        assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
+        if (supportToolCalls()) {
+            assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
+        }
 
         ChatModelResponse response = responseReference.get();
         if (assertResponseId()) {
@@ -131,14 +141,24 @@ public abstract class ChatModelListenerIT {
         assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
         assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
         assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0);
-        assertThat(response.finishReason()).isNotNull();
+        if (assertFinishReason()) {
+            assertThat(response.finishReason()).isNotNull();
+        }
         assertThat(response.aiMessage()).isEqualTo(aiMessage);
     }
 
+    protected boolean supportToolCalls() {
+        return true;
+    }
+
+    protected boolean assertResponseId() {
+        return true;
+    }
+
+    protected boolean assertFinishReason() {
+        return true;
+    }
+
     @Test
     void should_listen_error() {
```
**StreamingChatModelListenerIT.java**

```diff
@@ -4,12 +4,7 @@ import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.UserMessage;
 import dev.langchain4j.model.StreamingResponseHandler;
-import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
-import dev.langchain4j.model.chat.listener.ChatModelListener;
-import dev.langchain4j.model.chat.listener.ChatModelRequest;
-import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
-import dev.langchain4j.model.chat.listener.ChatModelResponse;
-import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
+import dev.langchain4j.model.chat.listener.*;
 import dev.langchain4j.model.output.Response;
 import org.assertj.core.data.Percentage;
 import org.junit.jupiter.api.Test;
@@ -75,7 +70,7 @@ public abstract class StreamingChatModelListenerIT {
 
     protected abstract StreamingChatLanguageModel createFailingModel(ChatModelListener listener);
 
-    protected abstract Class<?> expectedExceptionClass();
+    protected abstract Class<? extends Exception> expectedExceptionClass();
 
     @Test
     void should_listen_request_and_response() {
@@ -109,15 +104,22 @@ public abstract class StreamingChatModelListenerIT {
 
         UserMessage userMessage = UserMessage.from("hello");
 
-        ToolSpecification toolSpecification = ToolSpecification.builder()
-                .name("add")
-                .addParameter("a", INTEGER)
-                .addParameter("b", INTEGER)
-                .build();
+        ToolSpecification toolSpecification = null;
+        if (supportToolCalls()) {
+            toolSpecification = ToolSpecification.builder()
+                    .name("add")
+                    .addParameter("a", INTEGER)
+                    .addParameter("b", INTEGER)
+                    .build();
+        }
 
         // when
         TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
-        model.generate(singletonList(userMessage), singletonList(toolSpecification), handler);
+        if (supportToolCalls()) {
+            model.generate(singletonList(userMessage), singletonList(toolSpecification), handler);
+        } else {
+            model.generate(singletonList(userMessage), handler);
+        }
         AiMessage aiMessage = handler.get().content();
 
         // then
@@ -127,7 +129,9 @@ public abstract class StreamingChatModelListenerIT {
         assertThat(request.topP()).isEqualTo(topP());
         assertThat(request.maxTokens()).isEqualTo(maxTokens());
         assertThat(request.messages()).containsExactly(userMessage);
-        assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
+        if (supportToolCalls()) {
+            assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
+        }
 
         ChatModelResponse response = responseReference.get();
         if (assertResponseId()) {
@@ -137,14 +141,24 @@ public abstract class StreamingChatModelListenerIT {
         assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
         assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
         assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0);
-        assertThat(response.finishReason()).isNotNull();
+        if (assertFinishReason()) {
+            assertThat(response.finishReason()).isNotNull();
+        }
         assertThat(response.aiMessage()).isEqualTo(aiMessage);
     }
 
+    protected boolean supportToolCalls() {
+        return true;
+    }
+
+    protected boolean assertResponseId() {
+        return true;
+    }
+
+    protected boolean assertFinishReason() {
+        return true;
+    }
+
     @Test
     protected void should_listen_error() throws Exception {
```
**GoogleAiGeminiChatModelListenerIT.java**

```diff
@@ -43,7 +43,7 @@ class GoogleAiGeminiChatModelListenerIT extends ChatModelListenerIT {
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
         return RuntimeException.class;
     }
 }
```
**pom.xml**

```diff
@@ -68,6 +68,12 @@
             <scope>test</scope>
         </dependency>
 
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-core</artifactId>
+            <scope>test</scope>
+        </dependency>
+
         <dependency>
             <groupId>org.mockito</groupId>
             <artifactId>mockito-junit-jupiter</artifactId>
```
**OllamaChatModel.java**

```diff
@@ -4,21 +4,29 @@ import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
+import dev.langchain4j.model.chat.listener.ChatModelRequest;
 import dev.langchain4j.model.ollama.spi.OllamaChatModelBuilderFactory;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 
 import static dev.langchain4j.internal.RetryUtils.withRetry;
 import static dev.langchain4j.internal.Utils.getOrDefault;
 import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
 import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
+import static dev.langchain4j.model.ollama.OllamaChatModelListenerUtils.*;
 import static dev.langchain4j.model.ollama.OllamaMessagesUtils.*;
 import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 import static java.time.Duration.ofSeconds;
+import static java.util.Collections.emptyList;
 
 /**
  * <a href="https://github.com/jmorganca/ollama/blob/main/docs/api.md">Ollama API reference</a>
@@ -27,11 +35,14 @@ import static java.time.Duration.ofSeconds;
  */
 public class OllamaChatModel implements ChatLanguageModel {
 
+    private static final Logger log = LoggerFactory.getLogger(OllamaChatModel.class);
+
     private final OllamaClient client;
     private final String modelName;
     private final Options options;
     private final String format;
     private final Integer maxRetries;
+    private final List<ChatModelListener> listeners;
 
     public OllamaChatModel(String baseUrl,
                            String modelName,
@@ -48,7 +59,8 @@ public class OllamaChatModel implements ChatLanguageModel {
                            Integer maxRetries,
                            Map<String, String> customHeaders,
                            Boolean logRequests,
-                           Boolean logResponses) {
+                           Boolean logResponses,
+                           List<ChatModelListener> listeners) {
         this.client = OllamaClient.builder()
                 .baseUrl(baseUrl)
                 .timeout(getOrDefault(timeout, ofSeconds(60)))
@@ -69,6 +81,7 @@ public class OllamaChatModel implements ChatLanguageModel {
                 .build();
         this.format = format;
         this.maxRetries = getOrDefault(maxRetries, 3);
+        this.listeners = new ArrayList<>(getOrDefault(listeners, emptyList()));
     }
 
     public static OllamaChatModelBuilder builder() {
@@ -82,26 +95,17 @@ public class OllamaChatModel implements ChatLanguageModel {
     public Response<AiMessage> generate(List<ChatMessage> messages) {
         ensureNotEmpty(messages, "messages");
 
-        ChatRequest request = ChatRequest.builder()
-                .model(modelName)
-                .messages(toOllamaMessages(messages))
-                .options(options)
-                .format(format)
-                .stream(false)
-                .build();
-
-        ChatResponse response = withRetry(() -> client.chat(request), maxRetries);
-
-        return Response.from(
-                AiMessage.from(response.getMessage().getContent()),
-                new TokenUsage(response.getPromptEvalCount(), response.getEvalCount())
-        );
+        return doGenerate(messages, null);
     }
 
     @Override
     public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
         ensureNotEmpty(messages, "messages");
 
+        return doGenerate(messages, toolSpecifications);
+    }
+
+    private Response<AiMessage> doGenerate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
         ChatRequest request = ChatRequest.builder()
                 .model(modelName)
                 .messages(toOllamaMessages(messages))
@@ -111,14 +115,25 @@ public class OllamaChatModel implements ChatLanguageModel {
                 .tools(toOllamaTools(toolSpecifications))
                 .build();
 
-        ChatResponse response = withRetry(() -> client.chat(request), maxRetries);
+        ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
+        Map<Object, Object> attributes = new ConcurrentHashMap<>();
+        onListenRequest(listeners, modelListenerRequest, attributes);
 
-        return Response.from(
-                response.getMessage().getToolCalls() != null ?
-                        AiMessage.from(toToolExecutionRequest(response.getMessage().getToolCalls())) :
-                        AiMessage.from(response.getMessage().getContent()),
-                new TokenUsage(response.getPromptEvalCount(), response.getEvalCount())
-        );
+        try {
+            ChatResponse chatResponse = withRetry(() -> client.chat(request), maxRetries);
+            Response<AiMessage> response = Response.from(
+                    chatResponse.getMessage().getToolCalls() != null ?
+                            AiMessage.from(toToolExecutionRequest(chatResponse.getMessage().getToolCalls())) :
+                            AiMessage.from(chatResponse.getMessage().getContent()),
+                    new TokenUsage(chatResponse.getPromptEvalCount(), chatResponse.getEvalCount())
+            );
+            onListenResponse(listeners, response, modelListenerRequest, attributes);
+
+            return response;
+        } catch (Exception e) {
+            onListenError(listeners, e, modelListenerRequest, null, attributes);
+            throw e;
+        }
     }
 
     public static class OllamaChatModelBuilder {
@@ -139,6 +154,7 @@ public class OllamaChatModel implements ChatLanguageModel {
         private Map<String, String> customHeaders;
         private Boolean logRequests;
         private Boolean logResponses;
+        private List<ChatModelListener> listeners;
 
         public OllamaChatModelBuilder() {
             // This is public so it can be extended
@@ -225,6 +241,11 @@ public class OllamaChatModel implements ChatLanguageModel {
             return this;
         }
 
+        public OllamaChatModelBuilder listeners(List<ChatModelListener> listeners) {
+            this.listeners = listeners;
+            return this;
+        }
+
         public OllamaChatModel build() {
             return new OllamaChatModel(
                     baseUrl,
@@ -242,7 +263,8 @@ public class OllamaChatModel implements ChatLanguageModel {
                     maxRetries,
                     customHeaders,
                     logRequests,
-                    logResponses
+                    logResponses,
+                    listeners
             );
         }
     }
```
**OllamaChatModelListenerUtils.java** (new file)

```diff
@@ -0,0 +1,118 @@
+package dev.langchain4j.model.ollama;
+
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.model.chat.listener.*;
+import dev.langchain4j.model.output.Response;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.List;
+import java.util.Map;
+
+class OllamaChatModelListenerUtils {
+
+    private static final Logger log = LoggerFactory.getLogger(OllamaChatModelListenerUtils.class);
+
+    private OllamaChatModelListenerUtils() throws InstantiationException {
+        throw new InstantiationException("Can't instantiate this utility class.");
+    }
+
+    /**
+     * Processes a listen request by notifying all registered chat model listeners.
+     *
+     * @param listeners            A list of {@link ChatModelListener} instances to be notified. Should not be null.
+     * @param modelListenerRequest The {@link ChatModelRequest} containing the request details.
+     * @param attributes           A map of additional attributes to be passed to the context.
+     */
+    static void onListenRequest(List<ChatModelListener> listeners, ChatModelRequest modelListenerRequest, Map<Object, Object> attributes) {
+        ChatModelRequestContext context = new ChatModelRequestContext(modelListenerRequest, attributes);
+        listeners.forEach(listener -> {
+            try {
+                listener.onRequest(context);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });
+    }
+
+    /**
+     * Processes a listen response by notifying all registered chat model listeners.
+     *
+     * @param listeners            A list of {@link ChatModelListener} instances to be notified. Should not be null.
+     * @param response             The {@link Response} containing the response details.
+     * @param modelListenerRequest The original {@link ChatModelRequest} associated with the response.
+     * @param attributes           A map of additional attributes to be passed to the context.
+     */
+    static void onListenResponse(List<ChatModelListener> listeners, Response<AiMessage> response, ChatModelRequest modelListenerRequest, Map<Object, Object> attributes) {
+        ChatModelResponse modelListenerResponse = createModelListenerResponse(modelListenerRequest.model(), response);
+        ChatModelResponseContext context = new ChatModelResponseContext(
+                modelListenerResponse,
+                modelListenerRequest,
+                attributes
+        );
+        listeners.forEach(listener -> {
+            try {
+                listener.onResponse(context);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });
+    }
+
+    /**
+     * Processes a listen error by notifying all registered chat model listeners.
+     *
+     * @param listeners            A list of {@link ChatModelListener} instances to be notified. Should not be null.
+     * @param error                Error between chat
+     * @param modelListenerRequest The original {@link ChatModelRequest} associated with the response.
+     * @param partialResponse      The partial {@link Response} containing cur response details.
+     * @param attributes           A map of additional attributes to be passed to the context.
+     */
+    static void onListenError(List<ChatModelListener> listeners, Throwable error, ChatModelRequest modelListenerRequest, Response<AiMessage> partialResponse, Map<Object, Object> attributes) {
+        ChatModelResponse partialModelListenerResponse = createModelListenerResponse(modelListenerRequest.model(), partialResponse);
+        ChatModelErrorContext context = new ChatModelErrorContext(
+                error,
+                modelListenerRequest,
+                partialModelListenerResponse,
+                attributes
+        );
+        listeners.forEach(listener -> {
+            try {
+                listener.onError(context);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });
+    }
+
+    static ChatModelRequest createModelListenerRequest(ChatRequest request,
+                                                       List<ChatMessage> messages,
+                                                       List<ToolSpecification> toolSpecifications) {
+        Options options = request.getOptions();
+
+        return ChatModelRequest.builder()
+                .model(request.getModel())
+                .temperature(options.getTemperature())
+                .topP(options.getTopP())
+                .maxTokens(options.getNumPredict())
+                .messages(messages)
+                .toolSpecifications(toolSpecifications)
+                .build();
+    }
+
+    static ChatModelResponse createModelListenerResponse(String responseModel,
+                                                         Response<AiMessage> response) {
+        if (response == null) {
+            return null;
+        }
+
+        return ChatModelResponse.builder()
+                .model(responseModel)
+                .tokenUsage(response.tokenUsage())
+                .finishReason(response.finishReason())
+                .aiMessage(response.content())
+                .build();
+    }
+}
```
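One detail worth noting in the utilities above: the same `attributes` map instance is threaded through the request, response, and error contexts, which lets a listener carry state across the callbacks of a single call. A sketch (assuming the contexts expose an `attributes()` accessor; the key name is arbitrary):

```java
ChatModelListener timingListener = new ChatModelListener() {

    @Override
    public void onRequest(ChatModelRequestContext context) {
        // stash the start time; the same map is handed to onResponse/onError
        context.attributes().put("startNanos", System.nanoTime());
    }

    @Override
    public void onResponse(ChatModelResponseContext context) {
        long startNanos = (long) context.attributes().get("startNanos");
        System.out.printf("Ollama call took %d ms%n", (System.nanoTime() - startNanos) / 1_000_000);
    }
};
```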
**OllamaClient.java**

```diff
@@ -1,8 +1,11 @@
 package dev.langchain4j.model.ollama;
 
 import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.internal.Utils;
 import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
+import dev.langchain4j.model.chat.listener.ChatModelRequest;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import okhttp3.Interceptor;
@@ -22,10 +25,10 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.time.Duration;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Optional;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
 
+import static dev.langchain4j.model.ollama.OllamaChatModelListenerUtils.*;
 import static dev.langchain4j.model.ollama.OllamaJsonUtils.getObjectMapper;
 import static dev.langchain4j.model.ollama.OllamaJsonUtils.toObject;
 import static java.lang.Boolean.TRUE;
@@ -106,50 +109,6 @@ class OllamaClient {
     public void streamingCompletion(CompletionRequest request, StreamingResponseHandler<String> handler) {
         ollamaApi.streamingCompletion(request).enqueue(new Callback<ResponseBody>() {
-
-            @Override
-            public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
-                try (InputStream inputStream = retrofitResponse.body().byteStream()) {
-                    StringBuilder contentBuilder = new StringBuilder();
-                    while (true) {
-                        byte[] bytes = new byte[1024];
-                        int len = inputStream.read(bytes);
-                        String partialResponse = new String(bytes, 0, len);
-
-                        if (logStreamingResponses) {
-                            log.debug("Streaming partial response: {}", partialResponse);
-                        }
-
-                        CompletionResponse completionResponse = toObject(partialResponse, CompletionResponse.class);
-                        contentBuilder.append(completionResponse.getResponse());
-                        handler.onNext(completionResponse.getResponse());
-
-                        if (TRUE.equals(completionResponse.getDone())) {
-                            Response<String> response = Response.from(
-                                    contentBuilder.toString(),
-                                    new TokenUsage(
-                                            completionResponse.getPromptEvalCount(),
-                                            completionResponse.getEvalCount()
-                                    )
-                            );
-                            handler.onComplete(response);
-                            return;
-                        }
-                    }
-                } catch (Exception e) {
-                    handler.onError(e);
-                }
-            }
-
-            @Override
-            public void onFailure(Call<ResponseBody> call, Throwable throwable) {
-                handler.onError(throwable);
-            }
-        });
-    }
-
-    public void streamingChat(ChatRequest request, StreamingResponseHandler<AiMessage> handler) {
-        ollamaApi.streamingChat(request).enqueue(new Callback<ResponseBody>() {
 
             @Override
             public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
                 try (InputStream inputStream = retrofitResponse.body().byteStream()) {
@@ -162,17 +121,16 @@ class OllamaClient {
                             log.debug("Streaming partial response: {}", partialResponse);
                         }
 
-                        ChatResponse chatResponse = toObject(partialResponse, ChatResponse.class);
-                        String content = chatResponse.getMessage().getContent();
-                        contentBuilder.append(content);
-                        handler.onNext(content);
+                        CompletionResponse completionResponse = toObject(partialResponse, CompletionResponse.class);
+                        contentBuilder.append(completionResponse.getResponse());
+                        handler.onNext(completionResponse.getResponse());
 
-                        if (TRUE.equals(chatResponse.getDone())) {
-                            Response<AiMessage> response = Response.from(
-                                    AiMessage.from(contentBuilder.toString()),
+                        if (TRUE.equals(completionResponse.getDone())) {
+                            Response<String> response = Response.from(
+                                    contentBuilder.toString(),
                                     new TokenUsage(
-                                            chatResponse.getPromptEvalCount(),
-                                            chatResponse.getEvalCount()
+                                            completionResponse.getPromptEvalCount(),
+                                            completionResponse.getEvalCount()
                                     )
                             );
                             handler.onComplete(response);
@@ -192,6 +150,57 @@ class OllamaClient {
         });
     }
 
+    public void streamingChat(ChatRequest request, StreamingResponseHandler<AiMessage> handler,
+                              List<ChatModelListener> listeners, List<ChatMessage> messages) {
+        ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, new ArrayList<>());
+        Map<Object, Object> attributes = new ConcurrentHashMap<>();
+        onListenRequest(listeners, modelListenerRequest, attributes);
+
+        OllamaStreamingResponseBuilder responseBuilder = new OllamaStreamingResponseBuilder();
+        ollamaApi.streamingChat(request).enqueue(new Callback<ResponseBody>() {
+
+            @Override
+            public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
+                try (InputStream inputStream = retrofitResponse.body().byteStream()) {
+                    try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
+                        while (true) {
+                            String partialResponse = reader.readLine();
+
+                            if (logStreamingResponses) {
+                                log.debug("Streaming partial response: {}", partialResponse);
+                            }
+
+                            ChatResponse chatResponse = toObject(partialResponse, ChatResponse.class);
+                            String content = chatResponse.getMessage().getContent();
+                            responseBuilder.append(chatResponse);
+                            handler.onNext(content);
+
+                            if (TRUE.equals(chatResponse.getDone())) {
+                                Response<AiMessage> response = responseBuilder.build();
+                                handler.onComplete(response);
+
+                                onListenResponse(listeners, response, modelListenerRequest, attributes);
+
+                                return;
+                            }
+                        }
+                    }
+                } catch (Exception e) {
+                    onListenError(listeners, e, modelListenerRequest, responseBuilder.build(), attributes);
+
+                    handler.onError(e);
+                }
+            }
+
+            @Override
+            public void onFailure(Call<ResponseBody> call, Throwable throwable) {
+                onListenError(listeners, throwable, modelListenerRequest, responseBuilder.build(), attributes);
+
+                handler.onError(throwable);
+            }
+        });
+    }
+
     public EmbeddingResponse embed(EmbeddingRequest request) {
         try {
             retrofit2.Response<EmbeddingResponse> retrofitResponse = ollamaApi.embed(request).execute();
@@ -257,6 +266,7 @@ class OllamaClient {
         }
     }
 
+
     private RuntimeException toException(retrofit2.Response<?> response) throws IOException {
         int code = response.code();
         String body = response.errorBody().string();
```
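A note on the `EOFException` fix visible in the `OllamaClient` hunks above: the old loop read fixed 1024-byte chunks and handed each chunk to the JSON parser as if it were one complete event, so any event longer than a chunk (or split across two reads) ended mid-object and parsing failed with an unexpected end of input. Ollama streams newline-delimited JSON, so reading one line at a time always yields a complete event, regardless of how long the overall response grows. A self-contained sketch of the idea; the `Event` shape and `readAll` helper are illustrative, not the library code:

```java
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;

class NdJsonStreamSketch {

    private static final ObjectMapper MAPPER = new ObjectMapper()
            .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);

    // Illustrative shape of one streamed Ollama event; real events carry more fields.
    static class Event {
        public String response; // text delta for this event
        public Boolean done;    // true on the terminal event
    }

    // Each readLine() returns exactly one complete JSON object, so the parser
    // never sees a truncated event, however long the streamed answer becomes.
    static String readAll(InputStream in) throws IOException {
        StringBuilder content = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                Event event = MAPPER.readValue(line, Event.class);
                if (event.response != null) {
                    content.append(event.response);
                }
                if (Boolean.TRUE.equals(event.done)) {
                    break;
                }
            }
        }
        return content.toString();
    }
}
```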
**OllamaMessagesUtils.java**

```diff
@@ -34,6 +34,9 @@ class OllamaMessagesUtils {
     }
 
     static List<Tool> toOllamaTools(List<ToolSpecification> toolSpecifications) {
+        if (toolSpecifications == null) {
+            return null;
+        }
         return toolSpecifications.stream().map(toolSpecification ->
                 Tool.builder()
                         .function(Function.builder()
```
**OllamaStreamingChatModel.java**

```diff
@@ -4,9 +4,11 @@ import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
 import dev.langchain4j.model.ollama.spi.OllamaStreamingChatModelBuilderFactory;
 
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
@@ -16,6 +18,7 @@ import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
 import static dev.langchain4j.model.ollama.OllamaMessagesUtils.toOllamaMessages;
 import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 import static java.time.Duration.ofSeconds;
+import static java.util.Collections.emptyList;
 
 /**
  * <a href="https://github.com/jmorganca/ollama/blob/main/docs/api.md">Ollama API reference</a>
@@ -28,6 +31,7 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
     private final String modelName;
     private final Options options;
     private final String format;
+    private final List<ChatModelListener> listeners;
 
     public OllamaStreamingChatModel(String baseUrl,
                                     String modelName,
@@ -43,7 +47,8 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
                                     Duration timeout,
                                     Boolean logRequests,
                                     Boolean logResponses,
-                                    Map<String, String> customHeaders
+                                    Map<String, String> customHeaders,
+                                    List<ChatModelListener> listeners
     ) {
         this.client = OllamaClient.builder()
                 .baseUrl(baseUrl)
@@ -64,6 +69,7 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
                 .stop(stop)
                 .build();
         this.format = format;
+        this.listeners = new ArrayList<>(getOrDefault(listeners, emptyList()));
     }
 
     public static OllamaStreamingChatModelBuilder builder() {
@@ -85,7 +91,7 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
                 .stream(true)
                 .build();
 
-        client.streamingChat(request, handler);
+        client.streamingChat(request, handler, listeners, messages);
     }
 
     public static class OllamaStreamingChatModelBuilder {
@@ -105,6 +111,7 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
         private Map<String, String> customHeaders;
         private Boolean logRequests;
         private Boolean logResponses;
+        private List<ChatModelListener> listeners;
 
         public OllamaStreamingChatModelBuilder() {
             // This is public so it can be extended
@@ -186,6 +193,11 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
             return this;
         }
 
+        public OllamaStreamingChatModelBuilder listeners(List<ChatModelListener> listeners) {
+            this.listeners = listeners;
+            return this;
+        }
+
         public OllamaStreamingChatModel build() {
             return new OllamaStreamingChatModel(
                     baseUrl,
@@ -202,7 +214,8 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
                     timeout,
                     logRequests,
                     logResponses,
-                    customHeaders
+                    customHeaders,
+                    listeners
             );
         }
     }
```
**OllamaStreamingResponseBuilder.java** (new file)

```diff
@@ -0,0 +1,44 @@
+package dev.langchain4j.model.ollama;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
+
+/**
+ * This class needs to be thread safe because it is called when a streaming result comes back
+ * and there is no guarantee that this thread will be the same as the one that initiated the request,
+ * in fact it almost certainly won't be.
+ */
+class OllamaStreamingResponseBuilder {
+
+    private StringBuffer contentBuilder = new StringBuffer();
+    private volatile TokenUsage tokenUsage;
+
+    void append(ChatResponse partialResponse) {
+        if (partialResponse == null) {
+            return;
+        }
+
+        if (partialResponse.getEvalCount() != null && partialResponse.getPromptEvalCount() != null) {
+            this.tokenUsage = new TokenUsage(
+                    partialResponse.getPromptEvalCount(),
+                    partialResponse.getEvalCount()
+            );
+        }
+
+        String content = partialResponse.getMessage().getContent();
+        if (content != null) {
+            contentBuilder.append(content);
+        }
+    }
+
+    Response<AiMessage> build() {
+        if (contentBuilder.toString().isEmpty()) {
+            return null;
+        }
+        return Response.from(
+                AiMessage.from(contentBuilder.toString()),
+                tokenUsage
+        );
+    }
+}
```
**AbstractOllamaToolsLanguageModelInfrastructure.java**

```diff
@@ -14,5 +14,4 @@ class AbstractOllamaToolsLanguageModelInfrastructure {
     }
 
 
-
 }
```
**OllamaChatModelListenerIT.java** (new file)

```diff
@@ -0,0 +1,57 @@
+package dev.langchain4j.model.ollama;
+
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.ChatModelListenerIT;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
+
+import static dev.langchain4j.model.ollama.OllamaImage.TOOL_MODEL;
+import static java.util.Collections.singletonList;
+
+class OllamaChatModelListenerIT extends ChatModelListenerIT {
+
+    @Override
+    protected ChatLanguageModel createModel(ChatModelListener listener) {
+        return OllamaChatModel.builder()
+                .baseUrl(AbstractOllamaToolsLanguageModelInfrastructure.ollama.getEndpoint())
+                .modelName(TOOL_MODEL)
+                .temperature(temperature())
+                .topP(topP())
+                .numPredict(maxTokens())
+                .logRequests(true)
+                .logResponses(true)
+                .listeners(singletonList(listener))
+                .build();
+    }
+
+    @Override
+    protected String modelName() {
+        return TOOL_MODEL;
+    }
+
+    @Override
+    protected ChatLanguageModel createFailingModel(ChatModelListener listener) {
+        return OllamaChatModel.builder()
+                .baseUrl(AbstractOllamaToolsLanguageModelInfrastructure.ollama.getEndpoint())
+                .modelName("banana")
+                .maxRetries(0)
+                .logRequests(true)
+                .logResponses(true)
+                .listeners(singletonList(listener))
+                .build();
+    }
+
+    @Override
+    protected Class<? extends Exception> expectedExceptionClass() {
+        return NullPointerException.class;
+    }
+
+    @Override
+    protected boolean assertResponseId() {
+        return false;
+    }
+
+    @Override
+    protected boolean assertFinishReason() {
+        return false;
+    }
+}
```
```diff
@@ -3,14 +3,20 @@ package dev.langchain4j.model.ollama;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.UserMessage;
 import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.listener.*;
 import dev.langchain4j.model.openai.OpenAiChatModel;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import org.junit.jupiter.api.Test;
 
+import java.util.concurrent.atomic.AtomicReference;
+
 import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
 import static dev.langchain4j.model.output.FinishReason.STOP;
+import static java.util.Collections.singletonList;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Fail.fail;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 
 /**
  * Tests if Ollama can be used via OpenAI API (langchain4j-open-ai module)
```
**OllamaStreamingChatModelListenerIT.java** (new file)

```diff
@@ -0,0 +1,61 @@
+package dev.langchain4j.model.ollama;
+
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
+
+import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
+import static java.util.Collections.singletonList;
+
+public class OllamaStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
+
+    @Override
+    protected StreamingChatLanguageModel createModel(ChatModelListener listener) {
+        return OllamaStreamingChatModel.builder()
+                .baseUrl(AbstractOllamaLanguageModelInfrastructure.ollama.getEndpoint())
+                .modelName(TINY_DOLPHIN_MODEL)
+                .temperature(temperature())
+                .topP(topP())
+                .numPredict(maxTokens())
+                .logRequests(true)
+                .logResponses(true)
+                .listeners(singletonList(listener))
+                .build();
+    }
+
+    @Override
+    protected String modelName() {
+        return TINY_DOLPHIN_MODEL;
+    }
+
+    @Override
+    protected StreamingChatLanguageModel createFailingModel(ChatModelListener listener) {
+        return OllamaStreamingChatModel.builder()
+                .baseUrl(AbstractOllamaLanguageModelInfrastructure.ollama.getEndpoint())
+                .modelName("banana")
+                .logRequests(true)
+                .logResponses(true)
+                .listeners(singletonList(listener))
+                .build();
+    }
+
+    @Override
+    protected Class<? extends Exception> expectedExceptionClass() {
+        return NullPointerException.class;
+    }
+
+    @Override
+    protected boolean supportToolCalls() {
+        return false;
+    }
+
+    @Override
+    protected boolean assertResponseId() {
+        return false;
+    }
+
+    @Override
+    protected boolean assertFinishReason() {
+        return false;
+    }
+}
```
**OpenAiChatModelListenerIT.java**

```diff
@@ -41,7 +41,7 @@ class OpenAiChatModelListenerIT extends ChatModelListenerIT {
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
         return OpenAiHttpException.class;
     }
 }
```
**OpenAiStreamingChatModelListenerIT.java**

```diff
@@ -42,7 +42,7 @@ class OpenAiStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
         return OpenAiHttpException.class;
     }
 }
```
**VertexAiGeminiChatModelListenerIT.java**

```diff
@@ -45,7 +45,7 @@ public class VertexAiGeminiChatModelListenerIT extends ChatModelListenerIT {
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
         return RuntimeException.class;
     }
 }
```
**VertexAiGeminiStreamingChatModelListenerIT.java**

```diff
@@ -46,7 +46,7 @@ public class VertexAiGeminiStreamingChatModelListenerIT extends StreamingChatMod
     }
 
     @Override
-    protected Class<?> expectedExceptionClass() {
+    protected Class<? extends Exception> expectedExceptionClass() {
         return NotFoundException.class;
     }
 }
```