diff --git a/langchain4j-bom/pom.xml b/langchain4j-bom/pom.xml index e6a88f3ce..01abc332c 100644 --- a/langchain4j-bom/pom.xml +++ b/langchain4j-bom/pom.xml @@ -197,6 +197,12 @@ ${project.version} + + dev.langchain4j + langchain4j-mongodb-atlas + ${project.version} + + diff --git a/langchain4j-mongodb-atlas/pom.xml b/langchain4j-mongodb-atlas/pom.xml new file mode 100644 index 000000000..45c547316 --- /dev/null +++ b/langchain4j-mongodb-atlas/pom.xml @@ -0,0 +1,92 @@ + + + 4.0.0 + + dev.langchain4j + langchain4j-parent + 0.27.0-SNAPSHOT + ../langchain4j-parent/pom.xml + + + langchain4j-mongodb-atlas + jar + + LangChain4j :: Integration :: MongoDB Atlas + + + + dev.langchain4j + langchain4j-core + + + + org.mongodb + mongodb-driver-sync + 4.11.1 + + + + org.projectlombok + lombok + provided + + + + org.slf4j + slf4j-api + + + + dev.langchain4j + langchain4j-core + tests + test-jar + test + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.assertj + assertj-core + test + + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2-q + test + + + + org.testcontainers + junit-jupiter + test + + + + org.testcontainers + mongodb + test + + + + org.tinylog + tinylog-impl + test + + + + org.tinylog + slf4j-tinylog + test + + + + \ No newline at end of file diff --git a/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/IndexMapping.java b/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/IndexMapping.java new file mode 100644 index 000000000..b02d9c177 --- /dev/null +++ b/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/IndexMapping.java @@ -0,0 +1,26 @@ +package dev.langchain4j.store.embedding.mongodb; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.HashSet; +import java.util.Set; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +public class IndexMapping { + + private int dimension; + private Set metadataFieldNames; + + public static IndexMapping defaultIndexMapping() { + return IndexMapping.builder() + .dimension(1536) + .metadataFieldNames(new HashSet<>()) + .build(); + } +} diff --git a/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/MappingUtils.java b/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/MappingUtils.java new file mode 100644 index 000000000..bc4b512bb --- /dev/null +++ b/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/MappingUtils.java @@ -0,0 +1,77 @@ +package dev.langchain4j.store.embedding.mongodb; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import org.bson.Document; + +import java.util.Set; + +class MappingUtils { + + private MappingUtils() throws InstantiationException { + throw new InstantiationException("can't instantiate this class"); + } + + static MongoDbDocument toMongoDbDocument(String id, Embedding embedding, TextSegment textSegment) { + if (textSegment == null) { + return new MongoDbDocument(id, embedding.vectorAsList(), null, null); + } + return new MongoDbDocument(id, embedding.vectorAsList(), textSegment.text(), textSegment.metadata().asMap()); + } + + static EmbeddingMatch toEmbeddingMatch(MongoDbMatchedDocument matchedDocument) { + TextSegment textSegment = null; + if (matchedDocument.getText() != null) { + textSegment = matchedDocument.getMetadata() == null ? TextSegment.from(matchedDocument.getText()) : + TextSegment.from(matchedDocument.getText(), Metadata.from(matchedDocument.getMetadata())); + } + return new EmbeddingMatch<>(matchedDocument.getScore(), matchedDocument.getId(), Embedding.from(matchedDocument.getEmbedding()), textSegment); + } + + static Document fromIndexMapping(IndexMapping indexMapping) { + Document mapping = new Document(); + mapping.append("dynamic", false); + + Document fields = new Document(); + writeEmbedding(indexMapping.getDimension(), fields); + + Set metadataFields = indexMapping.getMetadataFieldNames(); + if (metadataFields != null && !metadataFields.isEmpty()) { + writeMetadata(metadataFields, fields); + } + + mapping.append("fields", fields); + + return new Document("mappings", mapping); + } + + private static void writeMetadata(Set metadataFields, Document fields) { + Document metadata = new Document(); + metadata.append("dynamic", false); + metadata.append("type", "document"); + + Document metadataFieldDoc = new Document(); + metadataFields.forEach(field -> writeMetadataField(metadataFieldDoc, field)); + + metadata.append("fields", metadataFieldDoc); + + fields.append("metadata", metadata); + } + + private static void writeMetadataField(Document metadataFieldDoc, String fieldName) { + Document field = new Document(); + field.append("type", "token"); + metadataFieldDoc.append(fieldName, field); + } + + private static void writeEmbedding(int dimensions, Document fields) { + Document embedding = new Document(); + embedding.append("dimensions", dimensions); + embedding.append("similarity", "cosine"); + embedding.append("type", "knnVector"); + + fields.append("embedding", embedding); + } +} diff --git a/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/MongoDbDocument.java b/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/MongoDbDocument.java new file mode 100644 index 000000000..3371aa149 --- /dev/null +++ b/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/MongoDbDocument.java @@ -0,0 +1,23 @@ +package dev.langchain4j.store.embedding.mongodb; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.bson.codecs.pojo.annotations.BsonId; + +import java.util.List; +import java.util.Map; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +public class MongoDbDocument { + + @BsonId + private String id; + private List embedding; + private String text; + private Map metadata; +} diff --git a/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStore.java b/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStore.java new file mode 100644 index 000000000..56bc06c12 --- /dev/null +++ b/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStore.java @@ -0,0 +1,322 @@ +package dev.langchain4j.store.embedding.mongodb; + +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCommandException; +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.CreateCollectionOptions; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.search.VectorSearchOptions; +import com.mongodb.client.result.InsertManyResult; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; +import org.bson.Document; +import org.bson.codecs.configuration.CodecRegistry; +import org.bson.codecs.pojo.PojoCodecProvider; +import org.bson.conversions.Bson; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import static com.mongodb.client.model.Aggregates.*; +import static com.mongodb.client.model.Projections.*; +import static com.mongodb.client.model.search.SearchPath.fieldPath; +import static com.mongodb.client.model.search.VectorSearchOptions.vectorSearchOptions; +import static dev.langchain4j.internal.Utils.*; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static dev.langchain4j.internal.ValidationUtils.ensureTrue; +import static dev.langchain4j.store.embedding.mongodb.IndexMapping.defaultIndexMapping; +import static dev.langchain4j.store.embedding.mongodb.MappingUtils.fromIndexMapping; +import static dev.langchain4j.store.embedding.mongodb.MappingUtils.toMongoDbDocument; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toList; +import static org.bson.codecs.configuration.CodecRegistries.fromProviders; +import static org.bson.codecs.configuration.CodecRegistries.fromRegistries; + +/** + * Represents a MongoDB index as an embedding store. + *

+ * More info to set up MongoDb as vectorDatabase + *

+ * tutorial how to use a knn-vector in MongoDB Atlas (great startingpoint) + */ +public class MongoDbEmbeddingStore implements EmbeddingStore { + + private static final Logger log = LoggerFactory.getLogger(MongoDbEmbeddingStore.class); + + private final MongoCollection collection; + + private final String indexName; + private final long maxResultRatio; + private final VectorSearchOptions vectorSearchOptions; + + public MongoDbEmbeddingStore(MongoClient mongoClient, + String databaseName, + String collectionName, + String indexName, + Long maxResultRatio, + CreateCollectionOptions createCollectionOptions, + Bson filter, + IndexMapping indexMapping, + Boolean createIndex) { + databaseName = ensureNotNull(databaseName, "databaseName"); + collectionName = ensureNotNull(collectionName, "collectionName"); + createIndex = getOrDefault(createIndex, false); + this.indexName = ensureNotNull(indexName, "indexName"); + this.maxResultRatio = getOrDefault(maxResultRatio, 10L); + + CodecRegistry pojoCodecRegistry = fromProviders(PojoCodecProvider.builder() + .register(MongoDbDocument.class, MongoDbMatchedDocument.class) + .build()); + CodecRegistry codecRegistry = fromRegistries(MongoClientSettings.getDefaultCodecRegistry(), pojoCodecRegistry); + + // create collection if not exist + MongoDatabase database = mongoClient.getDatabase(databaseName); + if (!isCollectionExist(database, collectionName)) { + createCollection(database, collectionName, getOrDefault(createCollectionOptions, new CreateCollectionOptions())); + } + + this.collection = database.getCollection(collectionName, MongoDbDocument.class).withCodecRegistry(codecRegistry); + this.vectorSearchOptions = filter == null ? vectorSearchOptions() : vectorSearchOptions().filter(filter); + + // create index if not exist + if (Boolean.TRUE.equals(createIndex) && !isIndexExist(this.indexName)) { + createIndex(this.indexName, getOrDefault(indexMapping, defaultIndexMapping())); + } + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private MongoClient mongoClient; + private String databaseName; + private String collectionName; + private String indexName; + private Long maxResultRatio; + private CreateCollectionOptions createCollectionOptions; + private Bson filter; + private IndexMapping indexMapping; + /** + * Whether MongoDB Atlas is deployed in cloud + * + *

if true, you need to create index in MongoDB Atlas

+ *

if false, {@link MongoDbEmbeddingStore} will create collection and index automatically

+ */ + private Boolean createIndex; + + /** + * Build Mongo Client, Please close the client to release resources after usage + */ + public Builder fromClient(MongoClient mongoClient) { + this.mongoClient = mongoClient; + return this; + } + + public Builder databaseName(String databaseName) { + this.databaseName = databaseName; + return this; + } + + public Builder collectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + public Builder indexName(String indexName) { + this.indexName = indexName; + return this; + } + + public Builder maxResultRatio(Long maxResultRatio) { + this.maxResultRatio = maxResultRatio; + return this; + } + + public Builder createCollectionOptions(CreateCollectionOptions createCollectionOptions) { + this.createCollectionOptions = createCollectionOptions; + return this; + } + + /** + * Document query filter, all fields included in filter must be contained in {@link IndexMapping#metadataFieldNames} + * + *

For example:

+ * + *
    + *
  • AND filter: Filters.and(Filters.in("type", asList("TXT", "md")), Filters.eqFull("test-key", "test-value"))
  • + *
  • OR filter: Filters.or(Filters.in("type", asList("TXT", "md")), Filters.eqFull("test-key", "test-value"))
  • + *
+ * + * @param filter document query filter + * @return builder + */ + public Builder filter(Bson filter) { + this.filter = filter; + return this; + } + + /** + * set MongoDB search index fields mapping + * + *

if {@link Builder#createIndex} is true, then indexMapping not work

+ * + * @param indexMapping MongoDB search index fields mapping + * @return builder + */ + public Builder indexMapping(IndexMapping indexMapping) { + this.indexMapping = indexMapping; + return this; + } + + /** + * Set whether in production mode, production mode will not create index automatically + * + *

default value is false

+ * + * @param createIndex whether in production mode + * @return builder + */ + public Builder createIndex(Boolean createIndex) { + this.createIndex = createIndex; + return this; + } + + public MongoDbEmbeddingStore build() { + return new MongoDbEmbeddingStore(mongoClient, databaseName, collectionName, indexName, maxResultRatio, createCollectionOptions, filter, indexMapping, createIndex); + } + } + + @Override + public String add(Embedding embedding) { + String id = randomUUID(); + add(id, embedding); + return id; + } + + @Override + public void add(String id, Embedding embedding) { + addInternal(id, embedding, null); + } + + @Override + public String add(Embedding embedding, TextSegment textSegment) { + String id = randomUUID(); + addInternal(id, embedding, textSegment); + return id; + } + + @Override + public List addAll(List embeddings) { + List ids = embeddings.stream() + .map(ignored -> randomUUID()) + .collect(toList()); + addAllInternal(ids, embeddings, null); + return ids; + } + + @Override + public List addAll(List embeddings, List embedded) { + List ids = embeddings.stream() + .map(ignored -> randomUUID()) + .collect(toList()); + addAllInternal(ids, embeddings, embedded); + return ids; + } + + @Override + public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { + List queryVector = referenceEmbedding.vectorAsList().stream() + .map(Float::doubleValue) + .collect(toList()); + long numCandidates = maxResults * maxResultRatio; + + List pipeline = Arrays.asList( + vectorSearch( + fieldPath("embedding"), + queryVector, + indexName, + numCandidates, + maxResults, + vectorSearchOptions), + project( + fields( + metaVectorSearchScore("score"), + include("embedding", "metadata", "text") + ) + ), + match( + Filters.gte("score", minScore) + )); + + try { + AggregateIterable results = collection.aggregate(pipeline, MongoDbMatchedDocument.class); + + return StreamSupport.stream(results.spliterator(), false) + .map(MappingUtils::toEmbeddingMatch) + .collect(Collectors.toList()); + + } catch (MongoCommandException e) { + if (log.isErrorEnabled()) { + log.error("Error in MongoDBEmbeddingStore.findRelevant", e); + } + throw new RuntimeException(e); + } + } + + private void addInternal(String id, Embedding embedding, TextSegment embedded) { + addAllInternal(singletonList(id), singletonList(embedding), embedded == null ? null : singletonList(embedded)); + } + + private void addAllInternal(List ids, List embeddings, List embedded) { + if (isNullOrEmpty(ids) || isNullOrEmpty(embeddings)) { + log.info("do not add empty embeddings to MongoDB Atlas"); + return; + } + ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size"); + ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size"); + + List documents = new ArrayList<>(ids.size()); + for (int i = 0; i < ids.size(); i++) { + MongoDbDocument document = toMongoDbDocument(ids.get(i), embeddings.get(i), embedded == null ? null : embedded.get(i)); + documents.add(document); + } + + InsertManyResult result = collection.insertMany(documents); + if (!result.wasAcknowledged() && log.isWarnEnabled()) { + String errMsg = String.format("[MongoDbEmbeddingStore] Add document failed, Document=%s", documents); + log.warn(errMsg); + throw new RuntimeException(errMsg); + } + } + + private boolean isCollectionExist(MongoDatabase database, String collectionName) { + return StreamSupport.stream(database.listCollectionNames().spliterator(), false) + .anyMatch(collectionName::equals); + } + + private void createCollection(MongoDatabase database, String collectionName, CreateCollectionOptions createCollectionOptions) { + database.createCollection(collectionName, createCollectionOptions); + } + + private boolean isIndexExist(String indexName) { + return StreamSupport.stream(collection.listSearchIndexes().spliterator(), false) + .anyMatch(index -> indexName.equals(index.getString("name"))); + } + + private void createIndex(String indexName, IndexMapping indexMapping) { + Document index = fromIndexMapping(indexMapping); + collection.createSearchIndex(indexName, index); + } +} diff --git a/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/MongoDbMatchedDocument.java b/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/MongoDbMatchedDocument.java new file mode 100644 index 000000000..cb4e8a437 --- /dev/null +++ b/langchain4j-mongodb-atlas/src/main/java/dev/langchain4j/store/embedding/mongodb/MongoDbMatchedDocument.java @@ -0,0 +1,20 @@ +package dev.langchain4j.store.embedding.mongodb; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; +import java.util.Map; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class MongoDbMatchedDocument { + + private String id; + private List embedding; + private String text; + private Map metadata; + private Double score; +} diff --git a/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreCloudIT.java b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreCloudIT.java new file mode 100644 index 000000000..75684b126 --- /dev/null +++ b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreCloudIT.java @@ -0,0 +1,78 @@ +package dev.langchain4j.store.embedding.mongodb; + +import com.mongodb.MongoClientSettings; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Filters; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreIT; +import lombok.SneakyThrows; +import org.bson.codecs.configuration.CodecRegistry; +import org.bson.codecs.pojo.PojoCodecProvider; +import org.bson.conversions.Bson; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; + +import static org.bson.codecs.configuration.CodecRegistries.fromProviders; +import static org.bson.codecs.configuration.CodecRegistries.fromRegistries; + +@Disabled("Need Cloud Mongo Atlas Credential") +class MongoDbEmbeddingStoreCloudIT extends EmbeddingStoreIT { + + static MongoClient client; + + MongoDbEmbeddingStore embeddingStore = MongoDbEmbeddingStore.builder() + .fromClient(client) + .databaseName("test_database") + .collectionName("test_collection") + .indexName("test_index") + .build(); + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @BeforeAll + static void beforeAll() { + client = MongoClients.create("mongodb+srv://:@/?retryWrites=true&w=majority"); + } + + @AfterAll + static void afterAll() { + client.close(); + } + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } + + @Override + protected void clearStore() { + CodecRegistry pojoCodecRegistry = fromProviders(PojoCodecProvider.builder() + .register(MongoDbDocument.class, MongoDbMatchedDocument.class) + .build()); + CodecRegistry codecRegistry = fromRegistries(MongoClientSettings.getDefaultCodecRegistry(), pojoCodecRegistry); + + MongoCollection collection = client.getDatabase("test_database") + .getCollection("test_collection", MongoDbDocument.class) + .withCodecRegistry(codecRegistry); + + Bson filter = Filters.exists("embedding"); + collection.deleteMany(filter); + } + + @Override + @SneakyThrows + protected void awaitUntilPersisted() { + Thread.sleep(2000); + } +} diff --git a/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreFilterIT.java b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreFilterIT.java new file mode 100644 index 000000000..a6d92e005 --- /dev/null +++ b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreFilterIT.java @@ -0,0 +1,112 @@ +package dev.langchain4j.store.embedding.mongodb; + +import com.mongodb.*; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.model.Filters; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; +import lombok.SneakyThrows; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.DockerComposeContainer; +import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy; +import org.testcontainers.shaded.com.google.common.collect.Sets; + +import java.io.File; +import java.time.Duration; +import java.util.List; + +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.data.Percentage.withPercentage; + +class MongoDbEmbeddingStoreFilterIT { + + static final String MONGO_SERVICE_NAME = "mongo"; + static final Integer MONGO_SERVICE_PORT = 27778; + static DockerComposeContainer mongodb = new DockerComposeContainer<>(new File("src/test/resources/docker-compose.yml")) + .withExposedService(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT, new LogMessageWaitStrategy() + .withRegEx(".*Deployment created!.*\\n") + .withTimes(1) + .withStartupTimeout(Duration.ofMinutes(30))); + + static MongoClient client; + + IndexMapping indexMapping = IndexMapping.builder() + .dimension(384) + .metadataFieldNames(Sets.newHashSet("test-key")) + .build(); + + EmbeddingStore embeddingStore = MongoDbEmbeddingStore.builder() + .fromClient(client) + .databaseName("test_database") + .collectionName("test_collection") + .indexName("test_index") + .filter(Filters.and(Filters.eqFull("metadata.test-key", "test-value"))) + .indexMapping(indexMapping) + .createIndex(true) + .build(); + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @BeforeAll + @SneakyThrows + static void start() { + mongodb.start(); + + MongoCredential credential = MongoCredential.createCredential("root", "admin", "root".toCharArray()); + client = MongoClients.create( + MongoClientSettings.builder() + .credential(credential) + .serverApi(ServerApi.builder().version(ServerApiVersion.V1).build()) + .applyConnectionString(new ConnectionString(String.format("mongodb://%s:%s/?directConnection=true", + mongodb.getServiceHost(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT), mongodb.getServicePort(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT)))) + .build()); + } + + @AfterAll + static void stop() { + mongodb.stop(); + client.close(); + } + + @Test + void should_find_relevant_with_filter() { + TextSegment segment = TextSegment.from("this segment should be found", Metadata.from("test-key", "test-value")); + Embedding embedding = embeddingModel.embed(segment.text()).content(); + + TextSegment filterSegment = TextSegment.from("this segment should not be found", Metadata.from("test-key", "no-value")); + Embedding filterEmbedding = embeddingModel.embed(filterSegment.text()).content(); + + List ids = embeddingStore.addAll(asList(embedding, filterEmbedding), asList(segment, filterSegment)); + assertThat(ids) + .hasSize(2); + + TextSegment refSegment = TextSegment.from("find a segment"); + Embedding refEmbedding = embeddingModel.embed(refSegment.text()).content(); + + awaitUntilPersisted(); + + List> relevant = embeddingStore.findRelevant(refEmbedding, 2); + // Only segment should be found, filterSegment should be filtered + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(0.88, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(ids.get(0)); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isEqualTo(segment); + } + + @SneakyThrows + protected void awaitUntilPersisted() { + Thread.sleep(2000); + } +} diff --git a/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreLocalIT.java b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreLocalIT.java new file mode 100644 index 000000000..7da6dc253 --- /dev/null +++ b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreLocalIT.java @@ -0,0 +1,112 @@ +package dev.langchain4j.store.embedding.mongodb; + +import com.mongodb.*; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Filters; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreIT; +import lombok.SneakyThrows; +import org.bson.codecs.configuration.CodecRegistry; +import org.bson.codecs.pojo.PojoCodecProvider; +import org.bson.conversions.Bson; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.testcontainers.containers.DockerComposeContainer; +import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy; +import org.testcontainers.shaded.com.google.common.collect.Sets; + +import java.io.File; +import java.time.Duration; + +import static org.bson.codecs.configuration.CodecRegistries.fromProviders; +import static org.bson.codecs.configuration.CodecRegistries.fromRegistries; + +/** + * If container startup timeout (because atlas cli need to download mongodb binaries, which may take a few minutes), + * the alternative way is running `docker compose up -d` in `src/test/resources` + */ +class MongoDbEmbeddingStoreLocalIT extends EmbeddingStoreIT { + + static final String MONGO_SERVICE_NAME = "mongo"; + static final Integer MONGO_SERVICE_PORT = 27778; + static DockerComposeContainer mongodb = new DockerComposeContainer<>(new File("src/test/resources/docker-compose.yml")) + .withExposedService(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT, new LogMessageWaitStrategy() + .withRegEx(".*Deployment created!.*\\n") + .withTimes(1) + .withStartupTimeout(Duration.ofMinutes(30))); + + static MongoClient client; + + IndexMapping indexMapping = IndexMapping.builder() + .dimension(384) + .metadataFieldNames(Sets.newHashSet("test-key")) + .build(); + + EmbeddingStore embeddingStore = MongoDbEmbeddingStore.builder() + .fromClient(client) + .databaseName("test_database") + .collectionName("test_collection") + .indexName("test_index") + .indexMapping(indexMapping) + .createIndex(true) + .build(); + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @BeforeAll + @SneakyThrows + static void start() { + mongodb.start(); + + MongoCredential credential = MongoCredential.createCredential("root", "admin", "root".toCharArray()); + client = MongoClients.create( + MongoClientSettings.builder() + .credential(credential) + .serverApi(ServerApi.builder().version(ServerApiVersion.V1).build()) + .applyConnectionString(new ConnectionString(String.format("mongodb://%s:%s/?directConnection=true", + mongodb.getServiceHost(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT), mongodb.getServicePort(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT)))) + .build()); + } + + @AfterAll + static void stop() { + mongodb.stop(); + client.close(); + } + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } + + @Override + protected void clearStore() { + CodecRegistry pojoCodecRegistry = fromProviders(PojoCodecProvider.builder() + .register(MongoDbDocument.class, MongoDbMatchedDocument.class) + .build()); + CodecRegistry codecRegistry = fromRegistries(MongoClientSettings.getDefaultCodecRegistry(), pojoCodecRegistry); + + MongoCollection collection = client.getDatabase("test_database") + .getCollection("test_collection", MongoDbDocument.class) + .withCodecRegistry(codecRegistry); + + Bson filter = Filters.exists("embedding"); + collection.deleteMany(filter); + } + + @Override + @SneakyThrows + protected void awaitUntilPersisted() { + Thread.sleep(2000); + } +} diff --git a/langchain4j-mongodb-atlas/src/test/resources/docker-compose.yml b/langchain4j-mongodb-atlas/src/test/resources/docker-compose.yml new file mode 100644 index 000000000..4959bcff0 --- /dev/null +++ b/langchain4j-mongodb-atlas/src/test/resources/docker-compose.yml @@ -0,0 +1,10 @@ +services: + mongo: + image: mongodb/atlas + privileged: true + command: | + /bin/bash -c "atlas deployments setup local-test --type local --port 27778 --bindIpAll --username root --password root --force && tail -f /dev/null" + volumes: + - /var/run/docker.sock:/var/run/docker.sock + ports: + - 27778:27778 \ No newline at end of file diff --git a/pom.xml b/pom.xml index 2e643c1c6..e5f89493c 100644 --- a/pom.xml +++ b/pom.xml @@ -49,6 +49,7 @@ langchain4j-weaviate langchain4j-neo4j langchain4j-vearch + langchain4j-mongodb-atlas document-loaders/langchain4j-document-loader-amazon-s3