From 9af248c9801235a9614673851e38998e3720251d Mon Sep 17 00:00:00 2001 From: LangChain4j Date: Wed, 30 Oct 2024 15:42:14 +0100 Subject: [PATCH] InfinispanEmbeddingStore: adopt new EmbeddingStore.search(EmbeddingSearchRequest) API --- .../infinispan/InfinispanEmbeddingStore.java | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/InfinispanEmbeddingStore.java b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/InfinispanEmbeddingStore.java index f4dcdfc23..66fab5a32 100644 --- a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/InfinispanEmbeddingStore.java +++ b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/InfinispanEmbeddingStore.java @@ -4,6 +4,8 @@ import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingStore; import org.infinispan.client.hotrod.RemoteCache; import org.infinispan.client.hotrod.RemoteCacheManager; @@ -32,7 +34,6 @@ import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.internal.ValidationUtils.ensureTrue; import static dev.langchain4j.store.embedding.infinispan.InfinispanStoreConfiguration.DEFAULT_CACHE_CONFIG; -import static dev.langchain4j.store.embedding.infinispan.InfinispanStoreConfiguration.DEFAULT_DISTANCE; import static java.util.Collections.singletonList; import static java.util.stream.Collectors.toList; @@ -168,14 +169,18 @@ public class InfinispanEmbeddingStore implements EmbeddingStore { } @Override - public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { - Query query = remoteCache.query("select i, score(i) from " + storeConfiguration.langchainItemFullType() + " i where i.embedding <-> " + Arrays.toString(referenceEmbedding.vector()) + "~" + storeConfiguration.distance()); - List hits = query.maxResults(maxResults).list(); + public EmbeddingSearchResult search(EmbeddingSearchRequest request) { + Query query = remoteCache.query("select i, score(i) from " + + storeConfiguration.langchainItemFullType() + " i where i.embedding <-> " + + Arrays.toString(request.queryEmbedding().vector()) + + "~" + storeConfiguration.distance()); - return hits.stream().map(obj -> { + List hits = query.maxResults(request.maxResults()).list(); + + List> matches = hits.stream().map(obj -> { LangChainInfinispanItem item = (LangChainInfinispanItem) obj[0]; Float score = (Float) obj[1]; - if (score.doubleValue() < minScore) { + if (score.doubleValue() < request.minScore()) { return null; } TextSegment embedded = null; @@ -188,7 +193,9 @@ public class InfinispanEmbeddingStore implements EmbeddingStore { } Embedding embedding = new Embedding(item.embedding()); return new EmbeddingMatch<>(score.doubleValue(), item.id(), embedding, embedded); - }).filter(Objects::nonNull).collect(Collectors.toList()); + }).filter(Objects::nonNull).collect(toList()); + + return new EmbeddingSearchResult<>(matches); } private void addInternal(String id, Embedding embedding, TextSegment embedded) {