InfinispanEmbeddingStore: adopt new EmbeddingStore.search(EmbeddingSearchRequest) API

This commit is contained in:
LangChain4j 2024-10-30 15:42:14 +01:00
parent e66eb6327a
commit 9af248c980
1 changed files with 14 additions and 7 deletions

View File

@ -4,6 +4,8 @@ import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import org.infinispan.client.hotrod.RemoteCache;
import org.infinispan.client.hotrod.RemoteCacheManager;
@ -32,7 +34,6 @@ import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.internal.ValidationUtils.ensureTrue;
import static dev.langchain4j.store.embedding.infinispan.InfinispanStoreConfiguration.DEFAULT_CACHE_CONFIG;
import static dev.langchain4j.store.embedding.infinispan.InfinispanStoreConfiguration.DEFAULT_DISTANCE;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
@ -168,14 +169,18 @@ public class InfinispanEmbeddingStore implements EmbeddingStore<TextSegment> {
}
@Override
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
Query<Object[]> query = remoteCache.query("select i, score(i) from " + storeConfiguration.langchainItemFullType() + " i where i.embedding <-> " + Arrays.toString(referenceEmbedding.vector()) + "~" + storeConfiguration.distance());
List<Object[]> hits = query.maxResults(maxResults).list();
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
Query<Object[]> query = remoteCache.query("select i, score(i) from " +
storeConfiguration.langchainItemFullType() + " i where i.embedding <-> " +
Arrays.toString(request.queryEmbedding().vector()) +
"~" + storeConfiguration.distance());
return hits.stream().map(obj -> {
List<Object[]> hits = query.maxResults(request.maxResults()).list();
List<EmbeddingMatch<TextSegment>> matches = hits.stream().map(obj -> {
LangChainInfinispanItem item = (LangChainInfinispanItem) obj[0];
Float score = (Float) obj[1];
if (score.doubleValue() < minScore) {
if (score.doubleValue() < request.minScore()) {
return null;
}
TextSegment embedded = null;
@ -188,7 +193,9 @@ public class InfinispanEmbeddingStore implements EmbeddingStore<TextSegment> {
}
Embedding embedding = new Embedding(item.embedding());
return new EmbeddingMatch<>(score.doubleValue(), item.id(), embedding, embedded);
}).filter(Objects::nonNull).collect(Collectors.toList());
}).filter(Objects::nonNull).collect(toList());
return new EmbeddingSearchResult<>(matches);
}
private void addInternal(String id, Embedding embedding, TextSegment embedded) {