From 12f2dde0870693daf7d90e040b54ad2af9d86418 Mon Sep 17 00:00:00 2001 From: LangChain4j Date: Thu, 21 Mar 2024 08:22:46 +0100 Subject: [PATCH] Add advanced RAG with Azure AI Search (#587): cosmetics --- .../search/AzureAiSearchContentRetriever.java | 88 ++++++++++++------- .../azure/search/AzureAiSearchQueryType.java | 31 +++++++ .../azure/search/AzureAiSearchQueryType.java | 6 -- .../AzureAiSearchContentRetrieverTest.java | 2 - .../AzureAiSearchContentRetrieverTestIT.java | 1 - 5 files changed, 88 insertions(+), 40 deletions(-) create mode 100644 langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchQueryType.java delete mode 100644 langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchQueryType.java diff --git a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetriever.java b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetriever.java index c2c5aab37..1cb5cb5aa 100644 --- a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetriever.java +++ b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetriever.java @@ -16,7 +16,6 @@ import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever; import dev.langchain4j.rag.query.Query; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore; -import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType; import dev.langchain4j.store.embedding.azure.search.AzureAiSearchRuntimeException; import dev.langchain4j.store.embedding.azure.search.Document; import org.slf4j.Logger; @@ -24,7 +23,6 @@ import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.List; -import java.util.Map; import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.Utils.randomUUID; @@ -35,29 +33,42 @@ import static java.util.stream.Collectors.toList; /** * Represents Azure AI Search Service as a {@link ContentRetriever}. - * + *
* This class supports 4 {@link AzureAiSearchQueryType}s: - * - VECTOR: Uses the vector search algorithm to find the most similar {@link TextSegment}s. - * See https://learn.microsoft.com/en-us/azure/search/vector-search-overview for more information. - * - FULL_TEXT: Uses the full text search to find the most similar {@link TextSegment}s. - * See https://learn.microsoft.com/en-us/azure/search/search-lucene-query-architecture for more information. - * - HYBRID: Uses a hybrid search (vector + full text) to find the most similar {@link TextSegment}s. - * See https://learn.microsoft.com/en-us/azure/search/hybrid-search-overview for more information. - * - HYBRID_WITH_RERANKING: Uses a hybrid search (vector + full text) to find the most similar {@link TextSegment}s, and uses the semantic re-ranker algorithm to rank the results. - * See https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking for more information. + *
+ * - {@code VECTOR}: Uses the vector search algorithm to find the most similar {@link TextSegment}s. + * More details can be found here. + *
+ * - {@code FULL_TEXT}: Uses the full text search to find the most similar {@code TextSegment}s. + * More details can be found here. + *
+ * - {@code HYBRID}: Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s. + * More details can be found here. + *
+ * - {@code HYBRID_WITH_RERANKING}: Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s, + * and uses the semantic re-ranker algorithm to rank the results. + * More details can be found here. */ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddingStore implements ContentRetriever { private static final Logger log = LoggerFactory.getLogger(AzureAiSearchContentRetriever.class); - private EmbeddingModel embeddingModel; + private final EmbeddingModel embeddingModel; - private AzureAiSearchQueryType azureAiSearchQueryType; + private final AzureAiSearchQueryType azureAiSearchQueryType; - private int maxResults; - private double minScore; + private final int maxResults; + private final double minScore; - public AzureAiSearchContentRetriever(String endpoint, AzureKeyCredential keyCredential, TokenCredential tokenCredential, int dimensions, SearchIndex index, EmbeddingModel embeddingModel, int maxResults, double minScore, AzureAiSearchQueryType azureAiSearchQueryType) { + public AzureAiSearchContentRetriever(String endpoint, + AzureKeyCredential keyCredential, + TokenCredential tokenCredential, + int dimensions, + SearchIndex index, + EmbeddingModel embeddingModel, + int maxResults, + double minScore, + AzureAiSearchQueryType azureAiSearchQueryType) { ensureNotNull(endpoint, "endpoint"); ensureTrue(keyCredential != null || tokenCredential != null, "either keyCredential or tokenCredential must be set"); ensureTrue(dimensions > 0 || index != null, "either dimensions or index must be set"); @@ -87,23 +98,37 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin * Add content to the full text search engine. */ public void add(String content) { - add(singletonList(content)); + add(singletonList(TextSegment.from(content))); } /** - * Add a list of content to the full text search engine. + * Add {@code Document} to the full text search engine. */ - public void add(List contents) { - if (isNullOrEmpty(contents)) { + public void add(dev.langchain4j.data.document.Document document) { + add(singletonList(document.toTextSegment())); + } + + /** + * Add {@code TextSegment} to the full text search engine. + */ + public void add(TextSegment segment) { + add(singletonList(segment)); + } + + /** + * Add a list of {@code TextSegment}s to the full text search engine. + */ + public void add(List segments) { + if (isNullOrEmpty(segments)) { log.info("Empty embeddings - no ops"); return; } List documents = new ArrayList<>(); - for (String content : contents) { + for (TextSegment segment : segments) { Document document = new Document(); document.setId(randomUUID()); - document.setContent(content); + document.setContent(segment.text()); documents.add(document); } @@ -126,7 +151,6 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin .map(EmbeddingMatch::embedded) .map(Content::from) .collect(toList()); - } else if (azureAiSearchQueryType == AzureAiSearchQueryType.FULL_TEXT) { String content = query.text(); return findRelevantWithFullText(content, maxResults, minScore); @@ -143,7 +167,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin } } - List findRelevantWithFullText(String content, int maxResults, double minScore) { + private List findRelevantWithFullText(String content, int maxResults, double minScore) { SearchPagedIterable searchResults = searchClient.search(content, new SearchOptions() @@ -153,7 +177,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin return mapResultsToContentList(searchResults, AzureAiSearchQueryType.FULL_TEXT, minScore); } - List findRelevantWithHybrid(Embedding referenceEmbedding, String content, int maxResults, double minScore) { + private List findRelevantWithHybrid(Embedding referenceEmbedding, String content, int maxResults, double minScore) { List vector = referenceEmbedding.vectorAsList(); VectorizedQuery vectorizedQuery = new VectorizedQuery(vector) @@ -170,7 +194,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin return mapResultsToContentList(searchResults, AzureAiSearchQueryType.HYBRID, minScore); } - List findRelevantWithHybridAndReranking(Embedding referenceEmbedding, String content, int maxResults, double minScore) { + private List findRelevantWithHybridAndReranking(Embedding referenceEmbedding, String content, int maxResults, double minScore) { List vector = referenceEmbedding.vectorAsList(); VectorizedQuery vectorizedQuery = new VectorizedQuery(vector) @@ -211,11 +235,11 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin if (azureAiSearchQueryType == AzureAiSearchQueryType.VECTOR) { // Calculates LangChain4j's RelevanceScore from Azure AI Search's score. - // Score in Azure AI Search is transformed into a cosine similarity as described here: - // https://learn.microsoft.com/en-us/azure/search/vector-search-ranking#scores-in-a-vector-search-results + // Score in Azure AI Search is transformed into a cosine similarity as described here: + // https://learn.microsoft.com/en-us/azure/search/vector-search-ranking#scores-in-a-vector-search-results - // RelevanceScore in LangChain4j is a derivative of cosine similarity, - // but it compresses it into 0..1 range (instead of -1..1) for ease of use. + // RelevanceScore in LangChain4j is a derivative of cosine similarity, + // but it compresses it into 0..1 range (instead of -1..1) for ease of use. double score = searchResult.getScore(); return AbstractAzureAiSearchEmbeddingStore.fromAzureScoreToRelevanceScore(score); } else if (azureAiSearchQueryType == AzureAiSearchQueryType.FULL_TEXT) { @@ -282,6 +306,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin /** * Used to authenticate to Azure OpenAI with Azure Active Directory credentials. + * * @param tokenCredential the credentials to authenticate with Azure Active Directory * @return builder */ @@ -360,7 +385,8 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin } public AzureAiSearchContentRetriever build() { - return new AzureAiSearchContentRetriever(endpoint, keyCredential, tokenCredential, dimensions, index, embeddingModel, maxResults, minScore, azureAiSearchQueryType); + return new AzureAiSearchContentRetriever(endpoint, keyCredential, tokenCredential, dimensions, index, + embeddingModel, maxResults, minScore, azureAiSearchQueryType); } } } diff --git a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchQueryType.java b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchQueryType.java new file mode 100644 index 000000000..660c0b440 --- /dev/null +++ b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchQueryType.java @@ -0,0 +1,31 @@ +package dev.langchain4j.rag.content.retriever.azure.search; + +import dev.langchain4j.data.segment.TextSegment; + +public enum AzureAiSearchQueryType { + + /** + * Uses the vector search algorithm to find the most similar {@link TextSegment}s. + * More details can be found here. + */ + VECTOR, + + /** + * Uses the full text search to find the most similar {@code TextSegment}s. + * More details can be found here. + */ + FULL_TEXT, + + /** + * Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s. + * More details can be found here. + */ + HYBRID, + + /** + * Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s, + * and uses the semantic re-ranker algorithm to rank the results. + * More details can be found here. + */ + HYBRID_WITH_RERANKING +} diff --git a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchQueryType.java b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchQueryType.java deleted file mode 100644 index 1ac7f15a7..000000000 --- a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchQueryType.java +++ /dev/null @@ -1,6 +0,0 @@ -package dev.langchain4j.store.embedding.azure.search; - -public enum AzureAiSearchQueryType { - - VECTOR, FULL_TEXT, HYBRID, HYBRID_WITH_RERANKING; -} diff --git a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTest.java b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTest.java index cf7207b8c..487811d7e 100644 --- a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTest.java +++ b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTest.java @@ -2,8 +2,6 @@ package dev.langchain4j.rag.content.retriever.azure.search; import com.azure.search.documents.models.SearchResult; import com.azure.search.documents.models.SemanticSearchResult; -import dev.langchain4j.rag.content.retriever.azure.search.AzureAiSearchContentRetriever; -import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; diff --git a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTestIT.java b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTestIT.java index d4f383056..331962de8 100644 --- a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTestIT.java +++ b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTestIT.java @@ -7,7 +7,6 @@ import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.rag.content.Content; import dev.langchain4j.rag.query.Query; import dev.langchain4j.store.embedding.*; -import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger;