Ollama chat model listener (#1765)

## Issue
Closes #1756
Closes #1750

## Change
1. `OllamaChatModel` and `OllamaStreamingChatModel` now support
`ChatModelListener` (a usage sketch follows this list).
2. Fix `OllamaStreamingLanguageModel` throwing an `EOFException` when the
response content is too long (see the line-reading sketch after the `OllamaClient` diff).
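
For reference, wiring up a listener happens on the builder. A minimal sketch, grounded in the API shown in the diff below; the endpoint URL and model name are placeholders, not part of this PR:

```java
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.ollama.OllamaChatModel;

import static java.util.Collections.singletonList;

class ListenerExample {

    public static void main(String[] args) {
        // A listener that logs each phase of the chat call.
        ChatModelListener listener = new ChatModelListener() {
            @Override
            public void onRequest(ChatModelRequestContext context) {
                System.out.println("request: " + context.request().messages());
            }

            @Override
            public void onResponse(ChatModelResponseContext context) {
                System.out.println("response: " + context.response().aiMessage());
            }

            @Override
            public void onError(ChatModelErrorContext context) {
                System.err.println("error: " + context.error().getMessage());
            }
        };

        ChatLanguageModel model = OllamaChatModel.builder()
                .baseUrl("http://localhost:11434") // placeholder endpoint
                .modelName("tinydolphin")          // placeholder model name
                .listeners(singletonList(listener))
                .build();

        model.generate("hello"); // listener callbacks fire around this call
    }
}
```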

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
Commit 16f410c788 (parent 07ce173fb3), authored by ZYinNJU on 2024-09-17 17:12:43 +08:00 and committed by GitHub.
20 changed files with 487 additions and 114 deletions

View File

@ -41,7 +41,7 @@ class AzureOpenAiChatModelListenerIT extends ChatModelListenerIT {
}
@Override
protected Class<?> expectedExceptionClass() {
protected Class<? extends Exception> expectedExceptionClass() {
return ClientAuthenticationException.class;
}
}

View File

@ -42,7 +42,7 @@ class AzureOpenAiStreamingChatModelListenerIT extends StreamingChatModelListener
}
@Override
protected Class<?> expectedExceptionClass() {
protected Class<? extends Exception> expectedExceptionClass() {
return ClientAuthenticationException.class;
}

View File

@ -71,7 +71,7 @@ public abstract class ChatModelListenerIT {
protected abstract ChatLanguageModel createFailingModel(ChatModelListener listener);
protected abstract Class<?> expectedExceptionClass();
protected abstract Class<? extends Exception> expectedExceptionClass();
@Test
void should_listen_request_and_response() {
@ -105,14 +105,22 @@ public abstract class ChatModelListenerIT {
UserMessage userMessage = UserMessage.from("hello");
ToolSpecification toolSpecification = ToolSpecification.builder()
.name("add")
.addParameter("a", INTEGER)
.addParameter("b", INTEGER)
.build();
ToolSpecification toolSpecification = null;
if (supportToolCalls()) {
toolSpecification = ToolSpecification.builder()
.name("add")
.addParameter("a", INTEGER)
.addParameter("b", INTEGER)
.build();
}
// when
AiMessage aiMessage = model.generate(singletonList(userMessage), singletonList(toolSpecification)).content();
AiMessage aiMessage;
if (supportToolCalls()) {
aiMessage = model.generate(singletonList(userMessage), singletonList(toolSpecification)).content();
} else {
aiMessage = model.generate(singletonList(userMessage)).content();
}
// then
ChatModelRequest request = requestReference.get();
@ -121,7 +129,9 @@ public abstract class ChatModelListenerIT {
assertThat(request.topP()).isEqualTo(topP());
assertThat(request.maxTokens()).isEqualTo(maxTokens());
assertThat(request.messages()).containsExactly(userMessage);
assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
if (supportToolCalls()) {
assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
}
ChatModelResponse response = responseReference.get();
if (assertResponseId()) {
@ -131,14 +141,24 @@ public abstract class ChatModelListenerIT {
assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0);
assertThat(response.finishReason()).isNotNull();
if (assertFinishReason()) {
assertThat(response.finishReason()).isNotNull();
}
assertThat(response.aiMessage()).isEqualTo(aiMessage);
}
protected boolean supportToolCalls() {
return true;
}
protected boolean assertResponseId() {
return true;
}
protected boolean assertFinishReason() {
return true;
}
@Test
void should_listen_error() {

View File

@ -4,12 +4,7 @@ import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.output.Response;
import org.assertj.core.data.Percentage;
import org.junit.jupiter.api.Test;
@ -75,7 +70,7 @@ public abstract class StreamingChatModelListenerIT {
protected abstract StreamingChatLanguageModel createFailingModel(ChatModelListener listener);
protected abstract Class<?> expectedExceptionClass();
protected abstract Class<? extends Exception> expectedExceptionClass();
@Test
void should_listen_request_and_response() {
@ -109,15 +104,22 @@ public abstract class StreamingChatModelListenerIT {
UserMessage userMessage = UserMessage.from("hello");
ToolSpecification toolSpecification = ToolSpecification.builder()
.name("add")
.addParameter("a", INTEGER)
.addParameter("b", INTEGER)
.build();
ToolSpecification toolSpecification = null;
if (supportToolCalls()) {
toolSpecification = ToolSpecification.builder()
.name("add")
.addParameter("a", INTEGER)
.addParameter("b", INTEGER)
.build();
}
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(singletonList(userMessage), singletonList(toolSpecification), handler);
if (supportToolCalls()) {
model.generate(singletonList(userMessage), singletonList(toolSpecification), handler);
} else {
model.generate(singletonList(userMessage), handler);
}
AiMessage aiMessage = handler.get().content();
// then
@ -127,7 +129,9 @@ public abstract class StreamingChatModelListenerIT {
assertThat(request.topP()).isEqualTo(topP());
assertThat(request.maxTokens()).isEqualTo(maxTokens());
assertThat(request.messages()).containsExactly(userMessage);
assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
if (supportToolCalls()) {
assertThat(request.toolSpecifications()).containsExactly(toolSpecification);
}
ChatModelResponse response = responseReference.get();
if (assertResponseId()) {
@ -137,14 +141,24 @@ public abstract class StreamingChatModelListenerIT {
assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0);
assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0);
assertThat(response.finishReason()).isNotNull();
if (assertFinishReason()) {
assertThat(response.finishReason()).isNotNull();
}
assertThat(response.aiMessage()).isEqualTo(aiMessage);
}
protected boolean supportToolCalls() {
return true;
}
protected boolean assertResponseId() {
return true;
}
protected boolean assertFinishReason() {
return true;
}
@Test
protected void should_listen_error() throws Exception {

View File

@ -43,7 +43,7 @@ class GoogleAiGeminiChatModelListenerIT extends ChatModelListenerIT {
}
@Override
protected Class<?> expectedExceptionClass() {
protected Class<? extends Exception> expectedExceptionClass() {
return RuntimeException.class;
}
}

View File

@ -68,6 +68,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-junit-jupiter</artifactId>

View File

@ -4,21 +4,29 @@ import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.ollama.spi.OllamaChatModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.model.ollama.OllamaChatModelListenerUtils.*;
import static dev.langchain4j.model.ollama.OllamaMessagesUtils.*;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.emptyList;
/**
* <a href="https://github.com/jmorganca/ollama/blob/main/docs/api.md">Ollama API reference</a>
@ -27,11 +35,14 @@ import static java.time.Duration.ofSeconds;
*/
public class OllamaChatModel implements ChatLanguageModel {
private static final Logger log = LoggerFactory.getLogger(OllamaChatModel.class);
private final OllamaClient client;
private final String modelName;
private final Options options;
private final String format;
private final Integer maxRetries;
private final List<ChatModelListener> listeners;
public OllamaChatModel(String baseUrl,
String modelName,
@ -48,7 +59,8 @@ public class OllamaChatModel implements ChatLanguageModel {
Integer maxRetries,
Map<String, String> customHeaders,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
List<ChatModelListener> listeners) {
this.client = OllamaClient.builder()
.baseUrl(baseUrl)
.timeout(getOrDefault(timeout, ofSeconds(60)))
@ -69,6 +81,7 @@ public class OllamaChatModel implements ChatLanguageModel {
.build();
this.format = format;
this.maxRetries = getOrDefault(maxRetries, 3);
this.listeners = new ArrayList<>(getOrDefault(listeners, emptyList()));
}
public static OllamaChatModelBuilder builder() {
@ -82,26 +95,17 @@ public class OllamaChatModel implements ChatLanguageModel {
public Response<AiMessage> generate(List<ChatMessage> messages) {
ensureNotEmpty(messages, "messages");
ChatRequest request = ChatRequest.builder()
.model(modelName)
.messages(toOllamaMessages(messages))
.options(options)
.format(format)
.stream(false)
.build();
ChatResponse response = withRetry(() -> client.chat(request), maxRetries);
return Response.from(
AiMessage.from(response.getMessage().getContent()),
new TokenUsage(response.getPromptEvalCount(), response.getEvalCount())
);
return doGenerate(messages, null);
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
ensureNotEmpty(messages, "messages");
return doGenerate(messages, toolSpecifications);
}
private Response<AiMessage> doGenerate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
ChatRequest request = ChatRequest.builder()
.model(modelName)
.messages(toOllamaMessages(messages))
@ -111,14 +115,25 @@ public class OllamaChatModel implements ChatLanguageModel {
.tools(toOllamaTools(toolSpecifications))
.build();
ChatResponse response = withRetry(() -> client.chat(request), maxRetries);
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
Map<Object, Object> attributes = new ConcurrentHashMap<>();
onListenRequest(listeners, modelListenerRequest, attributes);
return Response.from(
response.getMessage().getToolCalls() != null ?
AiMessage.from(toToolExecutionRequest(response.getMessage().getToolCalls())) :
AiMessage.from(response.getMessage().getContent()),
new TokenUsage(response.getPromptEvalCount(), response.getEvalCount())
);
try {
ChatResponse chatResponse = withRetry(() -> client.chat(request), maxRetries);
Response<AiMessage> response = Response.from(
chatResponse.getMessage().getToolCalls() != null ?
AiMessage.from(toToolExecutionRequest(chatResponse.getMessage().getToolCalls())) :
AiMessage.from(chatResponse.getMessage().getContent()),
new TokenUsage(chatResponse.getPromptEvalCount(), chatResponse.getEvalCount())
);
onListenResponse(listeners, response, modelListenerRequest, attributes);
return response;
} catch (Exception e) {
onListenError(listeners, e, modelListenerRequest, null, attributes);
throw e;
}
}
public static class OllamaChatModelBuilder {
@ -139,6 +154,7 @@ public class OllamaChatModel implements ChatLanguageModel {
private Map<String, String> customHeaders;
private Boolean logRequests;
private Boolean logResponses;
private List<ChatModelListener> listeners;
public OllamaChatModelBuilder() {
// This is public so it can be extended
@ -225,6 +241,11 @@ public class OllamaChatModel implements ChatLanguageModel {
return this;
}
public OllamaChatModelBuilder listeners(List<ChatModelListener> listeners) {
this.listeners = listeners;
return this;
}
public OllamaChatModel build() {
return new OllamaChatModel(
baseUrl,
@ -242,7 +263,8 @@ public class OllamaChatModel implements ChatLanguageModel {
maxRetries,
customHeaders,
logRequests,
logResponses
logResponses,
listeners
);
}
}

View File

@ -0,0 +1,118 @@
package dev.langchain4j.model.ollama;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.output.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.Map;
class OllamaChatModelListenerUtils {
private static final Logger log = LoggerFactory.getLogger(OllamaChatModelListenerUtils.class);
private OllamaChatModelListenerUtils() throws InstantiationException {
throw new InstantiationException("Can't instantiate this utility class.");
}
/**
* Processes a listen request by notifying all registered chat model listeners.
*
* @param listeners A list of {@link ChatModelListener} instances to be notified. Should not be null.
* @param modelListenerRequest The {@link ChatModelRequest} containing the request details.
* @param attributes A map of additional attributes to be passed to the context.
*/
static void onListenRequest(List<ChatModelListener> listeners, ChatModelRequest modelListenerRequest, Map<Object, Object> attributes) {
ChatModelRequestContext context = new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(listener -> {
try {
listener.onRequest(context);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
}
/**
* Processes a listen response by notifying all registered chat model listeners.
*
* @param listeners A list of {@link ChatModelListener} instances to be notified. Should not be null.
* @param response The {@link Response} containing the response details.
* @param modelListenerRequest The original {@link ChatModelRequest} associated with the response.
* @param attributes A map of additional attributes to be passed to the context.
*/
static void onListenResponse(List<ChatModelListener> listeners, Response<AiMessage> response, ChatModelRequest modelListenerRequest, Map<Object, Object> attributes) {
ChatModelResponse modelListenerResponse = createModelListenerResponse(modelListenerRequest.model(), response);
ChatModelResponseContext context = new ChatModelResponseContext(
modelListenerResponse,
modelListenerRequest,
attributes
);
listeners.forEach(listener -> {
try {
listener.onResponse(context);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
}
/**
* Processes a listen error by notifying all registered chat model listeners.
*
* @param listeners A list of {@link ChatModelListener} instances to be notified. Should not be null.
* @param error The error that occurred during the chat call.
* @param modelListenerRequest The original {@link ChatModelRequest} associated with the response.
* @param partialResponse The partial {@link Response} containing the response details gathered so far.
* @param attributes A map of additional attributes to be passed to the context.
*/
static void onListenError(List<ChatModelListener> listeners, Throwable error, ChatModelRequest modelListenerRequest, Response<AiMessage> partialResponse, Map<Object, Object> attributes) {
ChatModelResponse partialModelListenerResponse = createModelListenerResponse(modelListenerRequest.model(), partialResponse);
ChatModelErrorContext context = new ChatModelErrorContext(
error,
modelListenerRequest,
partialModelListenerResponse,
attributes
);
listeners.forEach(listener -> {
try {
listener.onError(context);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
}
static ChatModelRequest createModelListenerRequest(ChatRequest request,
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
Options options = request.getOptions();
return ChatModelRequest.builder()
.model(request.getModel())
.temperature(options.getTemperature())
.topP(options.getTopP())
.maxTokens(options.getNumPredict())
.messages(messages)
.toolSpecifications(toolSpecifications)
.build();
}
static ChatModelResponse createModelListenerResponse(String responseModel,
Response<AiMessage> response) {
if (response == null) {
return null;
}
return ChatModelResponse.builder()
.model(responseModel)
.tokenUsage(response.tokenUsage())
.finishReason(response.finishReason())
.aiMessage(response.content())
.build();
}
}

View File

@ -1,8 +1,11 @@
package dev.langchain4j.model.ollama;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import okhttp3.Interceptor;
@ -22,10 +25,10 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import static dev.langchain4j.model.ollama.OllamaChatModelListenerUtils.*;
import static dev.langchain4j.model.ollama.OllamaJsonUtils.getObjectMapper;
import static dev.langchain4j.model.ollama.OllamaJsonUtils.toObject;
import static java.lang.Boolean.TRUE;
@ -106,50 +109,6 @@ class OllamaClient {
public void streamingCompletion(CompletionRequest request, StreamingResponseHandler<String> handler) {
ollamaApi.streamingCompletion(request).enqueue(new Callback<ResponseBody>() {
@Override
public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
try (InputStream inputStream = retrofitResponse.body().byteStream()) {
StringBuilder contentBuilder = new StringBuilder();
while (true) {
byte[] bytes = new byte[1024];
int len = inputStream.read(bytes);
String partialResponse = new String(bytes, 0, len);
if (logStreamingResponses) {
log.debug("Streaming partial response: {}", partialResponse);
}
CompletionResponse completionResponse = toObject(partialResponse, CompletionResponse.class);
contentBuilder.append(completionResponse.getResponse());
handler.onNext(completionResponse.getResponse());
if (TRUE.equals(completionResponse.getDone())) {
Response<String> response = Response.from(
contentBuilder.toString(),
new TokenUsage(
completionResponse.getPromptEvalCount(),
completionResponse.getEvalCount()
)
);
handler.onComplete(response);
return;
}
}
} catch (Exception e) {
handler.onError(e);
}
}
@Override
public void onFailure(Call<ResponseBody> call, Throwable throwable) {
handler.onError(throwable);
}
});
}
public void streamingChat(ChatRequest request, StreamingResponseHandler<AiMessage> handler) {
ollamaApi.streamingChat(request).enqueue(new Callback<ResponseBody>() {
@Override
public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
try (InputStream inputStream = retrofitResponse.body().byteStream()) {
@ -162,17 +121,16 @@ class OllamaClient {
log.debug("Streaming partial response: {}", partialResponse);
}
ChatResponse chatResponse = toObject(partialResponse, ChatResponse.class);
String content = chatResponse.getMessage().getContent();
contentBuilder.append(content);
handler.onNext(content);
CompletionResponse completionResponse = toObject(partialResponse, CompletionResponse.class);
contentBuilder.append(completionResponse.getResponse());
handler.onNext(completionResponse.getResponse());
if (TRUE.equals(chatResponse.getDone())) {
Response<AiMessage> response = Response.from(
AiMessage.from(contentBuilder.toString()),
if (TRUE.equals(completionResponse.getDone())) {
Response<String> response = Response.from(
contentBuilder.toString(),
new TokenUsage(
chatResponse.getPromptEvalCount(),
chatResponse.getEvalCount()
completionResponse.getPromptEvalCount(),
completionResponse.getEvalCount()
)
);
handler.onComplete(response);
@ -192,6 +150,57 @@ class OllamaClient {
});
}
public void streamingChat(ChatRequest request, StreamingResponseHandler<AiMessage> handler,
List<ChatModelListener> listeners, List<ChatMessage> messages) {
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, new ArrayList<>());
Map<Object, Object> attributes = new ConcurrentHashMap<>();
onListenRequest(listeners, modelListenerRequest, attributes);
OllamaStreamingResponseBuilder responseBuilder = new OllamaStreamingResponseBuilder();
ollamaApi.streamingChat(request).enqueue(new Callback<ResponseBody>() {
@Override
public void onResponse(Call<ResponseBody> call, retrofit2.Response<ResponseBody> retrofitResponse) {
try (InputStream inputStream = retrofitResponse.body().byteStream()) {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
while (true) {
String partialResponse = reader.readLine();
if (logStreamingResponses) {
log.debug("Streaming partial response: {}", partialResponse);
}
ChatResponse chatResponse = toObject(partialResponse, ChatResponse.class);
String content = chatResponse.getMessage().getContent();
responseBuilder.append(chatResponse);
handler.onNext(content);
if (TRUE.equals(chatResponse.getDone())) {
Response<AiMessage> response = responseBuilder.build();
handler.onComplete(response);
onListenResponse(listeners, response, modelListenerRequest, attributes);
return;
}
}
}
} catch (Exception e) {
onListenError(listeners, e, modelListenerRequest, responseBuilder.build(), attributes);
handler.onError(e);
}
}
@Override
public void onFailure(Call<ResponseBody> call, Throwable throwable) {
onListenError(listeners, throwable, modelListenerRequest, responseBuilder.build(), attributes);
handler.onError(throwable);
}
});
}
public EmbeddingResponse embed(EmbeddingRequest request) {
try {
retrofit2.Response<EmbeddingResponse> retrofitResponse = ollamaApi.embed(request).execute();
@ -257,6 +266,7 @@ class OllamaClient {
}
}
private RuntimeException toException(retrofit2.Response<?> response) throws IOException {
int code = response.code();
String body = response.errorBody().string();
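
For context on fix 2: the old streaming code read the response in fixed 1024-byte chunks, so a single streamed JSON object longer than one chunk was cut mid-object and failed to parse with an `EOFException`. Ollama streams newline-delimited JSON, so the replacement above reads one complete line per object via `BufferedReader`. A standalone sketch of that idea; the `handleJson` callback is hypothetical, standing in for the `toObject(line, ChatResponse.class)` call in the diff:

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.function.Consumer;

class NdjsonReader {

    // Each line of an NDJSON stream is one complete JSON object, so parsing
    // per line cannot split an object, no matter how long the response grows.
    static void readLines(InputStream inputStream, Consumer<String> handleJson) throws IOException {
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) { // null signals end of stream
                handleJson.accept(line); // hypothetical callback, e.g. JSON deserialization
            }
        }
    }
}
```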

View File

@ -34,6 +34,9 @@ class OllamaMessagesUtils {
}
static List<Tool> toOllamaTools(List<ToolSpecification> toolSpecifications) {
if (toolSpecifications == null) {
return null;
}
return toolSpecifications.stream().map(toolSpecification ->
Tool.builder()
.function(Function.builder()

View File

@ -4,9 +4,11 @@ import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.ollama.spi.OllamaStreamingChatModelBuilderFactory;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@ -16,6 +18,7 @@ import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.model.ollama.OllamaMessagesUtils.toOllamaMessages;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.emptyList;
/**
* <a href="https://github.com/jmorganca/ollama/blob/main/docs/api.md">Ollama API reference</a>
@ -28,6 +31,7 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
private final String modelName;
private final Options options;
private final String format;
private final List<ChatModelListener> listeners;
public OllamaStreamingChatModel(String baseUrl,
String modelName,
@ -43,7 +47,8 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
Duration timeout,
Boolean logRequests,
Boolean logResponses,
Map<String, String> customHeaders
Map<String, String> customHeaders,
List<ChatModelListener> listeners
) {
this.client = OllamaClient.builder()
.baseUrl(baseUrl)
@ -64,6 +69,7 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
.stop(stop)
.build();
this.format = format;
this.listeners = new ArrayList<>(getOrDefault(listeners, emptyList()));
}
public static OllamaStreamingChatModelBuilder builder() {
@ -85,7 +91,7 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
.stream(true)
.build();
client.streamingChat(request, handler);
client.streamingChat(request, handler, listeners, messages);
}
public static class OllamaStreamingChatModelBuilder {
@ -105,6 +111,7 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
private Map<String, String> customHeaders;
private Boolean logRequests;
private Boolean logResponses;
private List<ChatModelListener> listeners;
public OllamaStreamingChatModelBuilder() {
// This is public so it can be extended
@ -186,6 +193,11 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
return this;
}
public OllamaStreamingChatModelBuilder listeners(List<ChatModelListener> listeners) {
this.listeners = listeners;
return this;
}
public OllamaStreamingChatModel build() {
return new OllamaStreamingChatModel(
baseUrl,
@ -202,7 +214,8 @@ public class OllamaStreamingChatModel implements StreamingChatLanguageModel {
timeout,
logRequests,
logResponses,
customHeaders
customHeaders,
listeners
);
}
}
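
The streaming side mirrors this. A sketch using the `listeners(...)` builder method added above; the endpoint and model name are again placeholders:

```java
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.ollama.OllamaStreamingChatModel;
import dev.langchain4j.model.output.Response;

import static java.util.Collections.singletonList;

class StreamingListenerExample {

    public static void main(String[] args) {
        // ChatModelListener methods are all default, so overriding one is enough.
        ChatModelListener listener = new ChatModelListener() {
            @Override
            public void onResponse(ChatModelResponseContext context) {
                System.out.println("listener saw: " + context.response().aiMessage());
            }
        };

        StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
                .baseUrl("http://localhost:11434") // placeholder endpoint
                .modelName("tinydolphin")          // placeholder model name
                .listeners(singletonList(listener))
                .build();

        model.generate("hello", new StreamingResponseHandler<AiMessage>() {
            @Override
            public void onNext(String token) {
                System.out.print(token); // partial content as it streams
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                // the registered listeners' onResponse fires alongside this
                System.out.println("\ndone: " + response.tokenUsage());
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace(); // listeners' onError fires too
            }
        });
    }
}
```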

View File

@ -0,0 +1,44 @@
package dev.langchain4j.model.ollama;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
/**
* This class needs to be thread safe because it is called when a streaming result comes back
* and there is no guarantee that this thread will be the same as the one that initiated the request,
* in fact it almost certainly won't be.
*/
class OllamaStreamingResponseBuilder {
private StringBuffer contentBuilder = new StringBuffer();
private volatile TokenUsage tokenUsage;
void append(ChatResponse partialResponse) {
if (partialResponse == null) {
return;
}
if (partialResponse.getEvalCount() != null && partialResponse.getPromptEvalCount() != null) {
this.tokenUsage = new TokenUsage(
partialResponse.getPromptEvalCount(),
partialResponse.getEvalCount()
);
}
String content = partialResponse.getMessage().getContent();
if (content != null) {
contentBuilder.append(content);
}
}
Response<AiMessage> build() {
if (contentBuilder.toString().isEmpty()) {
return null;
}
return Response.from(
AiMessage.from(contentBuilder.toString()),
tokenUsage
);
}
}

View File

@ -14,5 +14,4 @@ class AbstractOllamaToolsLanguageModelInfrastructure {
}
}

View File

@ -0,0 +1,57 @@
package dev.langchain4j.model.ollama;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.ChatModelListenerIT;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import static dev.langchain4j.model.ollama.OllamaImage.TOOL_MODEL;
import static java.util.Collections.singletonList;
class OllamaChatModelListenerIT extends ChatModelListenerIT {
@Override
protected ChatLanguageModel createModel(ChatModelListener listener) {
return OllamaChatModel.builder()
.baseUrl(AbstractOllamaToolsLanguageModelInfrastructure.ollama.getEndpoint())
.modelName(TOOL_MODEL)
.temperature(temperature())
.topP(topP())
.numPredict(maxTokens())
.logRequests(true)
.logResponses(true)
.listeners(singletonList(listener))
.build();
}
@Override
protected String modelName() {
return TOOL_MODEL;
}
@Override
protected ChatLanguageModel createFailingModel(ChatModelListener listener) {
return OllamaChatModel.builder()
.baseUrl(AbstractOllamaToolsLanguageModelInfrastructure.ollama.getEndpoint())
.modelName("banana")
.maxRetries(0)
.logRequests(true)
.logResponses(true)
.listeners(singletonList(listener))
.build();
}
@Override
protected Class<? extends Exception> expectedExceptionClass() {
return NullPointerException.class;
}
@Override
protected boolean assertResponseId() {
return false;
}
@Override
protected boolean assertFinishReason() {
return false;
}
}

View File

@ -3,14 +3,20 @@ package dev.langchain4j.model.ollama;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
import java.util.concurrent.atomic.AtomicReference;
import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Fail.fail;
import static org.junit.jupiter.api.Assertions.assertThrows;
/**
* Tests if Ollama can be used via OpenAI API (langchain4j-open-ai module)

View File

@ -0,0 +1,61 @@
package dev.langchain4j.model.ollama;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
import static java.util.Collections.singletonList;
public class OllamaStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
@Override
protected StreamingChatLanguageModel createModel(ChatModelListener listener) {
return OllamaStreamingChatModel.builder()
.baseUrl(AbstractOllamaLanguageModelInfrastructure.ollama.getEndpoint())
.modelName(TINY_DOLPHIN_MODEL)
.temperature(temperature())
.topP(topP())
.numPredict(maxTokens())
.logRequests(true)
.logResponses(true)
.listeners(singletonList(listener))
.build();
}
@Override
protected String modelName() {
return TINY_DOLPHIN_MODEL;
}
@Override
protected StreamingChatLanguageModel createFailingModel(ChatModelListener listener) {
return OllamaStreamingChatModel.builder()
.baseUrl(AbstractOllamaLanguageModelInfrastructure.ollama.getEndpoint())
.modelName("banana")
.logRequests(true)
.logResponses(true)
.listeners(singletonList(listener))
.build();
}
@Override
protected Class<? extends Exception> expectedExceptionClass() {
return NullPointerException.class;
}
@Override
protected boolean supportToolCalls() {
return false;
}
@Override
protected boolean assertResponseId() {
return false;
}
@Override
protected boolean assertFinishReason() {
return false;
}
}

View File

@ -41,7 +41,7 @@ class OpenAiChatModelListenerIT extends ChatModelListenerIT {
}
@Override
protected Class<?> expectedExceptionClass() {
protected Class<? extends Exception> expectedExceptionClass() {
return OpenAiHttpException.class;
}
}

View File

@ -42,7 +42,7 @@ class OpenAiStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
}
@Override
protected Class<?> expectedExceptionClass() {
protected Class<? extends Exception> expectedExceptionClass() {
return OpenAiHttpException.class;
}
}

View File

@ -45,7 +45,7 @@ public class VertexAiGeminiChatModelListenerIT extends ChatModelListenerIT {
}
@Override
protected Class<?> expectedExceptionClass() {
protected Class<? extends Exception> expectedExceptionClass() {
return RuntimeException.class;
}
}

View File

@ -46,7 +46,7 @@ public class VertexAiGeminiStreamingChatModelListenerIT extends StreamingChatMod
}
@Override
protected Class<?> expectedExceptionClass() {
protected Class<? extends Exception> expectedExceptionClass() {
return NotFoundException.class;
}
}