feat: Metadata filtering Qdrant (#1646)

## Description
- Adds filtering support for Qdrant, with tests for the filter converter.
- Updates metadata storage handling to support complex metadata types.
Metadata is no longer limited to `Map<String, String>`.
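
To illustrate what "complex metadata types" involves, here is a minimal sketch of recursively normalizing nested metadata values. It is not the code from this commit: `MetadataNormalizer` and `normalize` are hypothetical names, and plain Java collections stand in for Qdrant's `Value` type so the example runs without the client library.

```java
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;

// Hypothetical stand-in: plain Java collections replace Qdrant's Value type.
class MetadataNormalizer {

    // Recursively converts a metadata value into a payload-friendly shape.
    static Object normalize(Object value) {
        if (value == null) {
            return null;
        }
        if (value.getClass().isArray()) {
            // Arrays (including primitive arrays) become lists of normalized elements.
            int length = Array.getLength(value);
            List<Object> list = new ArrayList<>(length);
            for (int i = 0; i < length; i++) {
                list.add(normalize(Array.get(value, i)));
            }
            return list;
        }
        if (value instanceof Map) {
            // Nested maps are normalized entry by entry, keeping insertion order.
            Map<String, Object> result = new LinkedHashMap<>();
            for (Map.Entry<?, ?> entry : ((Map<?, ?>) value).entrySet()) {
                result.put(String.valueOf(entry.getKey()), normalize(entry.getValue()));
            }
            return result;
        }
        if (value instanceof UUID) {
            // UUIDs are stored as their string form.
            return value.toString();
        }
        if (value instanceof String || value instanceof Boolean
                || value instanceof Integer || value instanceof Long
                || value instanceof Float || value instanceof Double) {
            return value;
        }
        throw new IllegalArgumentException("Unsupported metadata value type: " + value.getClass());
    }

    public static void main(String[] args) {
        Map<String, Object> metadata = new LinkedHashMap<>();
        metadata.put("year", 2024);
        metadata.put("tags", new String[] {"a", "b"});
        System.out.println(normalize(metadata));
    }
}
```

The sketch rejects unknown types with an exception rather than silently stringifying them, so that a value never changes type on a round trip without the caller noticing.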

Closes #1600 

## NOTE
In Qdrant:
- Eq, NEq, In, and NIn do not accept float or double values, only integers and strings.
- LT, GT, LTE, and GTE accept only numeric values, not strings.
- For In and NIn conditions, a point is not matched if the key does not exist in its metadata.
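
The type rules above can be sketched as a small pre-check that a caller might run before building a filter. This is an illustrative helper only: `FilterValueRules` and its methods are hypothetical names, it covers just the rules stated in this note, and it does not call the Qdrant client.

```java
// Hypothetical helper: mirrors the type rules listed in the note, without calling Qdrant.
class FilterValueRules {

    // Eq, NEq, In, NIn: strings or whole numbers; floats and doubles are rejected.
    static boolean allowedForMatch(Object value) {
        return value instanceof String
                || value instanceof Integer
                || value instanceof Long;
    }

    // LT, GT, LTE, GTE: any numeric value; strings are rejected.
    static boolean allowedForRange(Object value) {
        return value instanceof Number;
    }

    public static void main(String[] args) {
        System.out.println("match \"blue\": " + allowedForMatch("blue"));
        System.out.println("match 1.5: " + allowedForMatch(1.5));
        System.out.println("range 1.5: " + allowedForRange(1.5));
        System.out.println("range \"blue\": " + allowedForRange("blue"));
    }
}
```

A check like this lets an unsupported value fail early with a clear message instead of surfacing as a conversion error at search time.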
This commit is contained in:
Anush 2024-09-12 14:48:28 +05:30 committed by GitHub
parent aaaa71a5d7
commit 2a5189a78a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 717 additions and 113 deletions

View File

@ -1,7 +1,7 @@
Thank you for investing your time and effort in contributing to our project, we appreciate it a lot! 🤗 Thank you for investing your time and effort in contributing to our project, we appreciate it a lot! 🤗
# General guidelines # General guidelines
- If you want to contribute a bug fix or a new feature that isn't listed in the [issues](https://github.com/langchain4j/langchain4j/issues) yet, please open a new issue for it. We will prioritize it shortly. - If you want to contribute a bug fix or a new feature that isn't listed in the [issues](https://github.com/langchain4j/langchain4j/issues) yet, please open a new issue for it. We will prioritize it shortly.
- Follow [Google's Best Practices for Java Libraries](https://jlbp.dev/) - Follow [Google's Best Practices for Java Libraries](https://jlbp.dev/)
- Keep the code compatible with Java 8. We plan to increase the baseline to Java 17 a bit later. - Keep the code compatible with Java 8. We plan to increase the baseline to Java 17 a bit later.
@ -14,20 +14,20 @@ Thank you for investing your time and effort in contributing to our project, we
- Follow existing code style present in the project. - Follow existing code style present in the project.
- Large features should be discussed with maintainers before implementation. Please ping @langchain4j in the comments on the issue. - Large features should be discussed with maintainers before implementation. Please ping @langchain4j in the comments on the issue.
# Priorities # Priorities
All [issues](https://github.com/langchain4j/langchain4j/issues) are prioritized by maintainers. There are 4 priorities: [P1](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP1), [P2](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP2), [P3](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP3) and [P4](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP4). All [issues](https://github.com/langchain4j/langchain4j/issues) are prioritized by maintainers. There are 4 priorities: [P1](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP1), [P2](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP2), [P3](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP3) and [P4](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP4).
Please start with the higher priorities. PRs will be reviewed in order of priority, with bugs being a higher priority than new features. Please start with the higher priorities. PRs will be reviewed in order of priority, with bugs being a higher priority than new features.
Please note that we do not have the capacity to review PRs immediately. We ask for your patience. We are doing our best to review your PR as quickly as possible. Please note that we do not have the capacity to review PRs immediately. We ask for your patience. We are doing our best to review your PR as quickly as possible.
# Opening an issue # Opening an issue
- Please fill in all sections of the issue template. - Please fill in all sections of the issue template.
# Opening a draft PR # Opening a draft PR
- Please open the PR as a draft initially. Once it is reviewed and approved, we will then ask you to finalize it (see section below). - Please open the PR as a draft initially. Once it is reviewed and approved, we will then ask you to finalize it (see section below).
- Fill in all the sections of the PR template. - Fill in all the sections of the PR template.
- Please make it easier to review your PR: - Please make it easier to review your PR:
@ -35,14 +35,14 @@ Please note that we do not have the capacity to review PRs immediately. We ask f
- Do not combine refactoring with changes in a single PR. - Do not combine refactoring with changes in a single PR.
- Avoid reformatting existing code. - Avoid reformatting existing code.
# Finalizing the draft PR # Finalizing the draft PR
- Add [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) (if required). - Add [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) (if required).
- Add an example to the [examples repository](https://github.com/langchain4j/langchain4j-examples) (if required). - Add an example to the [examples repository](https://github.com/langchain4j/langchain4j-examples) (if required).
- [Mark a PR as ready for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/changing-the-stage-of-a-pull-request#marking-a-pull-request-as-ready-for-review) - [Mark a PR as ready for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/changing-the-stage-of-a-pull-request#marking-a-pull-request-as-ready-for-review)
# Guidelines on adding a new model integration # Guidelines on adding a new model integration
- [Integration with Anthropic](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-anthropic) is a good example. - [Integration with Anthropic](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-anthropic) is a good example.
- Use the official SDK if available. - Use the official SDK if available.
- If the official SDK is not available, use Retrofit and Jackson to implement the client. - If the official SDK is not available, use Retrofit and Jackson to implement the client.
@ -51,8 +51,8 @@ Please note that we do not have the capacity to review PRs immediately. We ask f
- Add a new module to the appropriate section of the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml). - Add a new module to the appropriate section of the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml).
- It would be great if you could add a [Spring Boot starter](https://github.com/langchain4j/langchain4j-spring). - It would be great if you could add a [Spring Boot starter](https://github.com/langchain4j/langchain4j-spring).
# Guidelines on adding a new embedding store integration # Guidelines on adding a new embedding store integration
- [Integration with Chroma](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-chroma) is a good example. - [Integration with Chroma](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-chroma) is a good example.
- Use the official SDK if available. - Use the official SDK if available.
- If the official SDK is not available, use Retrofit and Jackson to implement the client. - If the official SDK is not available, use Retrofit and Jackson to implement the client.
@ -61,8 +61,8 @@ Please note that we do not have the capacity to review PRs immediately. We ask f
- Document the new integration [here](https://github.com/langchain4j/langchain4j/blob/main/README.md), [here](https://github.com/langchain4j/langchain4j/tree/main/docs/docs/integrations/embedding-stores) and [here](https://github.com/langchain4j/langchain4j/blob/main/docs/docs/integrations/embedding-stores/index.md). - Document the new integration [here](https://github.com/langchain4j/langchain4j/blob/main/README.md), [here](https://github.com/langchain4j/langchain4j/tree/main/docs/docs/integrations/embedding-stores) and [here](https://github.com/langchain4j/langchain4j/blob/main/docs/docs/integrations/embedding-stores/index.md).
- Add an example to the [examples repository](https://github.com/langchain4j/langchain4j-examples), similar to [this](https://github.com/langchain4j/langchain4j-examples/tree/main/chroma-example). - Add an example to the [examples repository](https://github.com/langchain4j/langchain4j-examples), similar to [this](https://github.com/langchain4j/langchain4j-examples/tree/main/chroma-example).
- Add a new module to the appropriate section of the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml). - Add a new module to the appropriate section of the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml).
- It would be great if you could add a [Spring Boot starter](https://github.com/langchain4j/langchain4j-spring). (after - It would be great if you could add a [Spring Boot starter](https://github.com/langchain4j/langchain4j-spring). (after
# Guidelines on changing an existing embedding store integration # Guidelines on changing an existing embedding store integration
- Ensure that your changes are backwards compatible. `Embedding`s and `TextSegment`s persisted with the latest released version of LangChain4j should still work. - Ensure that your changes are backwards compatible. `Embedding`s and `TextSegment`s persisted with the latest released version of LangChain4j should still work.

View File

@ -82,7 +82,7 @@
<dependency> <dependency>
<groupId>io.grpc</groupId> <groupId>io.grpc</groupId>
<artifactId>grpc-protobuf</artifactId> <artifactId>grpc-protobuf</artifactId>
<version>1.59.0</version> <version>1.65.1</version>
</dependency> </dependency>
<dependency> <dependency>
@ -94,8 +94,8 @@
<dependency> <dependency>
<groupId>io.qdrant</groupId> <groupId>io.qdrant</groupId>
<artifactId>client</artifactId> <artifactId>client</artifactId>
<version>1.7.1</version> <version>1.11.0</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.awaitility</groupId> <groupId>org.awaitility</groupId>

View File

@ -0,0 +1,49 @@
package dev.langchain4j.store.embedding.qdrant;
import java.util.Map;
import java.util.stream.Collectors;
import io.qdrant.client.grpc.JsonWithInt.ListValue;
import io.qdrant.client.grpc.JsonWithInt.Value;
/**
* Utility methods for building Java objects from io.qdrant.client.grpc.JsonWithInt.Value.
*
* @author Anush Shetty
* @since 0.8.1
*/
class ObjectFactory {
private ObjectFactory() {
}
public static Object object(Value value) {
switch (value.getKindCase()) {
case INTEGER_VALUE:
return value.getIntegerValue();
case STRING_VALUE:
return value.getStringValue();
case DOUBLE_VALUE:
return value.getDoubleValue();
case BOOL_VALUE:
return value.getBoolValue();
case LIST_VALUE:
return object(value.getListValue());
case STRUCT_VALUE:
return objectMap(value.getStructValue().getFieldsMap());
case NULL_VALUE:
return null;
case KIND_NOT_SET:
default:
throw new IllegalArgumentException("Unknown value type: " + value.getKindCase());
}
}
private static Map<String, Object> objectMap(Map<String, Value> payload) {
return payload.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> object(e.getValue())));
}
private static Object object(ListValue listValue) {
return listValue.getValuesList().stream().map(ObjectFactory::object).collect(Collectors.toList());
}
}

View File

@ -14,10 +14,7 @@ import static java.util.stream.Collectors.toMap;
import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.CosineSimilarity; import dev.langchain4j.store.embedding.*;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import io.qdrant.client.QdrantClient; import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient; import io.qdrant.client.QdrantGrpcClient;
import io.qdrant.client.WithVectorsSelectorFactory; import io.qdrant.client.WithVectorsSelectorFactory;
@ -38,7 +35,8 @@ import java.util.concurrent.ExecutionException;
import javax.annotation.Nullable; import javax.annotation.Nullable;
/** /**
* Represents a <a href="https://qdrant.tech/">Qdrant</a> collection as an embedding store. With * Represents a <a href="https://qdrant.tech/">Qdrant</a> collection as an
* embedding store. With
* support for storing {@link dev.langchain4j.data.document.Metadata}. * support for storing {@link dev.langchain4j.data.document.Metadata}.
*/ */
public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> { public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
@ -49,11 +47,12 @@ public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
/** /**
* @param collectionName The name of the Qdrant collection. * @param collectionName The name of the Qdrant collection.
* @param host The host of the Qdrant instance. * @param host The host of the Qdrant instance.
* @param port The GRPC port of the Qdrant instance. * @param port The GRPC port of the Qdrant instance.
* @param useTls Whether to use TLS(HTTPS). * @param useTls Whether to use TLS(HTTPS).
* @param payloadTextKey The field name of the text segment in the Qdrant payload. * @param payloadTextKey The field name of the text segment in the Qdrant
* @param apiKey The Qdrant API key to authenticate with. * payload.
* @param apiKey The Qdrant API key to authenticate with.
*/ */
public QdrantEmbeddingStore( public QdrantEmbeddingStore(
String collectionName, String collectionName,
@ -75,9 +74,10 @@ public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
} }
/** /**
* @param client A Qdrant client instance. * @param client A Qdrant client instance.
* @param collectionName The name of the Qdrant collection. * @param collectionName The name of the Qdrant collection.
* @param payloadTextKey The field name of the text segment in the Qdrant payload. * @param payloadTextKey The field name of the text segment in the Qdrant
* payload.
*/ */
public QdrantEmbeddingStore(QdrantClient client, String collectionName, String payloadTextKey) { public QdrantEmbeddingStore(QdrantClient client, String collectionName, String payloadTextKey) {
this.client = client; this.client = client;
@ -132,50 +132,88 @@ public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
} }
private void addAllInternal( private void addAllInternal(
List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) { List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) throws RuntimeException {
List<PointStruct> points = new ArrayList<>(embeddings.size());
for (int i = 0; i < embeddings.size(); i++) {
String id = ids.get(i);
UUID uuid = UUID.fromString(id);
Embedding embedding = embeddings.get(i);
PointStruct.Builder pointBuilder =
PointStruct.newBuilder().setId(id(uuid)).setVectors(vectors(embedding.vector()));
if (textSegments != null) {
pointBuilder.putPayload(payloadTextKey, value(textSegments.get(i).text()));
textSegments
.get(i)
.metadata()
.asMap()
.forEach((key, value) -> pointBuilder.putPayload(key, value(value)));
}
points.add(pointBuilder.build());
}
try { try {
List<PointStruct> points = new ArrayList<>(embeddings.size());
for (int i = 0; i < embeddings.size(); i++) {
String id = ids.get(i);
UUID uuid = UUID.fromString(id);
Embedding embedding = embeddings.get(i);
PointStruct.Builder pointBuilder = PointStruct.newBuilder().setId(id(uuid))
.setVectors(vectors(embedding.vector()));
if (textSegments != null) {
Map<String, Object> metadata = textSegments
.get(i)
.metadata()
.toMap();
Map<String, Value> payload = ValueMapFactory.valueMap(metadata);
payload.put(payloadTextKey, value(textSegments.get(i).text()));
pointBuilder.putAllPayload(payload);
}
points.add(pointBuilder.build());
}
client.upsertAsync(collectionName, points).get(); client.upsertAsync(collectionName, points).get();
} catch (InterruptedException | ExecutionException e) { } catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
SearchPoints.Builder searchBuilder = SearchPoints.newBuilder()
.setCollectionName(collectionName)
.addAllVector(request.queryEmbedding().vectorAsList())
.setWithVectors(WithVectorsSelectorFactory.enable(true))
.setWithPayload(enable(true))
.setLimit(request.maxResults());
if (request.filter() != null) {
Filter filter = QdrantFilterConverter.convertExpression(request.filter());
searchBuilder.setFilter(filter);
}
List<ScoredPoint> results;
try {
results = client.searchAsync(searchBuilder.build()).get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
if (results.isEmpty()) {
return new EmbeddingSearchResult<TextSegment>(emptyList());
}
List<EmbeddingMatch<TextSegment>> matches = results.stream()
.map(vector -> toEmbeddingMatch(vector, request.queryEmbedding()))
.filter(match -> match.score() >= request.minScore())
.sorted(comparingDouble(EmbeddingMatch::score))
.collect(toList());
Collections.reverse(matches);
return new EmbeddingSearchResult<TextSegment>(matches);
}
@Override @Override
public List<EmbeddingMatch<TextSegment>> findRelevant( public List<EmbeddingMatch<TextSegment>> findRelevant(
Embedding referenceEmbedding, int maxResults, double minScore) { Embedding referenceEmbedding, int maxResults, double minScore) {
SearchPoints search = SearchPoints search = SearchPoints.newBuilder()
SearchPoints.newBuilder() .setCollectionName(collectionName)
.setCollectionName(collectionName) .addAllVector(referenceEmbedding.vectorAsList())
.addAllVector(referenceEmbedding.vectorAsList()) .setWithVectors(WithVectorsSelectorFactory.enable(true))
.setWithVectors(WithVectorsSelectorFactory.enable(true)) .setWithPayload(enable(true))
.setWithPayload(enable(true)) .setLimit(maxResults)
.setLimit(maxResults) .build();
.build();
List<ScoredPoint> results; List<ScoredPoint> results;
@ -189,12 +227,11 @@ public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
return emptyList(); return emptyList();
} }
List<EmbeddingMatch<TextSegment>> matches = List<EmbeddingMatch<TextSegment>> matches = results.stream()
results.stream() .map(vector -> toEmbeddingMatch(vector, referenceEmbedding))
.map(vector -> toEmbeddingMatch(vector, referenceEmbedding)) .filter(match -> match.score() >= minScore)
.filter(match -> match.score() >= minScore) .sorted(comparingDouble(EmbeddingMatch::score))
.sorted(comparingDouble(EmbeddingMatch::score)) .collect(toList());
.collect(toList());
Collections.reverse(matches); Collections.reverse(matches);
@ -231,10 +268,9 @@ public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
Value textSegmentValue = payload.getOrDefault(payloadTextKey, null); Value textSegmentValue = payload.getOrDefault(payloadTextKey, null);
Map<String, String> metadata = Map<String, Object> metadata = payload.entrySet().stream()
payload.entrySet().stream() .filter(entry -> !entry.getKey().equals(payloadTextKey))
.filter(entry -> !entry.getKey().equals(payloadTextKey)) .collect(toMap(Map.Entry::getKey, entry -> ObjectFactory.object(entry.getValue())));
.collect(toMap(Map.Entry::getKey, entry -> entry.getValue().getStringValue()));
Embedding embedding = Embedding.from(scoredPoint.getVectors().getVector().getDataList()); Embedding embedding = Embedding.from(scoredPoint.getVectors().getVector().getDataList());
double cosineSimilarity = CosineSimilarity.between(embedding, referenceEmbedding); double cosineSimilarity = CosineSimilarity.between(embedding, referenceEmbedding);
@ -297,8 +333,9 @@ public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
} }
/** /**
* @param payloadTextKey The field name of the text segment in the payload. Defaults to * @param payloadTextKey The field name of the text segment in the payload.
* "text_segment". * Defaults to
* "text_segment".
* @return * @return
*/ */
public Builder payloadTextKey(String payloadTextKey) { public Builder payloadTextKey(String payloadTextKey) {

View File

@ -0,0 +1,205 @@
package dev.langchain4j.store.embedding.qdrant;
import io.qdrant.client.ConditionFactory;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.*;
import dev.langchain4j.store.embedding.filter.logical.And;
import dev.langchain4j.store.embedding.filter.logical.Not;
import dev.langchain4j.store.embedding.filter.logical.Or;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import io.qdrant.client.grpc.Points;
import io.qdrant.client.grpc.Points.Condition;
class QdrantFilterConverter {
public static Points.Filter convertExpression(Filter expression) {
return QdrantFilterConverter.convertOperand(expression);
}
private static Points.Filter convertOperand(Filter operand) {
Points.Filter.Builder context = Points.Filter.newBuilder();
List<Condition> mustClauses = new ArrayList<Condition>();
List<Condition> shouldClauses = new ArrayList<Condition>();
List<Condition> mustNotClauses = new ArrayList<Condition>();
if (operand instanceof Not) {
Not not = (Not) operand;
mustNotClauses.add(ConditionFactory.filter(convertOperand(not.expression())));
} else if (operand instanceof And) {
And and = (And) operand;
mustClauses.add(ConditionFactory.filter(convertOperand(and.left())));
mustClauses.add(ConditionFactory.filter(convertOperand(and.right())));
} else if (operand instanceof Or) {
Or or = (Or) operand;
shouldClauses.add(ConditionFactory.filter(convertOperand(or.left())));
shouldClauses.add(ConditionFactory.filter(convertOperand(or.right())));
} else {
mustClauses.add(parseComparison(operand));
}
return context.addAllMust(mustClauses).addAllShould(shouldClauses).addAllMustNot(mustNotClauses).build();
}
private static Condition parseComparison(Filter comparision) {
if (comparision instanceof IsEqualTo) {
return buildEqCondition((IsEqualTo) comparision);
} else if (comparision instanceof IsNotEqualTo) {
return buildNeCondition((IsNotEqualTo) comparision);
} else if (comparision instanceof IsGreaterThan) {
return buildGtCondition((IsGreaterThan) comparision);
} else if (comparision instanceof IsGreaterThanOrEqualTo) {
return buildGteCondition((IsGreaterThanOrEqualTo) comparision);
} else if (comparision instanceof IsLessThan) {
return buildLtCondition((IsLessThan) comparision);
} else if (comparision instanceof IsLessThanOrEqualTo) {
return buildLteCondition((IsLessThanOrEqualTo) comparision);
} else if (comparision instanceof IsIn) {
return buildInCondition((IsIn) comparision);
} else if (comparision instanceof IsNotIn) {
return buildNInCondition((IsNotIn) comparision);
} else {
throw new UnsupportedOperationException("Unsupported comparision type: " + comparision);
}
}
private static Condition buildEqCondition(IsEqualTo equalTo) {
String key = equalTo.key();
Object value = equalTo.comparisonValue();
if (value instanceof String || value instanceof UUID) {
return ConditionFactory.matchKeyword(key, value.toString());
} else if (value instanceof Boolean) {
return ConditionFactory.match(key, (Boolean) value);
} else if (value instanceof Integer || value instanceof Long) {
long lValue = Long.parseLong(value.toString());
return ConditionFactory.match(key, lValue);
}
throw new IllegalArgumentException(
"Invalid value type for IsEqualTo. Can either be a String or Boolean or Integer or Long");
}
private static Condition buildNeCondition(IsNotEqualTo notEqual) {
String key = notEqual.key();
Object value = notEqual.comparisonValue();
if (value instanceof String || value instanceof UUID) {
return ConditionFactory.filter(
Points.Filter.newBuilder().addMustNot(ConditionFactory.matchKeyword(key, value.toString()))
.build());
} else if (value instanceof Boolean) {
Condition condition = ConditionFactory.match(key, (Boolean) value);
return ConditionFactory.filter(Points.Filter.newBuilder().addMustNot(condition).build());
} else if (value instanceof Integer || value instanceof Long) {
long lValue = Long.parseLong(value.toString());
Condition condition = ConditionFactory.match(key, lValue);
return ConditionFactory.filter(Points.Filter.newBuilder().addMustNot(condition).build());
}
throw new IllegalArgumentException(
"Invalid value type for IsNotEqualto. Can either be a String or Boolean or Integer or Long");
}
private static Condition buildGtCondition(IsGreaterThan greaterThan) {
String key = greaterThan.key();
Object value = greaterThan.comparisonValue();
if (value instanceof Number) {
Double dvalue = Double.parseDouble(value.toString());
return ConditionFactory.range(key, Points.Range.newBuilder().setGt(dvalue).build());
}
throw new RuntimeException("Unsupported value type for IsGreaterThan condition. Only supports Number");
}
private static Condition buildLtCondition(IsLessThan lessThan) {
String key = lessThan.key();
Object value = lessThan.comparisonValue();
if (value instanceof Number) {
Double dvalue = Double.parseDouble(value.toString());
return ConditionFactory.range(key, Points.Range.newBuilder().setLt(dvalue).build());
}
throw new RuntimeException("Unsupported value type for IsLessThan condition. Only supports Number");
}
private static Condition buildGteCondition(IsGreaterThanOrEqualTo greaterThanOrEqualTo) {
String key = greaterThanOrEqualTo.key();
Object value = greaterThanOrEqualTo.comparisonValue();
if (value instanceof Number) {
Double dvalue = Double.parseDouble(value.toString());
return ConditionFactory.range(key, Points.Range.newBuilder().setGte(dvalue).build());
}
throw new RuntimeException("Unsupported value type for IsGreaterThanOrEqualTo condition. Only supports Number");
}
private static Condition buildLteCondition(IsLessThanOrEqualTo lessThanOrEqualTo) {
String key = lessThanOrEqualTo.key();
Object value = lessThanOrEqualTo.comparisonValue();
if (value instanceof Number) {
Double dvalue = Double.parseDouble(value.toString());
return ConditionFactory.range(key, Points.Range.newBuilder().setLte(dvalue).build());
}
throw new RuntimeException("Unsupported value type for IsLessThanOrEqualTo condition. Only supports Number");
}
private static Condition buildInCondition(IsIn in) {
String key = in.key();
List<?> valueList = new ArrayList<>(in.comparisonValues());
Object firstValue = valueList.get(0);
if (firstValue instanceof String || firstValue instanceof UUID) {
// If the first value is a string, then all values should be strings
List<String> stringValues = new ArrayList<String>();
for (Object valueObj : valueList) {
stringValues.add(valueObj.toString());
}
return ConditionFactory.matchKeywords(key, stringValues);
} else if (firstValue instanceof Integer || firstValue instanceof Long) {
// If the first value is a number, then all values should be numbers
List<Long> longValues = new ArrayList<Long>();
for (Object valueObj : valueList) {
Long longValue = Long.parseLong(valueObj.toString());
longValues.add(longValue);
}
return ConditionFactory.matchValues(key, longValues);
} else {
throw new RuntimeException(
"Unsupported value in IsIn value list. Only supports String or Integer or Long");
}
}
private static Condition buildNInCondition(IsNotIn notIn) {
String key = notIn.key();
List<?> valueList = new ArrayList<>(notIn.comparisonValues());
Object firstValue = valueList.get(0);
if (firstValue instanceof String || firstValue instanceof UUID) {
// If the first value is a string, then all values should be strings
List<String> stringValues = new ArrayList<String>();
for (Object valueObj : valueList) {
stringValues.add(valueObj.toString());
}
return ConditionFactory.matchExceptKeywords(key, stringValues);
} else if (firstValue instanceof Integer || firstValue instanceof Long) {
// If the first value is a number, then all values should be numbers
List<Long> longValues = new ArrayList<Long>();
for (Object valueObj : valueList) {
Long longValue = Long.parseLong(valueObj.toString());
longValues.add(longValue);
}
return ConditionFactory.matchExceptValues(key, longValues);
} else {
throw new RuntimeException(
"Unsupported value in IsNotIn value list. Only supports String or Integer or Long");
}
}
}

View File

@ -0,0 +1,85 @@
package dev.langchain4j.store.embedding.qdrant;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import io.qdrant.client.ValueFactory;
import io.qdrant.client.grpc.JsonWithInt.Struct;
import io.qdrant.client.grpc.JsonWithInt.Value;
/**
* Utility methods for building io.qdrant.client.grpc.JsonWithInt.Value from Java objects.
*
* @author Anush Shetty
* @since 0.8.1
*/
class ValueMapFactory {
private ValueMapFactory() {
}
public static Map<String, Value> valueMap(Map<String, Object> inputMap) {
return inputMap.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> value(e.getValue())));
}
@SuppressWarnings("unchecked")
private static Value value(Object value) {
if (value == null) {
return ValueFactory.nullValue();
}
if (value.getClass().isArray()) {
int length = Array.getLength(value);
Object[] objectArray = new Object[length];
for (int i = 0; i < length; i++) {
objectArray[i] = Array.get(value, i);
}
return value(objectArray);
}
if (value instanceof Map) {
return value((Map<String, Object>) value);
}
switch (value.getClass().getSimpleName()) {
case "UUID":
return ValueFactory.value(value.toString());
case "String":
return ValueFactory.value((String) value);
case "Integer":
return ValueFactory.value((Integer) value);
case "Long":
return ValueFactory.value((Long) value);
case "Double":
return ValueFactory.value((Double) value);
case "Float":
return ValueFactory.value((Float) value);
case "Boolean":
return ValueFactory.value((Boolean) value);
default:
throw new IllegalArgumentException("Unsupported Qdrant value type: " + value.getClass());
}
}
private static Value value(Object[] elements) {
List<Value> values = new ArrayList<Value>(elements.length);
for (Object element : elements) {
values.add(value(element));
}
return ValueFactory.list(values);
}
private static Value value(Map<String, Object> inputMap) {
Struct.Builder structBuilder = Struct.newBuilder();
Map<String, Value> map = valueMap(inputMap);
structBuilder.putAllFields(map);
return Value.newBuilder().setStructValue(structBuilder).build();
}
}

View File

@ -1,78 +1,176 @@
package dev.langchain4j.store.embedding.qdrant; package dev.langchain4j.store.embedding.qdrant;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIT; import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsNotIn;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsIn;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
import io.qdrant.client.QdrantClient; import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient; import io.qdrant.client.QdrantGrpcClient;
import io.qdrant.client.grpc.Collections.Distance; import io.qdrant.client.grpc.Collections.Distance;
import io.qdrant.client.grpc.Collections.VectorParams; import io.qdrant.client.grpc.Collections.VectorParams;
import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.provider.Arguments;
import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.qdrant.QdrantContainer; import org.testcontainers.qdrant.QdrantContainer;
import java.util.stream.Stream;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.lang.reflect.Method;
import static dev.langchain4j.internal.Utils.randomUUID; import static dev.langchain4j.internal.Utils.randomUUID;
@Testcontainers @Testcontainers
class QdrantEmbeddingStoreIT extends EmbeddingStoreIT { class QdrantEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT {
protected static final UUID TEST_UUID = UUID.randomUUID();
static final UUID TEST_UUID2 = UUID.randomUUID();
private static String collectionName = "langchain4j-" + randomUUID();
private static int dimension = 384;
private static Distance distance = Distance.Cosine;
private static QdrantEmbeddingStore embeddingStore;
private static String collectionName = "langchain4j-" + randomUUID(); @Container
private static int dimension = 384; private static final QdrantContainer qdrant = new QdrantContainer("qdrant/qdrant:latest");
private static int grpcPort = 6334;
private static Distance distance = Distance.Cosine;
private static QdrantEmbeddingStore embeddingStore;
@Container EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
private static final QdrantContainer qdrant = new QdrantContainer("qdrant/qdrant:latest");
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); @BeforeAll
static void setup() throws InterruptedException, ExecutionException {
@BeforeAll embeddingStore = QdrantEmbeddingStore.builder()
static void setup() throws InterruptedException, ExecutionException {
embeddingStore =
QdrantEmbeddingStore.builder()
.host(qdrant.getHost()) .host(qdrant.getHost())
.port(qdrant.getMappedPort(grpcPort)) .port(qdrant.getGrpcPort())
.collectionName(collectionName) .collectionName(collectionName)
.build(); .build();
QdrantClient client = QdrantClient client = new QdrantClient(
new QdrantClient( QdrantGrpcClient.newBuilder(qdrant.getHost(), qdrant.getGrpcPort(), false)
QdrantGrpcClient.newBuilder(qdrant.getHost(), qdrant.getMappedPort(grpcPort), false) .build());
.build());
client client
.createCollectionAsync( .createCollectionAsync(
collectionName, collectionName,
VectorParams.newBuilder().setDistance(distance).setSize(dimension).build()) VectorParams.newBuilder().setDistance(distance).setSize(dimension)
.get(); .build())
.get();
client.close(); client.close();
} }
@AfterAll @AfterAll
static void teardown() { static void teardown() {
embeddingStore.close(); embeddingStore.close();
} }
@Override @Override
protected EmbeddingStore<TextSegment> embeddingStore() { protected EmbeddingStore < TextSegment > embeddingStore() {
return embeddingStore; return embeddingStore;
} }
@Override @Override
protected EmbeddingModel embeddingModel() { protected EmbeddingModel embeddingModel() {
return embeddingModel; return embeddingModel;
} }
@Override @Override
protected void clearStore() { protected void clearStore() {
embeddingStore.clearStore(); embeddingStore.clearStore();
} }
}
// Qdrant filtering limitations:
// - Eq, NEq, In, NIn don't accept float or double values, only integers and
// strings.
// - LT, GT, LTE, GTE accept only numbers, not strings.
// - For In and NIn conditions, an entry whose metadata lacks the key is not
// matched.
protected static Stream<Arguments> should_filter_by_metadata_not() {
return EmbeddingStoreWithFilteringIT.should_filter_by_metadata_not()
.filter(arguments -> {
Filter filter = (Filter) arguments.get()[0];
if (filter instanceof IsNotIn) {
try {
IsNotIn notIn = (IsNotIn) filter;
Method method = notIn.getClass().getMethod("key");
String key = (String) method.invoke(filter);
List<Metadata> matchingMetadatas = (List<Metadata>) arguments.get()[1];
// For NIn conditions, if the key doesn't exist in the metadata it is not
// matched.
boolean matching = matchingMetadatas.stream()
.allMatch(metadata -> metadata.containsKey(key));
if (!matching) {
return false;
}
Object firstValue = notIn.comparisonValues().stream().findFirst().get();
return firstValue instanceof String ||
firstValue instanceof UUID ||
firstValue instanceof Integer ||
firstValue instanceof Long;
} catch (Exception e) {
throw new RuntimeException(e);
}
} else if (filter instanceof IsNotEqualTo) {
IsNotEqualTo notEqualTo = (IsNotEqualTo) filter;
return notEqualTo.comparisonValue() instanceof String ||
notEqualTo.comparisonValue() instanceof UUID ||
notEqualTo.comparisonValue() instanceof Integer ||
notEqualTo.comparisonValue() instanceof Long;
} else {
return true;
}
});
}
protected static Stream<Arguments> should_filter_by_metadata() {
return EmbeddingStoreWithFilteringIT.should_filter_by_metadata()
.filter(arguments -> {
Filter filter = (Filter) arguments.get()[0];
if (filter instanceof IsLessThan) {
IsLessThan lessThan = (IsLessThan) filter;
return lessThan.comparisonValue() instanceof Integer ||
lessThan.comparisonValue() instanceof Long;
} else if (filter instanceof IsLessThanOrEqualTo) {
IsLessThanOrEqualTo lessThanOrEqualTo = (IsLessThanOrEqualTo) filter;
return lessThanOrEqualTo.comparisonValue() instanceof Integer ||
lessThanOrEqualTo.comparisonValue() instanceof Long;
} else if (filter instanceof IsGreaterThan) {
IsGreaterThan greaterThan = (IsGreaterThan) filter;
return greaterThan.comparisonValue() instanceof Integer ||
greaterThan.comparisonValue() instanceof Long;
} else if (filter instanceof IsGreaterThanOrEqualTo) {
IsGreaterThanOrEqualTo greaterThanOrEqualTo = (IsGreaterThanOrEqualTo) filter;
return greaterThanOrEqualTo.comparisonValue() instanceof Integer ||
greaterThanOrEqualTo.comparisonValue() instanceof Long;
} else if (filter instanceof IsEqualTo) {
IsEqualTo equalTo = (IsEqualTo) filter;
return equalTo.comparisonValue() instanceof String ||
equalTo.comparisonValue() instanceof UUID ||
equalTo.comparisonValue() instanceof Integer ||
equalTo.comparisonValue() instanceof Long;
} else if (filter instanceof IsIn) {
IsIn in = (IsIn) filter;
Object firstValue = in.comparisonValues().stream().findFirst().get();
return firstValue instanceof String ||
firstValue instanceof UUID ||
firstValue instanceof Integer ||
firstValue instanceof Long;
} else {
return true;
}
});
}
}

@ -0,0 +1,130 @@
package dev.langchain4j.store.embedding.qdrant;
import dev.langchain4j.store.embedding.filter.comparison.*;
import io.qdrant.client.grpc.Points;
import dev.langchain4j.store.embedding.filter.Filter;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
import java.util.Arrays;
class QdrantFilterConverterTest {
@Test
void testIsEqualToFilter() {
Filter filter = new IsEqualTo("num-value", 5);
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals("num-value", convertedFilter.getMust(0).getField().getKey());
assertEquals(5, convertedFilter.getMust(0).getField().getMatch().getInteger());
filter = new IsEqualTo("str-value", "value");
convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals("str-value", convertedFilter.getMust(0).getField().getKey());
assertEquals("value", convertedFilter.getMust(0).getField().getMatch().getKeyword());
filter = new IsEqualTo("bool-value", true);
convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals("bool-value", convertedFilter.getMust(0).getField().getKey());
assertEquals(true, convertedFilter.getMust(0).getField().getMatch().getBoolean());
}
@Test
void testIsNotEqualToFilter() {
Filter filter = new IsNotEqualTo("num-value", 5);
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals("num-value", convertedFilter.getMust(0).getFilter().getMustNot(0).getField().getKey());
assertEquals(5, convertedFilter.getMust(0).getFilter().getMustNot(0).getField().getMatch().getInteger());
filter = new IsNotEqualTo("str-value", "value");
convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals("str-value", convertedFilter.getMust(0).getFilter().getMustNot(0).getField().getKey());
assertEquals("value", convertedFilter.getMust(0).getFilter().getMustNot(0).getField().getMatch().getKeyword());
filter = new IsNotEqualTo("bool-value", true);
convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals("bool-value", convertedFilter.getMust(0).getFilter().getMustNot(0).getField().getKey());
assertEquals(true, convertedFilter.getMust(0).getFilter().getMustNot(0).getField().getMatch().getBoolean());
}
@Test
void testIsGreaterThanFilter() {
Filter filter = new IsGreaterThan("key", 1);
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals(1, convertedFilter.getMust(0).getField().getRange().getGt());
}
@Test
void testIsLessThanFilter() {
Filter filter = new IsLessThan("key", 10);
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals(10, convertedFilter.getMust(0).getField().getRange().getLt());
}
@Test
void testIsGreaterThanOrEqualToFilter() {
Filter filter = new IsGreaterThanOrEqualTo("key", 1);
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals(1, convertedFilter.getMust(0).getField().getRange().getGte());
}
@Test
void testIsLessThanOrEqualToFilter() {
Filter filter = new IsLessThanOrEqualTo("key", 10);
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals(10, convertedFilter.getMust(0).getField().getRange().getLte());
}
@Test
void testInFilter() {
Filter filter = new IsIn("key", Arrays.asList(1, 2, 3));
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals(3, convertedFilter.getMust(0).getField().getMatch().getIntegers().getIntegersCount());
filter = new IsIn("key", Arrays.asList("a", "b", "c"));
convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals(3, convertedFilter.getMust(0).getField().getMatch().getKeywords().getStringsCount());
}
@Test
void testNInFilter() {
Filter filter = new IsNotIn("key", Arrays.asList(1, 2, 3, 4));
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals(4, convertedFilter.getMust(0).getField().getMatch().getExceptIntegers().getIntegersCount());
filter = new IsNotIn("key", Arrays.asList("a", "b", "c", "k"));
convertedFilter = QdrantFilterConverter.convertExpression(filter);
assertNotNull(convertedFilter);
assertEquals(1, convertedFilter.getMustCount());
assertEquals(4, convertedFilter.getMust(0).getField().getMatch().getExceptKeywords().getStringsCount());
}
}