diff --git a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetriever.java b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetriever.java
index c2c5aab37..1cb5cb5aa 100644
--- a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetriever.java
+++ b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetriever.java
@@ -16,7 +16,6 @@ import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore;
-import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType;
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchRuntimeException;
import dev.langchain4j.store.embedding.azure.search.Document;
import org.slf4j.Logger;
@@ -24,7 +23,6 @@ import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
-import java.util.Map;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.Utils.randomUUID;
@@ -35,29 +33,42 @@ import static java.util.stream.Collectors.toList;
/**
* Represents Azure AI Search Service as a {@link ContentRetriever}.
- *
+ *
* This class supports 4 {@link AzureAiSearchQueryType}s:
- * - VECTOR: Uses the vector search algorithm to find the most similar {@link TextSegment}s.
- * See https://learn.microsoft.com/en-us/azure/search/vector-search-overview for more information.
- * - FULL_TEXT: Uses the full text search to find the most similar {@link TextSegment}s.
- * See https://learn.microsoft.com/en-us/azure/search/search-lucene-query-architecture for more information.
- * - HYBRID: Uses a hybrid search (vector + full text) to find the most similar {@link TextSegment}s.
- * See https://learn.microsoft.com/en-us/azure/search/hybrid-search-overview for more information.
- * - HYBRID_WITH_RERANKING: Uses a hybrid search (vector + full text) to find the most similar {@link TextSegment}s, and uses the semantic re-ranker algorithm to rank the results.
- * See https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking for more information.
+ *
+ * - {@code VECTOR}: Uses the vector search algorithm to find the most similar {@link TextSegment}s.
+ * More details can be found here.
+ *
+ * - {@code FULL_TEXT}: Uses the full text search to find the most similar {@code TextSegment}s.
+ * More details can be found here.
+ *
+ * - {@code HYBRID}: Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s.
+ * More details can be found here.
+ *
+ * - {@code HYBRID_WITH_RERANKING}: Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s,
+ * and uses the semantic re-ranker algorithm to rank the results.
+ * More details can be found here.
*/
public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddingStore implements ContentRetriever {
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchContentRetriever.class);
- private EmbeddingModel embeddingModel;
+ private final EmbeddingModel embeddingModel;
- private AzureAiSearchQueryType azureAiSearchQueryType;
+ private final AzureAiSearchQueryType azureAiSearchQueryType;
- private int maxResults;
- private double minScore;
+ private final int maxResults;
+ private final double minScore;
- public AzureAiSearchContentRetriever(String endpoint, AzureKeyCredential keyCredential, TokenCredential tokenCredential, int dimensions, SearchIndex index, EmbeddingModel embeddingModel, int maxResults, double minScore, AzureAiSearchQueryType azureAiSearchQueryType) {
+ public AzureAiSearchContentRetriever(String endpoint,
+ AzureKeyCredential keyCredential,
+ TokenCredential tokenCredential,
+ int dimensions,
+ SearchIndex index,
+ EmbeddingModel embeddingModel,
+ int maxResults,
+ double minScore,
+ AzureAiSearchQueryType azureAiSearchQueryType) {
ensureNotNull(endpoint, "endpoint");
ensureTrue(keyCredential != null || tokenCredential != null, "either keyCredential or tokenCredential must be set");
ensureTrue(dimensions > 0 || index != null, "either dimensions or index must be set");
@@ -87,23 +98,37 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
* Add content to the full text search engine.
*/
public void add(String content) {
- add(singletonList(content));
+ add(singletonList(TextSegment.from(content)));
}
/**
- * Add a list of content to the full text search engine.
+ * Add {@code Document} to the full text search engine.
*/
- public void add(List contents) {
- if (isNullOrEmpty(contents)) {
+ public void add(dev.langchain4j.data.document.Document document) {
+ add(singletonList(document.toTextSegment()));
+ }
+
+ /**
+ * Add {@code TextSegment} to the full text search engine.
+ */
+ public void add(TextSegment segment) {
+ add(singletonList(segment));
+ }
+
+ /**
+ * Add a list of {@code TextSegment}s to the full text search engine.
+ */
+ public void add(List segments) {
+ if (isNullOrEmpty(segments)) {
log.info("Empty embeddings - no ops");
return;
}
List documents = new ArrayList<>();
- for (String content : contents) {
+ for (TextSegment segment : segments) {
Document document = new Document();
document.setId(randomUUID());
- document.setContent(content);
+ document.setContent(segment.text());
documents.add(document);
}
@@ -126,7 +151,6 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
.map(EmbeddingMatch::embedded)
.map(Content::from)
.collect(toList());
-
} else if (azureAiSearchQueryType == AzureAiSearchQueryType.FULL_TEXT) {
String content = query.text();
return findRelevantWithFullText(content, maxResults, minScore);
@@ -143,7 +167,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
}
}
- List findRelevantWithFullText(String content, int maxResults, double minScore) {
+ private List findRelevantWithFullText(String content, int maxResults, double minScore) {
SearchPagedIterable searchResults =
searchClient.search(content,
new SearchOptions()
@@ -153,7 +177,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
return mapResultsToContentList(searchResults, AzureAiSearchQueryType.FULL_TEXT, minScore);
}
- List findRelevantWithHybrid(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
+ private List findRelevantWithHybrid(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
List vector = referenceEmbedding.vectorAsList();
VectorizedQuery vectorizedQuery = new VectorizedQuery(vector)
@@ -170,7 +194,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
return mapResultsToContentList(searchResults, AzureAiSearchQueryType.HYBRID, minScore);
}
- List findRelevantWithHybridAndReranking(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
+ private List findRelevantWithHybridAndReranking(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
List vector = referenceEmbedding.vectorAsList();
VectorizedQuery vectorizedQuery = new VectorizedQuery(vector)
@@ -211,11 +235,11 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
if (azureAiSearchQueryType == AzureAiSearchQueryType.VECTOR) {
// Calculates LangChain4j's RelevanceScore from Azure AI Search's score.
- // Score in Azure AI Search is transformed into a cosine similarity as described here:
- // https://learn.microsoft.com/en-us/azure/search/vector-search-ranking#scores-in-a-vector-search-results
+ // Score in Azure AI Search is transformed into a cosine similarity as described here:
+ // https://learn.microsoft.com/en-us/azure/search/vector-search-ranking#scores-in-a-vector-search-results
- // RelevanceScore in LangChain4j is a derivative of cosine similarity,
- // but it compresses it into 0..1 range (instead of -1..1) for ease of use.
+ // RelevanceScore in LangChain4j is a derivative of cosine similarity,
+ // but it compresses it into 0..1 range (instead of -1..1) for ease of use.
double score = searchResult.getScore();
return AbstractAzureAiSearchEmbeddingStore.fromAzureScoreToRelevanceScore(score);
} else if (azureAiSearchQueryType == AzureAiSearchQueryType.FULL_TEXT) {
@@ -282,6 +306,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
/**
* Used to authenticate to Azure OpenAI with Azure Active Directory credentials.
+ *
* @param tokenCredential the credentials to authenticate with Azure Active Directory
* @return builder
*/
@@ -360,7 +385,8 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin
}
public AzureAiSearchContentRetriever build() {
- return new AzureAiSearchContentRetriever(endpoint, keyCredential, tokenCredential, dimensions, index, embeddingModel, maxResults, minScore, azureAiSearchQueryType);
+ return new AzureAiSearchContentRetriever(endpoint, keyCredential, tokenCredential, dimensions, index,
+ embeddingModel, maxResults, minScore, azureAiSearchQueryType);
}
}
}
diff --git a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchQueryType.java b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchQueryType.java
new file mode 100644
index 000000000..660c0b440
--- /dev/null
+++ b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchQueryType.java
@@ -0,0 +1,31 @@
+package dev.langchain4j.rag.content.retriever.azure.search;
+
+import dev.langchain4j.data.segment.TextSegment;
+
+public enum AzureAiSearchQueryType {
+
+ /**
+ * Uses the vector search algorithm to find the most similar {@link TextSegment}s.
+ * More details can be found here.
+ */
+ VECTOR,
+
+ /**
+ * Uses the full text search to find the most similar {@code TextSegment}s.
+ * More details can be found here.
+ */
+ FULL_TEXT,
+
+ /**
+ * Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s.
+ * More details can be found here.
+ */
+ HYBRID,
+
+ /**
+ * Uses the hybrid search (vector + full text) to find the most similar {@code TextSegment}s,
+ * and uses the semantic re-ranker algorithm to rank the results.
+ * More details can be found here.
+ */
+ HYBRID_WITH_RERANKING
+}
diff --git a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchQueryType.java b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchQueryType.java
deleted file mode 100644
index 1ac7f15a7..000000000
--- a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchQueryType.java
+++ /dev/null
@@ -1,6 +0,0 @@
-package dev.langchain4j.store.embedding.azure.search;
-
-public enum AzureAiSearchQueryType {
-
- VECTOR, FULL_TEXT, HYBRID, HYBRID_WITH_RERANKING;
-}
diff --git a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTest.java b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTest.java
index cf7207b8c..487811d7e 100644
--- a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTest.java
+++ b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTest.java
@@ -2,8 +2,6 @@ package dev.langchain4j.rag.content.retriever.azure.search;
import com.azure.search.documents.models.SearchResult;
import com.azure.search.documents.models.SemanticSearchResult;
-import dev.langchain4j.rag.content.retriever.azure.search.AzureAiSearchContentRetriever;
-import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
diff --git a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTestIT.java b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTestIT.java
index d4f383056..331962de8 100644
--- a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTestIT.java
+++ b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTestIT.java
@@ -7,7 +7,6 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.*;
-import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;