Add advanced RAG with Azure AI Search (#587): cosmetics

This commit is contained in:
LangChain4j 2024-03-21 08:22:46 +01:00
parent e8bfe166ea
commit 12f2dde087
5 changed files with 88 additions and 40 deletions

View File

@ -16,7 +16,6 @@ import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore;
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType;
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchRuntimeException;
import dev.langchain4j.store.embedding.azure.search.Document;
import org.slf4j.Logger;
@ -24,7 +23,6 @@ import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.Utils.randomUUID;
@ -35,29 +33,42 @@ import static java.util.stream.Collectors.toList;
/**
* Represents Azure AI Search Service as a {@link ContentRetriever}.
*
* <br>
* This class supports 4 {@link AzureAiSearchQueryType}s:
* - VECTOR: Uses the vector search algorithm to find the most similar {@link TextSegment}s.
* See https://learn.microsoft.com/en-us/azure/search/vector-search-overview for more information.
* - FULL_TEXT: Uses the full text search to find the most similar {@link TextSegment}s.
* See https://learn.microsoft.com/en-us/azure/search/search-lucene-query-architecture for more information.
* - HYBRID: Uses a hybrid search (vector + full text) to find the most similar {@link TextSegment}s.
* See https://learn.microsoft.com/en-us/azure/search/hybrid-search-overview for more information.
* - HYBRID_WITH_RERANKING: Uses a hybrid search (vector + full text) to find the most similar {@link TextSegment}s, and uses the semantic re-ranker algorithm to rank the results.
* See https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking for more information.
* <br>
* - {@code VECTOR}: Uses the vector search algorithm to find the most similar {@link TextSegment}s.
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/vector-search-overview">here</a>.
* <br>
* - {@code FULL_TEXT}: Uses the full text search to find the most similar {@code TextSegment}s.
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/search-lucene-query-architecture">here</a>.
* <br>
* - {@code HYBRID}: Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s.
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/hybrid-search-overview">here</a>.
* <br>
* - {@code HYBRID_WITH_RERANKING}: Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s,
* and uses the semantic re-ranker algorithm to rank the results.
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking">here</a>.
*/
public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddingStore implements ContentRetriever {
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchContentRetriever.class);
private EmbeddingModel embeddingModel;
private final EmbeddingModel embeddingModel;
private AzureAiSearchQueryType azureAiSearchQueryType;
private final AzureAiSearchQueryType azureAiSearchQueryType;
private int maxResults;
private double minScore;
private final int maxResults;
private final double minScore;
public AzureAiSearchContentRetriever(String endpoint, AzureKeyCredential keyCredential, TokenCredential tokenCredential, int dimensions, SearchIndex index, EmbeddingModel embeddingModel, int maxResults, double minScore, AzureAiSearchQueryType azureAiSearchQueryType) {
public AzureAiSearchContentRetriever(String endpoint,
AzureKeyCredential keyCredential,
TokenCredential tokenCredential,
int dimensions,
SearchIndex index,
EmbeddingModel embeddingModel,
int maxResults,
double minScore,
AzureAiSearchQueryType azureAiSearchQueryType) {
ensureNotNull(endpoint, "endpoint");
ensureTrue(keyCredential != null || tokenCredential != null, "either keyCredential or tokenCredential must be set");
ensureTrue(dimensions > 0 || index != null, "either dimensions or index must be set");
@ -87,23 +98,37 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
* Add content to the full text search engine.
*/
public void add(String content) {
add(singletonList(content));
add(singletonList(TextSegment.from(content)));
}
/**
* Add a list of content to the full text search engine.
* Add {@code Document} to the full text search engine.
*/
public void add(List<String> contents) {
if (isNullOrEmpty(contents)) {
public void add(dev.langchain4j.data.document.Document document) {
add(singletonList(document.toTextSegment()));
}
/**
* Add {@code TextSegment} to the full text search engine.
*/
public void add(TextSegment segment) {
add(singletonList(segment));
}
/**
* Add a list of {@code TextSegment}s to the full text search engine.
*/
public void add(List<TextSegment> segments) {
if (isNullOrEmpty(segments)) {
log.info("Empty embeddings - no ops");
return;
}
List<Document> documents = new ArrayList<>();
for (String content : contents) {
for (TextSegment segment : segments) {
Document document = new Document();
document.setId(randomUUID());
document.setContent(content);
document.setContent(segment.text());
documents.add(document);
}
@ -126,7 +151,6 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
.map(EmbeddingMatch::embedded)
.map(Content::from)
.collect(toList());
} else if (azureAiSearchQueryType == AzureAiSearchQueryType.FULL_TEXT) {
String content = query.text();
return findRelevantWithFullText(content, maxResults, minScore);
@ -143,7 +167,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
}
}
List<Content> findRelevantWithFullText(String content, int maxResults, double minScore) {
private List<Content> findRelevantWithFullText(String content, int maxResults, double minScore) {
SearchPagedIterable searchResults =
searchClient.search(content,
new SearchOptions()
@ -153,7 +177,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
return mapResultsToContentList(searchResults, AzureAiSearchQueryType.FULL_TEXT, minScore);
}
List<Content> findRelevantWithHybrid(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
private List<Content> findRelevantWithHybrid(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
List<Float> vector = referenceEmbedding.vectorAsList();
VectorizedQuery vectorizedQuery = new VectorizedQuery(vector)
@ -170,7 +194,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
return mapResultsToContentList(searchResults, AzureAiSearchQueryType.HYBRID, minScore);
}
List<Content> findRelevantWithHybridAndReranking(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
private List<Content> findRelevantWithHybridAndReranking(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
List<Float> vector = referenceEmbedding.vectorAsList();
VectorizedQuery vectorizedQuery = new VectorizedQuery(vector)
@ -211,11 +235,11 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
if (azureAiSearchQueryType == AzureAiSearchQueryType.VECTOR) {
// Calculates LangChain4j's RelevanceScore from Azure AI Search's score.
// Score in Azure AI Search is transformed into a cosine similarity as described here:
// https://learn.microsoft.com/en-us/azure/search/vector-search-ranking#scores-in-a-vector-search-results
// Score in Azure AI Search is transformed into a cosine similarity as described here:
// https://learn.microsoft.com/en-us/azure/search/vector-search-ranking#scores-in-a-vector-search-results
// RelevanceScore in LangChain4j is a derivative of cosine similarity,
// but it compresses it into 0..1 range (instead of -1..1) for ease of use.
// RelevanceScore in LangChain4j is a derivative of cosine similarity,
// but it compresses it into 0..1 range (instead of -1..1) for ease of use.
double score = searchResult.getScore();
return AbstractAzureAiSearchEmbeddingStore.fromAzureScoreToRelevanceScore(score);
} else if (azureAiSearchQueryType == AzureAiSearchQueryType.FULL_TEXT) {
@ -282,6 +306,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
/**
* Used to authenticate to Azure OpenAI with Azure Active Directory credentials.
*
* @param tokenCredential the credentials to authenticate with Azure Active Directory
* @return builder
*/
@ -360,7 +385,8 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
}
public AzureAiSearchContentRetriever build() {
return new AzureAiSearchContentRetriever(endpoint, keyCredential, tokenCredential, dimensions, index, embeddingModel, maxResults, minScore, azureAiSearchQueryType);
return new AzureAiSearchContentRetriever(endpoint, keyCredential, tokenCredential, dimensions, index,
embeddingModel, maxResults, minScore, azureAiSearchQueryType);
}
}
}

View File

@ -0,0 +1,31 @@
package dev.langchain4j.rag.content.retriever.azure.search;
import dev.langchain4j.data.segment.TextSegment;
public enum AzureAiSearchQueryType {
/**
* Uses the vector search algorithm to find the most similar {@link TextSegment}s.
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/vector-search-overview">here</a>.
*/
VECTOR,
/**
* Uses the full text search to find the most similar {@code TextSegment}s.
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/search-lucene-query-architecture">here</a>.
*/
FULL_TEXT,
/**
* Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s.
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/hybrid-search-overview">here</a>.
*/
HYBRID,
/**
* Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s,
* and uses the semantic re-ranker algorithm to rank the results.
* More details can be found <a href="https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking">here</a>.
*/
HYBRID_WITH_RERANKING
}

View File

@ -1,6 +0,0 @@
package dev.langchain4j.store.embedding.azure.search;
public enum AzureAiSearchQueryType {
VECTOR, FULL_TEXT, HYBRID, HYBRID_WITH_RERANKING;
}

View File

@ -2,8 +2,6 @@ package dev.langchain4j.rag.content.retriever.azure.search;
import com.azure.search.documents.models.SearchResult;
import com.azure.search.documents.models.SemanticSearchResult;
import dev.langchain4j.rag.content.retriever.azure.search.AzureAiSearchContentRetriever;
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;

View File

@ -7,7 +7,6 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.*;
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;