Fix #756: Allow blank content in AiMessage, propagate failures into streaming handler (Ollama) (#782)

<!-- Thank you so much for your contribution! -->

## Context
See https://github.com/langchain4j/langchain4j/issues/756

## Change
- Allow creating `AiMessage` with blank ("", " ") content. `null` is
still prohibited.
- In `OllamaStreamingChatModel` and `OllamaStreamingLanguageModel`: propagate failures from
the `onResponse()` method into the `StreamingResponseHandler.onError()` method

## Checklist
Before submitting this PR, please check the following points:
- [X] I have added unit and integration tests for my change
- [X] All unit and integration tests in the module I have added/changed
are green
- [X] All unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules are green
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added my new module in the
[BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)
(only when a new module is added)
This commit is contained in:
LangChain4j 2024-03-22 09:39:48 +01:00 committed by GitHub
parent 04055c896b
commit da816fd491
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 233 additions and 335 deletions

View File

@@ -8,8 +8,8 @@ import java.util.Objects;
import static dev.langchain4j.data.message.ChatMessageType.AI; import static dev.langchain4j.data.message.ChatMessageType.AI;
import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.Utils.quoted; import static dev.langchain4j.internal.Utils.quoted;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
/** /**
@@ -24,15 +24,17 @@ public class AiMessage implements ChatMessage {
/** /**
* Create a new {@link AiMessage} with the given text. * Create a new {@link AiMessage} with the given text.
*
* @param text the text of the message. * @param text the text of the message.
*/ */
public AiMessage(String text) { public AiMessage(String text) {
this.text = ensureNotBlank(text, "text"); this.text = ensureNotNull(text, "text");
this.toolExecutionRequests = null; this.toolExecutionRequests = null;
} }
/** /**
* Create a new {@link AiMessage} with the given tool execution requests. * Create a new {@link AiMessage} with the given tool execution requests.
*
* @param toolExecutionRequests the tool execution requests of the message. * @param toolExecutionRequests the tool execution requests of the message.
*/ */
public AiMessage(List<ToolExecutionRequest> toolExecutionRequests) { public AiMessage(List<ToolExecutionRequest> toolExecutionRequests) {
@@ -42,6 +44,7 @@ public class AiMessage implements ChatMessage {
/** /**
* Get the text of the message. * Get the text of the message.
*
* @return the text of the message. * @return the text of the message.
*/ */
public String text() { public String text() {
@@ -50,6 +53,7 @@ public class AiMessage implements ChatMessage {
/** /**
* Get the tool execution requests of the message. * Get the tool execution requests of the message.
*
* @return the tool execution requests of the message. * @return the tool execution requests of the message.
*/ */
public List<ToolExecutionRequest> toolExecutionRequests() { public List<ToolExecutionRequest> toolExecutionRequests() {
@@ -58,6 +62,7 @@ public class AiMessage implements ChatMessage {
/** /**
* Check if the message has ToolExecutionRequests. * Check if the message has ToolExecutionRequests.
*
* @return true if the message has ToolExecutionRequests, false otherwise. * @return true if the message has ToolExecutionRequests, false otherwise.
*/ */
public boolean hasToolExecutionRequests() { public boolean hasToolExecutionRequests() {
@@ -93,6 +98,7 @@ public class AiMessage implements ChatMessage {
/** /**
* Create a new {@link AiMessage} with the given text. * Create a new {@link AiMessage} with the given text.
*
* @param text the text of the message. * @param text the text of the message.
* @return the new {@link AiMessage}. * @return the new {@link AiMessage}.
*/ */
@@ -102,6 +108,7 @@ public class AiMessage implements ChatMessage {
/** /**
* Create a new {@link AiMessage} with the given tool execution requests. * Create a new {@link AiMessage} with the given tool execution requests.
*
* @param toolExecutionRequests the tool execution requests of the message. * @param toolExecutionRequests the tool execution requests of the message.
* @return the new {@link AiMessage}. * @return the new {@link AiMessage}.
*/ */
@@ -111,6 +118,7 @@ public class AiMessage implements ChatMessage {
/** /**
* Create a new {@link AiMessage} with the given tool execution requests. * Create a new {@link AiMessage} with the given tool execution requests.
*
* @param toolExecutionRequests the tool execution requests of the message. * @param toolExecutionRequests the tool execution requests of the message.
* @return the new {@link AiMessage}. * @return the new {@link AiMessage}.
*/ */
@@ -120,6 +128,7 @@ public class AiMessage implements ChatMessage {
/** /**
* Create a new {@link AiMessage} with the given text. * Create a new {@link AiMessage} with the given text.
*
* @param text the text of the message. * @param text the text of the message.
* @return the new {@link AiMessage}. * @return the new {@link AiMessage}.
*/ */
@@ -129,6 +138,7 @@ public class AiMessage implements ChatMessage {
/** /**
* Create a new {@link AiMessage} with the given tool execution requests. * Create a new {@link AiMessage} with the given tool execution requests.
*
* @param toolExecutionRequests the tool execution requests of the message. * @param toolExecutionRequests the tool execution requests of the message.
* @return the new {@link AiMessage}. * @return the new {@link AiMessage}.
*/ */
@@ -138,6 +148,7 @@ public class AiMessage implements ChatMessage {
/** /**
* Create a new {@link AiMessage} with the given tool execution requests. * Create a new {@link AiMessage} with the given tool execution requests.
*
* @param toolExecutionRequests the tool execution requests of the message. * @param toolExecutionRequests the tool execution requests of the message.
* @return the new {@link AiMessage}. * @return the new {@link AiMessage}.
*/ */

View File

@@ -116,4 +116,16 @@ class AiMessageTest implements WithAssertions {
} }
} }
@Test
void should_allow_blank_content() {
assertThat(AiMessage.from("").text()).isEqualTo("");
assertThat(AiMessage.from(" ").text()).isEqualTo(" ");
}
@Test
void should_fail_when_text_is_null() {
assertThatThrownBy(() -> AiMessage.from((String) null))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("text cannot be null");
}
} }

View File

@@ -68,6 +68,14 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>org.tinylog</groupId> <groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId> <artifactId>tinylog-impl</artifactId>

View File

@@ -106,8 +106,8 @@ class OllamaClient {
return; return;
} }
} }
} catch (IOException e) { } catch (Exception e) {
throw new RuntimeException(e); handler.onError(e);
} }
} }
@@ -147,8 +147,8 @@ class OllamaClient {
return; return;
} }
} }
} catch (IOException e) { } catch (Exception e) {
throw new RuntimeException(e); handler.onError(e);
} }
} }

View File

@@ -0,0 +1,18 @@
package dev.langchain4j.model.ollama;
import static dev.langchain4j.model.ollama.OllamaImage.ALL_MINILM_MODEL;
import static dev.langchain4j.model.ollama.OllamaImage.OLLAMA_IMAGE;
class AbstractOllamaEmbeddingModelInfrastructure {
private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OLLAMA_IMAGE, ALL_MINILM_MODEL);
static LangChain4jOllamaContainer ollama;
static {
ollama = new LangChain4jOllamaContainer(OllamaImage.resolve(OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE))
.withModel(ALL_MINILM_MODEL);
ollama.start();
ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
}
}

View File

@@ -1,33 +0,0 @@
package dev.langchain4j.model.ollama;
import com.github.dockerjava.api.DockerClient;
import com.github.dockerjava.api.model.Image;
import org.testcontainers.DockerClientFactory;
import org.testcontainers.utility.DockerImageName;
import java.util.List;
public class AbstractOllamaInfrastructure {
private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.PHI_MODEL);
static LangChain4jOllamaContainer ollama;
static {
ollama = new LangChain4jOllamaContainer(resolveImage(OllamaImage.OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE))
.withModel(OllamaImage.PHI_MODEL);
ollama.start();
ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
}
static DockerImageName resolveImage(String baseImage, String localImageName) {
DockerImageName dockerImageName = DockerImageName.parse(baseImage);
DockerClient dockerClient = DockerClientFactory.instance().client();
List<Image> images = dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec();
if (images.isEmpty()) {
return dockerImageName;
}
return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage);
}
}

View File

@@ -0,0 +1,15 @@
package dev.langchain4j.model.ollama;
class AbstractOllamaLanguageModelInfrastructure {
private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.TINY_DOLPHIN_MODEL);
static LangChain4jOllamaContainer ollama;
static {
ollama = new LangChain4jOllamaContainer(OllamaImage.resolve(OllamaImage.OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE))
.withModel(OllamaImage.TINY_DOLPHIN_MODEL);
ollama.start();
ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
}
}

View File

@@ -1,6 +1,6 @@
package dev.langchain4j.model.ollama; package dev.langchain4j.model.ollama;
public class AbstractOllamaInfrastructureVisionModel { class AbstractOllamaVisionModelInfrastructure {
private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.BAKLLAVA_MODEL); private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.BAKLLAVA_MODEL);
@@ -12,5 +12,4 @@ public class AbstractOllamaInfrastructureVisionModel {
ollama.start(); ollama.start();
ollama.commitToImage(LOCAL_OLLAMA_IMAGE); ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
} }
} }

View File

@@ -10,9 +10,11 @@ import org.junit.jupiter.api.Test;
import java.time.Duration; import java.time.Duration;
import static dev.langchain4j.model.ollama.OllamaImage.BAKLLAVA_MODEL;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionModel { class OllamaChatModeVisionModellITInfrastructure extends AbstractOllamaVisionModelInfrastructure {
static final String CAT_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/e/e9/Felis_silvestris_silvestris_small_gradual_decrease_of_quality.png"; static final String CAT_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/e/e9/Felis_silvestris_silvestris_small_gradual_decrease_of_quality.png";
static final String DICE_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"; static final String DICE_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png";
@@ -23,7 +25,7 @@ class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionMod
ChatLanguageModel model = OllamaChatModel.builder() ChatLanguageModel model = OllamaChatModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.timeout(Duration.ofMinutes(3)) .timeout(Duration.ofMinutes(3))
.modelName(OllamaImage.BAKLLAVA_MODEL) .modelName(BAKLLAVA_MODEL)
.temperature(0.0) .temperature(0.0)
.build(); .build();
@@ -45,7 +47,7 @@ class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionMod
ChatLanguageModel model = OllamaChatModel.builder() ChatLanguageModel model = OllamaChatModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.timeout(Duration.ofMinutes(3)) .timeout(Duration.ofMinutes(3))
.modelName(OllamaImage.BAKLLAVA_MODEL) .modelName(BAKLLAVA_MODEL)
.temperature(0.0) .temperature(0.0)
.build(); .build();
@@ -59,5 +61,4 @@ class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionMod
assertThat(response.content().text()) assertThat(response.content().text())
.containsIgnoringCase("dice"); .containsIgnoringCase("dice");
} }
} }

View File

@@ -11,14 +11,15 @@ import org.junit.jupiter.api.Test;
import java.util.List; import java.util.List;
import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
class OllamaChatModelIT extends AbstractOllamaInfrastructure { class OllamaChatModelIT extends AbstractOllamaLanguageModelInfrastructure {
ChatLanguageModel model = OllamaChatModel.builder() ChatLanguageModel model = OllamaChatModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL) .modelName(TINY_DOLPHIN_MODEL)
.temperature(0.0) .temperature(0.0)
.build(); .build();
@@ -38,7 +39,7 @@ class OllamaChatModelIT extends AbstractOllamaInfrastructure {
assertThat(aiMessage.toolExecutionRequests()).isNull(); assertThat(aiMessage.toolExecutionRequests()).isNull();
TokenUsage tokenUsage = response.tokenUsage(); TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(38); assertThat(tokenUsage.inputTokenCount()).isEqualTo(13);
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.totalTokenCount()) assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
@@ -54,7 +55,7 @@ class OllamaChatModelIT extends AbstractOllamaInfrastructure {
OllamaChatModel model = OllamaChatModel.builder() OllamaChatModel model = OllamaChatModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL) .modelName(TINY_DOLPHIN_MODEL)
.numPredict(numPredict) .numPredict(numPredict)
.temperature(0.0) .temperature(0.0)
.build(); .build();
@@ -113,7 +114,7 @@ class OllamaChatModelIT extends AbstractOllamaInfrastructure {
// given // given
ChatLanguageModel model = OllamaChatModel.builder() ChatLanguageModel model = OllamaChatModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL) .modelName(TINY_DOLPHIN_MODEL)
.format("json") .format("json")
.temperature(0.0) .temperature(0.0)
.build(); .build();

View File

@@ -6,7 +6,7 @@ import java.time.Duration;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
class OllamaClientIT extends AbstractOllamaInfrastructure { class OllamaClientIT extends AbstractOllamaLanguageModelInfrastructure {
@Test @Test
void should_respond_with_models_list() { void should_respond_with_models_list() {
@@ -22,12 +22,11 @@ class OllamaClientIT extends AbstractOllamaInfrastructure {
// then // then
assertThat(modelListResponse.getModels().size()).isGreaterThan(0); assertThat(modelListResponse.getModels().size()).isGreaterThan(0);
assertThat(modelListResponse.getModels().get(0).getName()).isEqualTo("phi:latest"); assertThat(modelListResponse.getModels().get(0).getName()).isEqualTo("tinydolphin:latest");
assertThat(modelListResponse.getModels().get(0).getDigest()).isNotNull(); assertThat(modelListResponse.getModels().get(0).getDigest()).isNotNull();
assertThat(modelListResponse.getModels().get(0).getSize()).isPositive(); assertThat(modelListResponse.getModels().get(0).getSize()).isPositive();
} }
@Test @Test
void should_respond_with_model_information() { void should_respond_with_model_information() {
// given AbstractOllamaInfrastructure // given AbstractOllamaInfrastructure
@@ -39,16 +38,14 @@ class OllamaClientIT extends AbstractOllamaInfrastructure {
.build(); .build();
OllamaModelCard modelDetailsResponse = ollamaClient.showInformation(ShowModelInformationRequest.builder() OllamaModelCard modelDetailsResponse = ollamaClient.showInformation(ShowModelInformationRequest.builder()
.name("phi:latest") .name("tinydolphin:latest")
.build()); .build());
// then // then
assertThat(modelDetailsResponse.getModelfile()).contains("# Modelfile generated by \"ollama show\""); assertThat(modelDetailsResponse.getModelfile()).contains("# Modelfile generated by \"ollama show\"");
assertThat(modelDetailsResponse.getParameters()).contains("stop"); assertThat(modelDetailsResponse.getParameters()).contains("stop");
assertThat(modelDetailsResponse.getTemplate()).contains("System:"); assertThat(modelDetailsResponse.getTemplate()).contains("<|im_start|>");
assertThat(modelDetailsResponse.getDetails().getFormat()).isEqualTo("gguf"); assertThat(modelDetailsResponse.getDetails().getFormat()).isEqualTo("gguf");
assertThat(modelDetailsResponse.getDetails().getFamily()).isEqualTo("phi2"); assertThat(modelDetailsResponse.getDetails().getFamily()).isEqualTo("llama");
} }
} }

View File

@@ -5,13 +5,14 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static dev.langchain4j.model.ollama.OllamaImage.ALL_MINILM_MODEL;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
class OllamaEmbeddingModelIT extends AbstractOllamaInfrastructure { class OllamaEmbeddingModelIT extends AbstractOllamaEmbeddingModelInfrastructure {
EmbeddingModel model = OllamaEmbeddingModel.builder() EmbeddingModel model = OllamaEmbeddingModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL) .modelName(ALL_MINILM_MODEL)
.build(); .build();
@Test @Test

View File

@@ -13,7 +13,9 @@ public class OllamaImage {
static final String BAKLLAVA_MODEL = "bakllava"; static final String BAKLLAVA_MODEL = "bakllava";
static final String PHI_MODEL = "phi"; static final String TINY_DOLPHIN_MODEL = "tinydolphin";
static final String ALL_MINILM_MODEL = "all-minilm";
static DockerImageName resolve(String baseImage, String localImageName) { static DockerImageName resolve(String baseImage, String localImageName) {
DockerImageName dockerImageName = DockerImageName.parse(baseImage); DockerImageName dockerImageName = DockerImageName.parse(baseImage);
@@ -24,5 +26,4 @@ public class OllamaImage {
} }
return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage); return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage);
} }
} }

View File

@@ -2,16 +2,16 @@ package dev.langchain4j.model.ollama;
import dev.langchain4j.model.language.LanguageModel; import dev.langchain4j.model.language.LanguageModel;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
class OllamaLanguageModelIT extends AbstractOllamaInfrastructure { class OllamaLanguageModelIT extends AbstractOllamaLanguageModelInfrastructure {
LanguageModel model = OllamaLanguageModel.builder() LanguageModel model = OllamaLanguageModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL) .modelName(TINY_DOLPHIN_MODEL)
.temperature(0.0) .temperature(0.0)
.build(); .build();
@@ -27,14 +27,6 @@ class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
// then // then
assertThat(response.content()).contains("Berlin"); assertThat(response.content()).contains("Berlin");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(38);
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
assertThat(response.finishReason()).isNull();
} }
@Test @Test
@@ -45,7 +37,7 @@ class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
LanguageModel model = OllamaLanguageModel.builder() LanguageModel model = OllamaLanguageModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL) .modelName(TINY_DOLPHIN_MODEL)
.numPredict(numPredict) .numPredict(numPredict)
.temperature(0.0) .temperature(0.0)
.build(); .build();
@@ -67,7 +59,7 @@ class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
// given // given
LanguageModel model = OllamaLanguageModel.builder() LanguageModel model = OllamaLanguageModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL) .modelName(TINY_DOLPHIN_MODEL)
.format("json") .format("json")
.temperature(0.0) .temperature(0.0)
.build(); .build();

View File

@@ -7,7 +7,7 @@ import java.util.List;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
class OllamaModelsIT extends AbstractOllamaInfrastructure { class OllamaModelsIT extends AbstractOllamaLanguageModelInfrastructure {
OllamaModels ollamaModels = OllamaModels.builder() OllamaModels ollamaModels = OllamaModels.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
@@ -22,7 +22,7 @@ class OllamaModelsIT extends AbstractOllamaInfrastructure {
// then // then
assertThat(response.content().size()).isGreaterThan(0); assertThat(response.content().size()).isGreaterThan(0);
assertThat(response.content().get(0).getName()).isEqualTo("phi:latest"); assertThat(response.content().get(0).getName()).isEqualTo("tinydolphin:latest");
} }
@Test @Test
@@ -31,7 +31,7 @@ class OllamaModelsIT extends AbstractOllamaInfrastructure {
// when // when
OllamaModel ollamaModel = OllamaModel.builder() OllamaModel ollamaModel = OllamaModel.builder()
.name("phi:latest") .name("tinydolphin:latest")
.build(); .build();
Response<OllamaModelCard> response = ollamaModels.modelCard(ollamaModel); Response<OllamaModelCard> response = ollamaModels.modelCard(ollamaModel);
@@ -40,7 +40,7 @@ class OllamaModelsIT extends AbstractOllamaInfrastructure {
assertThat(response.content().getModelfile()).isNotBlank(); assertThat(response.content().getModelfile()).isNotBlank();
assertThat(response.content().getTemplate()).isNotBlank(); assertThat(response.content().getTemplate()).isNotBlank();
assertThat(response.content().getParameters()).isNotBlank(); assertThat(response.content().getParameters()).isNotBlank();
assertThat(response.content().getDetails().getFamily()).isEqualTo("phi2"); assertThat(response.content().getDetails().getFamily()).isEqualTo("llama");
} }
@Test @Test
@@ -48,13 +48,12 @@ class OllamaModelsIT extends AbstractOllamaInfrastructure {
// given AbstractOllamaInfrastructure // given AbstractOllamaInfrastructure
// when // when
Response<OllamaModelCard> response = ollamaModels.modelCard("phi:latest"); Response<OllamaModelCard> response = ollamaModels.modelCard("tinydolphin:latest");
// then // then
assertThat(response.content().getModelfile()).isNotBlank(); assertThat(response.content().getModelfile()).isNotBlank();
assertThat(response.content().getTemplate()).isNotBlank(); assertThat(response.content().getTemplate()).isNotBlank();
assertThat(response.content().getParameters()).isNotBlank(); assertThat(response.content().getParameters()).isNotBlank();
assertThat(response.content().getDetails().getFamily()).isEqualTo("phi2"); assertThat(response.content().getDetails().getFamily()).isEqualTo("llama");
} }
} }

View File

@@ -6,6 +6,7 @@ import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler; import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage; import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@@ -13,55 +14,30 @@ import org.junit.jupiter.api.Test;
import java.util.List; import java.util.List;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure { class OllamaStreamingChatModelIT extends AbstractOllamaLanguageModelInfrastructure {
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder() StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL) .modelName(TINY_DOLPHIN_MODEL)
.temperature(0.0) .temperature(0.0)
.build(); .build();
@Test @Test
void should_stream_answer() throws Exception { void should_stream_answer() {
// given // given
String userMessage = "What is the capital of Germany?"; String userMessage = "What is the capital of Germany?";
// when // when
CompletableFuture<String> futureAnswer = new CompletableFuture<>(); TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>(); model.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
model.generate(userMessage, new StreamingResponseHandler<AiMessage>() { String answer = response.content().text();
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<AiMessage> response = futureResponse.get(30, SECONDS);
// then // then
assertThat(answer).contains("Berlin"); assertThat(answer).contains("Berlin");
@@ -71,7 +47,7 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
assertThat(aiMessage.toolExecutionRequests()).isNull(); assertThat(aiMessage.toolExecutionRequests()).isNull();
TokenUsage tokenUsage = response.tokenUsage(); TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(38); assertThat(tokenUsage.inputTokenCount()).isEqualTo(35);
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.totalTokenCount()) assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
@@ -80,14 +56,14 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
} }
@Test @Test
void should_respect_numPredict() throws Exception { void should_respect_numPredict() {
// given // given
int numPredict = 1; // max output tokens int numPredict = 1; // max output tokens
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder() StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL) .modelName(TINY_DOLPHIN_MODEL)
.numPredict(numPredict) .numPredict(numPredict)
.temperature(0.0) .temperature(0.0)
.build(); .build();
@@ -95,35 +71,10 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
UserMessage userMessage = UserMessage.from("What is the capital of Germany?"); UserMessage userMessage = UserMessage.from("What is the capital of Germany?");
// when // when
CompletableFuture<String> futureAnswer = new CompletableFuture<>(); TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>(); model.generate(singletonList(userMessage), handler);
Response<AiMessage> response = handler.get();
model.generate(singletonList(userMessage), new StreamingResponseHandler<AiMessage>() { String answer = response.content().text();
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<AiMessage> response = futureResponse.get(30, SECONDS);
// then // then
assertThat(answer).doesNotContain("Berlin"); assertThat(answer).doesNotContain("Berlin");
@@ -132,44 +83,18 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
assertThat(response.tokenUsage().outputTokenCount()).isBetween(numPredict, numPredict + 2); // bug in Ollama assertThat(response.tokenUsage().outputTokenCount()).isBetween(numPredict, numPredict + 2); // bug in Ollama
} }
@Test @Test
void should_respect_system_message() throws Exception { void should_respect_system_message() {
// given // given
SystemMessage systemMessage = SystemMessage.from("Translate messages from user into German"); SystemMessage systemMessage = SystemMessage.from("Translate messages from user into German");
UserMessage userMessage = UserMessage.from("I love you"); UserMessage userMessage = UserMessage.from("I love you");
// when // when
CompletableFuture<String> futureAnswer = new CompletableFuture<>(); TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>(); model.generate(asList(systemMessage, userMessage), handler);
Response<AiMessage> response = handler.get();
model.generate(asList(systemMessage, userMessage), new StreamingResponseHandler<AiMessage>() { String answer = response.content().text();
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<AiMessage> response = futureResponse.get(30, SECONDS);
// then // then
assertThat(answer).containsIgnoringCase("liebe"); assertThat(answer).containsIgnoringCase("liebe");
@ -177,7 +102,7 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
} }
@Test @Test
void should_respond_to_few_shot() throws Exception { void should_respond_to_few_shot() {
// given // given
List<ChatMessage> messages = asList( List<ChatMessage> messages = asList(
@ -191,35 +116,10 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
); );
// when // when
CompletableFuture<String> futureAnswer = new CompletableFuture<>(); TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>(); model.generate(messages, handler);
Response<AiMessage> response = handler.get();
model.generate(messages, new StreamingResponseHandler<AiMessage>() { String answer = response.content().text();
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<AiMessage> response = futureResponse.get(30, SECONDS);
// then // then
assertThat(answer).startsWith(">>> 8"); assertThat(answer).startsWith(">>> 8");
@ -227,12 +127,12 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
} }
@Test @Test
void should_generate_valid_json() throws Exception { void should_generate_valid_json() {
// given // given
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder() StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL) .modelName(TINY_DOLPHIN_MODEL)
.format("json") .format("json")
.temperature(0.0) .temperature(0.0)
.build(); .build();
@ -240,38 +140,51 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
String userMessage = "Return JSON with two fields: name and age of John Doe, 42 years old."; String userMessage = "Return JSON with two fields: name and age of John Doe, 42 years old.";
// when // when
CompletableFuture<String> futureAnswer = new CompletableFuture<>(); TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>(); model.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
model.generate(userMessage, new StreamingResponseHandler<AiMessage>() { String answer = response.content().text();
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<AiMessage> response = futureResponse.get(30, SECONDS);
// then // then
assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}"); assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
assertThat(response.content().text()).isEqualTo(answer); assertThat(response.content().text()).isEqualTo(answer);
} }
@Test
void should_propagate_failure_to_handler_onError() throws Exception {
// given
String wrongModelName = "banana";
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(wrongModelName)
.build();
CompletableFuture<Throwable> future = new CompletableFuture<>();
// when
model.generate("does not matter", new StreamingResponseHandler<AiMessage>() {
@Override
public void onNext(String token) {
future.completeExceptionally(new Exception("onNext should never be called"));
}
@Override
public void onComplete(Response<AiMessage> response) {
future.completeExceptionally(new Exception("onComplete should never be called"));
}
@Override
public void onError(Throwable error) {
future.complete(error);
}
});
// then
assertThat(future.get())
.isExactlyInstanceOf(NullPointerException.class)
.hasMessageContaining("is null");
}
} }

View File

@ -1,6 +1,7 @@
package dev.langchain4j.model.ollama; package dev.langchain4j.model.ollama;
import dev.langchain4j.model.StreamingResponseHandler; import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.language.StreamingLanguageModel; import dev.langchain4j.model.language.StreamingLanguageModel;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage; import dev.langchain4j.model.output.TokenUsage;
@ -8,60 +9,34 @@ import org.junit.jupiter.api.Test;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure { class OllamaStreamingLanguageModelIT extends AbstractOllamaLanguageModelInfrastructure {
@Test @Test
void should_stream_answer() throws Exception { void should_stream_answer() {
// given // given
String prompt = "What is the capital of Germany?"; String prompt = "What is the capital of Germany?";
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder() StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL) .modelName(OllamaImage.TINY_DOLPHIN_MODEL)
.temperature(0.0) .temperature(0.0)
.build(); .build();
// when // when
CompletableFuture<String> futureAnswer = new CompletableFuture<>(); TestStreamingResponseHandler<String> handler = new TestStreamingResponseHandler<>();
CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>(); model.generate(prompt, handler);
Response<String> response = handler.get();
model.generate(prompt, new StreamingResponseHandler<String>() { String answer = response.content();
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<String> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<String> response = futureResponse.get(30, SECONDS);
// then // then
assertThat(answer).contains("Berlin"); assertThat(answer).contains("Berlin");
assertThat(response.content()).isEqualTo(answer); assertThat(response.content()).isEqualTo(answer);
TokenUsage tokenUsage = response.tokenUsage(); TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(38); assertThat(tokenUsage.inputTokenCount()).isEqualTo(13);
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.totalTokenCount()) assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
@ -70,14 +45,14 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
} }
@Test @Test
void should_respect_numPredict() throws Exception { void should_respect_numPredict() {
// given // given
int numPredict = 1; // max output tokens int numPredict = 1; // max output tokens
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder() StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL) .modelName(OllamaImage.TINY_DOLPHIN_MODEL)
.numPredict(numPredict) .numPredict(numPredict)
.temperature(0.0) .temperature(0.0)
.build(); .build();
@ -85,35 +60,10 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
String prompt = "What is the capital of Germany?"; String prompt = "What is the capital of Germany?";
// when // when
CompletableFuture<String> futureAnswer = new CompletableFuture<>(); TestStreamingResponseHandler<String> handler = new TestStreamingResponseHandler<>();
CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>(); model.generate(prompt, handler);
Response<String> response = handler.get();
model.generate(prompt, new StreamingResponseHandler<String>() { String answer = response.content();
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<String> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<String> response = futureResponse.get(30, SECONDS);
// then // then
assertThat(answer).doesNotContain("Berlin"); assertThat(answer).doesNotContain("Berlin");
@ -123,12 +73,12 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
} }
@Test @Test
void should_stream_valid_json() throws Exception { void should_stream_valid_json() {
// given // given
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder() StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
.baseUrl(ollama.getEndpoint()) .baseUrl(ollama.getEndpoint())
.modelName(OllamaImage.PHI_MODEL) .modelName(OllamaImage.TINY_DOLPHIN_MODEL)
.format("json") .format("json")
.temperature(0.0) .temperature(0.0)
.build(); .build();
@ -136,38 +86,51 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
String prompt = "Return JSON with two fields: name and age of John Doe, 42 years old."; String prompt = "Return JSON with two fields: name and age of John Doe, 42 years old.";
// when // when
CompletableFuture<String> futureAnswer = new CompletableFuture<>(); TestStreamingResponseHandler<String> handler = new TestStreamingResponseHandler<>();
CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>(); model.generate(prompt, handler);
Response<String> response = handler.get();
model.generate(prompt, new StreamingResponseHandler<String>() { String answer = response.content();
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<String> response) {
System.out.println("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<String> response = futureResponse.get(30, SECONDS);
// then // then
assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}"); assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
assertThat(response.content()).isEqualTo(answer); assertThat(response.content()).isEqualTo(answer);
} }
@Test
void should_propagate_failure_to_handler_onError() throws Exception {
// given
String wrongModelName = "banana";
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
.baseUrl(ollama.getEndpoint())
.modelName(wrongModelName)
.build();
CompletableFuture<Throwable> future = new CompletableFuture<>();
// when
model.generate("does not matter", new StreamingResponseHandler<String>() {
@Override
public void onNext(String token) {
future.completeExceptionally(new Exception("onNext should never be called"));
}
@Override
public void onComplete(Response<String> response) {
future.completeExceptionally(new Exception("onComplete should never be called"));
}
@Override
public void onError(Throwable error) {
future.complete(error);
}
});
// then
assertThat(future.get())
.isExactlyInstanceOf(NullPointerException.class)
.hasMessageContaining("is null");
}
} }