Fix #756: Allow blank content in AiMessage, propagate failures into streaming handler (Ollama) (#782)

<!-- Thank you so much for your contribution! -->

## Context
See https://github.com/langchain4j/langchain4j/issues/756

## Change
- Allow creating `AiMessage` with blank ("", " ") content. `null` is
still prohibited.
- In `OllamaStreamingChat/LanguageModel`: propagate failures from
`onResponse()` method into `StreamingResponseHandler.onError()` method

## Checklist
Before submitting this PR, please check the following points:
- [X] I have added unit and integration tests for my change
- [X] All unit and integration tests in the module I have added/changed
are green
- [X] All unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules are green
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added my new module in the
[BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)
(only when a new module is added)
This commit is contained in:
LangChain4j 2024-03-22 09:39:48 +01:00 committed by GitHub
parent 04055c896b
commit da816fd491
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 233 additions and 335 deletions

View File

@ -8,8 +8,8 @@ import java.util.Objects;
import static dev.langchain4j.data.message.ChatMessageType.AI;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.Utils.quoted;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.util.Arrays.asList;
/**
@ -24,15 +24,17 @@ public class AiMessage implements ChatMessage {
/**
* Create a new {@link AiMessage} with the given text.
*
* @param text the text of the message.
*/
public AiMessage(String text) {
this.text = ensureNotBlank(text, "text");
this.text = ensureNotNull(text, "text");
this.toolExecutionRequests = null;
}
/**
* Create a new {@link AiMessage} with the given tool execution requests.
*
* @param toolExecutionRequests the tool execution requests of the message.
*/
public AiMessage(List<ToolExecutionRequest> toolExecutionRequests) {
@ -42,6 +44,7 @@ public class AiMessage implements ChatMessage {
/**
* Get the text of the message.
*
* @return the text of the message.
*/
public String text() {
@ -50,6 +53,7 @@ public class AiMessage implements ChatMessage {
/**
* Get the tool execution requests of the message.
*
* @return the tool execution requests of the message.
*/
public List<ToolExecutionRequest> toolExecutionRequests() {
@ -58,6 +62,7 @@ public class AiMessage implements ChatMessage {
/**
* Check if the message has ToolExecutionRequests.
*
* @return true if the message has ToolExecutionRequests, false otherwise.
*/
public boolean hasToolExecutionRequests() {
@ -93,6 +98,7 @@ public class AiMessage implements ChatMessage {
/**
* Create a new {@link AiMessage} with the given text.
*
* @param text the text of the message.
* @return the new {@link AiMessage}.
*/
@ -102,6 +108,7 @@ public class AiMessage implements ChatMessage {
/**
* Create a new {@link AiMessage} with the given tool execution requests.
*
* @param toolExecutionRequests the tool execution requests of the message.
* @return the new {@link AiMessage}.
*/
@ -111,6 +118,7 @@ public class AiMessage implements ChatMessage {
/**
* Create a new {@link AiMessage} with the given tool execution requests.
*
* @param toolExecutionRequests the tool execution requests of the message.
* @return the new {@link AiMessage}.
*/
@ -120,6 +128,7 @@ public class AiMessage implements ChatMessage {
/**
* Create a new {@link AiMessage} with the given text.
*
* @param text the text of the message.
* @return the new {@link AiMessage}.
*/
@ -129,6 +138,7 @@ public class AiMessage implements ChatMessage {
/**
* Create a new {@link AiMessage} with the given tool execution requests.
*
* @param toolExecutionRequests the tool execution requests of the message.
* @return the new {@link AiMessage}.
*/
@ -138,6 +148,7 @@ public class AiMessage implements ChatMessage {
/**
* Create a new {@link AiMessage} with the given tool execution requests.
*
* @param toolExecutionRequests the tool execution requests of the message.
* @return the new {@link AiMessage}.
*/

View File

@ -116,4 +116,16 @@ class AiMessageTest implements WithAssertions {
}
}
@Test
void should_allow_blank_content() {
    // Empty and whitespace-only text is now valid AiMessage content (issue #756)
    // and must round-trip through the constructor unchanged.
    AiMessage emptyMessage = AiMessage.from("");
    assertThat(emptyMessage.text()).isEqualTo("");

    AiMessage whitespaceMessage = AiMessage.from(" ");
    assertThat(whitespaceMessage.text()).isEqualTo(" ");
}
@Test
void should_fail_when_text_is_null() {
    // Blank text is allowed, but null is still rejected at construction time.
    assertThatThrownBy(() -> AiMessage.from((String) null))
            .hasMessage("text cannot be null")
            .isExactlyInstanceOf(IllegalArgumentException.class);
}
}

View File

@ -68,6 +68,14 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>

View File

@ -106,8 +106,8 @@ class OllamaClient {
return;
}
}
} catch (IOException e) {
throw new RuntimeException(e);
} catch (Exception e) {
handler.onError(e);
}
}
@ -147,8 +147,8 @@ class OllamaClient {
return;
}
}
} catch (IOException e) {
throw new RuntimeException(e);
} catch (Exception e) {
handler.onError(e);
}
}

View File

@ -0,0 +1,18 @@
package dev.langchain4j.model.ollama;
import static dev.langchain4j.model.ollama.OllamaImage.ALL_MINILM_MODEL;
import static dev.langchain4j.model.ollama.OllamaImage.OLLAMA_IMAGE;
/**
 * Shared test infrastructure: starts a single Ollama container with the
 * all-minilm embedding model pulled, then commits the container to a local
 * Docker image so that {@code OllamaImage.resolve} can reuse it on later runs.
 */
class AbstractOllamaEmbeddingModelInfrastructure {

    // Name of the locally cached image, e.g. "tc-<base image>-all-minilm".
    private static final String CACHED_IMAGE_NAME = "tc-" + OLLAMA_IMAGE + "-" + ALL_MINILM_MODEL;

    static LangChain4jOllamaContainer ollama;

    static {
        ollama = new LangChain4jOllamaContainer(OllamaImage.resolve(OLLAMA_IMAGE, CACHED_IMAGE_NAME))
                .withModel(ALL_MINILM_MODEL);
        ollama.start();
        ollama.commitToImage(CACHED_IMAGE_NAME);
    }
}

View File

@ -1,33 +0,0 @@
package dev.langchain4j.model.ollama;
import com.github.dockerjava.api.DockerClient;
import com.github.dockerjava.api.model.Image;
import org.testcontainers.DockerClientFactory;
import org.testcontainers.utility.DockerImageName;
import java.util.List;
// NOTE(review): this class is deleted by this PR, replaced by the per-model
// Abstract*Infrastructure classes; its resolveImage() logic moved to OllamaImage.resolve.
public class AbstractOllamaInfrastructure {
// Locally cached image name, e.g. "tc-<base image>-phi".
private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.PHI_MODEL);
static LangChain4jOllamaContainer ollama;
// Starts one shared container with the phi model, then commits it to a local
// image so later runs can skip the model pull (see resolveImage below).
static {
ollama = new LangChain4jOllamaContainer(resolveImage(OllamaImage.OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE))
.withModel(OllamaImage.PHI_MODEL)
ollama.start();
ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
}
// Returns the locally committed image if Docker already has it,
// otherwise falls back to the upstream base image.
static DockerImageName resolveImage(String baseImage, String localImageName) {
DockerImageName dockerImageName = DockerImageName.parse(baseImage);
DockerClient dockerClient = DockerClientFactory.instance().client();
List<Image> images = dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec();
if (images.isEmpty()) {
return dockerImageName;
}
// The cached image is tag-compatible with the base image for Testcontainers.
return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage);
}
}

View File

@ -0,0 +1,15 @@
package dev.langchain4j.model.ollama;
/**
 * Shared test infrastructure: starts a single Ollama container with the
 * tinydolphin model pulled, then commits the container to a local Docker
 * image so that {@code OllamaImage.resolve} can reuse it on later runs.
 */
class AbstractOllamaLanguageModelInfrastructure {

    // Name of the locally cached image, e.g. "tc-<base image>-tinydolphin".
    private static final String CACHED_IMAGE_NAME =
            "tc-" + OllamaImage.OLLAMA_IMAGE + "-" + OllamaImage.TINY_DOLPHIN_MODEL;

    static LangChain4jOllamaContainer ollama;

    static {
        ollama = new LangChain4jOllamaContainer(OllamaImage.resolve(OllamaImage.OLLAMA_IMAGE, CACHED_IMAGE_NAME))
                .withModel(OllamaImage.TINY_DOLPHIN_MODEL);
        ollama.start();
        ollama.commitToImage(CACHED_IMAGE_NAME);
    }
}

View File

@ -1,6 +1,6 @@
package dev.langchain4j.model.ollama;
public class AbstractOllamaInfrastructureVisionModel {
class AbstractOllamaVisionModelInfrastructure {
private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.BAKLLAVA_MODEL);
@ -12,5 +12,4 @@ public class AbstractOllamaInfrastructureVisionModel {
ollama.start();
ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
}
}

View File

@ -10,9 +10,11 @@ import org.junit.jupiter.api.Test;
import java.time.Duration;
import static dev.langchain4j.model.ollama.OllamaImage.BAKLLAVA_MODEL;
import static org.assertj.core.api.Assertions.assertThat;
class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionModel {
class OllamaChatModeVisionModellITInfrastructure extends AbstractOllamaVisionModelInfrastructure {
static final String CAT_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/e/e9/Felis_silvestris_silvestris_small_gradual_decrease_of_quality.png";
static final String DICE_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png";
@ -23,7 +25,7 @@ class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionMod
ChatLanguageModel model = OllamaChatModel.builder()
.baseUrl(ollama.getEndpoint())
.timeout(Duration.ofMinutes(3))
.modelName(OllamaImage.BAKLLAVA_MODEL)
.modelName(BAKLLAVA_MODEL)
.temperature(0.0)
.build();
@ -45,7 +47,7 @@ class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionMod
ChatLanguageModel model = OllamaChatModel.builder()
.baseUrl(ollama.getEndpoint())
.timeout(Duration.ofMinutes(3))
.modelName(OllamaImage.BAKLLAVA_MODEL)
.modelName(BAKLLAVA_MODEL)
.temperature(0.0)
.build();
@ -59,5 +61,4 @@ class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionMod
assertThat(response.content().text())
.containsIgnoringCase("dice");
}
}

View File

@ -11,14 +11,15 @@ import org.junit.jupiter.api.Test;
import java.util.List;
import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
class OllamaChatModelIT extends AbstractOllamaInfrastructure {
class OllamaChatModelIT extends AbstractOllamaLanguageModelInfrastructure {
ChatLanguageModel model = OllamaChatModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL)
.modelName(TINY_DOLPHIN_MODEL)
.temperature(0.0)
.build();
@ -38,7 +39,7 @@ class OllamaChatModelIT extends AbstractOllamaInfrastructure {
assertThat(aiMessage.toolExecutionRequests()).isNull();
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(38);
assertThat(tokenUsage.inputTokenCount()).isEqualTo(13);
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
@ -54,7 +55,7 @@ class OllamaChatModelIT extends AbstractOllamaInfrastructure {
OllamaChatModel model = OllamaChatModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL)
.modelName(TINY_DOLPHIN_MODEL)
.numPredict(numPredict)
.temperature(0.0)
.build();
@ -113,7 +114,7 @@ class OllamaChatModelIT extends AbstractOllamaInfrastructure {
// given
ChatLanguageModel model = OllamaChatModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL)
.modelName(TINY_DOLPHIN_MODEL)
.format("json")
.temperature(0.0)
.build();

View File

@ -6,7 +6,7 @@ import java.time.Duration;
import static org.assertj.core.api.Assertions.assertThat;
class OllamaClientIT extends AbstractOllamaInfrastructure {
class OllamaClientIT extends AbstractOllamaLanguageModelInfrastructure {
@Test
void should_respond_with_models_list() {
@ -22,12 +22,11 @@ class OllamaClientIT extends AbstractOllamaInfrastructure {
// then
assertThat(modelListResponse.getModels().size()).isGreaterThan(0);
assertThat(modelListResponse.getModels().get(0).getName()).isEqualTo("phi:latest");
assertThat(modelListResponse.getModels().get(0).getName()).isEqualTo("tinydolphin:latest");
assertThat(modelListResponse.getModels().get(0).getDigest()).isNotNull();
assertThat(modelListResponse.getModels().get(0).getSize()).isPositive();
}
@Test
void should_respond_with_model_information() {
// given AbstractOllamaLanguageModelInfrastructure
@ -39,16 +38,14 @@ class OllamaClientIT extends AbstractOllamaInfrastructure {
.build();
OllamaModelCard modelDetailsResponse = ollamaClient.showInformation(ShowModelInformationRequest.builder()
.name("phi:latest")
.name("tinydolphin:latest")
.build());
// then
assertThat(modelDetailsResponse.getModelfile()).contains("# Modelfile generated by \"ollama show\"");
assertThat(modelDetailsResponse.getParameters()).contains("stop");
assertThat(modelDetailsResponse.getTemplate()).contains("System:");
assertThat(modelDetailsResponse.getTemplate()).contains("<|im_start|>");
assertThat(modelDetailsResponse.getDetails().getFormat()).isEqualTo("gguf");
assertThat(modelDetailsResponse.getDetails().getFamily()).isEqualTo("phi2");
assertThat(modelDetailsResponse.getDetails().getFamily()).isEqualTo("llama");
}
}

View File

@ -5,13 +5,14 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Test;
import static dev.langchain4j.model.ollama.OllamaImage.ALL_MINILM_MODEL;
import static org.assertj.core.api.Assertions.assertThat;
class OllamaEmbeddingModelIT extends AbstractOllamaInfrastructure {
class OllamaEmbeddingModelIT extends AbstractOllamaEmbeddingModelInfrastructure {
EmbeddingModel model = OllamaEmbeddingModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL)
.modelName(ALL_MINILM_MODEL)
.build();
@Test

View File

@ -13,7 +13,9 @@ public class OllamaImage {
static final String BAKLLAVA_MODEL = "bakllava";
static final String PHI_MODEL = "phi";
static final String TINY_DOLPHIN_MODEL = "tinydolphin";
static final String ALL_MINILM_MODEL = "all-minilm";
static DockerImageName resolve(String baseImage, String localImageName) {
DockerImageName dockerImageName = DockerImageName.parse(baseImage);
@ -24,5 +26,4 @@ public class OllamaImage {
}
return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage);
}
}

View File

@ -2,16 +2,16 @@ package dev.langchain4j.model.ollama;
import dev.langchain4j.model.language.LanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
import static org.assertj.core.api.Assertions.assertThat;
class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
class OllamaLanguageModelIT extends AbstractOllamaLanguageModelInfrastructure {
LanguageModel model = OllamaLanguageModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL)
.modelName(TINY_DOLPHIN_MODEL)
.temperature(0.0)
.build();
@ -27,14 +27,6 @@ class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
// then
assertThat(response.content()).contains("Berlin");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(38);
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
assertThat(response.finishReason()).isNull();
}
@Test
@ -45,7 +37,7 @@ class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
LanguageModel model = OllamaLanguageModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL)
.modelName(TINY_DOLPHIN_MODEL)
.numPredict(numPredict)
.temperature(0.0)
.build();
@ -67,7 +59,7 @@ class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
// given
LanguageModel model = OllamaLanguageModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL)
.modelName(TINY_DOLPHIN_MODEL)
.format("json")
.temperature(0.0)
.build();

View File

@ -7,7 +7,7 @@ import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
class OllamaModelsIT extends AbstractOllamaInfrastructure {
class OllamaModelsIT extends AbstractOllamaLanguageModelInfrastructure {
OllamaModels ollamaModels = OllamaModels.builder()
.baseUrl(ollama.getEndpoint())
@ -22,7 +22,7 @@ class OllamaModelsIT extends AbstractOllamaInfrastructure {
// then
assertThat(response.content().size()).isGreaterThan(0);
assertThat(response.content().get(0).getName()).isEqualTo("phi:latest");
assertThat(response.content().get(0).getName()).isEqualTo("tinydolphin:latest");
}
@Test
@ -31,7 +31,7 @@ class OllamaModelsIT extends AbstractOllamaInfrastructure {
// when
OllamaModel ollamaModel = OllamaModel.builder()
.name("phi:latest")
.name("tinydolphin:latest")
.build();
Response<OllamaModelCard> response = ollamaModels.modelCard(ollamaModel);
@ -40,7 +40,7 @@ class OllamaModelsIT extends AbstractOllamaInfrastructure {
assertThat(response.content().getModelfile()).isNotBlank();
assertThat(response.content().getTemplate()).isNotBlank();
assertThat(response.content().getParameters()).isNotBlank();
assertThat(response.content().getDetails().getFamily()).isEqualTo("phi2");
assertThat(response.content().getDetails().getFamily()).isEqualTo("llama");
}
@Test
@ -48,13 +48,12 @@ class OllamaModelsIT extends AbstractOllamaInfrastructure {
// given AbstractOllamaLanguageModelInfrastructure
// when
Response<OllamaModelCard> response = ollamaModels.modelCard("phi:latest");
Response<OllamaModelCard> response = ollamaModels.modelCard("tinydolphin:latest");
// then
assertThat(response.content().getModelfile()).isNotBlank();
assertThat(response.content().getTemplate()).isNotBlank();
assertThat(response.content().getParameters()).isNotBlank();
assertThat(response.content().getDetails().getFamily()).isEqualTo("phi2");
assertThat(response.content().getDetails().getFamily()).isEqualTo("llama");
}
}

View File

@ -6,6 +6,7 @@ import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
@ -13,55 +14,30 @@ import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
class OllamaStreamingChatModelIT extends AbstractOllamaLanguageModelInfrastructure {
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL)
.modelName(TINY_DOLPHIN_MODEL)
.temperature(0.0)
.build();
@Test
void should_stream_answer() throws Exception {
void should_stream_answer() {
// given
String userMessage = "What is the capital of Germany?";
// when
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<AiMessage> response = futureResponse.get(30, SECONDS);
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
String answer = response.content().text();
// then
assertThat(answer).contains("Berlin");
@ -71,7 +47,7 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
assertThat(aiMessage.toolExecutionRequests()).isNull();
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(38);
assertThat(tokenUsage.inputTokenCount()).isEqualTo(35);
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
@ -80,14 +56,14 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
}
@Test
void should_respect_numPredict() throws Exception {
void should_respect_numPredict() {
// given
int numPredict = 1; // max output tokens
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL)
.modelName(TINY_DOLPHIN_MODEL)
.numPredict(numPredict)
.temperature(0.0)
.build();
@ -95,35 +71,10 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
UserMessage userMessage = UserMessage.from("What is the capital of Germany?");
// when
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
model.generate(singletonList(userMessage), new StreamingResponseHandler<AiMessage>() {
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<AiMessage> response = futureResponse.get(30, SECONDS);
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(singletonList(userMessage), handler);
Response<AiMessage> response = handler.get();
String answer = response.content().text();
// then
assertThat(answer).doesNotContain("Berlin");
@ -132,44 +83,18 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
assertThat(response.tokenUsage().outputTokenCount()).isBetween(numPredict, numPredict + 2); // bug in Ollama
}
@Test
void should_respect_system_message() throws Exception {
void should_respect_system_message() {
// given
SystemMessage systemMessage = SystemMessage.from("Translate messages from user into German");
UserMessage userMessage = UserMessage.from("I love you");
// when
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
model.generate(asList(systemMessage, userMessage), new StreamingResponseHandler<AiMessage>() {
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<AiMessage> response = futureResponse.get(30, SECONDS);
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(asList(systemMessage, userMessage), handler);
Response<AiMessage> response = handler.get();
String answer = response.content().text();
// then
assertThat(answer).containsIgnoringCase("liebe");
@ -177,7 +102,7 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
}
@Test
void should_respond_to_few_shot() throws Exception {
void should_respond_to_few_shot() {
// given
List<ChatMessage> messages = asList(
@ -191,35 +116,10 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
);
// when
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
model.generate(messages, new StreamingResponseHandler<AiMessage>() {
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<AiMessage> response = futureResponse.get(30, SECONDS);
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(messages, handler);
Response<AiMessage> response = handler.get();
String answer = response.content().text();
// then
assertThat(answer).startsWith(">>> 8");
@ -227,12 +127,12 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
}
@Test
void should_generate_valid_json() throws Exception {
void should_generate_valid_json() {
// given
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL)
.modelName(TINY_DOLPHIN_MODEL)
.format("json")
.temperature(0.0)
.build();
@ -240,38 +140,51 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
String userMessage = "Return JSON with two fields: name and age of John Doe, 42 years old.";
// when
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<AiMessage> response = futureResponse.get(30, SECONDS);
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
String answer = response.content().text();
// then
assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
assertThat(response.content().text()).isEqualTo(answer);
}
@Test
void should_propagate_failure_to_handler_onError() throws Exception {

    // given a model name the Ollama server does not have
    String wrongModelName = "banana";

    StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
            .baseUrl(ollama.getEndpoint())
            .modelName(wrongModelName)
            .build();

    CompletableFuture<Throwable> errorFuture = new CompletableFuture<>();

    // when: the failure must surface through onError(), never onNext()/onComplete()
    model.generate("does not matter", new StreamingResponseHandler<AiMessage>() {

        @Override
        public void onNext(String token) {
            errorFuture.completeExceptionally(new Exception("onNext should never be called"));
        }

        @Override
        public void onComplete(Response<AiMessage> response) {
            errorFuture.completeExceptionally(new Exception("onComplete should never be called"));
        }

        @Override
        public void onError(Throwable error) {
            errorFuture.complete(error);
        }
    });

    // then
    assertThat(errorFuture.get())
            .isExactlyInstanceOf(NullPointerException.class)
            .hasMessageContaining("is null");
}
}

View File

@ -1,6 +1,7 @@
package dev.langchain4j.model.ollama;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.language.StreamingLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
@ -8,60 +9,34 @@ import org.junit.jupiter.api.Test;
import java.util.concurrent.CompletableFuture;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
class OllamaStreamingLanguageModelIT extends AbstractOllamaLanguageModelInfrastructure {
@Test
void should_stream_answer() throws Exception {
void should_stream_answer() {
// given
String prompt = "What is the capital of Germany?";
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL)
.modelName(OllamaImage.TINY_DOLPHIN_MODEL)
.temperature(0.0)
.build();
// when
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
model.generate(prompt, new StreamingResponseHandler<String>() {
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<String> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<String> response = futureResponse.get(30, SECONDS);
TestStreamingResponseHandler<String> handler = new TestStreamingResponseHandler<>();
model.generate(prompt, handler);
Response<String> response = handler.get();
String answer = response.content();
// then
assertThat(answer).contains("Berlin");
assertThat(response.content()).isEqualTo(answer);
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(38);
assertThat(tokenUsage.inputTokenCount()).isEqualTo(13);
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
@ -70,14 +45,14 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
}
@Test
void should_respect_numPredict() throws Exception {
void should_respect_numPredict() {
// given
int numPredict = 1; // max output tokens
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL)
.modelName(OllamaImage.TINY_DOLPHIN_MODEL)
.numPredict(numPredict)
.temperature(0.0)
.build();
@ -85,35 +60,10 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
String prompt = "What is the capital of Germany?";
// when
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
model.generate(prompt, new StreamingResponseHandler<String>() {
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<String> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<String> response = futureResponse.get(30, SECONDS);
TestStreamingResponseHandler<String> handler = new TestStreamingResponseHandler<>();
model.generate(prompt, handler);
Response<String> response = handler.get();
String answer = response.content();
// then
assertThat(answer).doesNotContain("Berlin");
@ -123,12 +73,12 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
}
@Test
void should_stream_valid_json() throws Exception {
void should_stream_valid_json() {
// given
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL)
.modelName(OllamaImage.TINY_DOLPHIN_MODEL)
.format("json")
.temperature(0.0)
.build();
@ -136,38 +86,51 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
String prompt = "Return JSON with two fields: name and age of John Doe, 42 years old.";
// when
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
model.generate(prompt, new StreamingResponseHandler<String>() {
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<String> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<String> response = futureResponse.get(30, SECONDS);
TestStreamingResponseHandler<String> handler = new TestStreamingResponseHandler<>();
model.generate(prompt, handler);
Response<String> response = handler.get();
String answer = response.content();
// then
assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
assertThat(response.content()).isEqualTo(answer);
}
@Test
void should_propagate_failure_to_handler_onError() throws Exception {

    // given a model name the Ollama server does not have
    String wrongModelName = "banana";

    StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
            .baseUrl(ollama.getEndpoint())
            .modelName(wrongModelName)
            .build();

    CompletableFuture<Throwable> errorFuture = new CompletableFuture<>();

    // when: the failure must surface through onError(), never onNext()/onComplete()
    model.generate("does not matter", new StreamingResponseHandler<String>() {

        @Override
        public void onNext(String token) {
            errorFuture.completeExceptionally(new Exception("onNext should never be called"));
        }

        @Override
        public void onComplete(Response<String> response) {
            errorFuture.completeExceptionally(new Exception("onComplete should never be called"));
        }

        @Override
        public void onError(Throwable error) {
            errorFuture.complete(error);
        }
    });

    // then
    assertThat(errorFuture.get())
            .isExactlyInstanceOf(NullPointerException.class)
            .hasMessageContaining("is null");
}
}