Add advanced RAG with Azure AI Search (#587): cosmetics
This commit is contained in:
parent
e8bfe166ea
commit
12f2dde087
|
@ -16,7 +16,6 @@ import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
|
|||
import dev.langchain4j.rag.query.Query;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType;
|
||||
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchRuntimeException;
|
||||
import dev.langchain4j.store.embedding.azure.search.Document;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -24,7 +23,6 @@ import org.slf4j.LoggerFactory;
|
|||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
|
@ -35,29 +33,42 @@ import static java.util.stream.Collectors.toList;
|
|||
|
||||
/**
|
||||
* Represents Azure AI Search Service as a {@link ContentRetriever}.
|
||||
*
|
||||
* <br>
|
||||
* This class supports 4 {@link AzureAiSearchQueryType}s:
|
||||
* - VECTOR: Uses the vector search algorithm to find the most similar {@link TextSegment}s.
|
||||
* See https://learn.microsoft.com/en-us/azure/search/vector-search-overview for more information.
|
||||
* - FULL_TEXT: Uses the full text search to find the most similar {@link TextSegment}s.
|
||||
* See https://learn.microsoft.com/en-us/azure/search/search-lucene-query-architecture for more information.
|
||||
* - HYBRID: Uses a hybrid search (vector + full text) to find the most similar {@link TextSegment}s.
|
||||
* See https://learn.microsoft.com/en-us/azure/search/hybrid-search-overview for more information.
|
||||
* - HYBRID_WITH_RERANKING: Uses a hybrid search (vector + full text) to find the most similar {@link TextSegment}s, and uses the semantic re-ranker algorithm to rank the results.
|
||||
* See https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking for more information.
|
||||
* <br>
|
||||
* - {@code VECTOR}: Uses the vector search algorithm to find the most similar {@link TextSegment}s.
|
||||
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/vector-search-overview">here</a>.
|
||||
* <br>
|
||||
* - {@code FULL_TEXT}: Uses the full text search to find the most similar {@code TextSegment}s.
|
||||
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/search-lucene-query-architecture">here</a>.
|
||||
* <br>
|
||||
* - {@code HYBRID}: Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s.
|
||||
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/hybrid-search-overview">here</a>.
|
||||
* <br>
|
||||
* - {@code HYBRID_WITH_RERANKING}: Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s,
|
||||
* and uses the semantic re-ranker algorithm to rank the results.
|
||||
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking">here</a>.
|
||||
*/
|
||||
public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddingStore implements ContentRetriever {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchContentRetriever.class);
|
||||
|
||||
private EmbeddingModel embeddingModel;
|
||||
private final EmbeddingModel embeddingModel;
|
||||
|
||||
private AzureAiSearchQueryType azureAiSearchQueryType;
|
||||
private final AzureAiSearchQueryType azureAiSearchQueryType;
|
||||
|
||||
private int maxResults;
|
||||
private double minScore;
|
||||
private final int maxResults;
|
||||
private final double minScore;
|
||||
|
||||
public AzureAiSearchContentRetriever(String endpoint, AzureKeyCredential keyCredential, TokenCredential tokenCredential, int dimensions, SearchIndex index, EmbeddingModel embeddingModel, int maxResults, double minScore, AzureAiSearchQueryType azureAiSearchQueryType) {
|
||||
public AzureAiSearchContentRetriever(String endpoint,
|
||||
AzureKeyCredential keyCredential,
|
||||
TokenCredential tokenCredential,
|
||||
int dimensions,
|
||||
SearchIndex index,
|
||||
EmbeddingModel embeddingModel,
|
||||
int maxResults,
|
||||
double minScore,
|
||||
AzureAiSearchQueryType azureAiSearchQueryType) {
|
||||
ensureNotNull(endpoint, "endpoint");
|
||||
ensureTrue(keyCredential != null || tokenCredential != null, "either keyCredential or tokenCredential must be set");
|
||||
ensureTrue(dimensions > 0 || index != null, "either dimensions or index must be set");
|
||||
|
@ -87,23 +98,37 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
|
|||
* Add content to the full text search engine.
|
||||
*/
|
||||
public void add(String content) {
|
||||
add(singletonList(content));
|
||||
add(singletonList(TextSegment.from(content)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a list of content to the full text search engine.
|
||||
* Add {@code Document} to the full text search engine.
|
||||
*/
|
||||
public void add(List<String> contents) {
|
||||
if (isNullOrEmpty(contents)) {
|
||||
public void add(dev.langchain4j.data.document.Document document) {
|
||||
add(singletonList(document.toTextSegment()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Add {@code TextSegment} to the full text search engine.
|
||||
*/
|
||||
public void add(TextSegment segment) {
|
||||
add(singletonList(segment));
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a list of {@code TextSegment}s to the full text search engine.
|
||||
*/
|
||||
public void add(List<TextSegment> segments) {
|
||||
if (isNullOrEmpty(segments)) {
|
||||
log.info("Empty embeddings - no ops");
|
||||
return;
|
||||
}
|
||||
|
||||
List<Document> documents = new ArrayList<>();
|
||||
for (String content : contents) {
|
||||
for (TextSegment segment : segments) {
|
||||
Document document = new Document();
|
||||
document.setId(randomUUID());
|
||||
document.setContent(content);
|
||||
document.setContent(segment.text());
|
||||
documents.add(document);
|
||||
}
|
||||
|
||||
|
@ -126,7 +151,6 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
|
|||
.map(EmbeddingMatch::embedded)
|
||||
.map(Content::from)
|
||||
.collect(toList());
|
||||
|
||||
} else if (azureAiSearchQueryType == AzureAiSearchQueryType.FULL_TEXT) {
|
||||
String content = query.text();
|
||||
return findRelevantWithFullText(content, maxResults, minScore);
|
||||
|
@ -143,7 +167,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
|
|||
}
|
||||
}
|
||||
|
||||
List<Content> findRelevantWithFullText(String content, int maxResults, double minScore) {
|
||||
private List<Content> findRelevantWithFullText(String content, int maxResults, double minScore) {
|
||||
SearchPagedIterable searchResults =
|
||||
searchClient.search(content,
|
||||
new SearchOptions()
|
||||
|
@ -153,7 +177,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
|
|||
return mapResultsToContentList(searchResults, AzureAiSearchQueryType.FULL_TEXT, minScore);
|
||||
}
|
||||
|
||||
List<Content> findRelevantWithHybrid(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
|
||||
private List<Content> findRelevantWithHybrid(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
|
||||
List<Float> vector = referenceEmbedding.vectorAsList();
|
||||
|
||||
VectorizedQuery vectorizedQuery = new VectorizedQuery(vector)
|
||||
|
@ -170,7 +194,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
|
|||
return mapResultsToContentList(searchResults, AzureAiSearchQueryType.HYBRID, minScore);
|
||||
}
|
||||
|
||||
List<Content> findRelevantWithHybridAndReranking(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
|
||||
private List<Content> findRelevantWithHybridAndReranking(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
|
||||
List<Float> vector = referenceEmbedding.vectorAsList();
|
||||
|
||||
VectorizedQuery vectorizedQuery = new VectorizedQuery(vector)
|
||||
|
@ -211,11 +235,11 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
|
|||
if (azureAiSearchQueryType == AzureAiSearchQueryType.VECTOR) {
|
||||
// Calculates LangChain4j's RelevanceScore from Azure AI Search's score.
|
||||
|
||||
// Score in Azure AI Search is transformed into a cosine similarity as described here:
|
||||
// https://learn.microsoft.com/en-us/azure/search/vector-search-ranking#scores-in-a-vector-search-results
|
||||
// Score in Azure AI Search is transformed into a cosine similarity as described here:
|
||||
// https://learn.microsoft.com/en-us/azure/search/vector-search-ranking#scores-in-a-vector-search-results
|
||||
|
||||
// RelevanceScore in LangChain4j is a derivative of cosine similarity,
|
||||
// but it compresses it into 0..1 range (instead of -1..1) for ease of use.
|
||||
// RelevanceScore in LangChain4j is a derivative of cosine similarity,
|
||||
// but it compresses it into 0..1 range (instead of -1..1) for ease of use.
|
||||
double score = searchResult.getScore();
|
||||
return AbstractAzureAiSearchEmbeddingStore.fromAzureScoreToRelevanceScore(score);
|
||||
} else if (azureAiSearchQueryType == AzureAiSearchQueryType.FULL_TEXT) {
|
||||
|
@ -282,6 +306,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
|
|||
|
||||
/**
|
||||
* Used to authenticate to Azure OpenAI with Azure Active Directory credentials.
|
||||
*
|
||||
* @param tokenCredential the credentials to authenticate with Azure Active Directory
|
||||
* @return builder
|
||||
*/
|
||||
|
@ -360,7 +385,8 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
|
|||
}
|
||||
|
||||
public AzureAiSearchContentRetriever build() {
|
||||
return new AzureAiSearchContentRetriever(endpoint, keyCredential, tokenCredential, dimensions, index, embeddingModel, maxResults, minScore, azureAiSearchQueryType);
|
||||
return new AzureAiSearchContentRetriever(endpoint, keyCredential, tokenCredential, dimensions, index,
|
||||
embeddingModel, maxResults, minScore, azureAiSearchQueryType);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
package dev.langchain4j.rag.content.retriever.azure.search;
|
||||
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
|
||||
public enum AzureAiSearchQueryType {
|
||||
|
||||
/**
|
||||
* Uses the vector search algorithm to find the most similar {@link TextSegment}s.
|
||||
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/vector-search-overview">here</a>.
|
||||
*/
|
||||
VECTOR,
|
||||
|
||||
/**
|
||||
* Uses the full text search to find the most similar {@code TextSegment}s.
|
||||
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/search-lucene-query-architecture">here</a>.
|
||||
*/
|
||||
FULL_TEXT,
|
||||
|
||||
/**
|
||||
* Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s.
|
||||
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/hybrid-search-overview">here</a>.
|
||||
*/
|
||||
HYBRID,
|
||||
|
||||
/**
|
||||
* Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s,
|
||||
* and uses the semantic re-ranker algorithm to rank the results.
|
||||
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking">here</a>.
|
||||
*/
|
||||
HYBRID_WITH_RERANKING
|
||||
}
|
|
@ -1,6 +0,0 @@
|
|||
package dev.langchain4j.store.embedding.azure.search;
|
||||
|
||||
public enum AzureAiSearchQueryType {
|
||||
|
||||
VECTOR, FULL_TEXT, HYBRID, HYBRID_WITH_RERANKING;
|
||||
}
|
|
@ -2,8 +2,6 @@ package dev.langchain4j.rag.content.retriever.azure.search;
|
|||
|
||||
import com.azure.search.documents.models.SearchResult;
|
||||
import com.azure.search.documents.models.SemanticSearchResult;
|
||||
import dev.langchain4j.rag.content.retriever.azure.search.AzureAiSearchContentRetriever;
|
||||
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
|
|
@ -7,7 +7,6 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
|
|||
import dev.langchain4j.rag.content.Content;
|
||||
import dev.langchain4j.rag.query.Query;
|
||||
import dev.langchain4j.store.embedding.*;
|
||||
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
|
|
Loading…
Reference in New Issue