Fix #756: Allow blank content in AiMessage, propagate failures into streaming handler (Ollama) (#782)
<!-- Thank you so much for your contribution! --> ## Context See https://github.com/langchain4j/langchain4j/issues/756 ## Change - Allow creating `AiMessage` with blank ("", " ") content. `null` is still prohibited. - In `OllamaStreamingChat/LanguageModel`: propagate failures from `onResponse()` method into `StreamingResponseHandler.onError()` method ## Checklist Before submitting this PR, please check the following points: - [X] I have added unit and integration tests for my change - [X] All unit and integration tests in the module I have added/changed are green - [X] All unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules are green - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added my new module in the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml) (only when a new module is added)
This commit is contained in:
parent
04055c896b
commit
da816fd491
|
@ -8,8 +8,8 @@ import java.util.Objects;
|
||||||
import static dev.langchain4j.data.message.ChatMessageType.AI;
|
import static dev.langchain4j.data.message.ChatMessageType.AI;
|
||||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||||
import static dev.langchain4j.internal.Utils.quoted;
|
import static dev.langchain4j.internal.Utils.quoted;
|
||||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
|
||||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
||||||
|
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||||
import static java.util.Arrays.asList;
|
import static java.util.Arrays.asList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -24,15 +24,17 @@ public class AiMessage implements ChatMessage {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a new {@link AiMessage} with the given text.
|
* Create a new {@link AiMessage} with the given text.
|
||||||
|
*
|
||||||
* @param text the text of the message.
|
* @param text the text of the message.
|
||||||
*/
|
*/
|
||||||
public AiMessage(String text) {
|
public AiMessage(String text) {
|
||||||
this.text = ensureNotBlank(text, "text");
|
this.text = ensureNotNull(text, "text");
|
||||||
this.toolExecutionRequests = null;
|
this.toolExecutionRequests = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a new {@link AiMessage} with the given tool execution requests.
|
* Create a new {@link AiMessage} with the given tool execution requests.
|
||||||
|
*
|
||||||
* @param toolExecutionRequests the tool execution requests of the message.
|
* @param toolExecutionRequests the tool execution requests of the message.
|
||||||
*/
|
*/
|
||||||
public AiMessage(List<ToolExecutionRequest> toolExecutionRequests) {
|
public AiMessage(List<ToolExecutionRequest> toolExecutionRequests) {
|
||||||
|
@ -42,6 +44,7 @@ public class AiMessage implements ChatMessage {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the text of the message.
|
* Get the text of the message.
|
||||||
|
*
|
||||||
* @return the text of the message.
|
* @return the text of the message.
|
||||||
*/
|
*/
|
||||||
public String text() {
|
public String text() {
|
||||||
|
@ -50,6 +53,7 @@ public class AiMessage implements ChatMessage {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the tool execution requests of the message.
|
* Get the tool execution requests of the message.
|
||||||
|
*
|
||||||
* @return the tool execution requests of the message.
|
* @return the tool execution requests of the message.
|
||||||
*/
|
*/
|
||||||
public List<ToolExecutionRequest> toolExecutionRequests() {
|
public List<ToolExecutionRequest> toolExecutionRequests() {
|
||||||
|
@ -58,6 +62,7 @@ public class AiMessage implements ChatMessage {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if the message has ToolExecutionRequests.
|
* Check if the message has ToolExecutionRequests.
|
||||||
|
*
|
||||||
* @return true if the message has ToolExecutionRequests, false otherwise.
|
* @return true if the message has ToolExecutionRequests, false otherwise.
|
||||||
*/
|
*/
|
||||||
public boolean hasToolExecutionRequests() {
|
public boolean hasToolExecutionRequests() {
|
||||||
|
@ -93,6 +98,7 @@ public class AiMessage implements ChatMessage {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a new {@link AiMessage} with the given text.
|
* Create a new {@link AiMessage} with the given text.
|
||||||
|
*
|
||||||
* @param text the text of the message.
|
* @param text the text of the message.
|
||||||
* @return the new {@link AiMessage}.
|
* @return the new {@link AiMessage}.
|
||||||
*/
|
*/
|
||||||
|
@ -102,6 +108,7 @@ public class AiMessage implements ChatMessage {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a new {@link AiMessage} with the given tool execution requests.
|
* Create a new {@link AiMessage} with the given tool execution requests.
|
||||||
|
*
|
||||||
* @param toolExecutionRequests the tool execution requests of the message.
|
* @param toolExecutionRequests the tool execution requests of the message.
|
||||||
* @return the new {@link AiMessage}.
|
* @return the new {@link AiMessage}.
|
||||||
*/
|
*/
|
||||||
|
@ -111,6 +118,7 @@ public class AiMessage implements ChatMessage {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a new {@link AiMessage} with the given tool execution requests.
|
* Create a new {@link AiMessage} with the given tool execution requests.
|
||||||
|
*
|
||||||
* @param toolExecutionRequests the tool execution requests of the message.
|
* @param toolExecutionRequests the tool execution requests of the message.
|
||||||
* @return the new {@link AiMessage}.
|
* @return the new {@link AiMessage}.
|
||||||
*/
|
*/
|
||||||
|
@ -120,6 +128,7 @@ public class AiMessage implements ChatMessage {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a new {@link AiMessage} with the given text.
|
* Create a new {@link AiMessage} with the given text.
|
||||||
|
*
|
||||||
* @param text the text of the message.
|
* @param text the text of the message.
|
||||||
* @return the new {@link AiMessage}.
|
* @return the new {@link AiMessage}.
|
||||||
*/
|
*/
|
||||||
|
@ -129,6 +138,7 @@ public class AiMessage implements ChatMessage {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a new {@link AiMessage} with the given tool execution requests.
|
* Create a new {@link AiMessage} with the given tool execution requests.
|
||||||
|
*
|
||||||
* @param toolExecutionRequests the tool execution requests of the message.
|
* @param toolExecutionRequests the tool execution requests of the message.
|
||||||
* @return the new {@link AiMessage}.
|
* @return the new {@link AiMessage}.
|
||||||
*/
|
*/
|
||||||
|
@ -138,6 +148,7 @@ public class AiMessage implements ChatMessage {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a new {@link AiMessage} with the given tool execution requests.
|
* Create a new {@link AiMessage} with the given tool execution requests.
|
||||||
|
*
|
||||||
* @param toolExecutionRequests the tool execution requests of the message.
|
* @param toolExecutionRequests the tool execution requests of the message.
|
||||||
* @return the new {@link AiMessage}.
|
* @return the new {@link AiMessage}.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -116,4 +116,16 @@ class AiMessageTest implements WithAssertions {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void should_allow_blank_content() {
|
||||||
|
assertThat(AiMessage.from("").text()).isEqualTo("");
|
||||||
|
assertThat(AiMessage.from(" ").text()).isEqualTo(" ");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void should_fail_when_text_is_null() {
|
||||||
|
assertThatThrownBy(() -> AiMessage.from((String) null))
|
||||||
|
.isExactlyInstanceOf(IllegalArgumentException.class)
|
||||||
|
.hasMessage("text cannot be null");
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -68,6 +68,14 @@
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-core</artifactId>
|
||||||
|
<classifier>tests</classifier>
|
||||||
|
<type>test-jar</type>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.tinylog</groupId>
|
<groupId>org.tinylog</groupId>
|
||||||
<artifactId>tinylog-impl</artifactId>
|
<artifactId>tinylog-impl</artifactId>
|
||||||
|
|
|
@ -106,8 +106,8 @@ class OllamaClient {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (IOException e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException(e);
|
handler.onError(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,8 +147,8 @@ class OllamaClient {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (IOException e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException(e);
|
handler.onError(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
package dev.langchain4j.model.ollama;
|
||||||
|
|
||||||
|
import static dev.langchain4j.model.ollama.OllamaImage.ALL_MINILM_MODEL;
|
||||||
|
import static dev.langchain4j.model.ollama.OllamaImage.OLLAMA_IMAGE;
|
||||||
|
|
||||||
|
class AbstractOllamaEmbeddingModelInfrastructure {
|
||||||
|
|
||||||
|
private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OLLAMA_IMAGE, ALL_MINILM_MODEL);
|
||||||
|
|
||||||
|
static LangChain4jOllamaContainer ollama;
|
||||||
|
|
||||||
|
static {
|
||||||
|
ollama = new LangChain4jOllamaContainer(OllamaImage.resolve(OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE))
|
||||||
|
.withModel(ALL_MINILM_MODEL);
|
||||||
|
ollama.start();
|
||||||
|
ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,33 +0,0 @@
|
||||||
package dev.langchain4j.model.ollama;
|
|
||||||
|
|
||||||
import com.github.dockerjava.api.DockerClient;
|
|
||||||
import com.github.dockerjava.api.model.Image;
|
|
||||||
import org.testcontainers.DockerClientFactory;
|
|
||||||
import org.testcontainers.utility.DockerImageName;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class AbstractOllamaInfrastructure {
|
|
||||||
|
|
||||||
private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.PHI_MODEL);
|
|
||||||
|
|
||||||
static LangChain4jOllamaContainer ollama;
|
|
||||||
|
|
||||||
static {
|
|
||||||
ollama = new LangChain4jOllamaContainer(resolveImage(OllamaImage.OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE))
|
|
||||||
.withModel(OllamaImage.PHI_MODEL);
|
|
||||||
ollama.start();
|
|
||||||
ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
|
|
||||||
}
|
|
||||||
|
|
||||||
static DockerImageName resolveImage(String baseImage, String localImageName) {
|
|
||||||
DockerImageName dockerImageName = DockerImageName.parse(baseImage);
|
|
||||||
DockerClient dockerClient = DockerClientFactory.instance().client();
|
|
||||||
List<Image> images = dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec();
|
|
||||||
if (images.isEmpty()) {
|
|
||||||
return dockerImageName;
|
|
||||||
}
|
|
||||||
return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
package dev.langchain4j.model.ollama;
|
||||||
|
|
||||||
|
class AbstractOllamaLanguageModelInfrastructure {
|
||||||
|
|
||||||
|
private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.TINY_DOLPHIN_MODEL);
|
||||||
|
|
||||||
|
static LangChain4jOllamaContainer ollama;
|
||||||
|
|
||||||
|
static {
|
||||||
|
ollama = new LangChain4jOllamaContainer(OllamaImage.resolve(OllamaImage.OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE))
|
||||||
|
.withModel(OllamaImage.TINY_DOLPHIN_MODEL);
|
||||||
|
ollama.start();
|
||||||
|
ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,6 +1,6 @@
|
||||||
package dev.langchain4j.model.ollama;
|
package dev.langchain4j.model.ollama;
|
||||||
|
|
||||||
public class AbstractOllamaInfrastructureVisionModel {
|
class AbstractOllamaVisionModelInfrastructure {
|
||||||
|
|
||||||
private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.BAKLLAVA_MODEL);
|
private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.BAKLLAVA_MODEL);
|
||||||
|
|
||||||
|
@ -12,5 +12,4 @@ public class AbstractOllamaInfrastructureVisionModel {
|
||||||
ollama.start();
|
ollama.start();
|
||||||
ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
|
ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
|
@ -10,9 +10,11 @@ import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
|
||||||
|
import static dev.langchain4j.model.ollama.OllamaImage.BAKLLAVA_MODEL;
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionModel {
|
class OllamaChatModeVisionModellITInfrastructure extends AbstractOllamaVisionModelInfrastructure {
|
||||||
|
|
||||||
static final String CAT_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/e/e9/Felis_silvestris_silvestris_small_gradual_decrease_of_quality.png";
|
static final String CAT_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/e/e9/Felis_silvestris_silvestris_small_gradual_decrease_of_quality.png";
|
||||||
static final String DICE_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png";
|
static final String DICE_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png";
|
||||||
|
|
||||||
|
@ -23,7 +25,7 @@ class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionMod
|
||||||
ChatLanguageModel model = OllamaChatModel.builder()
|
ChatLanguageModel model = OllamaChatModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.timeout(Duration.ofMinutes(3))
|
.timeout(Duration.ofMinutes(3))
|
||||||
.modelName(OllamaImage.BAKLLAVA_MODEL)
|
.modelName(BAKLLAVA_MODEL)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -45,7 +47,7 @@ class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionMod
|
||||||
ChatLanguageModel model = OllamaChatModel.builder()
|
ChatLanguageModel model = OllamaChatModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.timeout(Duration.ofMinutes(3))
|
.timeout(Duration.ofMinutes(3))
|
||||||
.modelName(OllamaImage.BAKLLAVA_MODEL)
|
.modelName(BAKLLAVA_MODEL)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -59,5 +61,4 @@ class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionMod
|
||||||
assertThat(response.content().text())
|
assertThat(response.content().text())
|
||||||
.containsIgnoringCase("dice");
|
.containsIgnoringCase("dice");
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
|
@ -11,14 +11,15 @@ import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
|
||||||
import static java.util.Arrays.asList;
|
import static java.util.Arrays.asList;
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
class OllamaChatModelIT extends AbstractOllamaInfrastructure {
|
class OllamaChatModelIT extends AbstractOllamaLanguageModelInfrastructure {
|
||||||
|
|
||||||
ChatLanguageModel model = OllamaChatModel.builder()
|
ChatLanguageModel model = OllamaChatModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.modelName(OllamaImage.PHI_MODEL)
|
.modelName(TINY_DOLPHIN_MODEL)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -38,7 +39,7 @@ class OllamaChatModelIT extends AbstractOllamaInfrastructure {
|
||||||
assertThat(aiMessage.toolExecutionRequests()).isNull();
|
assertThat(aiMessage.toolExecutionRequests()).isNull();
|
||||||
|
|
||||||
TokenUsage tokenUsage = response.tokenUsage();
|
TokenUsage tokenUsage = response.tokenUsage();
|
||||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(38);
|
assertThat(tokenUsage.inputTokenCount()).isEqualTo(13);
|
||||||
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
|
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
|
||||||
assertThat(tokenUsage.totalTokenCount())
|
assertThat(tokenUsage.totalTokenCount())
|
||||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
||||||
|
@ -54,7 +55,7 @@ class OllamaChatModelIT extends AbstractOllamaInfrastructure {
|
||||||
|
|
||||||
OllamaChatModel model = OllamaChatModel.builder()
|
OllamaChatModel model = OllamaChatModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.modelName(OllamaImage.PHI_MODEL)
|
.modelName(TINY_DOLPHIN_MODEL)
|
||||||
.numPredict(numPredict)
|
.numPredict(numPredict)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
@ -113,7 +114,7 @@ class OllamaChatModelIT extends AbstractOllamaInfrastructure {
|
||||||
// given
|
// given
|
||||||
ChatLanguageModel model = OllamaChatModel.builder()
|
ChatLanguageModel model = OllamaChatModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.modelName(OllamaImage.PHI_MODEL)
|
.modelName(TINY_DOLPHIN_MODEL)
|
||||||
.format("json")
|
.format("json")
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -6,7 +6,7 @@ import java.time.Duration;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
class OllamaClientIT extends AbstractOllamaInfrastructure {
|
class OllamaClientIT extends AbstractOllamaLanguageModelInfrastructure {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void should_respond_with_models_list() {
|
void should_respond_with_models_list() {
|
||||||
|
@ -22,12 +22,11 @@ class OllamaClientIT extends AbstractOllamaInfrastructure {
|
||||||
|
|
||||||
// then
|
// then
|
||||||
assertThat(modelListResponse.getModels().size()).isGreaterThan(0);
|
assertThat(modelListResponse.getModels().size()).isGreaterThan(0);
|
||||||
assertThat(modelListResponse.getModels().get(0).getName()).isEqualTo("phi:latest");
|
assertThat(modelListResponse.getModels().get(0).getName()).isEqualTo("tinydolphin:latest");
|
||||||
assertThat(modelListResponse.getModels().get(0).getDigest()).isNotNull();
|
assertThat(modelListResponse.getModels().get(0).getDigest()).isNotNull();
|
||||||
assertThat(modelListResponse.getModels().get(0).getSize()).isPositive();
|
assertThat(modelListResponse.getModels().get(0).getSize()).isPositive();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void should_respond_with_model_information() {
|
void should_respond_with_model_information() {
|
||||||
// given AbstractOllamaInfrastructure
|
// given AbstractOllamaInfrastructure
|
||||||
|
@ -39,16 +38,14 @@ class OllamaClientIT extends AbstractOllamaInfrastructure {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
OllamaModelCard modelDetailsResponse = ollamaClient.showInformation(ShowModelInformationRequest.builder()
|
OllamaModelCard modelDetailsResponse = ollamaClient.showInformation(ShowModelInformationRequest.builder()
|
||||||
.name("phi:latest")
|
.name("tinydolphin:latest")
|
||||||
.build());
|
.build());
|
||||||
|
|
||||||
// then
|
// then
|
||||||
assertThat(modelDetailsResponse.getModelfile()).contains("# Modelfile generated by \"ollama show\"");
|
assertThat(modelDetailsResponse.getModelfile()).contains("# Modelfile generated by \"ollama show\"");
|
||||||
assertThat(modelDetailsResponse.getParameters()).contains("stop");
|
assertThat(modelDetailsResponse.getParameters()).contains("stop");
|
||||||
assertThat(modelDetailsResponse.getTemplate()).contains("System:");
|
assertThat(modelDetailsResponse.getTemplate()).contains("<|im_start|>");
|
||||||
assertThat(modelDetailsResponse.getDetails().getFormat()).isEqualTo("gguf");
|
assertThat(modelDetailsResponse.getDetails().getFormat()).isEqualTo("gguf");
|
||||||
assertThat(modelDetailsResponse.getDetails().getFamily()).isEqualTo("phi2");
|
assertThat(modelDetailsResponse.getDetails().getFamily()).isEqualTo("llama");
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
|
@ -5,13 +5,14 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import static dev.langchain4j.model.ollama.OllamaImage.ALL_MINILM_MODEL;
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
class OllamaEmbeddingModelIT extends AbstractOllamaInfrastructure {
|
class OllamaEmbeddingModelIT extends AbstractOllamaEmbeddingModelInfrastructure {
|
||||||
|
|
||||||
EmbeddingModel model = OllamaEmbeddingModel.builder()
|
EmbeddingModel model = OllamaEmbeddingModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.modelName(OllamaImage.PHI_MODEL)
|
.modelName(ALL_MINILM_MODEL)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -13,7 +13,9 @@ public class OllamaImage {
|
||||||
|
|
||||||
static final String BAKLLAVA_MODEL = "bakllava";
|
static final String BAKLLAVA_MODEL = "bakllava";
|
||||||
|
|
||||||
static final String PHI_MODEL = "phi";
|
static final String TINY_DOLPHIN_MODEL = "tinydolphin";
|
||||||
|
|
||||||
|
static final String ALL_MINILM_MODEL = "all-minilm";
|
||||||
|
|
||||||
static DockerImageName resolve(String baseImage, String localImageName) {
|
static DockerImageName resolve(String baseImage, String localImageName) {
|
||||||
DockerImageName dockerImageName = DockerImageName.parse(baseImage);
|
DockerImageName dockerImageName = DockerImageName.parse(baseImage);
|
||||||
|
@ -24,5 +26,4 @@ public class OllamaImage {
|
||||||
}
|
}
|
||||||
return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage);
|
return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,16 +2,16 @@ package dev.langchain4j.model.ollama;
|
||||||
|
|
||||||
import dev.langchain4j.model.language.LanguageModel;
|
import dev.langchain4j.model.language.LanguageModel;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.model.output.TokenUsage;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
|
class OllamaLanguageModelIT extends AbstractOllamaLanguageModelInfrastructure {
|
||||||
|
|
||||||
LanguageModel model = OllamaLanguageModel.builder()
|
LanguageModel model = OllamaLanguageModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.modelName(OllamaImage.PHI_MODEL)
|
.modelName(TINY_DOLPHIN_MODEL)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -27,14 +27,6 @@ class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
|
||||||
|
|
||||||
// then
|
// then
|
||||||
assertThat(response.content()).contains("Berlin");
|
assertThat(response.content()).contains("Berlin");
|
||||||
|
|
||||||
TokenUsage tokenUsage = response.tokenUsage();
|
|
||||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(38);
|
|
||||||
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
|
|
||||||
assertThat(tokenUsage.totalTokenCount())
|
|
||||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
|
||||||
|
|
||||||
assertThat(response.finishReason()).isNull();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -45,7 +37,7 @@ class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
|
||||||
|
|
||||||
LanguageModel model = OllamaLanguageModel.builder()
|
LanguageModel model = OllamaLanguageModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.modelName(OllamaImage.PHI_MODEL)
|
.modelName(TINY_DOLPHIN_MODEL)
|
||||||
.numPredict(numPredict)
|
.numPredict(numPredict)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
@ -67,7 +59,7 @@ class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
|
||||||
// given
|
// given
|
||||||
LanguageModel model = OllamaLanguageModel.builder()
|
LanguageModel model = OllamaLanguageModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.modelName(OllamaImage.PHI_MODEL)
|
.modelName(TINY_DOLPHIN_MODEL)
|
||||||
.format("json")
|
.format("json")
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -7,7 +7,7 @@ import java.util.List;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
class OllamaModelsIT extends AbstractOllamaInfrastructure {
|
class OllamaModelsIT extends AbstractOllamaLanguageModelInfrastructure {
|
||||||
|
|
||||||
OllamaModels ollamaModels = OllamaModels.builder()
|
OllamaModels ollamaModels = OllamaModels.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
|
@ -22,7 +22,7 @@ class OllamaModelsIT extends AbstractOllamaInfrastructure {
|
||||||
|
|
||||||
// then
|
// then
|
||||||
assertThat(response.content().size()).isGreaterThan(0);
|
assertThat(response.content().size()).isGreaterThan(0);
|
||||||
assertThat(response.content().get(0).getName()).isEqualTo("phi:latest");
|
assertThat(response.content().get(0).getName()).isEqualTo("tinydolphin:latest");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -31,7 +31,7 @@ class OllamaModelsIT extends AbstractOllamaInfrastructure {
|
||||||
|
|
||||||
// when
|
// when
|
||||||
OllamaModel ollamaModel = OllamaModel.builder()
|
OllamaModel ollamaModel = OllamaModel.builder()
|
||||||
.name("phi:latest")
|
.name("tinydolphin:latest")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
Response<OllamaModelCard> response = ollamaModels.modelCard(ollamaModel);
|
Response<OllamaModelCard> response = ollamaModels.modelCard(ollamaModel);
|
||||||
|
@ -40,7 +40,7 @@ class OllamaModelsIT extends AbstractOllamaInfrastructure {
|
||||||
assertThat(response.content().getModelfile()).isNotBlank();
|
assertThat(response.content().getModelfile()).isNotBlank();
|
||||||
assertThat(response.content().getTemplate()).isNotBlank();
|
assertThat(response.content().getTemplate()).isNotBlank();
|
||||||
assertThat(response.content().getParameters()).isNotBlank();
|
assertThat(response.content().getParameters()).isNotBlank();
|
||||||
assertThat(response.content().getDetails().getFamily()).isEqualTo("phi2");
|
assertThat(response.content().getDetails().getFamily()).isEqualTo("llama");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -48,13 +48,12 @@ class OllamaModelsIT extends AbstractOllamaInfrastructure {
|
||||||
// given AbstractOllamaInfrastructure
|
// given AbstractOllamaInfrastructure
|
||||||
|
|
||||||
// when
|
// when
|
||||||
Response<OllamaModelCard> response = ollamaModels.modelCard("phi:latest");
|
Response<OllamaModelCard> response = ollamaModels.modelCard("tinydolphin:latest");
|
||||||
|
|
||||||
// then
|
// then
|
||||||
assertThat(response.content().getModelfile()).isNotBlank();
|
assertThat(response.content().getModelfile()).isNotBlank();
|
||||||
assertThat(response.content().getTemplate()).isNotBlank();
|
assertThat(response.content().getTemplate()).isNotBlank();
|
||||||
assertThat(response.content().getParameters()).isNotBlank();
|
assertThat(response.content().getParameters()).isNotBlank();
|
||||||
assertThat(response.content().getDetails().getFamily()).isEqualTo("phi2");
|
assertThat(response.content().getDetails().getFamily()).isEqualTo("llama");
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
|
@ -6,6 +6,7 @@ import dev.langchain4j.data.message.SystemMessage;
|
||||||
import dev.langchain4j.data.message.UserMessage;
|
import dev.langchain4j.data.message.UserMessage;
|
||||||
import dev.langchain4j.model.StreamingResponseHandler;
|
import dev.langchain4j.model.StreamingResponseHandler;
|
||||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||||
|
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.model.output.TokenUsage;
|
import dev.langchain4j.model.output.TokenUsage;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
@ -13,55 +14,30 @@ import org.junit.jupiter.api.Test;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
|
||||||
|
import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
|
||||||
import static java.util.Arrays.asList;
|
import static java.util.Arrays.asList;
|
||||||
import static java.util.Collections.singletonList;
|
import static java.util.Collections.singletonList;
|
||||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
|
class OllamaStreamingChatModelIT extends AbstractOllamaLanguageModelInfrastructure {
|
||||||
|
|
||||||
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
|
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.modelName(OllamaImage.PHI_MODEL)
|
.modelName(TINY_DOLPHIN_MODEL)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void should_stream_answer() throws Exception {
|
void should_stream_answer() {
|
||||||
|
|
||||||
// given
|
// given
|
||||||
String userMessage = "What is the capital of Germany?";
|
String userMessage = "What is the capital of Germany?";
|
||||||
|
|
||||||
// when
|
// when
|
||||||
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
|
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||||
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
|
model.generate(userMessage, handler);
|
||||||
|
Response<AiMessage> response = handler.get();
|
||||||
model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
|
String answer = response.content().text();
|
||||||
|
|
||||||
private final StringBuilder answerBuilder = new StringBuilder();
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onNext(String token) {
|
|
||||||
System.out.println("onNext: '" + token + "'");
|
|
||||||
answerBuilder.append(token);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onComplete(Response<AiMessage> response) {
|
|
||||||
System.out.println("onComplete: '" + response + "'");
|
|
||||||
futureAnswer.complete(answerBuilder.toString());
|
|
||||||
futureResponse.complete(response);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onError(Throwable error) {
|
|
||||||
futureAnswer.completeExceptionally(error);
|
|
||||||
futureResponse.completeExceptionally(error);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
String answer = futureAnswer.get(30, SECONDS);
|
|
||||||
Response<AiMessage> response = futureResponse.get(30, SECONDS);
|
|
||||||
|
|
||||||
// then
|
// then
|
||||||
assertThat(answer).contains("Berlin");
|
assertThat(answer).contains("Berlin");
|
||||||
|
@ -71,7 +47,7 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
|
||||||
assertThat(aiMessage.toolExecutionRequests()).isNull();
|
assertThat(aiMessage.toolExecutionRequests()).isNull();
|
||||||
|
|
||||||
TokenUsage tokenUsage = response.tokenUsage();
|
TokenUsage tokenUsage = response.tokenUsage();
|
||||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(38);
|
assertThat(tokenUsage.inputTokenCount()).isEqualTo(35);
|
||||||
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
|
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
|
||||||
assertThat(tokenUsage.totalTokenCount())
|
assertThat(tokenUsage.totalTokenCount())
|
||||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
||||||
|
@ -80,14 +56,14 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void should_respect_numPredict() throws Exception {
|
void should_respect_numPredict() {
|
||||||
|
|
||||||
// given
|
// given
|
||||||
int numPredict = 1; // max output tokens
|
int numPredict = 1; // max output tokens
|
||||||
|
|
||||||
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
|
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.modelName(OllamaImage.PHI_MODEL)
|
.modelName(TINY_DOLPHIN_MODEL)
|
||||||
.numPredict(numPredict)
|
.numPredict(numPredict)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
@ -95,35 +71,10 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
|
||||||
UserMessage userMessage = UserMessage.from("What is the capital of Germany?");
|
UserMessage userMessage = UserMessage.from("What is the capital of Germany?");
|
||||||
|
|
||||||
// when
|
// when
|
||||||
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
|
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||||
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
|
model.generate(singletonList(userMessage), handler);
|
||||||
|
Response<AiMessage> response = handler.get();
|
||||||
model.generate(singletonList(userMessage), new StreamingResponseHandler<AiMessage>() {
|
String answer = response.content().text();
|
||||||
|
|
||||||
private final StringBuilder answerBuilder = new StringBuilder();
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onNext(String token) {
|
|
||||||
System.out.println("onNext: '" + token + "'");
|
|
||||||
answerBuilder.append(token);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onComplete(Response<AiMessage> response) {
|
|
||||||
System.out.println("onComplete: '" + response + "'");
|
|
||||||
futureAnswer.complete(answerBuilder.toString());
|
|
||||||
futureResponse.complete(response);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onError(Throwable error) {
|
|
||||||
futureAnswer.completeExceptionally(error);
|
|
||||||
futureResponse.completeExceptionally(error);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
String answer = futureAnswer.get(30, SECONDS);
|
|
||||||
Response<AiMessage> response = futureResponse.get(30, SECONDS);
|
|
||||||
|
|
||||||
// then
|
// then
|
||||||
assertThat(answer).doesNotContain("Berlin");
|
assertThat(answer).doesNotContain("Berlin");
|
||||||
|
@ -132,44 +83,18 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
|
||||||
assertThat(response.tokenUsage().outputTokenCount()).isBetween(numPredict, numPredict + 2); // bug in Ollama
|
assertThat(response.tokenUsage().outputTokenCount()).isBetween(numPredict, numPredict + 2); // bug in Ollama
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void should_respect_system_message() throws Exception {
|
void should_respect_system_message() {
|
||||||
|
|
||||||
// given
|
// given
|
||||||
SystemMessage systemMessage = SystemMessage.from("Translate messages from user into German");
|
SystemMessage systemMessage = SystemMessage.from("Translate messages from user into German");
|
||||||
UserMessage userMessage = UserMessage.from("I love you");
|
UserMessage userMessage = UserMessage.from("I love you");
|
||||||
|
|
||||||
// when
|
// when
|
||||||
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
|
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||||
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
|
model.generate(asList(systemMessage, userMessage), handler);
|
||||||
|
Response<AiMessage> response = handler.get();
|
||||||
model.generate(asList(systemMessage, userMessage), new StreamingResponseHandler<AiMessage>() {
|
String answer = response.content().text();
|
||||||
|
|
||||||
private final StringBuilder answerBuilder = new StringBuilder();
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onNext(String token) {
|
|
||||||
System.out.println("onNext: '" + token + "'");
|
|
||||||
answerBuilder.append(token);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onComplete(Response<AiMessage> response) {
|
|
||||||
System.out.println("onComplete: '" + response + "'");
|
|
||||||
futureAnswer.complete(answerBuilder.toString());
|
|
||||||
futureResponse.complete(response);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onError(Throwable error) {
|
|
||||||
futureAnswer.completeExceptionally(error);
|
|
||||||
futureResponse.completeExceptionally(error);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
String answer = futureAnswer.get(30, SECONDS);
|
|
||||||
Response<AiMessage> response = futureResponse.get(30, SECONDS);
|
|
||||||
|
|
||||||
// then
|
// then
|
||||||
assertThat(answer).containsIgnoringCase("liebe");
|
assertThat(answer).containsIgnoringCase("liebe");
|
||||||
|
@ -177,7 +102,7 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void should_respond_to_few_shot() throws Exception {
|
void should_respond_to_few_shot() {
|
||||||
|
|
||||||
// given
|
// given
|
||||||
List<ChatMessage> messages = asList(
|
List<ChatMessage> messages = asList(
|
||||||
|
@ -191,35 +116,10 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
|
||||||
);
|
);
|
||||||
|
|
||||||
// when
|
// when
|
||||||
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
|
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||||
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
|
model.generate(messages, handler);
|
||||||
|
Response<AiMessage> response = handler.get();
|
||||||
model.generate(messages, new StreamingResponseHandler<AiMessage>() {
|
String answer = response.content().text();
|
||||||
|
|
||||||
private final StringBuilder answerBuilder = new StringBuilder();
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onNext(String token) {
|
|
||||||
System.out.println("onNext: '" + token + "'");
|
|
||||||
answerBuilder.append(token);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onComplete(Response<AiMessage> response) {
|
|
||||||
System.out.println("onComplete: '" + response + "'");
|
|
||||||
futureAnswer.complete(answerBuilder.toString());
|
|
||||||
futureResponse.complete(response);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onError(Throwable error) {
|
|
||||||
futureAnswer.completeExceptionally(error);
|
|
||||||
futureResponse.completeExceptionally(error);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
String answer = futureAnswer.get(30, SECONDS);
|
|
||||||
Response<AiMessage> response = futureResponse.get(30, SECONDS);
|
|
||||||
|
|
||||||
// then
|
// then
|
||||||
assertThat(answer).startsWith(">>> 8");
|
assertThat(answer).startsWith(">>> 8");
|
||||||
|
@ -227,12 +127,12 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void should_generate_valid_json() throws Exception {
|
void should_generate_valid_json() {
|
||||||
|
|
||||||
// given
|
// given
|
||||||
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
|
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.modelName(OllamaImage.PHI_MODEL)
|
.modelName(TINY_DOLPHIN_MODEL)
|
||||||
.format("json")
|
.format("json")
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
@ -240,38 +140,51 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
|
||||||
String userMessage = "Return JSON with two fields: name and age of John Doe, 42 years old.";
|
String userMessage = "Return JSON with two fields: name and age of John Doe, 42 years old.";
|
||||||
|
|
||||||
// when
|
// when
|
||||||
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
|
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||||
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
|
model.generate(userMessage, handler);
|
||||||
|
Response<AiMessage> response = handler.get();
|
||||||
model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
|
String answer = response.content().text();
|
||||||
|
|
||||||
private final StringBuilder answerBuilder = new StringBuilder();
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onNext(String token) {
|
|
||||||
System.out.println("onNext: '" + token + "'");
|
|
||||||
answerBuilder.append(token);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onComplete(Response<AiMessage> response) {
|
|
||||||
System.out.println("onComplete: '" + response + "'");
|
|
||||||
futureAnswer.complete(answerBuilder.toString());
|
|
||||||
futureResponse.complete(response);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onError(Throwable error) {
|
|
||||||
futureAnswer.completeExceptionally(error);
|
|
||||||
futureResponse.completeExceptionally(error);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
String answer = futureAnswer.get(30, SECONDS);
|
|
||||||
Response<AiMessage> response = futureResponse.get(30, SECONDS);
|
|
||||||
|
|
||||||
// then
|
// then
|
||||||
assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
|
assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
|
||||||
assertThat(response.content().text()).isEqualTo(answer);
|
assertThat(response.content().text()).isEqualTo(answer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void should_propagate_failure_to_handler_onError() throws Exception {
|
||||||
|
|
||||||
|
// given
|
||||||
|
String wrongModelName = "banana";
|
||||||
|
|
||||||
|
StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
|
||||||
|
.baseUrl(ollama.getEndpoint())
|
||||||
|
.modelName(wrongModelName)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
CompletableFuture<Throwable> future = new CompletableFuture<>();
|
||||||
|
|
||||||
|
// when
|
||||||
|
model.generate("does not matter", new StreamingResponseHandler<AiMessage>() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onNext(String token) {
|
||||||
|
future.completeExceptionally(new Exception("onNext should never be called"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onComplete(Response<AiMessage> response) {
|
||||||
|
future.completeExceptionally(new Exception("onComplete should never be called"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onError(Throwable error) {
|
||||||
|
future.complete(error);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// then
|
||||||
|
assertThat(future.get())
|
||||||
|
.isExactlyInstanceOf(NullPointerException.class)
|
||||||
|
.hasMessageContaining("is null");
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
package dev.langchain4j.model.ollama;
|
package dev.langchain4j.model.ollama;
|
||||||
|
|
||||||
import dev.langchain4j.model.StreamingResponseHandler;
|
import dev.langchain4j.model.StreamingResponseHandler;
|
||||||
|
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
|
||||||
import dev.langchain4j.model.language.StreamingLanguageModel;
|
import dev.langchain4j.model.language.StreamingLanguageModel;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.model.output.TokenUsage;
|
import dev.langchain4j.model.output.TokenUsage;
|
||||||
|
@ -8,60 +9,34 @@ import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
|
||||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
|
class OllamaStreamingLanguageModelIT extends AbstractOllamaLanguageModelInfrastructure {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void should_stream_answer() throws Exception {
|
void should_stream_answer() {
|
||||||
|
|
||||||
// given
|
// given
|
||||||
String prompt = "What is the capital of Germany?";
|
String prompt = "What is the capital of Germany?";
|
||||||
|
|
||||||
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
|
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.modelName(OllamaImage.PHI_MODEL)
|
.modelName(OllamaImage.TINY_DOLPHIN_MODEL)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
// when
|
// when
|
||||||
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
|
TestStreamingResponseHandler<String> handler = new TestStreamingResponseHandler<>();
|
||||||
CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
|
model.generate(prompt, handler);
|
||||||
|
Response<String> response = handler.get();
|
||||||
model.generate(prompt, new StreamingResponseHandler<String>() {
|
String answer = response.content();
|
||||||
|
|
||||||
private final StringBuilder answerBuilder = new StringBuilder();
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onNext(String token) {
|
|
||||||
System.out.println("onNext: '" + token + "'");
|
|
||||||
answerBuilder.append(token);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onComplete(Response<String> response) {
|
|
||||||
System.out.println("onComplete: '" + response + "'");
|
|
||||||
futureAnswer.complete(answerBuilder.toString());
|
|
||||||
futureResponse.complete(response);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onError(Throwable error) {
|
|
||||||
futureAnswer.completeExceptionally(error);
|
|
||||||
futureResponse.completeExceptionally(error);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
String answer = futureAnswer.get(30, SECONDS);
|
|
||||||
Response<String> response = futureResponse.get(30, SECONDS);
|
|
||||||
|
|
||||||
// then
|
// then
|
||||||
assertThat(answer).contains("Berlin");
|
assertThat(answer).contains("Berlin");
|
||||||
assertThat(response.content()).isEqualTo(answer);
|
assertThat(response.content()).isEqualTo(answer);
|
||||||
|
|
||||||
TokenUsage tokenUsage = response.tokenUsage();
|
TokenUsage tokenUsage = response.tokenUsage();
|
||||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(38);
|
assertThat(tokenUsage.inputTokenCount()).isEqualTo(13);
|
||||||
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
|
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
|
||||||
assertThat(tokenUsage.totalTokenCount())
|
assertThat(tokenUsage.totalTokenCount())
|
||||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
||||||
|
@ -70,14 +45,14 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void should_respect_numPredict() throws Exception {
|
void should_respect_numPredict() {
|
||||||
|
|
||||||
// given
|
// given
|
||||||
int numPredict = 1; // max output tokens
|
int numPredict = 1; // max output tokens
|
||||||
|
|
||||||
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
|
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.modelName(OllamaImage.PHI_MODEL)
|
.modelName(OllamaImage.TINY_DOLPHIN_MODEL)
|
||||||
.numPredict(numPredict)
|
.numPredict(numPredict)
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
@ -85,35 +60,10 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
|
||||||
String prompt = "What is the capital of Germany?";
|
String prompt = "What is the capital of Germany?";
|
||||||
|
|
||||||
// when
|
// when
|
||||||
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
|
TestStreamingResponseHandler<String> handler = new TestStreamingResponseHandler<>();
|
||||||
CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
|
model.generate(prompt, handler);
|
||||||
|
Response<String> response = handler.get();
|
||||||
model.generate(prompt, new StreamingResponseHandler<String>() {
|
String answer = response.content();
|
||||||
|
|
||||||
private final StringBuilder answerBuilder = new StringBuilder();
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onNext(String token) {
|
|
||||||
System.out.println("onNext: '" + token + "'");
|
|
||||||
answerBuilder.append(token);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onComplete(Response<String> response) {
|
|
||||||
System.out.println("onComplete: '" + response + "'");
|
|
||||||
futureAnswer.complete(answerBuilder.toString());
|
|
||||||
futureResponse.complete(response);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onError(Throwable error) {
|
|
||||||
futureAnswer.completeExceptionally(error);
|
|
||||||
futureResponse.completeExceptionally(error);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
String answer = futureAnswer.get(30, SECONDS);
|
|
||||||
Response<String> response = futureResponse.get(30, SECONDS);
|
|
||||||
|
|
||||||
// then
|
// then
|
||||||
assertThat(answer).doesNotContain("Berlin");
|
assertThat(answer).doesNotContain("Berlin");
|
||||||
|
@ -123,12 +73,12 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void should_stream_valid_json() throws Exception {
|
void should_stream_valid_json() {
|
||||||
|
|
||||||
// given
|
// given
|
||||||
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
|
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
|
||||||
.baseUrl(ollama.getEndpoint())
|
.baseUrl(ollama.getEndpoint())
|
||||||
.modelName(OllamaImage.PHI_MODEL)
|
.modelName(OllamaImage.TINY_DOLPHIN_MODEL)
|
||||||
.format("json")
|
.format("json")
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
@ -136,38 +86,51 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
|
||||||
String prompt = "Return JSON with two fields: name and age of John Doe, 42 years old.";
|
String prompt = "Return JSON with two fields: name and age of John Doe, 42 years old.";
|
||||||
|
|
||||||
// when
|
// when
|
||||||
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
|
TestStreamingResponseHandler<String> handler = new TestStreamingResponseHandler<>();
|
||||||
CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
|
model.generate(prompt, handler);
|
||||||
|
Response<String> response = handler.get();
|
||||||
model.generate(prompt, new StreamingResponseHandler<String>() {
|
String answer = response.content();
|
||||||
|
|
||||||
private final StringBuilder answerBuilder = new StringBuilder();
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onNext(String token) {
|
|
||||||
System.out.println("onNext: '" + token + "'");
|
|
||||||
answerBuilder.append(token);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onComplete(Response<String> response) {
|
|
||||||
System.out.println("onComplete: '" + response + "'");
|
|
||||||
futureAnswer.complete(answerBuilder.toString());
|
|
||||||
futureResponse.complete(response);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onError(Throwable error) {
|
|
||||||
futureAnswer.completeExceptionally(error);
|
|
||||||
futureResponse.completeExceptionally(error);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
String answer = futureAnswer.get(30, SECONDS);
|
|
||||||
Response<String> response = futureResponse.get(30, SECONDS);
|
|
||||||
|
|
||||||
// then
|
// then
|
||||||
assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
|
assertThat(answer).isEqualToIgnoringWhitespace("{\"name\": \"John Doe\", \"age\": 42}");
|
||||||
assertThat(response.content()).isEqualTo(answer);
|
assertThat(response.content()).isEqualTo(answer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void should_propagate_failure_to_handler_onError() throws Exception {
|
||||||
|
|
||||||
|
// given
|
||||||
|
String wrongModelName = "banana";
|
||||||
|
|
||||||
|
StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
|
||||||
|
.baseUrl(ollama.getEndpoint())
|
||||||
|
.modelName(wrongModelName)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
CompletableFuture<Throwable> future = new CompletableFuture<>();
|
||||||
|
|
||||||
|
// when
|
||||||
|
model.generate("does not matter", new StreamingResponseHandler<String>() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onNext(String token) {
|
||||||
|
future.completeExceptionally(new Exception("onNext should never be called"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onComplete(Response<String> response) {
|
||||||
|
future.completeExceptionally(new Exception("onComplete should never be called"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onError(Throwable error) {
|
||||||
|
future.complete(error);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// then
|
||||||
|
assertThat(future.get())
|
||||||
|
.isExactlyInstanceOf(NullPointerException.class)
|
||||||
|
.hasMessageContaining("is null");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue