Integration with MongoDB (#535)

see original PR #254. There are four mainly differences:

1. Using `Testcontainer` and MongoDB Atlas Local Deployment to test
2. Create collection and index when the `MongoDBEmbeddingStore`
initialize, rather than create when adding new embedding at the first
time.
3. Optimize `BsonUtils`, which is replaced by `org.bson.Document` to
create index mapping.
4. Rename `langchain4j-mongodb` to `langchain4j-mongodb-atlas`

Local deployment tests are all passed, but cloud tests are not tested
yet because I encounter some network problem when communicating with
MongoDB Atlas. (But I think it doesn't matter, because local deployment
is the same as cloud, the purpose of local deployment is to development
and test)
This commit is contained in:
ZYinNJU 2024-02-08 17:13:29 +08:00 committed by GitHub
parent dc4028b546
commit c694755cc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 879 additions and 0 deletions

View File

@ -197,6 +197,12 @@
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-mongodb-atlas</artifactId>
<version>${project.version}</version>
</dependency>
<!-- code execution engines -->
<dependency>

View File

@ -0,0 +1,92 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.27.0-SNAPSHOT</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
<artifactId>langchain4j-mongodb-atlas</artifactId>
<packaging>jar</packaging>
<name>LangChain4j :: Integration :: MongoDB Atlas</name>
<dependencies>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</dependency>
<dependency>
<groupId>org.mongodb</groupId>
<artifactId>mongodb-driver-sync</artifactId>
<version>4.11.1</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>mongodb</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,26 @@
package dev.langchain4j.store.embedding.mongodb;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.HashSet;
import java.util.Set;
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class IndexMapping {
private int dimension;
private Set<String> metadataFieldNames;
public static IndexMapping defaultIndexMapping() {
return IndexMapping.builder()
.dimension(1536)
.metadataFieldNames(new HashSet<>())
.build();
}
}

View File

@ -0,0 +1,77 @@
package dev.langchain4j.store.embedding.mongodb;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import org.bson.Document;
import java.util.Set;
class MappingUtils {
private MappingUtils() throws InstantiationException {
throw new InstantiationException("can't instantiate this class");
}
static MongoDbDocument toMongoDbDocument(String id, Embedding embedding, TextSegment textSegment) {
if (textSegment == null) {
return new MongoDbDocument(id, embedding.vectorAsList(), null, null);
}
return new MongoDbDocument(id, embedding.vectorAsList(), textSegment.text(), textSegment.metadata().asMap());
}
static EmbeddingMatch<TextSegment> toEmbeddingMatch(MongoDbMatchedDocument matchedDocument) {
TextSegment textSegment = null;
if (matchedDocument.getText() != null) {
textSegment = matchedDocument.getMetadata() == null ? TextSegment.from(matchedDocument.getText()) :
TextSegment.from(matchedDocument.getText(), Metadata.from(matchedDocument.getMetadata()));
}
return new EmbeddingMatch<>(matchedDocument.getScore(), matchedDocument.getId(), Embedding.from(matchedDocument.getEmbedding()), textSegment);
}
static Document fromIndexMapping(IndexMapping indexMapping) {
Document mapping = new Document();
mapping.append("dynamic", false);
Document fields = new Document();
writeEmbedding(indexMapping.getDimension(), fields);
Set<String> metadataFields = indexMapping.getMetadataFieldNames();
if (metadataFields != null && !metadataFields.isEmpty()) {
writeMetadata(metadataFields, fields);
}
mapping.append("fields", fields);
return new Document("mappings", mapping);
}
private static void writeMetadata(Set<String> metadataFields, Document fields) {
Document metadata = new Document();
metadata.append("dynamic", false);
metadata.append("type", "document");
Document metadataFieldDoc = new Document();
metadataFields.forEach(field -> writeMetadataField(metadataFieldDoc, field));
metadata.append("fields", metadataFieldDoc);
fields.append("metadata", metadata);
}
private static void writeMetadataField(Document metadataFieldDoc, String fieldName) {
Document field = new Document();
field.append("type", "token");
metadataFieldDoc.append(fieldName, field);
}
private static void writeEmbedding(int dimensions, Document fields) {
Document embedding = new Document();
embedding.append("dimensions", dimensions);
embedding.append("similarity", "cosine");
embedding.append("type", "knnVector");
fields.append("embedding", embedding);
}
}

View File

@ -0,0 +1,23 @@
package dev.langchain4j.store.embedding.mongodb;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.bson.codecs.pojo.annotations.BsonId;
import java.util.List;
import java.util.Map;
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class MongoDbDocument {
@BsonId
private String id;
private List<Float> embedding;
private String text;
private Map<String, String> metadata;
}

View File

@ -0,0 +1,322 @@
package dev.langchain4j.store.embedding.mongodb;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoCommandException;
import com.mongodb.client.AggregateIterable;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.model.CreateCollectionOptions;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.search.VectorSearchOptions;
import com.mongodb.client.result.InsertManyResult;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import org.bson.Document;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.codecs.pojo.PojoCodecProvider;
import org.bson.conversions.Bson;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import static com.mongodb.client.model.Aggregates.*;
import static com.mongodb.client.model.Projections.*;
import static com.mongodb.client.model.search.SearchPath.fieldPath;
import static com.mongodb.client.model.search.VectorSearchOptions.vectorSearchOptions;
import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.internal.ValidationUtils.ensureTrue;
import static dev.langchain4j.store.embedding.mongodb.IndexMapping.defaultIndexMapping;
import static dev.langchain4j.store.embedding.mongodb.MappingUtils.fromIndexMapping;
import static dev.langchain4j.store.embedding.mongodb.MappingUtils.toMongoDbDocument;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
import static org.bson.codecs.configuration.CodecRegistries.fromProviders;
import static org.bson.codecs.configuration.CodecRegistries.fromRegistries;
/**
* Represents a <a href="https://www.mongodb.com/">MongoDB</a> index as an embedding store.
* <p>
* More <a href="https://www.mongodb.com/docs/atlas/atlas-search/field-types/knn-vector/">info</a> to set up MongoDb as vectorDatabase
* <p>
* <a href="https://www.mongodb.com/developer/products/atlas/semantic-search-mongodb-atlas-vector-search/">tutorial</a> how to use a knn-vector in MongoDB Atlas (great startingpoint)
*/
public class MongoDbEmbeddingStore implements EmbeddingStore<TextSegment> {
private static final Logger log = LoggerFactory.getLogger(MongoDbEmbeddingStore.class);
private final MongoCollection<MongoDbDocument> collection;
private final String indexName;
private final long maxResultRatio;
private final VectorSearchOptions vectorSearchOptions;
public MongoDbEmbeddingStore(MongoClient mongoClient,
String databaseName,
String collectionName,
String indexName,
Long maxResultRatio,
CreateCollectionOptions createCollectionOptions,
Bson filter,
IndexMapping indexMapping,
Boolean createIndex) {
databaseName = ensureNotNull(databaseName, "databaseName");
collectionName = ensureNotNull(collectionName, "collectionName");
createIndex = getOrDefault(createIndex, false);
this.indexName = ensureNotNull(indexName, "indexName");
this.maxResultRatio = getOrDefault(maxResultRatio, 10L);
CodecRegistry pojoCodecRegistry = fromProviders(PojoCodecProvider.builder()
.register(MongoDbDocument.class, MongoDbMatchedDocument.class)
.build());
CodecRegistry codecRegistry = fromRegistries(MongoClientSettings.getDefaultCodecRegistry(), pojoCodecRegistry);
// create collection if not exist
MongoDatabase database = mongoClient.getDatabase(databaseName);
if (!isCollectionExist(database, collectionName)) {
createCollection(database, collectionName, getOrDefault(createCollectionOptions, new CreateCollectionOptions()));
}
this.collection = database.getCollection(collectionName, MongoDbDocument.class).withCodecRegistry(codecRegistry);
this.vectorSearchOptions = filter == null ? vectorSearchOptions() : vectorSearchOptions().filter(filter);
// create index if not exist
if (Boolean.TRUE.equals(createIndex) && !isIndexExist(this.indexName)) {
createIndex(this.indexName, getOrDefault(indexMapping, defaultIndexMapping()));
}
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private MongoClient mongoClient;
private String databaseName;
private String collectionName;
private String indexName;
private Long maxResultRatio;
private CreateCollectionOptions createCollectionOptions;
private Bson filter;
private IndexMapping indexMapping;
/**
* Whether MongoDB Atlas is deployed in cloud
*
* <p>if true, you need to create index in <a href="https://cloud.mongodb.com/">MongoDB Atlas</a></p>
* <p>if false, {@link MongoDbEmbeddingStore} will create collection and index automatically</p>
*/
private Boolean createIndex;
/**
* Build Mongo Client, Please close the client to release resources after usage
*/
public Builder fromClient(MongoClient mongoClient) {
this.mongoClient = mongoClient;
return this;
}
public Builder databaseName(String databaseName) {
this.databaseName = databaseName;
return this;
}
public Builder collectionName(String collectionName) {
this.collectionName = collectionName;
return this;
}
public Builder indexName(String indexName) {
this.indexName = indexName;
return this;
}
public Builder maxResultRatio(Long maxResultRatio) {
this.maxResultRatio = maxResultRatio;
return this;
}
public Builder createCollectionOptions(CreateCollectionOptions createCollectionOptions) {
this.createCollectionOptions = createCollectionOptions;
return this;
}
/**
* Document query filter, all fields included in filter must be contained in {@link IndexMapping#metadataFieldNames}
*
* <p>For example:</p>
*
* <ul>
* <li>AND filter: Filters.and(Filters.in("type", asList("TXT", "md")), Filters.eqFull("test-key", "test-value"))</li>
* <li>OR filter: Filters.or(Filters.in("type", asList("TXT", "md")), Filters.eqFull("test-key", "test-value"))</li>
* </ul>
*
* @param filter document query filter
* @return builder
*/
public Builder filter(Bson filter) {
this.filter = filter;
return this;
}
/**
* set MongoDB search index fields mapping
*
* <p>if {@link Builder#createIndex} is true, then indexMapping not work</p>
*
* @param indexMapping MongoDB search index fields mapping
* @return builder
*/
public Builder indexMapping(IndexMapping indexMapping) {
this.indexMapping = indexMapping;
return this;
}
/**
* Set whether in production mode, production mode will not create index automatically
*
* <p>default value is false</p>
*
* @param createIndex whether in production mode
* @return builder
*/
public Builder createIndex(Boolean createIndex) {
this.createIndex = createIndex;
return this;
}
public MongoDbEmbeddingStore build() {
return new MongoDbEmbeddingStore(mongoClient, databaseName, collectionName, indexName, maxResultRatio, createCollectionOptions, filter, indexMapping, createIndex);
}
}
@Override
public String add(Embedding embedding) {
String id = randomUUID();
add(id, embedding);
return id;
}
@Override
public void add(String id, Embedding embedding) {
addInternal(id, embedding, null);
}
@Override
public String add(Embedding embedding, TextSegment textSegment) {
String id = randomUUID();
addInternal(id, embedding, textSegment);
return id;
}
@Override
public List<String> addAll(List<Embedding> embeddings) {
List<String> ids = embeddings.stream()
.map(ignored -> randomUUID())
.collect(toList());
addAllInternal(ids, embeddings, null);
return ids;
}
@Override
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
List<String> ids = embeddings.stream()
.map(ignored -> randomUUID())
.collect(toList());
addAllInternal(ids, embeddings, embedded);
return ids;
}
@Override
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
List<Double> queryVector = referenceEmbedding.vectorAsList().stream()
.map(Float::doubleValue)
.collect(toList());
long numCandidates = maxResults * maxResultRatio;
List<Bson> pipeline = Arrays.asList(
vectorSearch(
fieldPath("embedding"),
queryVector,
indexName,
numCandidates,
maxResults,
vectorSearchOptions),
project(
fields(
metaVectorSearchScore("score"),
include("embedding", "metadata", "text")
)
),
match(
Filters.gte("score", minScore)
));
try {
AggregateIterable<MongoDbMatchedDocument> results = collection.aggregate(pipeline, MongoDbMatchedDocument.class);
return StreamSupport.stream(results.spliterator(), false)
.map(MappingUtils::toEmbeddingMatch)
.collect(Collectors.toList());
} catch (MongoCommandException e) {
if (log.isErrorEnabled()) {
log.error("Error in MongoDBEmbeddingStore.findRelevant", e);
}
throw new RuntimeException(e);
}
}
private void addInternal(String id, Embedding embedding, TextSegment embedded) {
addAllInternal(singletonList(id), singletonList(embedding), embedded == null ? null : singletonList(embedded));
}
private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
if (isNullOrEmpty(ids) || isNullOrEmpty(embeddings)) {
log.info("do not add empty embeddings to MongoDB Atlas");
return;
}
ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");
List<MongoDbDocument> documents = new ArrayList<>(ids.size());
for (int i = 0; i < ids.size(); i++) {
MongoDbDocument document = toMongoDbDocument(ids.get(i), embeddings.get(i), embedded == null ? null : embedded.get(i));
documents.add(document);
}
InsertManyResult result = collection.insertMany(documents);
if (!result.wasAcknowledged() && log.isWarnEnabled()) {
String errMsg = String.format("[MongoDbEmbeddingStore] Add document failed, Document=%s", documents);
log.warn(errMsg);
throw new RuntimeException(errMsg);
}
}
private boolean isCollectionExist(MongoDatabase database, String collectionName) {
return StreamSupport.stream(database.listCollectionNames().spliterator(), false)
.anyMatch(collectionName::equals);
}
private void createCollection(MongoDatabase database, String collectionName, CreateCollectionOptions createCollectionOptions) {
database.createCollection(collectionName, createCollectionOptions);
}
private boolean isIndexExist(String indexName) {
return StreamSupport.stream(collection.listSearchIndexes().spliterator(), false)
.anyMatch(index -> indexName.equals(index.getString("name")));
}
private void createIndex(String indexName, IndexMapping indexMapping) {
Document index = fromIndexMapping(indexMapping);
collection.createSearchIndex(indexName, index);
}
}

View File

@ -0,0 +1,20 @@
package dev.langchain4j.store.embedding.mongodb;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
import java.util.Map;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class MongoDbMatchedDocument {
private String id;
private List<Float> embedding;
private String text;
private Map<String, String> metadata;
private Double score;
}

View File

@ -0,0 +1,78 @@
package dev.langchain4j.store.embedding.mongodb;
import com.mongodb.MongoClientSettings;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Filters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
import lombok.SneakyThrows;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.codecs.pojo.PojoCodecProvider;
import org.bson.conversions.Bson;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import static org.bson.codecs.configuration.CodecRegistries.fromProviders;
import static org.bson.codecs.configuration.CodecRegistries.fromRegistries;
@Disabled("Need Cloud Mongo Atlas Credential")
class MongoDbEmbeddingStoreCloudIT extends EmbeddingStoreIT {
static MongoClient client;
MongoDbEmbeddingStore embeddingStore = MongoDbEmbeddingStore.builder()
.fromClient(client)
.databaseName("test_database")
.collectionName("test_collection")
.indexName("test_index")
.build();
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
@BeforeAll
static void beforeAll() {
client = MongoClients.create("mongodb+srv://<username>:<password>@<host>/?retryWrites=true&w=majority");
}
@AfterAll
static void afterAll() {
client.close();
}
@Override
protected EmbeddingStore<TextSegment> embeddingStore() {
return embeddingStore;
}
@Override
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}
@Override
protected void clearStore() {
CodecRegistry pojoCodecRegistry = fromProviders(PojoCodecProvider.builder()
.register(MongoDbDocument.class, MongoDbMatchedDocument.class)
.build());
CodecRegistry codecRegistry = fromRegistries(MongoClientSettings.getDefaultCodecRegistry(), pojoCodecRegistry);
MongoCollection<MongoDbDocument> collection = client.getDatabase("test_database")
.getCollection("test_collection", MongoDbDocument.class)
.withCodecRegistry(codecRegistry);
Bson filter = Filters.exists("embedding");
collection.deleteMany(filter);
}
@Override
@SneakyThrows
protected void awaitUntilPersisted() {
Thread.sleep(2000);
}
}

View File

@ -0,0 +1,112 @@
package dev.langchain4j.store.embedding.mongodb;
import com.mongodb.*;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.model.Filters;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import lombok.SneakyThrows;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.DockerComposeContainer;
import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy;
import org.testcontainers.shaded.com.google.common.collect.Sets;
import java.io.File;
import java.time.Duration;
import java.util.List;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;
class MongoDbEmbeddingStoreFilterIT {
static final String MONGO_SERVICE_NAME = "mongo";
static final Integer MONGO_SERVICE_PORT = 27778;
static DockerComposeContainer<?> mongodb = new DockerComposeContainer<>(new File("src/test/resources/docker-compose.yml"))
.withExposedService(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT, new LogMessageWaitStrategy()
.withRegEx(".*Deployment created!.*\\n")
.withTimes(1)
.withStartupTimeout(Duration.ofMinutes(30)));
static MongoClient client;
IndexMapping indexMapping = IndexMapping.builder()
.dimension(384)
.metadataFieldNames(Sets.newHashSet("test-key"))
.build();
EmbeddingStore<TextSegment> embeddingStore = MongoDbEmbeddingStore.builder()
.fromClient(client)
.databaseName("test_database")
.collectionName("test_collection")
.indexName("test_index")
.filter(Filters.and(Filters.eqFull("metadata.test-key", "test-value")))
.indexMapping(indexMapping)
.createIndex(true)
.build();
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
@BeforeAll
@SneakyThrows
static void start() {
mongodb.start();
MongoCredential credential = MongoCredential.createCredential("root", "admin", "root".toCharArray());
client = MongoClients.create(
MongoClientSettings.builder()
.credential(credential)
.serverApi(ServerApi.builder().version(ServerApiVersion.V1).build())
.applyConnectionString(new ConnectionString(String.format("mongodb://%s:%s/?directConnection=true",
mongodb.getServiceHost(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT), mongodb.getServicePort(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT))))
.build());
}
@AfterAll
static void stop() {
mongodb.stop();
client.close();
}
@Test
void should_find_relevant_with_filter() {
TextSegment segment = TextSegment.from("this segment should be found", Metadata.from("test-key", "test-value"));
Embedding embedding = embeddingModel.embed(segment.text()).content();
TextSegment filterSegment = TextSegment.from("this segment should not be found", Metadata.from("test-key", "no-value"));
Embedding filterEmbedding = embeddingModel.embed(filterSegment.text()).content();
List<String> ids = embeddingStore.addAll(asList(embedding, filterEmbedding), asList(segment, filterSegment));
assertThat(ids)
.hasSize(2);
TextSegment refSegment = TextSegment.from("find a segment");
Embedding refEmbedding = embeddingModel.embed(refSegment.text()).content();
awaitUntilPersisted();
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(refEmbedding, 2);
// Only segment should be found, filterSegment should be filtered
assertThat(relevant).hasSize(1);
EmbeddingMatch<TextSegment> match = relevant.get(0);
assertThat(match.score()).isCloseTo(0.88, withPercentage(1));
assertThat(match.embeddingId()).isEqualTo(ids.get(0));
assertThat(match.embedding()).isEqualTo(embedding);
assertThat(match.embedded()).isEqualTo(segment);
}
@SneakyThrows
protected void awaitUntilPersisted() {
Thread.sleep(2000);
}
}

View File

@ -0,0 +1,112 @@
package dev.langchain4j.store.embedding.mongodb;
import com.mongodb.*;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Filters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
import lombok.SneakyThrows;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.codecs.pojo.PojoCodecProvider;
import org.bson.conversions.Bson;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.testcontainers.containers.DockerComposeContainer;
import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy;
import org.testcontainers.shaded.com.google.common.collect.Sets;
import java.io.File;
import java.time.Duration;
import static org.bson.codecs.configuration.CodecRegistries.fromProviders;
import static org.bson.codecs.configuration.CodecRegistries.fromRegistries;
/**
* If container startup timeout (because atlas cli need to download mongodb binaries, which may take a few minutes),
* the alternative way is running `docker compose up -d` in `src/test/resources`
*/
class MongoDbEmbeddingStoreLocalIT extends EmbeddingStoreIT {
static final String MONGO_SERVICE_NAME = "mongo";
static final Integer MONGO_SERVICE_PORT = 27778;
static DockerComposeContainer<?> mongodb = new DockerComposeContainer<>(new File("src/test/resources/docker-compose.yml"))
.withExposedService(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT, new LogMessageWaitStrategy()
.withRegEx(".*Deployment created!.*\\n")
.withTimes(1)
.withStartupTimeout(Duration.ofMinutes(30)));
static MongoClient client;
IndexMapping indexMapping = IndexMapping.builder()
.dimension(384)
.metadataFieldNames(Sets.newHashSet("test-key"))
.build();
EmbeddingStore<TextSegment> embeddingStore = MongoDbEmbeddingStore.builder()
.fromClient(client)
.databaseName("test_database")
.collectionName("test_collection")
.indexName("test_index")
.indexMapping(indexMapping)
.createIndex(true)
.build();
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
@BeforeAll
@SneakyThrows
static void start() {
mongodb.start();
MongoCredential credential = MongoCredential.createCredential("root", "admin", "root".toCharArray());
client = MongoClients.create(
MongoClientSettings.builder()
.credential(credential)
.serverApi(ServerApi.builder().version(ServerApiVersion.V1).build())
.applyConnectionString(new ConnectionString(String.format("mongodb://%s:%s/?directConnection=true",
mongodb.getServiceHost(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT), mongodb.getServicePort(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT))))
.build());
}
@AfterAll
static void stop() {
mongodb.stop();
client.close();
}
@Override
protected EmbeddingStore<TextSegment> embeddingStore() {
return embeddingStore;
}
@Override
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}
@Override
protected void clearStore() {
CodecRegistry pojoCodecRegistry = fromProviders(PojoCodecProvider.builder()
.register(MongoDbDocument.class, MongoDbMatchedDocument.class)
.build());
CodecRegistry codecRegistry = fromRegistries(MongoClientSettings.getDefaultCodecRegistry(), pojoCodecRegistry);
MongoCollection<MongoDbDocument> collection = client.getDatabase("test_database")
.getCollection("test_collection", MongoDbDocument.class)
.withCodecRegistry(codecRegistry);
Bson filter = Filters.exists("embedding");
collection.deleteMany(filter);
}
@Override
@SneakyThrows
protected void awaitUntilPersisted() {
Thread.sleep(2000);
}
}

View File

@ -0,0 +1,10 @@
services:
mongo:
image: mongodb/atlas
privileged: true
command: |
/bin/bash -c "atlas deployments setup local-test --type local --port 27778 --bindIpAll --username root --password root --force && tail -f /dev/null"
volumes:
- /var/run/docker.sock:/var/run/docker.sock
ports:
- 27778:27778

View File

@ -49,6 +49,7 @@
<module>langchain4j-weaviate</module>
<module>langchain4j-neo4j</module>
<module>langchain4j-vearch</module>
<module>langchain4j-mongodb-atlas</module>
<!-- document loaders -->
<module>document-loaders/langchain4j-document-loader-amazon-s3</module>