Use Testcontainers Ollama module (#702)

Testcontainers 1.19.7 offers an Ollama module. The module also configures GPU
support if one is available.
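
For reference, a minimal sketch of the module's basic usage (illustrative, not part of this diff; the class name is hypothetical, and the `org.testcontainers:ollama` artifact is assumed to be on the test classpath):

```java
import org.testcontainers.ollama.OllamaContainer;

public class OllamaModuleSketch {

    public static void main(String[] args) {
        // The module exposes Ollama's port 11434 and, per this release,
        // requests GPU access when the Docker host supports it.
        try (OllamaContainer ollama = new OllamaContainer("ollama/ollama:latest")) {
            ollama.start();
            // getEndpoint() returns the base URL, e.g. http://localhost:<mapped-port>
            System.out.println("Ollama available at " + ollama.getEndpoint());
        }
    }
}
```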
Authored by Eddú Meléndez Gonzales on 2024-03-19 07:05:24 -05:00; committed by GitHub
parent f19a0a3b11
commit 3fabe0ed66
15 changed files with 128 additions and 219 deletions


@@ -52,7 +52,8 @@ Try out a simple chat example code:
 ```java
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.ollama.OllamaChatModel;
-import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.ollama.OllamaContainer;
+import org.testcontainers.utility.DockerImageName;
 public class OllamaChatExample {
@@ -62,9 +63,9 @@ public class OllamaChatExample {
         String modelName = "orca-mini";
         // Create and start the Ollama container
-        GenericContainer<?> ollama =
-                new GenericContainer<>("langchain4j/ollama-" + modelName + ":latest")
-                        .withExposedPorts(11434);
+        OllamaContainer ollama =
+                new OllamaContainer(DockerImageName.parse("langchain4j/ollama-" + modelName + ":latest")
+                        .asCompatibleSubstituteFor("ollama/ollama"));
         ollama.start();
         // Build the ChatLanguageModel
@@ -94,7 +95,8 @@ import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.ollama.OllamaStreamingChatModel;
 import dev.langchain4j.model.output.Response;
-import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.ollama.OllamaContainer;
+import org.testcontainers.utility.DockerImageName;
 import java.util.concurrent.CompletableFuture;
@@ -102,10 +104,9 @@ public class OllamaStreamingChatExample {
     static String MODEL_NAME = "orca-mini"; // try "mistral", "llama2", "codellama" or "phi"
     static String DOCKER_IMAGE_NAME = "langchain4j/ollama-" + MODEL_NAME + ":latest";
-    static Integer PORT = 11434;
-    static GenericContainer<?> ollama = new GenericContainer<>(DOCKER_IMAGE_NAME)
-            .withExposedPorts(PORT);
+    static OllamaContainer ollama = new OllamaContainer(
+            DockerImageName.parse(DOCKER_IMAGE_NAME).asCompatibleSubstituteFor("ollama/ollama"));
     public static void main(String[] args) {
         ollama.start();


@@ -64,7 +64,7 @@
         <dependency>
             <groupId>org.testcontainers</groupId>
-            <artifactId>testcontainers</artifactId>
+            <artifactId>ollama</artifactId>
             <scope>test</scope>
         </dependency>


@@ -1,103 +1,33 @@
 package dev.langchain4j.model.ollama;
 import com.github.dockerjava.api.DockerClient;
-import com.github.dockerjava.api.command.InspectContainerResponse;
 import com.github.dockerjava.api.model.Image;
-import lombok.extern.slf4j.Slf4j;
 import org.testcontainers.DockerClientFactory;
-import org.testcontainers.containers.GenericContainer;
 import org.testcontainers.utility.DockerImageName;
-import org.testcontainers.utility.LazyFuture;
-import java.io.IOException;
-import java.util.Collections;
 import java.util.List;
-@Slf4j
 public class AbstractOllamaInfrastructure {
-    private static final String OLLAMA_IMAGE = "ollama/ollama:latest";
-    static final String MODEL = "phi";
-    private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OLLAMA_IMAGE, MODEL);
+    private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.PHI_MODEL);
-    static OllamaContainer ollama;
+    static LangChain4jOllamaContainer ollama;
     static {
-        ollama = new OllamaContainer(new OllamaImage(OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE));
+        ollama = new LangChain4jOllamaContainer(resolveImage(OllamaImage.OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE))
+                .withModel(OllamaImage.PHI_MODEL);
         ollama.start();
-        createImage(ollama, LOCAL_OLLAMA_IMAGE);
+        ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
     }
-    String getBaseUrl() {
-        return "http://" + ollama.getHost() + ":" + ollama.getMappedPort(11434);
-    }
-    static void createImage(GenericContainer<?> container, String localImageName) {
-        DockerImageName dockerImageName = DockerImageName.parse(container.getDockerImageName());
-        if (!dockerImageName.equals(DockerImageName.parse(localImageName))) {
-            DockerClient dockerClient = DockerClientFactory.instance().client();
-            List<Image> images = dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec();
-            if (images.isEmpty()) {
-                DockerImageName imageModel = DockerImageName.parse(localImageName);
-                dockerClient.commitCmd(container.getContainerId())
-                        .withRepository(imageModel.getUnversionedPart())
-                        .withLabels(Collections.singletonMap("org.testcontainers.sessionId", ""))
-                        .withTag(imageModel.getVersionPart())
-                        .exec();
-            }
-        }
-    }
+    static DockerImageName resolveImage(String baseImage, String localImageName) {
+        DockerImageName dockerImageName = DockerImageName.parse(baseImage);
+        DockerClient dockerClient = DockerClientFactory.instance().client();
+        List<Image> images = dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec();
+        if (images.isEmpty()) {
+            return dockerImageName;
+        }
+        return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage);
+    }
-    static class OllamaContainer extends GenericContainer<OllamaContainer> {
-        private final DockerImageName dockerImageName;
-        OllamaContainer(LazyFuture<DockerImageName> image) {
-            super(image.get());
-            this.dockerImageName = image.get();
-            withExposedPorts(11434);
-            withImagePullPolicy(dockerImageName -> !dockerImageName.getVersionPart().endsWith(MODEL));
-        }
-        @Override
-        protected void containerIsStarted(InspectContainerResponse containerInfo) {
-            if (!this.dockerImageName.equals(DockerImageName.parse(LOCAL_OLLAMA_IMAGE))) {
-                try {
-                    log.info("Start pulling the '{}' model ... would take several minutes ...", MODEL);
-                    execInContainer("ollama", "pull", MODEL);
-                    log.info("Model pulling competed!");
-                } catch (IOException | InterruptedException e) {
-                    throw new RuntimeException("Error pulling model", e);
-                }
-            }
-        }
-    }
-    static class OllamaImage extends LazyFuture<DockerImageName> {
-        private final String baseImage;
-        private final String localImageName;
-        OllamaImage(String baseImage, String localImageName) {
-            this.baseImage = baseImage;
-            this.localImageName = localImageName;
-        }
-        @Override
-        protected DockerImageName resolve() {
-            DockerImageName dockerImageName = DockerImageName.parse(this.baseImage);
-            DockerClient dockerClient = DockerClientFactory.instance().client();
-            List<Image> images = dockerClient.listImagesCmd().withReferenceFilter(this.localImageName).exec();
-            if (images.isEmpty()) {
-                return dockerImageName;
-            }
-            return DockerImageName.parse(this.localImageName);
-        }
-    }
 }


@@ -1,105 +1,16 @@
 package dev.langchain4j.model.ollama;
-import com.github.dockerjava.api.DockerClient;
-import com.github.dockerjava.api.command.InspectContainerResponse;
-import com.github.dockerjava.api.model.Image;
-import lombok.extern.slf4j.Slf4j;
-import org.testcontainers.DockerClientFactory;
-import org.testcontainers.containers.GenericContainer;
-import org.testcontainers.utility.DockerImageName;
-import org.testcontainers.utility.LazyFuture;
-import java.io.IOException;
-import java.util.Collections;
-import java.util.List;
-@Slf4j
 public class AbstractOllamaInfrastructureVisionModel {
-    private static final String OLLAMA_IMAGE = "ollama/ollama:latest";
-    static final String MODEL = "bakllava";
-    private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OLLAMA_IMAGE, MODEL);
+    private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.BAKLLAVA_MODEL);
-    static OllamaContainer ollama;
+    static LangChain4jOllamaContainer ollama;
     static {
-        ollama = new OllamaContainer(new OllamaImage(OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE));
+        ollama = new LangChain4jOllamaContainer(OllamaImage.resolve(OllamaImage.OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE))
+                .withModel(OllamaImage.BAKLLAVA_MODEL);
         ollama.start();
-        createImage(ollama, LOCAL_OLLAMA_IMAGE);
+        ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
     }
-    String getBaseUrl() {
-        return "http://" + ollama.getHost() + ":" + ollama.getMappedPort(11434);
-    }
-    static void createImage(GenericContainer<?> container, String localImageName) {
-        DockerImageName dockerImageName = DockerImageName.parse(container.getDockerImageName());
-        if (!dockerImageName.equals(DockerImageName.parse(localImageName))) {
-            DockerClient dockerClient = DockerClientFactory.instance().client();
-            List<Image> images = dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec();
-            if (images.isEmpty()) {
-                DockerImageName imageModel = DockerImageName.parse(localImageName);
-                dockerClient.commitCmd(container.getContainerId())
-                        .withRepository(imageModel.getUnversionedPart())
-                        .withLabels(Collections.singletonMap("org.testcontainers.sessionId", ""))
-                        .withTag(imageModel.getVersionPart())
-                        .exec();
-            }
-        }
-    }
-    static class OllamaContainer extends GenericContainer<OllamaContainer> {
-        private final DockerImageName dockerImageName;
-        OllamaContainer(LazyFuture<DockerImageName> image) {
-            super(image.get());
-            this.dockerImageName = image.get();
-            withExposedPorts(11434);
-            withImagePullPolicy(dockerImageName -> !dockerImageName.getVersionPart().endsWith(MODEL));
-        }
-        @Override
-        protected void containerIsStarted(InspectContainerResponse containerInfo) {
-            if (!this.dockerImageName.equals(DockerImageName.parse(LOCAL_OLLAMA_IMAGE))) {
-                try {
-                    log.info("Start pulling the '{}' model ... would take several minutes ...", MODEL);
-                    execInContainer("ollama", "pull", MODEL);
-                    log.info("Model pulling competed!");
-                } catch (IOException | InterruptedException e) {
-                    throw new RuntimeException("Error pulling model", e);
-                }
-            }
-        }
-    }
-    static class OllamaImage extends LazyFuture<DockerImageName> {
-        private final String baseImage;
-        private final String localImageName;
-        OllamaImage(String baseImage, String localImageName) {
-            this.baseImage = baseImage;
-            this.localImageName = localImageName;
-        }
-        @Override
-        protected DockerImageName resolve() {
-            DockerImageName dockerImageName = DockerImageName.parse(this.baseImage);
-            DockerClient dockerClient = DockerClientFactory.instance().client();
-            List<Image> images = dockerClient.listImagesCmd().withReferenceFilter(this.localImageName).exec();
-            if (images.isEmpty()) {
-                return dockerImageName;
-            }
-            return DockerImageName.parse(this.localImageName);
-        }
-    }
 }


@@ -0,0 +1,39 @@
+package dev.langchain4j.model.ollama;
+import com.github.dockerjava.api.command.InspectContainerResponse;
+import lombok.extern.slf4j.Slf4j;
+import org.testcontainers.ollama.OllamaContainer;
+import org.testcontainers.utility.DockerImageName;
+import java.io.IOException;
+@Slf4j
+public class LangChain4jOllamaContainer extends OllamaContainer {
+    private final DockerImageName dockerImageName;
+    private String model;
+    LangChain4jOllamaContainer(DockerImageName dockerImageName) {
+        super(dockerImageName);
+        this.dockerImageName = dockerImageName;
+    }
+    LangChain4jOllamaContainer withModel(String model) {
+        this.model = model;
+        return this;
+    }
+    @Override
+    protected void containerIsStarted(InspectContainerResponse containerInfo) {
+        if (this.model != null) {
+            try {
+                log.info("Start pulling the '{}' model ... would take several minutes ...", this.model);
+                execInContainer("ollama", "pull", this.model);
+                log.info("Model pulling completed!");
+            } catch (IOException | InterruptedException e) {
+                throw new RuntimeException("Error pulling model", e);
+            }
+        }
+    }
+}


@@ -21,9 +21,9 @@ class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionMod
         // given
         ChatLanguageModel model = OllamaChatModel.builder()
-                .baseUrl(getBaseUrl())
+                .baseUrl(ollama.getEndpoint())
                 .timeout(Duration.ofMinutes(3))
-                .modelName(MODEL)
+                .modelName(OllamaImage.BAKLLAVA_MODEL)
                 .temperature(0.0)
                 .build();
@@ -43,9 +43,9 @@ class OllamaChatModeVisionModellIT extends AbstractOllamaInfrastructureVisionMod
         // given
         ChatLanguageModel model = OllamaChatModel.builder()
-                .baseUrl(getBaseUrl())
+                .baseUrl(ollama.getEndpoint())
                 .timeout(Duration.ofMinutes(3))
-                .modelName(MODEL)
+                .modelName(OllamaImage.BAKLLAVA_MODEL)
                 .temperature(0.0)
                 .build();


@@ -17,8 +17,8 @@ import static org.assertj.core.api.Assertions.assertThat;
 class OllamaChatModelIT extends AbstractOllamaInfrastructure {
     ChatLanguageModel model = OllamaChatModel.builder()
-            .baseUrl(getBaseUrl())
-            .modelName(MODEL)
+            .baseUrl(ollama.getEndpoint())
+            .modelName(OllamaImage.PHI_MODEL)
             .temperature(0.0)
             .build();
@@ -53,8 +53,8 @@ class OllamaChatModelIT extends AbstractOllamaInfrastructure {
         int numPredict = 1; // max output tokens
         OllamaChatModel model = OllamaChatModel.builder()
-                .baseUrl(getBaseUrl())
-                .modelName(MODEL)
+                .baseUrl(ollama.getEndpoint())
+                .modelName(OllamaImage.PHI_MODEL)
                 .numPredict(numPredict)
                 .temperature(0.0)
                 .build();
@@ -112,8 +112,8 @@ class OllamaChatModelIT extends AbstractOllamaInfrastructure {
         // given
        ChatLanguageModel model = OllamaChatModel.builder()
-                .baseUrl(getBaseUrl())
-                .modelName(MODEL)
+                .baseUrl(ollama.getEndpoint())
+                .modelName(OllamaImage.PHI_MODEL)
                 .format("json")
                 .temperature(0.0)
                 .build();


@@ -14,7 +14,7 @@ class OllamaClientIT extends AbstractOllamaInfrastructure {
         // when
         OllamaClient ollamaClient = OllamaClient.builder()
-                .baseUrl(getBaseUrl())
+                .baseUrl(ollama.getEndpoint())
                 .timeout(Duration.ofMinutes(1))
                 .build();
@@ -34,7 +34,7 @@ class OllamaClientIT extends AbstractOllamaInfrastructure {
         // when
         OllamaClient ollamaClient = OllamaClient.builder()
-                .baseUrl(getBaseUrl())
+                .baseUrl(ollama.getEndpoint())
                 .timeout(Duration.ofMinutes(1))
                 .build();


@@ -10,8 +10,8 @@ import static org.assertj.core.api.Assertions.assertThat;
 class OllamaEmbeddingModelIT extends AbstractOllamaInfrastructure {
     EmbeddingModel model = OllamaEmbeddingModel.builder()
-            .baseUrl(getBaseUrl())
-            .modelName(MODEL)
+            .baseUrl(ollama.getEndpoint())
+            .modelName(OllamaImage.PHI_MODEL)
             .build();
     @Test


@@ -0,0 +1,28 @@
+package dev.langchain4j.model.ollama;
+import com.github.dockerjava.api.DockerClient;
+import com.github.dockerjava.api.model.Image;
+import org.testcontainers.DockerClientFactory;
+import org.testcontainers.utility.DockerImageName;
+import java.util.List;
+public class OllamaImage {
+    static final String OLLAMA_IMAGE = "ollama/ollama:latest";
+    static final String BAKLLAVA_MODEL = "bakllava";
+    static final String PHI_MODEL = "phi";
+    static DockerImageName resolve(String baseImage, String localImageName) {
+        DockerImageName dockerImageName = DockerImageName.parse(baseImage);
+        DockerClient dockerClient = DockerClientFactory.instance().client();
+        List<Image> images = dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec();
+        if (images.isEmpty()) {
+            return dockerImageName;
+        }
+        return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage);
+    }
+}


@@ -10,8 +10,8 @@ import static org.assertj.core.api.Assertions.assertThat;
 class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
     LanguageModel model = OllamaLanguageModel.builder()
-            .baseUrl(getBaseUrl())
-            .modelName(MODEL)
+            .baseUrl(ollama.getEndpoint())
+            .modelName(OllamaImage.PHI_MODEL)
             .temperature(0.0)
             .build();
@@ -44,8 +44,8 @@ class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
         int numPredict = 1; // max output tokens
         LanguageModel model = OllamaLanguageModel.builder()
-                .baseUrl(getBaseUrl())
-                .modelName(MODEL)
+                .baseUrl(ollama.getEndpoint())
+                .modelName(OllamaImage.PHI_MODEL)
                 .numPredict(numPredict)
                 .temperature(0.0)
                 .build();
@@ -66,8 +66,8 @@ class OllamaLanguageModelIT extends AbstractOllamaInfrastructure {
         // given
         LanguageModel model = OllamaLanguageModel.builder()
-                .baseUrl(getBaseUrl())
-                .modelName(MODEL)
+                .baseUrl(ollama.getEndpoint())
+                .modelName(OllamaImage.PHI_MODEL)
                 .format("json")
                 .temperature(0.0)
                 .build();


@@ -10,7 +10,7 @@ import static org.assertj.core.api.Assertions.assertThat;
 class OllamaModelsIT extends AbstractOllamaInfrastructure {
     OllamaModels ollamaModels = OllamaModels.builder()
-            .baseUrl(getBaseUrl())
+            .baseUrl(ollama.getEndpoint())
             .build();
     @Test


@@ -21,8 +21,8 @@ import static org.assertj.core.api.Assertions.assertThat;
 class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
     StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
-            .baseUrl(getBaseUrl())
-            .modelName(MODEL)
+            .baseUrl(ollama.getEndpoint())
+            .modelName(OllamaImage.PHI_MODEL)
             .temperature(0.0)
             .build();
@@ -86,8 +86,8 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
         int numPredict = 1; // max output tokens
         StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
-                .baseUrl(getBaseUrl())
-                .modelName(MODEL)
+                .baseUrl(ollama.getEndpoint())
+                .modelName(OllamaImage.PHI_MODEL)
                 .numPredict(numPredict)
                 .temperature(0.0)
                 .build();
@@ -231,8 +231,8 @@ class OllamaStreamingChatModelIT extends AbstractOllamaInfrastructure {
         // given
         StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
-                .baseUrl(getBaseUrl())
-                .modelName(MODEL)
+                .baseUrl(ollama.getEndpoint())
+                .modelName(OllamaImage.PHI_MODEL)
                 .format("json")
                 .temperature(0.0)
                 .build();


@@ -20,8 +20,8 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
         String prompt = "What is the capital of Germany?";
         StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
-                .baseUrl(getBaseUrl())
-                .modelName(MODEL)
+                .baseUrl(ollama.getEndpoint())
+                .modelName(OllamaImage.PHI_MODEL)
                 .temperature(0.0)
                 .build();
@@ -76,8 +76,8 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
         int numPredict = 1; // max output tokens
         StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
-                .baseUrl(getBaseUrl())
-                .modelName(MODEL)
+                .baseUrl(ollama.getEndpoint())
+                .modelName(OllamaImage.PHI_MODEL)
                 .numPredict(numPredict)
                 .temperature(0.0)
                 .build();
@@ -127,8 +127,8 @@ class OllamaStreamingLanguageModelIT extends AbstractOllamaInfrastructure {
         // given
         StreamingLanguageModel model = OllamaStreamingLanguageModel.builder()
-                .baseUrl(getBaseUrl())
-                .modelName(MODEL)
+                .baseUrl(ollama.getEndpoint())
+                .modelName(OllamaImage.PHI_MODEL)
                 .format("json")
                 .temperature(0.0)
                 .build();


@@ -32,7 +32,7 @@
         <slf4j-api.version>2.0.7</slf4j-api.version>
         <gson.version>2.10.1</gson.version>
         <junit.version>5.10.0</junit.version>
-        <testcontainers.version>1.19.6</testcontainers.version>
+        <testcontainers.version>1.19.7</testcontainers.version>
         <bytebuddy.version>1.14.10</bytebuddy.version>
         <mockito.version>4.11.0</mockito.version>
         <assertj.version>3.24.2</assertj.version>