InfinispanEmbeddingStore: adopt new EmbeddingStore.search(EmbeddingSearchRequest) API
This commit is contained in:
parent
e66eb6327a
commit
9af248c980
|
@ -4,6 +4,8 @@ import dev.langchain4j.data.document.Metadata;
|
|||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import org.infinispan.client.hotrod.RemoteCache;
|
||||
import org.infinispan.client.hotrod.RemoteCacheManager;
|
||||
|
@ -32,7 +34,6 @@ import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
|||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureTrue;
|
||||
import static dev.langchain4j.store.embedding.infinispan.InfinispanStoreConfiguration.DEFAULT_CACHE_CONFIG;
|
||||
import static dev.langchain4j.store.embedding.infinispan.InfinispanStoreConfiguration.DEFAULT_DISTANCE;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
|
@ -168,14 +169,18 @@ public class InfinispanEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
}
|
||||
|
||||
@Override
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
|
||||
Query<Object[]> query = remoteCache.query("select i, score(i) from " + storeConfiguration.langchainItemFullType() + " i where i.embedding <-> " + Arrays.toString(referenceEmbedding.vector()) + "~" + storeConfiguration.distance());
|
||||
List<Object[]> hits = query.maxResults(maxResults).list();
|
||||
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
|
||||
Query<Object[]> query = remoteCache.query("select i, score(i) from " +
|
||||
storeConfiguration.langchainItemFullType() + " i where i.embedding <-> " +
|
||||
Arrays.toString(request.queryEmbedding().vector()) +
|
||||
"~" + storeConfiguration.distance());
|
||||
|
||||
return hits.stream().map(obj -> {
|
||||
List<Object[]> hits = query.maxResults(request.maxResults()).list();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> matches = hits.stream().map(obj -> {
|
||||
LangChainInfinispanItem item = (LangChainInfinispanItem) obj[0];
|
||||
Float score = (Float) obj[1];
|
||||
if (score.doubleValue() < minScore) {
|
||||
if (score.doubleValue() < request.minScore()) {
|
||||
return null;
|
||||
}
|
||||
TextSegment embedded = null;
|
||||
|
@ -188,7 +193,9 @@ public class InfinispanEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
}
|
||||
Embedding embedding = new Embedding(item.embedding());
|
||||
return new EmbeddingMatch<>(score.doubleValue(), item.id(), embedding, embedded);
|
||||
}).filter(Objects::nonNull).collect(Collectors.toList());
|
||||
}).filter(Objects::nonNull).collect(toList());
|
||||
|
||||
return new EmbeddingSearchResult<>(matches);
|
||||
}
|
||||
|
||||
private void addInternal(String id, Embedding embedding, TextSegment embedded) {
|
||||
|
|
Loading…
Reference in New Issue