From a886eee15a012818452300629b53000157d14392 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edd=C3=BA=20Mel=C3=A9ndez=20Gonzales?= Date: Fri, 8 Dec 2023 03:33:56 -0600 Subject: [PATCH] Use Testcontainers for Weaviate IT (#332) Run `semitechnologies/weaviate` image using `GenericContainer` and allow rapid feeback when using Weaviate. Relax `apiKey` constraint to allow empty value for testing purposes. --- langchain4j-weaviate/pom.xml | 6 +++++ .../weaviate/WeaviateEmbeddingStore.java | 3 ++- .../weaviate/WeaviateEmbeddingStoreIT.java | 22 ++++++++++++++----- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/langchain4j-weaviate/pom.xml b/langchain4j-weaviate/pom.xml index 3fcf2f510..893f30b40 100644 --- a/langchain4j-weaviate/pom.xml +++ b/langchain4j-weaviate/pom.xml @@ -70,6 +70,12 @@ test + + org.testcontainers + junit-jupiter + test + + diff --git a/langchain4j-weaviate/src/main/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStore.java b/langchain4j-weaviate/src/main/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStore.java index 18fd9908b..26d2110f9 100644 --- a/langchain4j-weaviate/src/main/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStore.java +++ b/langchain4j-weaviate/src/main/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStore.java @@ -2,6 +2,7 @@ package dev.langchain4j.store.embedding.weaviate; import static dev.langchain4j.internal.Utils.*; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static io.weaviate.client.v1.data.replication.model.ConsistencyLevel.QUORUM; import static java.util.Arrays.stream; import static java.util.Collections.emptyList; @@ -66,7 +67,7 @@ public class WeaviateEmbeddingStore implements EmbeddingStore { ) { try { Config config = new Config(ensureNotBlank(scheme, "scheme"), ensureNotBlank(host, "host")); - this.client = WeaviateAuthClient.apiKey(config, ensureNotBlank(apiKey, "apiKey")); + this.client = WeaviateAuthClient.apiKey(config, ensureNotNull(apiKey, "apiKey")); } catch (AuthException e) { throw new IllegalArgumentException(e); } diff --git a/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStoreIT.java b/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStoreIT.java index be02e8245..094803fa9 100644 --- a/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStoreIT.java +++ b/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStoreIT.java @@ -5,17 +5,29 @@ import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreWithoutMetadataIT; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; import static dev.langchain4j.internal.Utils.randomUUID; -@EnabledIfEnvironmentVariable(named = "WEAVIATE_API_KEY", matches = ".+") +@Testcontainers class WeaviateEmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT { + @Container + static GenericContainer weaviate = new GenericContainer<>("semitechnologies/weaviate:1.22.4") + .withEnv("AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED", "true") + .withEnv("PERSISTENCE_DATA_PATH", "/var/lib/weaviate") + .withEnv("QUERY_DEFAULTS_LIMIT", "25") + .withEnv("DEFAULT_VECTORIZER_MODULE", "none") + .withEnv("CLUSTER_HOSTNAME", "node1") + .withExposedPorts(8080); + EmbeddingStore embeddingStore = WeaviateEmbeddingStore.builder() - .apiKey(System.getenv("WEAVIATE_API_KEY")) - .scheme("https") - .host("test-am8ocede.weaviate.network") + .apiKey("") + .scheme("http") + .host(String.format("%s:%d", weaviate.getHost(), weaviate.getMappedPort(8080))) .objectClass("Test" + randomUUID().replace("-", "")) .build();