Integration with MongoDB (#535)
See the original PR #254. There are four main differences: 1. `Testcontainers` and a MongoDB Atlas local deployment are used for testing. 2. The collection and index are created when `MongoDbEmbeddingStore` is initialized, rather than when the first embedding is added. 3. `BsonUtils` is optimized: it is replaced by `org.bson.Document` for building the index mapping. 4. `langchain4j-mongodb` is renamed to `langchain4j-mongodb-atlas`. All local deployment tests pass, but the cloud tests have not been run yet because I encountered network problems when communicating with MongoDB Atlas. (I don't think this matters, because a local deployment behaves the same as the cloud one; the purpose of the local deployment is development and testing.)
This commit is contained in:
parent
dc4028b546
commit
c694755cc3
|
@ -197,6 +197,12 @@
|
|||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-mongodb-atlas</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- code execution engines -->
|
||||
|
||||
<dependency>
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.27.0-SNAPSHOT</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
<artifactId>langchain4j-mongodb-atlas</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>LangChain4j :: Integration :: MongoDB Atlas</name>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mongodb</groupId>
|
||||
<artifactId>mongodb-driver-sync</artifactId>
|
||||
<version>4.11.1</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<classifier>tests</classifier>
|
||||
<type>test-jar</type>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.assertj</groupId>
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.testcontainers</groupId>
|
||||
<artifactId>junit-jupiter</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.testcontainers</groupId>
|
||||
<artifactId>mongodb</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.tinylog</groupId>
|
||||
<artifactId>tinylog-impl</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.tinylog</groupId>
|
||||
<artifactId>slf4j-tinylog</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
|
@ -0,0 +1,26 @@
|
|||
package dev.langchain4j.store.embedding.mongodb;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Builder
|
||||
public class IndexMapping {
|
||||
|
||||
private int dimension;
|
||||
private Set<String> metadataFieldNames;
|
||||
|
||||
public static IndexMapping defaultIndexMapping() {
|
||||
return IndexMapping.builder()
|
||||
.dimension(1536)
|
||||
.metadataFieldNames(new HashSet<>())
|
||||
.build();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package dev.langchain4j.store.embedding.mongodb;
|
||||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import org.bson.Document;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
class MappingUtils {
|
||||
|
||||
private MappingUtils() throws InstantiationException {
|
||||
throw new InstantiationException("can't instantiate this class");
|
||||
}
|
||||
|
||||
static MongoDbDocument toMongoDbDocument(String id, Embedding embedding, TextSegment textSegment) {
|
||||
if (textSegment == null) {
|
||||
return new MongoDbDocument(id, embedding.vectorAsList(), null, null);
|
||||
}
|
||||
return new MongoDbDocument(id, embedding.vectorAsList(), textSegment.text(), textSegment.metadata().asMap());
|
||||
}
|
||||
|
||||
static EmbeddingMatch<TextSegment> toEmbeddingMatch(MongoDbMatchedDocument matchedDocument) {
|
||||
TextSegment textSegment = null;
|
||||
if (matchedDocument.getText() != null) {
|
||||
textSegment = matchedDocument.getMetadata() == null ? TextSegment.from(matchedDocument.getText()) :
|
||||
TextSegment.from(matchedDocument.getText(), Metadata.from(matchedDocument.getMetadata()));
|
||||
}
|
||||
return new EmbeddingMatch<>(matchedDocument.getScore(), matchedDocument.getId(), Embedding.from(matchedDocument.getEmbedding()), textSegment);
|
||||
}
|
||||
|
||||
static Document fromIndexMapping(IndexMapping indexMapping) {
|
||||
Document mapping = new Document();
|
||||
mapping.append("dynamic", false);
|
||||
|
||||
Document fields = new Document();
|
||||
writeEmbedding(indexMapping.getDimension(), fields);
|
||||
|
||||
Set<String> metadataFields = indexMapping.getMetadataFieldNames();
|
||||
if (metadataFields != null && !metadataFields.isEmpty()) {
|
||||
writeMetadata(metadataFields, fields);
|
||||
}
|
||||
|
||||
mapping.append("fields", fields);
|
||||
|
||||
return new Document("mappings", mapping);
|
||||
}
|
||||
|
||||
private static void writeMetadata(Set<String> metadataFields, Document fields) {
|
||||
Document metadata = new Document();
|
||||
metadata.append("dynamic", false);
|
||||
metadata.append("type", "document");
|
||||
|
||||
Document metadataFieldDoc = new Document();
|
||||
metadataFields.forEach(field -> writeMetadataField(metadataFieldDoc, field));
|
||||
|
||||
metadata.append("fields", metadataFieldDoc);
|
||||
|
||||
fields.append("metadata", metadata);
|
||||
}
|
||||
|
||||
private static void writeMetadataField(Document metadataFieldDoc, String fieldName) {
|
||||
Document field = new Document();
|
||||
field.append("type", "token");
|
||||
metadataFieldDoc.append(fieldName, field);
|
||||
}
|
||||
|
||||
private static void writeEmbedding(int dimensions, Document fields) {
|
||||
Document embedding = new Document();
|
||||
embedding.append("dimensions", dimensions);
|
||||
embedding.append("similarity", "cosine");
|
||||
embedding.append("type", "knnVector");
|
||||
|
||||
fields.append("embedding", embedding);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package dev.langchain4j.store.embedding.mongodb;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.bson.codecs.pojo.annotations.BsonId;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Builder
|
||||
public class MongoDbDocument {
|
||||
|
||||
@BsonId
|
||||
private String id;
|
||||
private List<Float> embedding;
|
||||
private String text;
|
||||
private Map<String, String> metadata;
|
||||
}
|
|
@ -0,0 +1,322 @@
|
|||
package dev.langchain4j.store.embedding.mongodb;
|
||||
|
||||
import com.mongodb.MongoClientSettings;
|
||||
import com.mongodb.MongoCommandException;
|
||||
import com.mongodb.client.AggregateIterable;
|
||||
import com.mongodb.client.MongoClient;
|
||||
import com.mongodb.client.MongoCollection;
|
||||
import com.mongodb.client.MongoDatabase;
|
||||
import com.mongodb.client.model.CreateCollectionOptions;
|
||||
import com.mongodb.client.model.Filters;
|
||||
import com.mongodb.client.model.search.VectorSearchOptions;
|
||||
import com.mongodb.client.result.InsertManyResult;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import org.bson.Document;
|
||||
import org.bson.codecs.configuration.CodecRegistry;
|
||||
import org.bson.codecs.pojo.PojoCodecProvider;
|
||||
import org.bson.conversions.Bson;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.StreamSupport;
|
||||
|
||||
import static com.mongodb.client.model.Aggregates.*;
|
||||
import static com.mongodb.client.model.Projections.*;
|
||||
import static com.mongodb.client.model.search.SearchPath.fieldPath;
|
||||
import static com.mongodb.client.model.search.VectorSearchOptions.vectorSearchOptions;
|
||||
import static dev.langchain4j.internal.Utils.*;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureTrue;
|
||||
import static dev.langchain4j.store.embedding.mongodb.IndexMapping.defaultIndexMapping;
|
||||
import static dev.langchain4j.store.embedding.mongodb.MappingUtils.fromIndexMapping;
|
||||
import static dev.langchain4j.store.embedding.mongodb.MappingUtils.toMongoDbDocument;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
import static org.bson.codecs.configuration.CodecRegistries.fromProviders;
|
||||
import static org.bson.codecs.configuration.CodecRegistries.fromRegistries;
|
||||
|
||||
/**
|
||||
* Represents a <a href="https://www.mongodb.com/">MongoDB</a> index as an embedding store.
|
||||
* <p>
|
||||
* More <a href="https://www.mongodb.com/docs/atlas/atlas-search/field-types/knn-vector/">info</a> to set up MongoDb as vectorDatabase
|
||||
* <p>
|
||||
* <a href="https://www.mongodb.com/developer/products/atlas/semantic-search-mongodb-atlas-vector-search/">tutorial</a> how to use a knn-vector in MongoDB Atlas (great startingpoint)
|
||||
*/
|
||||
public class MongoDbEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(MongoDbEmbeddingStore.class);
|
||||
|
||||
private final MongoCollection<MongoDbDocument> collection;
|
||||
|
||||
private final String indexName;
|
||||
private final long maxResultRatio;
|
||||
private final VectorSearchOptions vectorSearchOptions;
|
||||
|
||||
public MongoDbEmbeddingStore(MongoClient mongoClient,
|
||||
String databaseName,
|
||||
String collectionName,
|
||||
String indexName,
|
||||
Long maxResultRatio,
|
||||
CreateCollectionOptions createCollectionOptions,
|
||||
Bson filter,
|
||||
IndexMapping indexMapping,
|
||||
Boolean createIndex) {
|
||||
databaseName = ensureNotNull(databaseName, "databaseName");
|
||||
collectionName = ensureNotNull(collectionName, "collectionName");
|
||||
createIndex = getOrDefault(createIndex, false);
|
||||
this.indexName = ensureNotNull(indexName, "indexName");
|
||||
this.maxResultRatio = getOrDefault(maxResultRatio, 10L);
|
||||
|
||||
CodecRegistry pojoCodecRegistry = fromProviders(PojoCodecProvider.builder()
|
||||
.register(MongoDbDocument.class, MongoDbMatchedDocument.class)
|
||||
.build());
|
||||
CodecRegistry codecRegistry = fromRegistries(MongoClientSettings.getDefaultCodecRegistry(), pojoCodecRegistry);
|
||||
|
||||
// create collection if not exist
|
||||
MongoDatabase database = mongoClient.getDatabase(databaseName);
|
||||
if (!isCollectionExist(database, collectionName)) {
|
||||
createCollection(database, collectionName, getOrDefault(createCollectionOptions, new CreateCollectionOptions()));
|
||||
}
|
||||
|
||||
this.collection = database.getCollection(collectionName, MongoDbDocument.class).withCodecRegistry(codecRegistry);
|
||||
this.vectorSearchOptions = filter == null ? vectorSearchOptions() : vectorSearchOptions().filter(filter);
|
||||
|
||||
// create index if not exist
|
||||
if (Boolean.TRUE.equals(createIndex) && !isIndexExist(this.indexName)) {
|
||||
createIndex(this.indexName, getOrDefault(indexMapping, defaultIndexMapping()));
|
||||
}
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private MongoClient mongoClient;
|
||||
private String databaseName;
|
||||
private String collectionName;
|
||||
private String indexName;
|
||||
private Long maxResultRatio;
|
||||
private CreateCollectionOptions createCollectionOptions;
|
||||
private Bson filter;
|
||||
private IndexMapping indexMapping;
|
||||
/**
|
||||
* Whether MongoDB Atlas is deployed in cloud
|
||||
*
|
||||
* <p>if true, you need to create index in <a href="https://cloud.mongodb.com/">MongoDB Atlas</a></p>
|
||||
* <p>if false, {@link MongoDbEmbeddingStore} will create collection and index automatically</p>
|
||||
*/
|
||||
private Boolean createIndex;
|
||||
|
||||
/**
|
||||
* Build Mongo Client, Please close the client to release resources after usage
|
||||
*/
|
||||
public Builder fromClient(MongoClient mongoClient) {
|
||||
this.mongoClient = mongoClient;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder databaseName(String databaseName) {
|
||||
this.databaseName = databaseName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder collectionName(String collectionName) {
|
||||
this.collectionName = collectionName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder indexName(String indexName) {
|
||||
this.indexName = indexName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder maxResultRatio(Long maxResultRatio) {
|
||||
this.maxResultRatio = maxResultRatio;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder createCollectionOptions(CreateCollectionOptions createCollectionOptions) {
|
||||
this.createCollectionOptions = createCollectionOptions;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Document query filter, all fields included in filter must be contained in {@link IndexMapping#metadataFieldNames}
|
||||
*
|
||||
* <p>For example:</p>
|
||||
*
|
||||
* <ul>
|
||||
* <li>AND filter: Filters.and(Filters.in("type", asList("TXT", "md")), Filters.eqFull("test-key", "test-value"))</li>
|
||||
* <li>OR filter: Filters.or(Filters.in("type", asList("TXT", "md")), Filters.eqFull("test-key", "test-value"))</li>
|
||||
* </ul>
|
||||
*
|
||||
* @param filter document query filter
|
||||
* @return builder
|
||||
*/
|
||||
public Builder filter(Bson filter) {
|
||||
this.filter = filter;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* set MongoDB search index fields mapping
|
||||
*
|
||||
* <p>if {@link Builder#createIndex} is true, then indexMapping not work</p>
|
||||
*
|
||||
* @param indexMapping MongoDB search index fields mapping
|
||||
* @return builder
|
||||
*/
|
||||
public Builder indexMapping(IndexMapping indexMapping) {
|
||||
this.indexMapping = indexMapping;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set whether in production mode, production mode will not create index automatically
|
||||
*
|
||||
* <p>default value is false</p>
|
||||
*
|
||||
* @param createIndex whether in production mode
|
||||
* @return builder
|
||||
*/
|
||||
public Builder createIndex(Boolean createIndex) {
|
||||
this.createIndex = createIndex;
|
||||
return this;
|
||||
}
|
||||
|
||||
public MongoDbEmbeddingStore build() {
|
||||
return new MongoDbEmbeddingStore(mongoClient, databaseName, collectionName, indexName, maxResultRatio, createCollectionOptions, filter, indexMapping, createIndex);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding) {
|
||||
String id = randomUUID();
|
||||
add(id, embedding);
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(String id, Embedding embedding) {
|
||||
addInternal(id, embedding, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding, TextSegment textSegment) {
|
||||
String id = randomUUID();
|
||||
addInternal(id, embedding, textSegment);
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings) {
|
||||
List<String> ids = embeddings.stream()
|
||||
.map(ignored -> randomUUID())
|
||||
.collect(toList());
|
||||
addAllInternal(ids, embeddings, null);
|
||||
return ids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
|
||||
List<String> ids = embeddings.stream()
|
||||
.map(ignored -> randomUUID())
|
||||
.collect(toList());
|
||||
addAllInternal(ids, embeddings, embedded);
|
||||
return ids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
|
||||
List<Double> queryVector = referenceEmbedding.vectorAsList().stream()
|
||||
.map(Float::doubleValue)
|
||||
.collect(toList());
|
||||
long numCandidates = maxResults * maxResultRatio;
|
||||
|
||||
List<Bson> pipeline = Arrays.asList(
|
||||
vectorSearch(
|
||||
fieldPath("embedding"),
|
||||
queryVector,
|
||||
indexName,
|
||||
numCandidates,
|
||||
maxResults,
|
||||
vectorSearchOptions),
|
||||
project(
|
||||
fields(
|
||||
metaVectorSearchScore("score"),
|
||||
include("embedding", "metadata", "text")
|
||||
)
|
||||
),
|
||||
match(
|
||||
Filters.gte("score", minScore)
|
||||
));
|
||||
|
||||
try {
|
||||
AggregateIterable<MongoDbMatchedDocument> results = collection.aggregate(pipeline, MongoDbMatchedDocument.class);
|
||||
|
||||
return StreamSupport.stream(results.spliterator(), false)
|
||||
.map(MappingUtils::toEmbeddingMatch)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
} catch (MongoCommandException e) {
|
||||
if (log.isErrorEnabled()) {
|
||||
log.error("Error in MongoDBEmbeddingStore.findRelevant", e);
|
||||
}
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private void addInternal(String id, Embedding embedding, TextSegment embedded) {
|
||||
addAllInternal(singletonList(id), singletonList(embedding), embedded == null ? null : singletonList(embedded));
|
||||
}
|
||||
|
||||
private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
|
||||
if (isNullOrEmpty(ids) || isNullOrEmpty(embeddings)) {
|
||||
log.info("do not add empty embeddings to MongoDB Atlas");
|
||||
return;
|
||||
}
|
||||
ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
|
||||
ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");
|
||||
|
||||
List<MongoDbDocument> documents = new ArrayList<>(ids.size());
|
||||
for (int i = 0; i < ids.size(); i++) {
|
||||
MongoDbDocument document = toMongoDbDocument(ids.get(i), embeddings.get(i), embedded == null ? null : embedded.get(i));
|
||||
documents.add(document);
|
||||
}
|
||||
|
||||
InsertManyResult result = collection.insertMany(documents);
|
||||
if (!result.wasAcknowledged() && log.isWarnEnabled()) {
|
||||
String errMsg = String.format("[MongoDbEmbeddingStore] Add document failed, Document=%s", documents);
|
||||
log.warn(errMsg);
|
||||
throw new RuntimeException(errMsg);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isCollectionExist(MongoDatabase database, String collectionName) {
|
||||
return StreamSupport.stream(database.listCollectionNames().spliterator(), false)
|
||||
.anyMatch(collectionName::equals);
|
||||
}
|
||||
|
||||
private void createCollection(MongoDatabase database, String collectionName, CreateCollectionOptions createCollectionOptions) {
|
||||
database.createCollection(collectionName, createCollectionOptions);
|
||||
}
|
||||
|
||||
private boolean isIndexExist(String indexName) {
|
||||
return StreamSupport.stream(collection.listSearchIndexes().spliterator(), false)
|
||||
.anyMatch(index -> indexName.equals(index.getString("name")));
|
||||
}
|
||||
|
||||
private void createIndex(String indexName, IndexMapping indexMapping) {
|
||||
Document index = fromIndexMapping(indexMapping);
|
||||
collection.createSearchIndex(indexName, index);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
package dev.langchain4j.store.embedding.mongodb;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class MongoDbMatchedDocument {
|
||||
|
||||
private String id;
|
||||
private List<Float> embedding;
|
||||
private String text;
|
||||
private Map<String, String> metadata;
|
||||
private Double score;
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
package dev.langchain4j.store.embedding.mongodb;
|
||||
|
||||
import com.mongodb.MongoClientSettings;
|
||||
import com.mongodb.client.MongoClient;
|
||||
import com.mongodb.client.MongoClients;
|
||||
import com.mongodb.client.MongoCollection;
|
||||
import com.mongodb.client.model.Filters;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import lombok.SneakyThrows;
|
||||
import org.bson.codecs.configuration.CodecRegistry;
|
||||
import org.bson.codecs.pojo.PojoCodecProvider;
|
||||
import org.bson.conversions.Bson;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
|
||||
import static org.bson.codecs.configuration.CodecRegistries.fromProviders;
|
||||
import static org.bson.codecs.configuration.CodecRegistries.fromRegistries;
|
||||
|
||||
@Disabled("Need Cloud Mongo Atlas Credential")
|
||||
class MongoDbEmbeddingStoreCloudIT extends EmbeddingStoreIT {
|
||||
|
||||
static MongoClient client;
|
||||
|
||||
MongoDbEmbeddingStore embeddingStore = MongoDbEmbeddingStore.builder()
|
||||
.fromClient(client)
|
||||
.databaseName("test_database")
|
||||
.collectionName("test_collection")
|
||||
.indexName("test_index")
|
||||
.build();
|
||||
|
||||
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
@BeforeAll
|
||||
static void beforeAll() {
|
||||
client = MongoClients.create("mongodb+srv://<username>:<password>@<host>/?retryWrites=true&w=majority");
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
static void afterAll() {
|
||||
client.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingModel embeddingModel() {
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void clearStore() {
|
||||
CodecRegistry pojoCodecRegistry = fromProviders(PojoCodecProvider.builder()
|
||||
.register(MongoDbDocument.class, MongoDbMatchedDocument.class)
|
||||
.build());
|
||||
CodecRegistry codecRegistry = fromRegistries(MongoClientSettings.getDefaultCodecRegistry(), pojoCodecRegistry);
|
||||
|
||||
MongoCollection<MongoDbDocument> collection = client.getDatabase("test_database")
|
||||
.getCollection("test_collection", MongoDbDocument.class)
|
||||
.withCodecRegistry(codecRegistry);
|
||||
|
||||
Bson filter = Filters.exists("embedding");
|
||||
collection.deleteMany(filter);
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
protected void awaitUntilPersisted() {
|
||||
Thread.sleep(2000);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
package dev.langchain4j.store.embedding.mongodb;
|
||||
|
||||
import com.mongodb.*;
|
||||
import com.mongodb.client.MongoClient;
|
||||
import com.mongodb.client.MongoClients;
|
||||
import com.mongodb.client.model.Filters;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import lombok.SneakyThrows;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.testcontainers.containers.DockerComposeContainer;
|
||||
import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy;
|
||||
import org.testcontainers.shaded.com.google.common.collect.Sets;
|
||||
|
||||
import java.io.File;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.data.Percentage.withPercentage;
|
||||
|
||||
class MongoDbEmbeddingStoreFilterIT {
|
||||
|
||||
static final String MONGO_SERVICE_NAME = "mongo";
|
||||
static final Integer MONGO_SERVICE_PORT = 27778;
|
||||
static DockerComposeContainer<?> mongodb = new DockerComposeContainer<>(new File("src/test/resources/docker-compose.yml"))
|
||||
.withExposedService(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT, new LogMessageWaitStrategy()
|
||||
.withRegEx(".*Deployment created!.*\\n")
|
||||
.withTimes(1)
|
||||
.withStartupTimeout(Duration.ofMinutes(30)));
|
||||
|
||||
static MongoClient client;
|
||||
|
||||
IndexMapping indexMapping = IndexMapping.builder()
|
||||
.dimension(384)
|
||||
.metadataFieldNames(Sets.newHashSet("test-key"))
|
||||
.build();
|
||||
|
||||
EmbeddingStore<TextSegment> embeddingStore = MongoDbEmbeddingStore.builder()
|
||||
.fromClient(client)
|
||||
.databaseName("test_database")
|
||||
.collectionName("test_collection")
|
||||
.indexName("test_index")
|
||||
.filter(Filters.and(Filters.eqFull("metadata.test-key", "test-value")))
|
||||
.indexMapping(indexMapping)
|
||||
.createIndex(true)
|
||||
.build();
|
||||
|
||||
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
@BeforeAll
|
||||
@SneakyThrows
|
||||
static void start() {
|
||||
mongodb.start();
|
||||
|
||||
MongoCredential credential = MongoCredential.createCredential("root", "admin", "root".toCharArray());
|
||||
client = MongoClients.create(
|
||||
MongoClientSettings.builder()
|
||||
.credential(credential)
|
||||
.serverApi(ServerApi.builder().version(ServerApiVersion.V1).build())
|
||||
.applyConnectionString(new ConnectionString(String.format("mongodb://%s:%s/?directConnection=true",
|
||||
mongodb.getServiceHost(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT), mongodb.getServicePort(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT))))
|
||||
.build());
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
static void stop() {
|
||||
mongodb.stop();
|
||||
client.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_find_relevant_with_filter() {
|
||||
TextSegment segment = TextSegment.from("this segment should be found", Metadata.from("test-key", "test-value"));
|
||||
Embedding embedding = embeddingModel.embed(segment.text()).content();
|
||||
|
||||
TextSegment filterSegment = TextSegment.from("this segment should not be found", Metadata.from("test-key", "no-value"));
|
||||
Embedding filterEmbedding = embeddingModel.embed(filterSegment.text()).content();
|
||||
|
||||
List<String> ids = embeddingStore.addAll(asList(embedding, filterEmbedding), asList(segment, filterSegment));
|
||||
assertThat(ids)
|
||||
.hasSize(2);
|
||||
|
||||
TextSegment refSegment = TextSegment.from("find a segment");
|
||||
Embedding refEmbedding = embeddingModel.embed(refSegment.text()).content();
|
||||
|
||||
awaitUntilPersisted();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(refEmbedding, 2);
|
||||
// Only segment should be found, filterSegment should be filtered
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(0.88, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(ids.get(0));
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
assertThat(match.embedded()).isEqualTo(segment);
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
protected void awaitUntilPersisted() {
|
||||
Thread.sleep(2000);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
package dev.langchain4j.store.embedding.mongodb;
|
||||
|
||||
import com.mongodb.*;
|
||||
import com.mongodb.client.MongoClient;
|
||||
import com.mongodb.client.MongoClients;
|
||||
import com.mongodb.client.MongoCollection;
|
||||
import com.mongodb.client.model.Filters;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import lombok.SneakyThrows;
|
||||
import org.bson.codecs.configuration.CodecRegistry;
|
||||
import org.bson.codecs.pojo.PojoCodecProvider;
|
||||
import org.bson.conversions.Bson;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.testcontainers.containers.DockerComposeContainer;
|
||||
import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy;
|
||||
import org.testcontainers.shaded.com.google.common.collect.Sets;
|
||||
|
||||
import java.io.File;
|
||||
import java.time.Duration;
|
||||
|
||||
import static org.bson.codecs.configuration.CodecRegistries.fromProviders;
|
||||
import static org.bson.codecs.configuration.CodecRegistries.fromRegistries;
|
||||
|
||||
/**
|
||||
* If container startup timeout (because atlas cli need to download mongodb binaries, which may take a few minutes),
|
||||
* the alternative way is running `docker compose up -d` in `src/test/resources`
|
||||
*/
|
||||
class MongoDbEmbeddingStoreLocalIT extends EmbeddingStoreIT {
|
||||
|
||||
static final String MONGO_SERVICE_NAME = "mongo";
|
||||
static final Integer MONGO_SERVICE_PORT = 27778;
|
||||
static DockerComposeContainer<?> mongodb = new DockerComposeContainer<>(new File("src/test/resources/docker-compose.yml"))
|
||||
.withExposedService(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT, new LogMessageWaitStrategy()
|
||||
.withRegEx(".*Deployment created!.*\\n")
|
||||
.withTimes(1)
|
||||
.withStartupTimeout(Duration.ofMinutes(30)));
|
||||
|
||||
static MongoClient client;
|
||||
|
||||
IndexMapping indexMapping = IndexMapping.builder()
|
||||
.dimension(384)
|
||||
.metadataFieldNames(Sets.newHashSet("test-key"))
|
||||
.build();
|
||||
|
||||
EmbeddingStore<TextSegment> embeddingStore = MongoDbEmbeddingStore.builder()
|
||||
.fromClient(client)
|
||||
.databaseName("test_database")
|
||||
.collectionName("test_collection")
|
||||
.indexName("test_index")
|
||||
.indexMapping(indexMapping)
|
||||
.createIndex(true)
|
||||
.build();
|
||||
|
||||
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
@BeforeAll
|
||||
@SneakyThrows
|
||||
static void start() {
|
||||
mongodb.start();
|
||||
|
||||
MongoCredential credential = MongoCredential.createCredential("root", "admin", "root".toCharArray());
|
||||
client = MongoClients.create(
|
||||
MongoClientSettings.builder()
|
||||
.credential(credential)
|
||||
.serverApi(ServerApi.builder().version(ServerApiVersion.V1).build())
|
||||
.applyConnectionString(new ConnectionString(String.format("mongodb://%s:%s/?directConnection=true",
|
||||
mongodb.getServiceHost(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT), mongodb.getServicePort(MONGO_SERVICE_NAME, MONGO_SERVICE_PORT))))
|
||||
.build());
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
static void stop() {
|
||||
mongodb.stop();
|
||||
client.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingModel embeddingModel() {
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void clearStore() {
|
||||
CodecRegistry pojoCodecRegistry = fromProviders(PojoCodecProvider.builder()
|
||||
.register(MongoDbDocument.class, MongoDbMatchedDocument.class)
|
||||
.build());
|
||||
CodecRegistry codecRegistry = fromRegistries(MongoClientSettings.getDefaultCodecRegistry(), pojoCodecRegistry);
|
||||
|
||||
MongoCollection<MongoDbDocument> collection = client.getDatabase("test_database")
|
||||
.getCollection("test_collection", MongoDbDocument.class)
|
||||
.withCodecRegistry(codecRegistry);
|
||||
|
||||
Bson filter = Filters.exists("embedding");
|
||||
collection.deleteMany(filter);
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
protected void awaitUntilPersisted() {
|
||||
Thread.sleep(2000);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
services:
|
||||
mongo:
|
||||
image: mongodb/atlas
|
||||
privileged: true
|
||||
command: |
|
||||
/bin/bash -c "atlas deployments setup local-test --type local --port 27778 --bindIpAll --username root --password root --force && tail -f /dev/null"
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
ports:
|
||||
- 27778:27778
|
Loading…
Reference in New Issue