feat: Metadata filtering Qdrant (#1646)
## Description - Filtering support for Qdrant with tests for the converter. - Updated metadata storage handling to support complex metadata types. No longer limited to `Map<String, String>`. Closes #1600 ## NOTE In Qdrant, - Eq, NEq, In, NIn don't allow float and double values. Only integers and strings. - LT, GT, LTE, GTE allow only numbers, not alphabets. - For In and NIn conditions, if the key doesn't exist in the metadata, it is not matched.
This commit is contained in:
parent
aaaa71a5d7
commit
2a5189a78a
|
@ -1,7 +1,7 @@
|
|||
Thank you for investing your time and effort in contributing to our project, we appreciate it a lot! 🤗
|
||||
|
||||
|
||||
# General guidelines
|
||||
|
||||
- If you want to contribute a bug fix or a new feature that isn't listed in the [issues](https://github.com/langchain4j/langchain4j/issues) yet, please open a new issue for it. We will prioritize is shortly.
|
||||
- Follow [Google's Best Practices for Java Libraries](https://jlbp.dev/)
|
||||
- Keep the code compatible with Java 8. We plan to increase the baseline to Java 17 a bit later.
|
||||
|
@ -14,20 +14,20 @@ Thank you for investing your time and effort in contributing to our project, we
|
|||
- Follow existing code style present in the project.
|
||||
- Large features should be discussed with maintainers before implementation. Please ping @langchain4j in the comments on the issue.
|
||||
|
||||
|
||||
# Priorities
|
||||
|
||||
All [issues](https://github.com/langchain4j/langchain4j/issues) are prioritized by maintainers. There are 4 priorities: [P1](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP1), [P2](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP2), [P3](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP3) and [P4](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP4).
|
||||
|
||||
Please start with the higher priorities. PRs will be reviewed in order of priority, with bugs being a higher priority than new features.
|
||||
|
||||
Please note that we do not have the capacity to review PRs immediately. We ask for your patience. We are doing our best to review your PR as quickly as possible.
|
||||
|
||||
|
||||
# Opening an issue
|
||||
|
||||
- Please fill in all sections of the issue template.
|
||||
|
||||
|
||||
# Opening a draft PR
|
||||
|
||||
- Please open the PR as a draft initially. Once it is reviewed and approved, we will then ask you to finalize it (see section below).
|
||||
- Fill in all the sections of the PR template.
|
||||
- Please make it easier to review your PR:
|
||||
|
@ -35,14 +35,14 @@ Please note that we do not have the capacity to review PRs immediately. We ask f
|
|||
- Do not combine refactoring with changes in a single PR.
|
||||
- Avoid reformatting existing code.
|
||||
|
||||
|
||||
# Finalizing the draft PR
|
||||
|
||||
- Add [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) (if required).
|
||||
- Add an example to the [examples repository](https://github.com/langchain4j/langchain4j-examples) (if required).
|
||||
- [Mark a PR as ready for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/changing-the-stage-of-a-pull-request#marking-a-pull-request-as-ready-for-review)
|
||||
|
||||
|
||||
# Guidelines on adding a new model integration
|
||||
|
||||
- [Integration with Anthropic](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-anthropic) is a good example.
|
||||
- Use the official SDK if available.
|
||||
- If the official SDK is not available, use Retrofit and Jackson to implement the client.
|
||||
|
@ -51,8 +51,8 @@ Please note that we do not have the capacity to review PRs immediately. We ask f
|
|||
- Add a new module to the appropriate section of the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml).
|
||||
- It would be great if you could add a [Spring Boot starter](https://github.com/langchain4j/langchain4j-spring).
|
||||
|
||||
|
||||
# Guidelines on adding a new embedding store integration
|
||||
|
||||
- [Integration with Chroma](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-chroma) is a good example.
|
||||
- Use the official SDK if available.
|
||||
- If the official SDK is not available, use Retrofit and Jackson to implement the client.
|
||||
|
@ -63,6 +63,6 @@ Please note that we do not have the capacity to review PRs immediately. We ask f
|
|||
- Add a new module to the appropriate section of the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml).
|
||||
- It would be great if you could add a [Spring Boot starter](https://github.com/langchain4j/langchain4j-spring). (after
|
||||
|
||||
|
||||
# Guidelines on changing an existing embedding store integration
|
||||
|
||||
- Ensure that your changes are backwards compatible. `Embedding`s and `TextSegment`s persisted with the latest released version of LangChain4j should still work.
|
||||
|
|
|
@ -82,7 +82,7 @@
|
|||
<dependency>
|
||||
<groupId>io.grpc</groupId>
|
||||
<artifactId>grpc-protobuf</artifactId>
|
||||
<version>1.59.0</version>
|
||||
<version>1.65.1</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
@ -94,7 +94,7 @@
|
|||
<dependency>
|
||||
<groupId>io.qdrant</groupId>
|
||||
<artifactId>client</artifactId>
|
||||
<version>1.7.1</version>
|
||||
<version>1.11.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
package dev.langchain4j.store.embedding.qdrant;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import io.qdrant.client.grpc.JsonWithInt.ListValue;
|
||||
import io.qdrant.client.grpc.JsonWithInt.Value;
|
||||
/**
|
||||
* Utility methods for building Java objects from io.qdrant.client.grpc.JsonWithInt.Value.
|
||||
*
|
||||
* @author Anush Shetty
|
||||
* @since 0.8.1
|
||||
*/
|
||||
class ObjectFactory {
|
||||
|
||||
private ObjectFactory() {
|
||||
}
|
||||
|
||||
public static Object object(Value value) {
|
||||
switch (value.getKindCase()) {
|
||||
case INTEGER_VALUE:
|
||||
return value.getIntegerValue();
|
||||
case STRING_VALUE:
|
||||
return value.getStringValue();
|
||||
case DOUBLE_VALUE:
|
||||
return value.getDoubleValue();
|
||||
case BOOL_VALUE:
|
||||
return value.getBoolValue();
|
||||
case LIST_VALUE:
|
||||
return object(value.getListValue());
|
||||
case STRUCT_VALUE:
|
||||
return objectMap(value.getStructValue().getFieldsMap());
|
||||
case NULL_VALUE:
|
||||
return null;
|
||||
case KIND_NOT_SET:
|
||||
default:
|
||||
throw new IllegalArgumentException("Unknown value type: " + value.getKindCase());
|
||||
}
|
||||
}
|
||||
|
||||
private static Map<String, Object> objectMap(Map<String, Value> payload) {
|
||||
return payload.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> object(e.getValue())));
|
||||
}
|
||||
|
||||
private static Object object(ListValue listValue) {
|
||||
return listValue.getValuesList().stream().map(ObjectFactory::object).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
|
@ -14,10 +14,7 @@ import static java.util.stream.Collectors.toMap;
|
|||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||
import dev.langchain4j.store.embedding.*;
|
||||
import io.qdrant.client.QdrantClient;
|
||||
import io.qdrant.client.QdrantGrpcClient;
|
||||
import io.qdrant.client.WithVectorsSelectorFactory;
|
||||
|
@ -38,7 +35,8 @@ import java.util.concurrent.ExecutionException;
|
|||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
* Represents a <a href="https://qdrant.tech/">Qdrant</a> collection as an embedding store. With
|
||||
* Represents a <a href="https://qdrant.tech/">Qdrant</a> collection as an
|
||||
* embedding store. With
|
||||
* support for storing {@link dev.langchain4j.data.document.Metadata}.
|
||||
*/
|
||||
public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
|
@ -49,11 +47,12 @@ public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
|
||||
/**
|
||||
* @param collectionName The name of the Qdrant collection.
|
||||
* @param host The host of the Qdrant instance.
|
||||
* @param port The GRPC port of the Qdrant instance.
|
||||
* @param useTls Whether to use TLS(HTTPS).
|
||||
* @param payloadTextKey The field name of the text segment in the Qdrant payload.
|
||||
* @param apiKey The Qdrant API key to authenticate with.
|
||||
* @param host The host of the Qdrant instance.
|
||||
* @param port The GRPC port of the Qdrant instance.
|
||||
* @param useTls Whether to use TLS(HTTPS).
|
||||
* @param payloadTextKey The field name of the text segment in the Qdrant
|
||||
* payload.
|
||||
* @param apiKey The Qdrant API key to authenticate with.
|
||||
*/
|
||||
public QdrantEmbeddingStore(
|
||||
String collectionName,
|
||||
|
@ -75,9 +74,10 @@ public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
}
|
||||
|
||||
/**
|
||||
* @param client A Qdrant client instance.
|
||||
* @param client A Qdrant client instance.
|
||||
* @param collectionName The name of the Qdrant collection.
|
||||
* @param payloadTextKey The field name of the text segment in the Qdrant payload.
|
||||
* @param payloadTextKey The field name of the text segment in the Qdrant
|
||||
* payload.
|
||||
*/
|
||||
public QdrantEmbeddingStore(QdrantClient client, String collectionName, String payloadTextKey) {
|
||||
this.client = client;
|
||||
|
@ -132,50 +132,88 @@ public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
}
|
||||
|
||||
private void addAllInternal(
|
||||
List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
|
||||
|
||||
List<PointStruct> points = new ArrayList<>(embeddings.size());
|
||||
|
||||
for (int i = 0; i < embeddings.size(); i++) {
|
||||
|
||||
String id = ids.get(i);
|
||||
UUID uuid = UUID.fromString(id);
|
||||
Embedding embedding = embeddings.get(i);
|
||||
|
||||
PointStruct.Builder pointBuilder =
|
||||
PointStruct.newBuilder().setId(id(uuid)).setVectors(vectors(embedding.vector()));
|
||||
|
||||
if (textSegments != null) {
|
||||
pointBuilder.putPayload(payloadTextKey, value(textSegments.get(i).text()));
|
||||
textSegments
|
||||
.get(i)
|
||||
.metadata()
|
||||
.asMap()
|
||||
.forEach((key, value) -> pointBuilder.putPayload(key, value(value)));
|
||||
}
|
||||
|
||||
points.add(pointBuilder.build());
|
||||
}
|
||||
List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) throws RuntimeException {
|
||||
|
||||
try {
|
||||
List<PointStruct> points = new ArrayList<>(embeddings.size());
|
||||
|
||||
for (int i = 0; i < embeddings.size(); i++) {
|
||||
|
||||
String id = ids.get(i);
|
||||
UUID uuid = UUID.fromString(id);
|
||||
Embedding embedding = embeddings.get(i);
|
||||
|
||||
PointStruct.Builder pointBuilder = PointStruct.newBuilder().setId(id(uuid))
|
||||
.setVectors(vectors(embedding.vector()));
|
||||
|
||||
if (textSegments != null) {
|
||||
Map<String, Object> metadata = textSegments
|
||||
.get(i)
|
||||
.metadata()
|
||||
.toMap();
|
||||
|
||||
Map<String, Value> payload = ValueMapFactory.valueMap(metadata);
|
||||
payload.put(payloadTextKey, value(textSegments.get(i).text()));
|
||||
pointBuilder.putAllPayload(payload);
|
||||
}
|
||||
|
||||
points.add(pointBuilder.build());
|
||||
}
|
||||
|
||||
client.upsertAsync(collectionName, points).get();
|
||||
} catch (InterruptedException | ExecutionException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
|
||||
|
||||
SearchPoints.Builder searchBuilder = SearchPoints.newBuilder()
|
||||
.setCollectionName(collectionName)
|
||||
.addAllVector(request.queryEmbedding().vectorAsList())
|
||||
.setWithVectors(WithVectorsSelectorFactory.enable(true))
|
||||
.setWithPayload(enable(true))
|
||||
.setLimit(request.maxResults());
|
||||
|
||||
if (request.filter() != null) {
|
||||
Filter filter = QdrantFilterConverter.convertExpression(request.filter());
|
||||
searchBuilder.setFilter(filter);
|
||||
}
|
||||
|
||||
List<ScoredPoint> results;
|
||||
|
||||
try {
|
||||
results = client.searchAsync(searchBuilder.build()).get();
|
||||
} catch (InterruptedException | ExecutionException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
if (results.isEmpty()) {
|
||||
return new EmbeddingSearchResult<TextSegment>(emptyList());
|
||||
}
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> matches = results.stream()
|
||||
.map(vector -> toEmbeddingMatch(vector, request.queryEmbedding()))
|
||||
.filter(match -> match.score() >= request.minScore())
|
||||
.sorted(comparingDouble(EmbeddingMatch::score))
|
||||
.collect(toList());
|
||||
|
||||
Collections.reverse(matches);
|
||||
|
||||
return new EmbeddingSearchResult<TextSegment>(matches);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(
|
||||
Embedding referenceEmbedding, int maxResults, double minScore) {
|
||||
|
||||
SearchPoints search =
|
||||
SearchPoints.newBuilder()
|
||||
.setCollectionName(collectionName)
|
||||
.addAllVector(referenceEmbedding.vectorAsList())
|
||||
.setWithVectors(WithVectorsSelectorFactory.enable(true))
|
||||
.setWithPayload(enable(true))
|
||||
.setLimit(maxResults)
|
||||
.build();
|
||||
SearchPoints search = SearchPoints.newBuilder()
|
||||
.setCollectionName(collectionName)
|
||||
.addAllVector(referenceEmbedding.vectorAsList())
|
||||
.setWithVectors(WithVectorsSelectorFactory.enable(true))
|
||||
.setWithPayload(enable(true))
|
||||
.setLimit(maxResults)
|
||||
.build();
|
||||
|
||||
List<ScoredPoint> results;
|
||||
|
||||
|
@ -189,12 +227,11 @@ public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
return emptyList();
|
||||
}
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> matches =
|
||||
results.stream()
|
||||
.map(vector -> toEmbeddingMatch(vector, referenceEmbedding))
|
||||
.filter(match -> match.score() >= minScore)
|
||||
.sorted(comparingDouble(EmbeddingMatch::score))
|
||||
.collect(toList());
|
||||
List<EmbeddingMatch<TextSegment>> matches = results.stream()
|
||||
.map(vector -> toEmbeddingMatch(vector, referenceEmbedding))
|
||||
.filter(match -> match.score() >= minScore)
|
||||
.sorted(comparingDouble(EmbeddingMatch::score))
|
||||
.collect(toList());
|
||||
|
||||
Collections.reverse(matches);
|
||||
|
||||
|
@ -231,10 +268,9 @@ public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
|
||||
Value textSegmentValue = payload.getOrDefault(payloadTextKey, null);
|
||||
|
||||
Map<String, String> metadata =
|
||||
payload.entrySet().stream()
|
||||
.filter(entry -> !entry.getKey().equals(payloadTextKey))
|
||||
.collect(toMap(Map.Entry::getKey, entry -> entry.getValue().getStringValue()));
|
||||
Map<String, Object> metadata = payload.entrySet().stream()
|
||||
.filter(entry -> !entry.getKey().equals(payloadTextKey))
|
||||
.collect(toMap(Map.Entry::getKey, entry -> ObjectFactory.object(entry.getValue())));
|
||||
|
||||
Embedding embedding = Embedding.from(scoredPoint.getVectors().getVector().getDataList());
|
||||
double cosineSimilarity = CosineSimilarity.between(embedding, referenceEmbedding);
|
||||
|
@ -297,8 +333,9 @@ public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
}
|
||||
|
||||
/**
|
||||
* @param payloadTextKey The field name of the text segment in the payload. Defaults to
|
||||
* "text_segment".
|
||||
* @param payloadTextKey The field name of the text segment in the payload.
|
||||
* Defaults to
|
||||
* "text_segment".
|
||||
* @return
|
||||
*/
|
||||
public Builder payloadTextKey(String payloadTextKey) {
|
||||
|
|
|
@ -0,0 +1,205 @@
|
|||
package dev.langchain4j.store.embedding.qdrant;
|
||||
|
||||
import io.qdrant.client.ConditionFactory;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
|
||||
import dev.langchain4j.store.embedding.filter.comparison.*;
|
||||
import dev.langchain4j.store.embedding.filter.logical.And;
|
||||
import dev.langchain4j.store.embedding.filter.logical.Not;
|
||||
import dev.langchain4j.store.embedding.filter.logical.Or;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import io.qdrant.client.grpc.Points;
|
||||
import io.qdrant.client.grpc.Points.Condition;
|
||||
|
||||
class QdrantFilterConverter {
|
||||
|
||||
public static Points.Filter convertExpression(Filter expression) {
|
||||
return QdrantFilterConverter.convertOperand(expression);
|
||||
}
|
||||
|
||||
private static Points.Filter convertOperand(Filter operand) {
|
||||
Points.Filter.Builder context = Points.Filter.newBuilder();
|
||||
List<Condition> mustClauses = new ArrayList<Condition>();
|
||||
List<Condition> shouldClauses = new ArrayList<Condition>();
|
||||
List<Condition> mustNotClauses = new ArrayList<Condition>();
|
||||
|
||||
if (operand instanceof Not) {
|
||||
Not not = (Not) operand;
|
||||
mustNotClauses.add(ConditionFactory.filter(convertOperand(not.expression())));
|
||||
} else if (operand instanceof And) {
|
||||
And and = (And) operand;
|
||||
mustClauses.add(ConditionFactory.filter(convertOperand(and.left())));
|
||||
mustClauses.add(ConditionFactory.filter(convertOperand(and.right())));
|
||||
} else if (operand instanceof Or) {
|
||||
Or or = (Or) operand;
|
||||
shouldClauses.add(ConditionFactory.filter(convertOperand(or.left())));
|
||||
shouldClauses.add(ConditionFactory.filter(convertOperand(or.right())));
|
||||
} else {
|
||||
mustClauses.add(parseComparison(operand));
|
||||
}
|
||||
|
||||
return context.addAllMust(mustClauses).addAllShould(shouldClauses).addAllMustNot(mustNotClauses).build();
|
||||
}
|
||||
|
||||
private static Condition parseComparison(Filter comparision) {
|
||||
if (comparision instanceof IsEqualTo) {
|
||||
return buildEqCondition((IsEqualTo) comparision);
|
||||
} else if (comparision instanceof IsNotEqualTo) {
|
||||
return buildNeCondition((IsNotEqualTo) comparision);
|
||||
} else if (comparision instanceof IsGreaterThan) {
|
||||
return buildGtCondition((IsGreaterThan) comparision);
|
||||
} else if (comparision instanceof IsGreaterThanOrEqualTo) {
|
||||
return buildGteCondition((IsGreaterThanOrEqualTo) comparision);
|
||||
} else if (comparision instanceof IsLessThan) {
|
||||
return buildLtCondition((IsLessThan) comparision);
|
||||
} else if (comparision instanceof IsLessThanOrEqualTo) {
|
||||
return buildLteCondition((IsLessThanOrEqualTo) comparision);
|
||||
} else if (comparision instanceof IsIn) {
|
||||
return buildInCondition((IsIn) comparision);
|
||||
} else if (comparision instanceof IsNotIn) {
|
||||
return buildNInCondition((IsNotIn) comparision);
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unsupported comparision type: " + comparision);
|
||||
}
|
||||
}
|
||||
|
||||
private static Condition buildEqCondition(IsEqualTo equalTo) {
|
||||
String key = equalTo.key();
|
||||
Object value = equalTo.comparisonValue();
|
||||
if (value instanceof String || value instanceof UUID) {
|
||||
return ConditionFactory.matchKeyword(key, value.toString());
|
||||
} else if (value instanceof Boolean) {
|
||||
return ConditionFactory.match(key, (Boolean) value);
|
||||
} else if (value instanceof Integer || value instanceof Long) {
|
||||
long lValue = Long.parseLong(value.toString());
|
||||
return ConditionFactory.match(key, lValue);
|
||||
}
|
||||
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid value type for IsEqualTo. Can either be a String or Boolean or Integer or Long");
|
||||
|
||||
}
|
||||
|
||||
private static Condition buildNeCondition(IsNotEqualTo notEqual) {
|
||||
String key = notEqual.key();
|
||||
Object value = notEqual.comparisonValue();
|
||||
if (value instanceof String || value instanceof UUID) {
|
||||
return ConditionFactory.filter(
|
||||
Points.Filter.newBuilder().addMustNot(ConditionFactory.matchKeyword(key, value.toString()))
|
||||
.build());
|
||||
} else if (value instanceof Boolean) {
|
||||
Condition condition = ConditionFactory.match(key, (Boolean) value);
|
||||
return ConditionFactory.filter(Points.Filter.newBuilder().addMustNot(condition).build());
|
||||
} else if (value instanceof Integer || value instanceof Long) {
|
||||
long lValue = Long.parseLong(value.toString());
|
||||
Condition condition = ConditionFactory.match(key, lValue);
|
||||
return ConditionFactory.filter(Points.Filter.newBuilder().addMustNot(condition).build());
|
||||
}
|
||||
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid value type for IsNotEqualto. Can either be a String or Boolean or Integer or Long");
|
||||
|
||||
}
|
||||
|
||||
private static Condition buildGtCondition(IsGreaterThan greaterThan) {
|
||||
String key = greaterThan.key();
|
||||
Object value = greaterThan.comparisonValue();
|
||||
if (value instanceof Number) {
|
||||
Double dvalue = Double.parseDouble(value.toString());
|
||||
return ConditionFactory.range(key, Points.Range.newBuilder().setGt(dvalue).build());
|
||||
}
|
||||
throw new RuntimeException("Unsupported value type for IsGreaterThan condition. Only supports Number");
|
||||
|
||||
}
|
||||
|
||||
private static Condition buildLtCondition(IsLessThan lessThan) {
|
||||
String key = lessThan.key();
|
||||
Object value = lessThan.comparisonValue();
|
||||
if (value instanceof Number) {
|
||||
Double dvalue = Double.parseDouble(value.toString());
|
||||
return ConditionFactory.range(key, Points.Range.newBuilder().setLt(dvalue).build());
|
||||
}
|
||||
throw new RuntimeException("Unsupported value type for IsLessThan condition. Only supports Number");
|
||||
|
||||
}
|
||||
|
||||
private static Condition buildGteCondition(IsGreaterThanOrEqualTo greaterThanOrEqualTo) {
|
||||
String key = greaterThanOrEqualTo.key();
|
||||
Object value = greaterThanOrEqualTo.comparisonValue();
|
||||
if (value instanceof Number) {
|
||||
Double dvalue = Double.parseDouble(value.toString());
|
||||
return ConditionFactory.range(key, Points.Range.newBuilder().setGte(dvalue).build());
|
||||
}
|
||||
throw new RuntimeException("Unsupported value type for IsGreaterThanOrEqualTo condition. Only supports Number");
|
||||
|
||||
}
|
||||
|
||||
private static Condition buildLteCondition(IsLessThanOrEqualTo lessThanOrEqualTo) {
|
||||
String key = lessThanOrEqualTo.key();
|
||||
Object value = lessThanOrEqualTo.comparisonValue();
|
||||
if (value instanceof Number) {
|
||||
Double dvalue = Double.parseDouble(value.toString());
|
||||
return ConditionFactory.range(key, Points.Range.newBuilder().setLte(dvalue).build());
|
||||
}
|
||||
throw new RuntimeException("Unsupported value type for IsLessThanOrEqualTo condition. Only supports Number");
|
||||
|
||||
}
|
||||
|
||||
private static Condition buildInCondition(IsIn in) {
|
||||
String key = in.key();
|
||||
List<?> valueList = new ArrayList<>(in.comparisonValues());
|
||||
|
||||
Object firstValue = valueList.get(0);
|
||||
if (firstValue instanceof String || firstValue instanceof UUID) {
|
||||
// If the first value is a string, then all values should be strings
|
||||
List<String> stringValues = new ArrayList<String>();
|
||||
for (Object valueObj : valueList) {
|
||||
stringValues.add(valueObj.toString());
|
||||
}
|
||||
return ConditionFactory.matchKeywords(key, stringValues);
|
||||
} else if (firstValue instanceof Integer || firstValue instanceof Long) {
|
||||
// If the first value is a number, then all values should be numbers
|
||||
List<Long> longValues = new ArrayList<Long>();
|
||||
for (Object valueObj : valueList) {
|
||||
Long longValue = Long.parseLong(valueObj.toString());
|
||||
longValues.add(longValue);
|
||||
}
|
||||
return ConditionFactory.matchValues(key, longValues);
|
||||
} else {
|
||||
throw new RuntimeException(
|
||||
"Unsupported value in IsIn value list. Only supports String or Integer or Long");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static Condition buildNInCondition(IsNotIn notIn) {
|
||||
String key = notIn.key();
|
||||
List<?> valueList = new ArrayList<>(notIn.comparisonValues());
|
||||
Object firstValue = valueList.get(0);
|
||||
|
||||
if (firstValue instanceof String || firstValue instanceof UUID) {
|
||||
// If the first value is a string, then all values should be strings
|
||||
List<String> stringValues = new ArrayList<String>();
|
||||
for (Object valueObj : valueList) {
|
||||
stringValues.add(valueObj.toString());
|
||||
}
|
||||
return ConditionFactory.matchExceptKeywords(key, stringValues);
|
||||
} else if (firstValue instanceof Integer || firstValue instanceof Long) {
|
||||
// If the first value is a number, then all values should be numbers
|
||||
List<Long> longValues = new ArrayList<Long>();
|
||||
for (Object valueObj : valueList) {
|
||||
Long longValue = Long.parseLong(valueObj.toString());
|
||||
longValues.add(longValue);
|
||||
}
|
||||
return ConditionFactory.matchExceptValues(key, longValues);
|
||||
} else {
|
||||
throw new RuntimeException(
|
||||
"Unsupported value in IsNotIn value list. Only supports String or Integer or Long");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
package dev.langchain4j.store.embedding.qdrant;
|
||||
|
||||
import java.lang.reflect.Array;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import io.qdrant.client.ValueFactory;
|
||||
import io.qdrant.client.grpc.JsonWithInt.Struct;
|
||||
import io.qdrant.client.grpc.JsonWithInt.Value;
|
||||
|
||||
/**
|
||||
* Utility methods for building io.qdrant.client.grpc.JsonWithInt.Value from Java objects.
|
||||
*
|
||||
* @author Anush Shetty
|
||||
* @since 0.8.1
|
||||
*/
|
||||
class ValueMapFactory {
|
||||
|
||||
private ValueMapFactory() {
|
||||
}
|
||||
|
||||
public static Map<String, Value> valueMap(Map<String, Object> inputMap) {
|
||||
return inputMap.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> value(e.getValue())));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static Value value(Object value) {
|
||||
|
||||
if (value == null) {
|
||||
return ValueFactory.nullValue();
|
||||
}
|
||||
|
||||
if (value.getClass().isArray()) {
|
||||
int length = Array.getLength(value);
|
||||
Object[] objectArray = new Object[length];
|
||||
for (int i = 0; i < length; i++) {
|
||||
objectArray[i] = Array.get(value, i);
|
||||
}
|
||||
return value(objectArray);
|
||||
}
|
||||
|
||||
if (value instanceof Map) {
|
||||
return value((Map<String, Object>) value);
|
||||
}
|
||||
|
||||
switch (value.getClass().getSimpleName()) {
|
||||
case "UUID":
|
||||
return ValueFactory.value(value.toString());
|
||||
case "String":
|
||||
return ValueFactory.value((String) value);
|
||||
case "Integer":
|
||||
return ValueFactory.value((Integer) value);
|
||||
case "Long":
|
||||
return ValueFactory.value((Long) value);
|
||||
case "Double":
|
||||
return ValueFactory.value((Double) value);
|
||||
case "Float":
|
||||
return ValueFactory.value((Float) value);
|
||||
case "Boolean":
|
||||
return ValueFactory.value((Boolean) value);
|
||||
default:
|
||||
throw new IllegalArgumentException("Unsupported Qdrant value type: " + value.getClass());
|
||||
}
|
||||
}
|
||||
|
||||
private static Value value(Object[] elements) {
|
||||
List<Value> values = new ArrayList<Value>(elements.length);
|
||||
|
||||
for (Object element : elements) {
|
||||
values.add(value(element));
|
||||
}
|
||||
|
||||
return ValueFactory.list(values);
|
||||
}
|
||||
|
||||
private static Value value(Map<String, Object> inputMap) {
|
||||
Struct.Builder structBuilder = Struct.newBuilder();
|
||||
Map<String, Value> map = valueMap(inputMap);
|
||||
structBuilder.putAllFields(map);
|
||||
return Value.newBuilder().setStructValue(structBuilder).build();
|
||||
}
|
||||
|
||||
}
|
|
@ -1,78 +1,176 @@
|
|||
package dev.langchain4j.store.embedding.qdrant;
|
||||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsNotIn;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsIn;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
|
||||
import io.qdrant.client.QdrantClient;
|
||||
import io.qdrant.client.QdrantGrpcClient;
|
||||
import io.qdrant.client.grpc.Collections.Distance;
|
||||
import io.qdrant.client.grpc.Collections.VectorParams;
|
||||
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.qdrant.QdrantContainer;
|
||||
|
||||
import java.util.stream.Stream;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.lang.reflect.Method;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
|
||||
@Testcontainers
|
||||
class QdrantEmbeddingStoreIT extends EmbeddingStoreIT {
|
||||
class QdrantEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT {
|
||||
protected static final UUID TEST_UUID = UUID.randomUUID();
|
||||
static final UUID TEST_UUID2 = UUID.randomUUID();
|
||||
private static String collectionName = "langchain4j-" + randomUUID();
|
||||
private static int dimension = 384;
|
||||
private static Distance distance = Distance.Cosine;
|
||||
private static QdrantEmbeddingStore embeddingStore;
|
||||
|
||||
private static String collectionName = "langchain4j-" + randomUUID();
|
||||
private static int dimension = 384;
|
||||
private static int grpcPort = 6334;
|
||||
private static Distance distance = Distance.Cosine;
|
||||
private static QdrantEmbeddingStore embeddingStore;
|
||||
@Container
|
||||
private static final QdrantContainer qdrant = new QdrantContainer("qdrant/qdrant:latest");
|
||||
|
||||
@Container
|
||||
private static final QdrantContainer qdrant = new QdrantContainer("qdrant/qdrant:latest");
|
||||
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
@BeforeAll
|
||||
static void setup() throws InterruptedException, ExecutionException {
|
||||
embeddingStore =
|
||||
QdrantEmbeddingStore.builder()
|
||||
@BeforeAll
|
||||
static void setup() throws InterruptedException, ExecutionException {
|
||||
embeddingStore = QdrantEmbeddingStore.builder()
|
||||
.host(qdrant.getHost())
|
||||
.port(qdrant.getMappedPort(grpcPort))
|
||||
.port(qdrant.getGrpcPort())
|
||||
.collectionName(collectionName)
|
||||
.build();
|
||||
|
||||
QdrantClient client =
|
||||
new QdrantClient(
|
||||
QdrantGrpcClient.newBuilder(qdrant.getHost(), qdrant.getMappedPort(grpcPort), false)
|
||||
.build());
|
||||
QdrantClient client = new QdrantClient(
|
||||
QdrantGrpcClient.newBuilder(qdrant.getHost(), qdrant.getGrpcPort(), false)
|
||||
.build());
|
||||
|
||||
client
|
||||
.createCollectionAsync(
|
||||
collectionName,
|
||||
VectorParams.newBuilder().setDistance(distance).setSize(dimension).build())
|
||||
.get();
|
||||
client
|
||||
.createCollectionAsync(
|
||||
collectionName,
|
||||
VectorParams.newBuilder().setDistance(distance).setSize(dimension)
|
||||
.build())
|
||||
.get();
|
||||
|
||||
client.close();
|
||||
}
|
||||
client.close();
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
static void teardown() {
|
||||
embeddingStore.close();
|
||||
}
|
||||
@AfterAll
|
||||
static void teardown() {
|
||||
embeddingStore.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
return embeddingStore;
|
||||
}
|
||||
@Override
|
||||
protected EmbeddingStore < TextSegment > embeddingStore() {
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingModel embeddingModel() {
|
||||
return embeddingModel;
|
||||
}
|
||||
@Override
|
||||
protected EmbeddingModel embeddingModel() {
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void clearStore() {
|
||||
embeddingStore.clearStore();
|
||||
}
|
||||
@Override
|
||||
protected void clearStore() {
|
||||
embeddingStore.clearStore();
|
||||
}
|
||||
|
||||
// - Eq, NEq, In, NIn don't allow float and double values. Only integers and
|
||||
// strings.
|
||||
// - LT, GT, LTE, GTE allow only numbers, not alphabets.
|
||||
// - For In and NIn conditions, if the key doesn't exist in the metadata, it is
|
||||
// not matched.
|
||||
|
||||
protected static Stream < Arguments > should_filter_by_metadata_not() {
|
||||
return EmbeddingStoreWithFilteringIT.should_filter_by_metadata_not()
|
||||
.filter(arguments -> {
|
||||
Filter filter = (Filter) arguments.get()[0];
|
||||
if (filter instanceof IsNotIn) {
|
||||
try {
|
||||
IsNotIn notIn = (IsNotIn) filter;
|
||||
Method method = notIn.getClass().getMethod("key");
|
||||
String key = (String) method.invoke(filter);
|
||||
|
||||
List < Metadata > matchingMetadatas = (List < Metadata > ) arguments.get()[1];
|
||||
// For NIn conditions, if the key doesn't exist in the metadata it is not
|
||||
// matched.
|
||||
Boolean matching = matchingMetadatas.stream()
|
||||
.allMatch(metadata -> metadata.containsKey(key));
|
||||
if (!matching) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Object firstValue = notIn.comparisonValues().stream().findFirst().get();
|
||||
return firstValue instanceof String ||
|
||||
firstValue instanceof UUID ||
|
||||
firstValue instanceof Integer ||
|
||||
firstValue instanceof Long;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
} else if (filter instanceof IsNotEqualTo) {
|
||||
IsNotEqualTo notEqualTo = (IsNotEqualTo) filter;
|
||||
return notEqualTo.comparisonValue() instanceof String ||
|
||||
notEqualTo.comparisonValue() instanceof UUID ||
|
||||
notEqualTo.comparisonValue() instanceof Integer ||
|
||||
notEqualTo.comparisonValue() instanceof Long;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
protected static Stream < Arguments > should_filter_by_metadata() {
|
||||
return EmbeddingStoreWithFilteringIT.should_filter_by_metadata()
|
||||
.filter(arguments -> {
|
||||
Filter filter = (Filter) arguments.get()[0];
|
||||
if (filter instanceof IsLessThan) {
|
||||
IsLessThan lessThan = (IsLessThan) filter;
|
||||
return lessThan.comparisonValue() instanceof Integer ||
|
||||
lessThan.comparisonValue() instanceof Long;
|
||||
} else if (filter instanceof IsLessThanOrEqualTo) {
|
||||
IsLessThanOrEqualTo lessThanOrEqualTo = (IsLessThanOrEqualTo) filter;
|
||||
return lessThanOrEqualTo.comparisonValue() instanceof Integer ||
|
||||
lessThanOrEqualTo.comparisonValue() instanceof Long;
|
||||
} else if (filter instanceof IsGreaterThan) {
|
||||
IsGreaterThan greaterThan = (IsGreaterThan) filter;
|
||||
return greaterThan.comparisonValue() instanceof Integer ||
|
||||
greaterThan.comparisonValue() instanceof Long;
|
||||
} else if (filter instanceof IsGreaterThanOrEqualTo) {
|
||||
IsGreaterThanOrEqualTo greaterThanOrEqualTo = (IsGreaterThanOrEqualTo) filter;
|
||||
return greaterThanOrEqualTo.comparisonValue() instanceof Integer ||
|
||||
greaterThanOrEqualTo.comparisonValue() instanceof Long;
|
||||
} else if (filter instanceof IsEqualTo) {
|
||||
IsEqualTo equalTo = (IsEqualTo) filter;
|
||||
return equalTo.comparisonValue() instanceof String ||
|
||||
equalTo.comparisonValue() instanceof UUID ||
|
||||
equalTo.comparisonValue() instanceof Integer ||
|
||||
equalTo.comparisonValue() instanceof Long;
|
||||
} else if (filter instanceof IsIn) {
|
||||
IsIn in = (IsIn) filter;
|
||||
Object firstValue = in .comparisonValues().stream().findFirst().get();
|
||||
return firstValue instanceof String ||
|
||||
firstValue instanceof UUID ||
|
||||
firstValue instanceof Integer ||
|
||||
firstValue instanceof Long;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
|
@ -0,0 +1,130 @@
|
|||
package dev.langchain4j.store.embedding.qdrant;
|
||||
|
||||
import dev.langchain4j.store.embedding.filter.comparison.*;
|
||||
import io.qdrant.client.grpc.Points;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import java.util.Arrays;
|
||||
|
||||
class QdrantFilterConverterTest {
|
||||
|
||||
@Test
|
||||
void testIsEqualToFilter() {
|
||||
Filter filter = new IsEqualTo("num-value", 5);
|
||||
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals("num-value", convertedFilter.getMust(0).getField().getKey());
|
||||
assertEquals(5, convertedFilter.getMust(0).getField().getMatch().getInteger());
|
||||
|
||||
filter = new IsEqualTo("str-value", "value");
|
||||
convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals("str-value", convertedFilter.getMust(0).getField().getKey());
|
||||
assertEquals("value", convertedFilter.getMust(0).getField().getMatch().getKeyword());
|
||||
|
||||
filter = new IsEqualTo("bool-value", true);
|
||||
convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals("bool-value", convertedFilter.getMust(0).getField().getKey());
|
||||
assertEquals(true, convertedFilter.getMust(0).getField().getMatch().getBoolean());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testIsNotEqualToFilter() {
|
||||
Filter filter = new IsNotEqualTo("num-value", 5);
|
||||
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals("num-value", convertedFilter.getMust(0).getFilter().getMustNot(0).getField().getKey());
|
||||
assertEquals(5, convertedFilter.getMust(0).getFilter().getMustNot(0).getField().getMatch().getInteger());
|
||||
|
||||
filter = new IsNotEqualTo("str-value", "value");
|
||||
convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals("str-value", convertedFilter.getMust(0).getFilter().getMustNot(0).getField().getKey());
|
||||
assertEquals("value", convertedFilter.getMust(0).getFilter().getMustNot(0).getField().getMatch().getKeyword());
|
||||
|
||||
filter = new IsNotEqualTo("bool-value", true);
|
||||
convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals("bool-value", convertedFilter.getMust(0).getFilter().getMustNot(0).getField().getKey());
|
||||
assertEquals(true, convertedFilter.getMust(0).getFilter().getMustNot(0).getField().getMatch().getBoolean());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testIsGreaterThanFilter() {
|
||||
Filter filter = new IsGreaterThan("key", 1);
|
||||
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals(convertedFilter.getMust(0).getField().getRange().getGt(), 1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testIsLessThanFilter() {
|
||||
Filter filter = new IsLessThan("key", 10);
|
||||
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals(convertedFilter.getMust(0).getField().getRange().getLt(), 10);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testIsGreaterThanOrEqualToFilter() {
|
||||
Filter filter = new IsGreaterThanOrEqualTo("key", 1);
|
||||
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals(convertedFilter.getMust(0).getField().getRange().getGte(), 1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testIsLessThanOrEqualToFilter() {
|
||||
Filter filter = new IsLessThanOrEqualTo("key", 10);
|
||||
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals(convertedFilter.getMust(0).getField().getRange().getLte(), 10);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testInFilter() {
|
||||
Filter filter = new IsIn("key", Arrays.asList(1, 2, 3));
|
||||
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals(3, convertedFilter.getMust(0).getField().getMatch().getIntegers().getIntegersCount());
|
||||
|
||||
filter = new IsIn("key", Arrays.asList("a", "b", "c"));
|
||||
convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals(3, convertedFilter.getMust(0).getField().getMatch().getKeywords().getStringsCount());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNInFilter() {
|
||||
Filter filter = new IsNotIn("key", Arrays.asList(1, 2, 3, 4));
|
||||
Points.Filter convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals(4, convertedFilter.getMust(0).getField().getMatch().getExceptIntegers().getIntegersCount());
|
||||
|
||||
filter = new IsNotIn("key", Arrays.asList("a", "b", "c", "k"));
|
||||
convertedFilter = QdrantFilterConverter.convertExpression(filter);
|
||||
assertNotNull(convertedFilter);
|
||||
assertEquals(1, convertedFilter.getMustCount());
|
||||
assertEquals(4, convertedFilter.getMust(0).getField().getMatch().getExceptKeywords().getStringsCount());
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue