Add advanced RAG with Azure AI Search (#587)

This PR should fix #576 and add advanced RAG with hybrid search and
semantic re-ranking with Azure AI Search.

In the current implementation, the scoring for full text search, hybrid
search and semantic search are done using comments directly from the
Azure AI Search team, as it seems the documentation is only correct for
vector search. Have a look at the `fromAzureScoreToRelevanceScore`
function for more information.
This commit is contained in:
Julien Dubois 2024-03-21 08:00:52 +01:00 committed by GitHub
parent 3783e95587
commit e8bfe166ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1052 additions and 301 deletions

View File

@ -15,7 +15,7 @@ jobs:
java_version: [8, 11, 17, 21]
include:
- java_version: '8'
included_modules: '-pl !code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-cassandra,!langchain4j-infinispan,!langchain4j-neo4j,!langchain4j-opensearch'
included_modules: '-pl !code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-cassandra,!langchain4j-infinispan,!langchain4j-neo4j,!langchain4j-opensearch,!langchain4j-azure-ai-search'
- java_version: '11'
included_modules: '-pl !code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-infinispan,!langchain4j-neo4j'
- java_version: '17'

View File

@ -88,6 +88,13 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>5.11.0</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -0,0 +1,366 @@
package dev.langchain4j.rag.content.retriever.azure.search;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.credential.TokenCredential;
import com.azure.core.util.Context;
import com.azure.search.documents.SearchDocument;
import com.azure.search.documents.indexes.models.SearchIndex;
import com.azure.search.documents.models.*;
import com.azure.search.documents.util.SearchPagedIterable;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore;
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType;
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchRuntimeException;
import dev.langchain4j.store.embedding.azure.search.Document;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.Utils.randomUUID;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.internal.ValidationUtils.ensureTrue;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
/**
* Represents Azure AI Search Service as a {@link ContentRetriever}.
*
* This class supports 4 {@link AzureAiSearchQueryType}s:
* - VECTOR: Uses the vector search algorithm to find the most similar {@link TextSegment}s.
* See https://learn.microsoft.com/en-us/azure/search/vector-search-overview for more information.
* - FULL_TEXT: Uses the full text search to find the most similar {@link TextSegment}s.
* See https://learn.microsoft.com/en-us/azure/search/search-lucene-query-architecture for more information.
* - HYBRID: Uses a hybrid search (vector + full text) to find the most similar {@link TextSegment}s.
* See https://learn.microsoft.com/en-us/azure/search/hybrid-search-overview for more information.
* - HYBRID_WITH_RERANKING: Uses a hybrid search (vector + full text) to find the most similar {@link TextSegment}s, and uses the semantic re-ranker algorithm to rank the results.
* See https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking for more information.
*/
public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddingStore implements ContentRetriever {
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchContentRetriever.class);
private EmbeddingModel embeddingModel;
private AzureAiSearchQueryType azureAiSearchQueryType;
private int maxResults;
private double minScore;
public AzureAiSearchContentRetriever(String endpoint, AzureKeyCredential keyCredential, TokenCredential tokenCredential, int dimensions, SearchIndex index, EmbeddingModel embeddingModel, int maxResults, double minScore, AzureAiSearchQueryType azureAiSearchQueryType) {
ensureNotNull(endpoint, "endpoint");
ensureTrue(keyCredential != null || tokenCredential != null, "either keyCredential or tokenCredential must be set");
ensureTrue(dimensions > 0 || index != null, "either dimensions or index must be set");
if (!AzureAiSearchQueryType.FULL_TEXT.equals(azureAiSearchQueryType)) {
ensureNotNull(embeddingModel, "embeddingModel");
}
if (keyCredential == null) {
if (index == null) {
this.initialize(endpoint, null, tokenCredential, dimensions, null);
} else {
this.initialize(endpoint, null, tokenCredential, 0, index);
}
} else {
if (index == null) {
this.initialize(endpoint, keyCredential, null, dimensions, null);
} else {
this.initialize(endpoint, keyCredential, null, 0, index);
}
}
this.embeddingModel = embeddingModel;
this.azureAiSearchQueryType = azureAiSearchQueryType;
this.maxResults = maxResults;
this.minScore = minScore;
}
/**
* Add content to the full text search engine.
*/
public void add(String content) {
add(singletonList(content));
}
/**
* Add a list of content to the full text search engine.
*/
public void add(List<String> contents) {
if (isNullOrEmpty(contents)) {
log.info("Empty embeddings - no ops");
return;
}
List<Document> documents = new ArrayList<>();
for (String content : contents) {
Document document = new Document();
document.setId(randomUUID());
document.setContent(content);
documents.add(document);
}
List<IndexingResult> indexingResults = searchClient.uploadDocuments(documents).getResults();
for (IndexingResult indexingResult : indexingResults) {
if (!indexingResult.isSucceeded()) {
throw new AzureAiSearchRuntimeException("Failed to add content: " + indexingResult.getErrorMessage());
} else {
log.debug("Added content: {}", indexingResult.getKey());
}
}
}
@Override
public List<Content> retrieve(Query query) {
if (azureAiSearchQueryType == AzureAiSearchQueryType.VECTOR) {
Embedding referenceEmbedding = embeddingModel.embed(query.text()).content();
List<EmbeddingMatch<TextSegment>> searchResult = super.findRelevant(referenceEmbedding, maxResults, minScore);
return searchResult.stream()
.map(EmbeddingMatch::embedded)
.map(Content::from)
.collect(toList());
} else if (azureAiSearchQueryType == AzureAiSearchQueryType.FULL_TEXT) {
String content = query.text();
return findRelevantWithFullText(content, maxResults, minScore);
} else if (azureAiSearchQueryType == AzureAiSearchQueryType.HYBRID) {
Embedding referenceEmbedding = embeddingModel.embed(query.text()).content();
String content = query.text();
return findRelevantWithHybrid(referenceEmbedding, content, maxResults, minScore);
} else if (azureAiSearchQueryType == AzureAiSearchQueryType.HYBRID_WITH_RERANKING) {
Embedding referenceEmbedding = embeddingModel.embed(query.text()).content();
String content = query.text();
return findRelevantWithHybridAndReranking(referenceEmbedding, content, maxResults, minScore);
} else {
throw new AzureAiSearchRuntimeException("Unknown Azure AI Search Query Type: " + azureAiSearchQueryType);
}
}
List<Content> findRelevantWithFullText(String content, int maxResults, double minScore) {
SearchPagedIterable searchResults =
searchClient.search(content,
new SearchOptions()
.setTop(maxResults),
Context.NONE);
return mapResultsToContentList(searchResults, AzureAiSearchQueryType.FULL_TEXT, minScore);
}
List<Content> findRelevantWithHybrid(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
List<Float> vector = referenceEmbedding.vectorAsList();
VectorizedQuery vectorizedQuery = new VectorizedQuery(vector)
.setFields(DEFAULT_FIELD_CONTENT_VECTOR)
.setKNearestNeighborsCount(maxResults);
SearchPagedIterable searchResults =
searchClient.search(content,
new SearchOptions()
.setVectorSearchOptions(new VectorSearchOptions().setQueries(vectorizedQuery))
.setTop(maxResults),
Context.NONE);
return mapResultsToContentList(searchResults, AzureAiSearchQueryType.HYBRID, minScore);
}
List<Content> findRelevantWithHybridAndReranking(Embedding referenceEmbedding, String content, int maxResults, double minScore) {
List<Float> vector = referenceEmbedding.vectorAsList();
VectorizedQuery vectorizedQuery = new VectorizedQuery(vector)
.setFields(DEFAULT_FIELD_CONTENT_VECTOR)
.setKNearestNeighborsCount(maxResults);
SearchPagedIterable searchResults =
searchClient.search(content,
new SearchOptions()
.setVectorSearchOptions(new VectorSearchOptions().setQueries(vectorizedQuery))
.setSemanticSearchOptions(new SemanticSearchOptions().setSemanticConfigurationName(SEMANTIC_SEARCH_CONFIG_NAME))
.setQueryType(com.azure.search.documents.models.QueryType.SEMANTIC)
.setTop(maxResults),
Context.NONE);
return mapResultsToContentList(searchResults, AzureAiSearchQueryType.HYBRID_WITH_RERANKING, minScore);
}
private List<Content> mapResultsToContentList(SearchPagedIterable searchResults, AzureAiSearchQueryType azureAiSearchQueryType, double minScore) {
List<Content> result = new ArrayList<>();
for (SearchResult searchResult : searchResults) {
double score = fromAzureScoreToRelevanceScore(searchResult, azureAiSearchQueryType);
if (score < minScore) {
continue;
}
SearchDocument searchDocument = searchResult.getDocument(SearchDocument.class);
String embeddedContent = (String) searchDocument.get(DEFAULT_FIELD_CONTENT);
Content content = Content.from(embeddedContent);
result.add(content);
}
return result;
}
/**
* Calculates LangChain4j's RelevanceScore from Azure AI Search's score, for the 4 types of search.
*/
static double fromAzureScoreToRelevanceScore(SearchResult searchResult, AzureAiSearchQueryType azureAiSearchQueryType) {
if (azureAiSearchQueryType == AzureAiSearchQueryType.VECTOR) {
// Calculates LangChain4j's RelevanceScore from Azure AI Search's score.
// Score in Azure AI Search is transformed into a cosine similarity as described here:
// https://learn.microsoft.com/en-us/azure/search/vector-search-ranking#scores-in-a-vector-search-results
// RelevanceScore in LangChain4j is a derivative of cosine similarity,
// but it compresses it into 0..1 range (instead of -1..1) for ease of use.
double score = searchResult.getScore();
return AbstractAzureAiSearchEmbeddingStore.fromAzureScoreToRelevanceScore(score);
} else if (azureAiSearchQueryType == AzureAiSearchQueryType.FULL_TEXT) {
// Search score is into 0..1 range already
return searchResult.getScore();
} else if (azureAiSearchQueryType == AzureAiSearchQueryType.HYBRID) {
// Search score is into 0..1 range already
return searchResult.getScore();
} else if (azureAiSearchQueryType == AzureAiSearchQueryType.HYBRID_WITH_RERANKING) {
// Re-ranker score is into 0..4 range, so we need to divide the re-reranker score by 4 to fit in the 0..1 range.
// The re-ranker score is a separate result from the original search score.
// See https://azuresdkdocs.blob.core.windows.net/$web/java/azure-search-documents/11.6.2/com/azure/search/documents/models/SearchResult.html#getSemanticSearch()
return searchResult.getSemanticSearch().getRerankerScore() / 4.0;
} else {
throw new AzureAiSearchRuntimeException("Unknown Azure AI Search Query Type: " + azureAiSearchQueryType);
}
}
public static AzureAiSearchContentRetrieverBuilder builder() {
return new AzureAiSearchContentRetrieverBuilder();
}
public static class AzureAiSearchContentRetrieverBuilder {
private String endpoint;
private AzureKeyCredential keyCredential;
private TokenCredential tokenCredential;
private int dimensions;
private SearchIndex index;
private EmbeddingModel embeddingModel;
private int maxResults = EmbeddingStoreContentRetriever.DEFAULT_MAX_RESULTS.apply(null);
private double minScore = EmbeddingStoreContentRetriever.DEFAULT_MIN_SCORE.apply(null);
private AzureAiSearchQueryType azureAiSearchQueryType;
/**
* Sets the Azure AI Search endpoint. This is a mandatory parameter.
*
* @param endpoint The Azure AI Search endpoint in the format: https://{resource}.search.windows.net
* @return builder
*/
public AzureAiSearchContentRetrieverBuilder endpoint(String endpoint) {
this.endpoint = endpoint;
return this;
}
/**
* Sets the Azure AI Search API key.
*
* @param apiKey The Azure AI Search API key.
* @return builder
*/
public AzureAiSearchContentRetrieverBuilder apiKey(String apiKey) {
this.keyCredential = new AzureKeyCredential(apiKey);
return this;
}
/**
* Used to authenticate to Azure OpenAI with Azure Active Directory credentials.
* @param tokenCredential the credentials to authenticate with Azure Active Directory
* @return builder
*/
public AzureAiSearchContentRetrieverBuilder tokenCredential(TokenCredential tokenCredential) {
this.tokenCredential = tokenCredential;
return this;
}
/**
* If using the ready-made index, sets the number of dimensions of the embeddings.
* This parameter is exclusive of the index parameter.
*
* @param dimensions The number of dimensions of the embeddings.
* @return builder
*/
public AzureAiSearchContentRetrieverBuilder dimensions(int dimensions) {
this.dimensions = dimensions;
return this;
}
/**
* If using a custom index, sets the index to be used.
* This parameter is exclusive of the dimensions parameter.
*
* @param index The index to be used.
* @return builder
*/
public AzureAiSearchContentRetrieverBuilder index(SearchIndex index) {
this.index = index;
return this;
}
/**
* Sets the Embedding Model.
*
* @param embeddingModel The Embedding Model.
* @return builder
*/
public AzureAiSearchContentRetrieverBuilder embeddingModel(EmbeddingModel embeddingModel) {
this.embeddingModel = embeddingModel;
return this;
}
/**
* Sets the maximum number of {@link Content}s to retrieve.
*
* @param maxResults The maximum number of {@link Content}s to retrieve.
* @return builder
*/
public AzureAiSearchContentRetrieverBuilder maxResults(int maxResults) {
this.maxResults = maxResults;
return this;
}
/**
* Sets the minimum relevance score for the returned {@link Content}s.
* {@link Content}s scoring below {@code #minScore} are excluded from the results.
*
* @param minScore The minimum relevance score for the returned {@link Content}s.
* @return builder
*/
public AzureAiSearchContentRetrieverBuilder minScore(double minScore) {
this.minScore = minScore;
return this;
}
/**
* Sets the Azure AI Search Query Type.
*
* @param azureAiSearchQueryType The Azure AI Search Query Type.
* @return builder
*/
public AzureAiSearchContentRetrieverBuilder queryType(AzureAiSearchQueryType azureAiSearchQueryType) {
this.azureAiSearchQueryType = azureAiSearchQueryType;
return this;
}
public AzureAiSearchContentRetriever build() {
return new AzureAiSearchContentRetriever(endpoint, keyCredential, tokenCredential, dimensions, index, embeddingModel, maxResults, minScore, azureAiSearchQueryType);
}
}
}

View File

@ -0,0 +1,330 @@
package dev.langchain4j.store.embedding.azure.search;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.credential.TokenCredential;
import com.azure.core.util.Context;
import com.azure.search.documents.SearchClient;
import com.azure.search.documents.SearchClientBuilder;
import com.azure.search.documents.SearchDocument;
import com.azure.search.documents.indexes.SearchIndexClient;
import com.azure.search.documents.indexes.SearchIndexClientBuilder;
import com.azure.search.documents.indexes.models.*;
import com.azure.search.documents.models.*;
import com.azure.search.documents.util.SearchPagedIterable;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureTrue;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
public abstract class AbstractAzureAiSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
private static final Logger log = LoggerFactory.getLogger(AbstractAzureAiSearchEmbeddingStore.class);
static final String INDEX_NAME = "vectorsearch";
static final String DEFAULT_FIELD_ID = "id";
protected static final String DEFAULT_FIELD_CONTENT = "content";
protected final String DEFAULT_FIELD_CONTENT_VECTOR = "content_vector";
protected static final String DEFAULT_FIELD_METADATA = "metadata";
protected static final String DEFAULT_FIELD_METADATA_SOURCE = "source";
protected static final String DEFAULT_FIELD_METADATA_ATTRS = "attributes";
protected static final String SEMANTIC_SEARCH_CONFIG_NAME = "semantic-search-config";
protected static final String VECTOR_ALGORITHM_NAME = "vector-search-algorithm";
protected static final String VECTOR_SEARCH_PROFILE_NAME = "vector-search-profile";
private SearchIndexClient searchIndexClient;
protected SearchClient searchClient;
protected void initialize(String endpoint, AzureKeyCredential keyCredential, TokenCredential tokenCredential, int dimensions, SearchIndex index) {
if (keyCredential != null) {
searchIndexClient = new SearchIndexClientBuilder()
.endpoint(endpoint)
.credential(keyCredential)
.buildClient();
searchClient = new SearchClientBuilder()
.endpoint(endpoint)
.credential(keyCredential)
.indexName(INDEX_NAME)
.buildClient();
} else {
searchIndexClient = new SearchIndexClientBuilder()
.endpoint(endpoint)
.credential(tokenCredential)
.buildClient();
searchClient = new SearchClientBuilder()
.endpoint(endpoint)
.credential(tokenCredential)
.indexName(INDEX_NAME)
.buildClient();
}
if (index == null) {
createOrUpdateIndex(dimensions);
} else {
createOrUpdateIndex(index);
}
}
/**
* Creates or updates the index using a ready-made index.
*
* @param dimensions The number of dimensions of the embeddings.
*/
public void createOrUpdateIndex(int dimensions) {
ensureTrue(dimensions > 0, "Dimensions must be greater than 0");
List<SearchField> fields = new ArrayList<>();
fields.add(new SearchField(DEFAULT_FIELD_ID, SearchFieldDataType.STRING)
.setKey(true)
.setFilterable(true));
fields.add(new SearchField(DEFAULT_FIELD_CONTENT, SearchFieldDataType.STRING)
.setSearchable(true)
.setFilterable(true));
fields.add(new SearchField(DEFAULT_FIELD_CONTENT_VECTOR, SearchFieldDataType.collection(SearchFieldDataType.SINGLE))
.setSearchable(true)
.setVectorSearchDimensions(dimensions)
.setVectorSearchProfileName(VECTOR_SEARCH_PROFILE_NAME));
fields.add((new SearchField(DEFAULT_FIELD_METADATA, SearchFieldDataType.COMPLEX)).setFields(
Arrays.asList(
new SearchField(DEFAULT_FIELD_METADATA_SOURCE, SearchFieldDataType.STRING)
.setFilterable(true),
(new SearchField(DEFAULT_FIELD_METADATA_ATTRS, SearchFieldDataType.collection(SearchFieldDataType.COMPLEX))).setFields(
Arrays.asList(
new SearchField("key", SearchFieldDataType.STRING)
.setFilterable(true),
new SearchField("value", SearchFieldDataType.STRING)
.setFilterable(true)
)
)
)
));
VectorSearch vectorSearch = new VectorSearch()
.setAlgorithms(Collections.singletonList(
new HnswAlgorithmConfiguration(VECTOR_ALGORITHM_NAME)
.setParameters(
new HnswParameters()
.setMetric(VectorSearchAlgorithmMetric.COSINE)
.setM(4)
.setEfSearch(500)
.setEfConstruction(400))))
.setProfiles(Collections.singletonList(
new VectorSearchProfile(VECTOR_SEARCH_PROFILE_NAME, VECTOR_ALGORITHM_NAME)));
SemanticSearch semanticSearch = new SemanticSearch().setDefaultConfigurationName(SEMANTIC_SEARCH_CONFIG_NAME)
.setConfigurations(singletonList(
new SemanticConfiguration(SEMANTIC_SEARCH_CONFIG_NAME,
new SemanticPrioritizedFields()
.setContentFields(new SemanticField(DEFAULT_FIELD_CONTENT))
.setKeywordsFields(new SemanticField(DEFAULT_FIELD_CONTENT)))));
SearchIndex index = new SearchIndex(INDEX_NAME)
.setFields(fields)
.setVectorSearch(vectorSearch)
.setSemanticSearch(semanticSearch);
searchIndexClient.createOrUpdateIndex(index);
}
/**
* Creates or updates the index, with full control on its configuration.
*
* @param index The index to be created or updated.
*/
void createOrUpdateIndex(SearchIndex index) {
searchIndexClient.createOrUpdateIndex(index);
}
public void deleteIndex() {
searchIndexClient.deleteIndex(INDEX_NAME);
}
/**
* Add an embedding to the store.
*/
@Override
public String add(Embedding embedding) {
String id = randomUUID();
addInternal(id, embedding, null);
return id;
}
/**
* Add an embedding to the store.
*/
@Override
public void add(String id, Embedding embedding) {
addInternal(id, embedding, null);
}
/**
* Add an embedding and the related content to the store.
*/
@Override
public String add(Embedding embedding, TextSegment textSegment) {
String id = randomUUID();
addInternal(id, embedding, textSegment);
return id;
}
/**
* Add a list of embeddings to the store.
*/
@Override
public List<String> addAll(List<Embedding> embeddings) {
List<String> ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList());
addAllInternal(ids, embeddings, null);
return ids;
}
/**
* Add a list of embeddings, and the list of related content, to the store.
*/
@Override
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
List<String> ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList());
addAllInternal(ids, embeddings, embedded);
return ids;
}
@Override
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
List<Float> vector = referenceEmbedding.vectorAsList();
VectorizedQuery vectorizedQuery = new VectorizedQuery(vector)
.setFields(DEFAULT_FIELD_CONTENT_VECTOR)
.setKNearestNeighborsCount(maxResults);
SearchPagedIterable searchResults =
searchClient.search(null,
new SearchOptions()
.setVectorSearchOptions(new VectorSearchOptions().setQueries(vectorizedQuery)),
Context.NONE);
List<EmbeddingMatch<TextSegment>> result = new ArrayList<>();
for (SearchResult searchResult : searchResults) {
Double score = fromAzureScoreToRelevanceScore(searchResult.getScore());
if (score < minScore) {
continue;
}
SearchDocument searchDocument = searchResult.getDocument(SearchDocument.class);
String embeddingId = (String) searchDocument.get(DEFAULT_FIELD_ID);
List<Double> embeddingList = (List<Double>) searchDocument.get(DEFAULT_FIELD_CONTENT_VECTOR);
float[] embeddingArray = doublesListToFloatArray(embeddingList);
Embedding embedding = Embedding.from(embeddingArray);
String embeddedContent = (String) searchDocument.get(DEFAULT_FIELD_CONTENT);
EmbeddingMatch<TextSegment> embeddingMatch;
if (isNotNullOrBlank(embeddedContent)) {
LinkedHashMap metadata = (LinkedHashMap) searchDocument.get(DEFAULT_FIELD_METADATA);
List attributes = (List) metadata.get(DEFAULT_FIELD_METADATA_ATTRS);
Map<String, String> attributesMap = new HashMap<>();
for (Object attribute : attributes) {
LinkedHashMap innerAttribute = (LinkedHashMap) attribute;
String key = (String) innerAttribute.get("key");
String value = (String) innerAttribute.get("value");
attributesMap.put(key, value);
}
Metadata langChainMetadata = Metadata.from(attributesMap);
TextSegment embedded = TextSegment.textSegment(embeddedContent, langChainMetadata);
embeddingMatch = new EmbeddingMatch<>(score, embeddingId, embedding, embedded);
} else {
embeddingMatch = new EmbeddingMatch<>(score, embeddingId, embedding, null);
}
result.add(embeddingMatch);
}
return result;
}
private void addInternal(String id, Embedding embedding, TextSegment embedded) {
addAllInternal(
singletonList(id),
singletonList(embedding),
embedded == null ? null : singletonList(embedded));
}
private void addAllInternal(
List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
if (isNullOrEmpty(ids) || isNullOrEmpty(embeddings)) {
log.info("Empty embeddings - no ops");
return;
}
ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
ensureTrue(embedded == null || embeddings.size() == embedded.size(),
"embeddings size is not equal to embedded size");
List<Document> documents = new ArrayList<>();
for (int i = 0; i < ids.size(); ++i) {
Document document = new Document();
document.setId(ids.get(i));
document.setContentVector(embeddings.get(i).vectorAsList());
if (embedded != null) {
document.setContent(embedded.get(i).text());
Document.Metadata metadata = new Document.Metadata();
List<Document.Metadata.Attribute> attributes = new ArrayList<>();
for (Map.Entry<String, String> entry : embedded.get(i).metadata().asMap().entrySet()) {
Document.Metadata.Attribute attribute = new Document.Metadata.Attribute();
attribute.setKey(entry.getKey());
attribute.setValue(entry.getValue());
attributes.add(attribute);
}
metadata.setAttributes(attributes);
document.setMetadata(metadata);
}
documents.add(document);
}
List<IndexingResult> indexingResults = searchClient.uploadDocuments(documents).getResults();
for (IndexingResult indexingResult : indexingResults) {
if (!indexingResult.isSucceeded()) {
throw new AzureAiSearchRuntimeException("Failed to add embedding: " + indexingResult.getErrorMessage());
} else {
log.debug("Added embedding: {}", indexingResult.getKey());
}
}
}
float[] doublesListToFloatArray(List<Double> doubles) {
float[] array = new float[doubles.size()];
for (int i = 0; i < doubles.size(); ++i) {
array[i] = doubles.get(i).floatValue();
}
return array;
}
/**
* Calculates LangChain4j's RelevanceScore from Azure AI Search's score.
* <p>
* Score in Azure AI Search is transformed into a cosine similarity as described here:
* https://learn.microsoft.com/en-us/azure/search/vector-search-ranking#scores-in-a-vector-search-results
* <p>
* RelevanceScore in LangChain4j is a derivative of cosine similarity,
* but it compresses it into 0..1 range (instead of -1..1) for ease of use.
*/
protected static double fromAzureScoreToRelevanceScore(double score) {
double cosineDistance = (1 - score) / score;
double cosineSimilarity = -cosineDistance + 1;
return RelevanceScore.fromCosineSimilarity(cosineSimilarity);
}
}

View File

@ -2,56 +2,17 @@ package dev.langchain4j.store.embedding.azure.search;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.credential.TokenCredential;
import com.azure.core.util.Context;
import com.azure.search.documents.SearchClient;
import com.azure.search.documents.SearchClientBuilder;
import com.azure.search.documents.SearchDocument;
import com.azure.search.documents.indexes.SearchIndexClient;
import com.azure.search.documents.indexes.SearchIndexClientBuilder;
import com.azure.search.documents.indexes.models.*;
import com.azure.search.documents.models.*;
import com.azure.search.documents.util.SearchPagedIterable;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import com.azure.search.documents.indexes.models.SearchIndex;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.*;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.internal.ValidationUtils.ensureTrue;
/**
* Azure AI Search EmbeddingStore Implementation
*/
public class AzureAiSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchEmbeddingStore.class);
private static final String INDEX_NAME = "vectorsearch";
private static final String DEFAULT_FIELD_ID = "id";
private static final String DEFAULT_FIELD_CONTENT = "content";
private static final String DEFAULT_FIELD_CONTENT_VECTOR = "content_vector";
private static final String DEFAULT_FIELD_METADATA = "metadata";
private static final String DEFAULT_FIELD_METADATA_SOURCE = "source";
private static final String DEFAULT_FIELD_METADATA_ATTRS = "attributes";
private SearchIndexClient searchIndexClient;
private SearchClient searchClient;
public class AzureAiSearchEmbeddingStore extends AbstractAzureAiSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
public AzureAiSearchEmbeddingStore(String endpoint, AzureKeyCredential keyCredential, int dimensions) {
this.initialize(endpoint, keyCredential, null, dimensions, null);
@ -69,261 +30,6 @@ public class AzureAiSearchEmbeddingStore implements EmbeddingStore<TextSegment>
this.initialize(endpoint, null, tokenCredential, 0, index);
}
private void initialize(String endpoint, AzureKeyCredential keyCredential, TokenCredential tokenCredential, int dimensions, SearchIndex index) {
if (keyCredential != null) {
searchIndexClient = new SearchIndexClientBuilder()
.endpoint(endpoint)
.credential(keyCredential)
.buildClient();
searchClient = new SearchClientBuilder()
.endpoint(endpoint)
.credential(keyCredential)
.indexName(INDEX_NAME)
.buildClient();
} else {
searchIndexClient = new SearchIndexClientBuilder()
.endpoint(endpoint)
.credential(tokenCredential)
.buildClient();
searchClient = new SearchClientBuilder()
.endpoint(endpoint)
.credential(tokenCredential)
.indexName(INDEX_NAME)
.buildClient();
}
if (index == null) {
createOrUpdateIndex(dimensions);
} else {
createOrUpdateIndex(index);
}
}
/**
* Creates or updates the index using a ready-made index.
* @param dimensions The number of dimensions of the embeddings.
*/
void createOrUpdateIndex(int dimensions) {
List<SearchField> fields = new ArrayList<>();
fields.add(new SearchField(DEFAULT_FIELD_ID, SearchFieldDataType.STRING)
.setKey(true)
.setFilterable(true));
fields.add(new SearchField(DEFAULT_FIELD_CONTENT, SearchFieldDataType.STRING)
.setSearchable(true)
.setFilterable(true));
fields.add(new SearchField(DEFAULT_FIELD_CONTENT_VECTOR, SearchFieldDataType.collection(SearchFieldDataType.SINGLE))
.setSearchable(true)
.setVectorSearchDimensions(dimensions)
.setVectorSearchProfileName("vector-search-profile"));
fields.add((new SearchField(DEFAULT_FIELD_METADATA, SearchFieldDataType.COMPLEX)).setFields(
Arrays.asList(
new SearchField(DEFAULT_FIELD_METADATA_SOURCE, SearchFieldDataType.STRING)
.setFilterable(true),
(new SearchField(DEFAULT_FIELD_METADATA_ATTRS, SearchFieldDataType.collection(SearchFieldDataType.COMPLEX))).setFields(
Arrays.asList(
new SearchField("key", SearchFieldDataType.STRING)
.setFilterable(true),
new SearchField("value", SearchFieldDataType.STRING)
.setFilterable(true)
)
)
)
));
VectorSearch vectorSearch = new VectorSearch()
.setAlgorithms(Collections.singletonList(
new HnswAlgorithmConfiguration("vector-search-algorithm")
.setParameters(
new HnswParameters()
.setMetric(VectorSearchAlgorithmMetric.COSINE)
.setM(4)
.setEfSearch(500)
.setEfConstruction(400))))
.setProfiles(Collections.singletonList(
new VectorSearchProfile("vector-search-profile", "vector-search-algorithm")));
SemanticSearch semanticSearch = new SemanticSearch().setDefaultConfigurationName("semantic-search-config")
.setConfigurations(Arrays.asList(
new SemanticConfiguration("semantic-search-config",
new SemanticPrioritizedFields()
.setContentFields(new SemanticField(DEFAULT_FIELD_CONTENT))
.setKeywordsFields(new SemanticField(DEFAULT_FIELD_CONTENT)))));
SearchIndex index = new SearchIndex(INDEX_NAME)
.setFields(fields)
.setVectorSearch(vectorSearch)
.setSemanticSearch(semanticSearch);
searchIndexClient.createOrUpdateIndex(index);
}
/**
* Creates or updates the index, with full control on its configuration.
* @param index The index to be created or updated.
*/
void createOrUpdateIndex(SearchIndex index) {
searchIndexClient.createOrUpdateIndex(index);
}
public void deleteIndex() {
searchIndexClient.deleteIndex(INDEX_NAME);
}
@Override
public String add(Embedding embedding) {
String id = randomUUID();
addInternal(id, embedding, null);
return id;
}
@Override
public void add(String id, Embedding embedding) {
addInternal(id, embedding, null);
}
@Override
public String add(Embedding embedding, TextSegment textSegment) {
String id = randomUUID();
addInternal(id, embedding, textSegment);
return id;
}
@Override
public List<String> addAll(List<Embedding> embeddings) {
List<String> ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList());
addAllInternal(ids, embeddings, null);
return ids;
}
@Override
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
List<String> ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList());
addAllInternal(ids, embeddings, embedded);
return ids;
}
@Override
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
List<Float> vector = referenceEmbedding.vectorAsList();
VectorizedQuery vectorizedQuery = new VectorizedQuery(vector)
.setFields(DEFAULT_FIELD_CONTENT_VECTOR)
.setKNearestNeighborsCount(maxResults);
SearchPagedIterable searchResults =
searchClient.search(null,
new SearchOptions()
.setVectorSearchOptions(new VectorSearchOptions().setQueries(vectorizedQuery)),
Context.NONE);
List<EmbeddingMatch<TextSegment>> result = new ArrayList<>();
for (SearchResult searchResult : searchResults) {
Double score = fromAzureScoreToRelevanceScore(searchResult.getScore());
if (score < minScore) {
continue;
}
SearchDocument searchDocument = searchResult.getDocument(SearchDocument.class);
String embeddingId = (String) searchDocument.get(DEFAULT_FIELD_ID);
List<Double> embeddingList = (List<Double>) searchDocument.get(DEFAULT_FIELD_CONTENT_VECTOR);
float[] embeddingArray = doublesListToFloatArray(embeddingList);
Embedding embedding = Embedding.from(embeddingArray);
String embeddedContent = (String) searchDocument.get(DEFAULT_FIELD_CONTENT);
EmbeddingMatch<TextSegment> embeddingMatch;
if (isNotNullOrBlank(embeddedContent)) {
LinkedHashMap metadata = (LinkedHashMap) searchDocument.get(DEFAULT_FIELD_METADATA);
List attributes = (List) metadata.get(DEFAULT_FIELD_METADATA_ATTRS);
Map<String, String> attributesMap = new HashMap<>();
for (Object attribute : attributes) {
LinkedHashMap innerAttribute = (LinkedHashMap) attribute;
String key = (String) innerAttribute.get("key");
String value = (String) innerAttribute.get("value");
attributesMap.put(key, value);
}
Metadata langChainMetadata = Metadata.from(attributesMap);
TextSegment embedded = TextSegment.textSegment(embeddedContent, langChainMetadata);
embeddingMatch = new EmbeddingMatch<>(score, embeddingId, embedding, embedded);
} else {
embeddingMatch = new EmbeddingMatch<>(score, embeddingId, embedding, null);
}
result.add(embeddingMatch);
}
return result;
}
private void addInternal(String id, Embedding embedding, TextSegment embedded) {
addAllInternal(
singletonList(id),
singletonList(embedding),
embedded == null ? null : singletonList(embedded));
}
private void addAllInternal(
List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
if (isNullOrEmpty(ids) || isNullOrEmpty(embeddings)) {
log.info("Empty embeddings - no ops");
return;
}
ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
ensureTrue(embedded == null || embeddings.size() == embedded.size(),
"embeddings size is not equal to embedded size");
List<Document> searchDocuments = new ArrayList<>();
for (int i = 0; i < ids.size(); ++i) {
Document document = new Document();
document.setId(ids.get(i));
document.setContentVector(embeddings.get(i).vectorAsList());
if (embedded != null) {
document.setContent(embedded.get(i).text());
Document.Metadata metadata = new Document.Metadata();
List<Document.Metadata.Attribute> attributes = new ArrayList<>();
for (Map.Entry<String, String> entry : embedded.get(i).metadata().asMap().entrySet()) {
Document.Metadata.Attribute attribute = new Document.Metadata.Attribute();
attribute.setKey(entry.getKey());
attribute.setValue(entry.getValue());
attributes.add(attribute);
}
metadata.setAttributes(attributes);
document.setMetadata(metadata);
}
searchDocuments.add(document);
}
List<IndexingResult> indexingResults = searchClient.uploadDocuments(searchDocuments).getResults();
for (IndexingResult indexingResult : indexingResults) {
if (!indexingResult.isSucceeded()) {
throw new AzureAiSearchRuntimeException("Failed to add embedding: " + indexingResult.getErrorMessage());
} else {
log.debug("Added embedding: {}", indexingResult.getKey());
}
}
}
private float[] doublesListToFloatArray(List<Double> doubles) {
float[] array = new float[doubles.size()];
for (int i = 0; i < doubles.size(); ++i) {
array[i] = doubles.get(i).floatValue();
}
return array;
}
/**
* Calculates LangChain4j's RelevanceScore from Azure AI Search's score.
*
* Score in Azure AI Search is transformed into a cosine similarity as described here:
* https://learn.microsoft.com/en-us/azure/search/vector-search-ranking#scores-in-a-vector-search-results
*
* RelevanceScore in LangChain4j is a derivative of cosine similarity,
* but it compresses it into 0..1 range (instead of -1..1) for ease of use.
*/
private double fromAzureScoreToRelevanceScore(double score) {
double cosineDistance = (1 - score) / score;
double cosineSimilarity = -cosineDistance + 1;
return RelevanceScore.fromCosineSimilarity(cosineSimilarity);
}
public static Builder builder() {
return new Builder();
}

View File

@ -0,0 +1,6 @@
package dev.langchain4j.store.embedding.azure.search;
public enum AzureAiSearchQueryType {
VECTOR, FULL_TEXT, HYBRID, HYBRID_WITH_RERANKING;
}

View File

@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Collection;
class Document {
public class Document {
private String id;

View File

@ -0,0 +1,57 @@
package dev.langchain4j.rag.content.retriever.azure.search;
import com.azure.search.documents.models.SearchResult;
import com.azure.search.documents.models.SemanticSearchResult;
import dev.langchain4j.rag.content.retriever.azure.search.AzureAiSearchContentRetriever;
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class AzureAiSearchContentRetrieverTest {
@Test
public void testFromAzureScoreToRelevanceScore_VECTOR() {
SearchResult mockResult = mock(SearchResult.class);
when(mockResult.getScore()).thenReturn(0.6);
double result = AzureAiSearchContentRetriever.fromAzureScoreToRelevanceScore(mockResult, AzureAiSearchQueryType.VECTOR);
assertEquals(0.6666666666666666, result);
}
@Test
public void testFromAzureScoreToRelevanceScore_FULL_TEXT() {
SearchResult mockResult = mock(SearchResult.class);
when(mockResult.getScore()).thenReturn(0.4);
double result = AzureAiSearchContentRetriever.fromAzureScoreToRelevanceScore(mockResult, AzureAiSearchQueryType.FULL_TEXT);
assertEquals(0.4, result);
}
@Test
public void testFromAzureScoreToRelevanceScore_HYBRID() {
SearchResult mockResult = mock(SearchResult.class);
when(mockResult.getScore()).thenReturn(0.7);
double result = AzureAiSearchContentRetriever.fromAzureScoreToRelevanceScore(mockResult, AzureAiSearchQueryType.HYBRID);
assertEquals(0.7, result);
}
@Test
public void testFromAzureScoreToRelevanceScore_HYBRID_WITH_RERANKING() {
SearchResult mockResult = mock(SearchResult.class);
SemanticSearchResult mockSemanticSearchResult = mock(SemanticSearchResult.class);
when(mockResult.getSemanticSearch()).thenReturn(mockSemanticSearchResult);
when(mockSemanticSearchResult.getRerankerScore()).thenReturn(1.5);
double result = AzureAiSearchContentRetriever.fromAzureScoreToRelevanceScore(mockResult, AzureAiSearchQueryType.HYBRID_WITH_RERANKING);
assertEquals(0.375, result);
}
}

View File

@ -0,0 +1,280 @@
package dev.langchain4j.rag.content.retriever.azure.search;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.*;
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchQueryType;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
@EnabledIfEnvironmentVariable(named = "AZURE_SEARCH_ENDPOINT", matches = ".+")
public class AzureAiSearchContentRetrieverTestIT extends EmbeddingStoreIT {
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchContentRetrieverTestIT.class);
private final EmbeddingModel embeddingModel;
private final AzureAiSearchContentRetriever contentRetrieverWithVector;
private final AzureAiSearchContentRetriever contentRetrieverWithFullText;
private final AzureAiSearchContentRetriever contentRetrieverWithHybrid;
private final AzureAiSearchContentRetriever contentRetrieverWithHybridAndReranking;
private final int dimensions;
public AzureAiSearchContentRetrieverTestIT() {
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
dimensions = embeddingModel.embed("test").content().vector().length;
contentRetrieverWithVector = createContentRetriever(AzureAiSearchQueryType.VECTOR);
contentRetrieverWithFullText = createContentRetriever(AzureAiSearchQueryType.FULL_TEXT);
contentRetrieverWithHybrid = createContentRetriever(AzureAiSearchQueryType.HYBRID);
contentRetrieverWithHybridAndReranking = createContentRetriever(AzureAiSearchQueryType.HYBRID_WITH_RERANKING);
}
private AzureAiSearchContentRetriever createContentRetriever(AzureAiSearchQueryType azureAiSearchQueryType) {
return AzureAiSearchContentRetriever.builder()
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
.apiKey(System.getenv("AZURE_SEARCH_KEY"))
.dimensions(dimensions)
.embeddingModel(embeddingModel)
.queryType(azureAiSearchQueryType)
.maxResults(3)
.minScore(0.0)
.build();
}
@Test
void testAddEmbeddingsAndRetrieveRelevant() {
String content1 = "banana";
String content2 = "computer";
String content3 = "apple";
String content4 = "pizza";
String content5 = "strawberry";
String content6 = "chess";
List<String> contents = asList(content1, content2, content3, content4, content5, content6);
for (String content : contents) {
TextSegment textSegment = TextSegment.from(content);
Embedding embedding = embeddingModel.embed(content).content();
contentRetrieverWithVector.add(embedding, textSegment);
}
awaitUntilPersisted();
String content = "fruit";
Query query = Query.from(content);
List<Content> relevant = contentRetrieverWithVector.retrieve(query);
assertThat(relevant).hasSize(3);
assertThat(relevant.get(0).textSegment()).isNotNull();
assertThat(relevant.get(0).textSegment().text()).isIn(content1, content3, content5);
log.info("#1 relevant item: {}", relevant.get(0).textSegment().text());
assertThat(relevant.get(1).textSegment()).isNotNull();
assertThat(relevant.get(1).textSegment().text()).isIn(content1, content3, content5);
log.info("#2 relevant item: {}", relevant.get(1).textSegment().text());
assertThat(relevant.get(2).textSegment()).isNotNull();
assertThat(relevant.get(2).textSegment().text()).isIn(content1, content3, content5);
log.info("#3 relevant item: {}", relevant.get(2).textSegment().text());
}
@Test
void testAllTypesOfSearch() {
String content1 = "This book is about politics";
String content2 = "Cats sleeps a lot.";
String content3 = "Sandwiches taste good.";
String content4 = "The house is open";
List<String> contents = asList(content1, content2, content3, content4);
for (String content : contents) {
TextSegment textSegment = TextSegment.from(content);
Embedding embedding = embeddingModel.embed(content).content();
contentRetrieverWithVector.add(embedding, textSegment);
}
awaitUntilPersisted();
String content = "house";
Query query = Query.from(content);
log.info("Testing Vector Search");
List<Content> relevant = contentRetrieverWithVector.retrieve(query);
assertThat(relevant).hasSizeGreaterThan(0);
assertThat(relevant.get(0).textSegment().text()).isEqualTo("The house is open");
log.info("#1 relevant item: {}", relevant.get(0).textSegment().text());
log.info("Testing Full Text Search");
// This uses the same storage as the vector search, so we don't need to add the content again
List<Content> relevant2 = contentRetrieverWithFullText.retrieve(query);
assertThat(relevant2).hasSizeGreaterThan(0);
assertThat(relevant2.get(0).textSegment().text()).isEqualTo("The house is open");
log.info("#1 relevant item: {}", relevant2.get(0).textSegment().text());
log.info("Testing Hybrid Search");
List<Content> relevant3 = contentRetrieverWithHybrid.retrieve(query);
assertThat(relevant3).hasSizeGreaterThan(0);
assertThat(relevant3.get(0).textSegment().text()).isEqualTo("The house is open");
log.info("#1 relevant item: {}", relevant3.get(0).textSegment().text());
log.info("Testing Hybrid Search with Reranking");
List<Content> relevant4 = contentRetrieverWithHybridAndReranking.retrieve(query);
assertThat(relevant4).hasSizeGreaterThan(0);
assertThat(relevant4.get(0).textSegment().text()).isEqualTo("The house is open");
log.info("#1 relevant item: {}", relevant4.get(0).textSegment().text());
log.info("Test complete");
}
@Test
void testFullTextSearch() {
String content1 = "Émile-Auguste Chartier (3 March 1868 2 June 1951), commonly known as Alain, was a French philosopher, journalist, essayist, pacifist, and teacher of philosophy. He adopted his pseudonym as the most banal he could find. There is no evidence he ever thought in so doing of the 15th century Norman poet Alain Chartier.";
String content2 = "Emmanuel Levinas (12 January 1906 25 December 1995) was a French philosopher of Lithuanian Jewish ancestry who is known for his work within Jewish philosophy, existentialism, and phenomenology, focusing on the relationship of ethics to metaphysics and ontology.";
String content3 = "Maurice Jean Jacques Merleau-Ponty (14 March 1908 3 May 1961) was a French phenomenological philosopher, strongly influenced by Edmund Husserl and Martin Heidegger. The constitution of meaning in human experience was his main interest and he wrote on perception, art, politics, religion, biology, psychology, psychoanalysis, language, nature, and history. He was the lead editor of Les Temps modernes, the leftist magazine he established with Jean-Paul Sartre and Simone de Beauvoir in 1945.";
List<String> contents = asList(content1, content2, content3);
for (String content : contents) {
contentRetrieverWithFullText.add(content);
}
awaitUntilPersisted();
Query query = Query.from("Alain");
List<Content> relevant = contentRetrieverWithHybrid.retrieve(query);
assertThat(relevant).hasSizeGreaterThan(0);
log.info("#1 relevant item: {}", relevant.get(0).textSegment().text());
assertThat(relevant.get(0).textSegment().text()).contains("Émile-Auguste Chartier");
Query query2 = Query.from("Heidegger");
List<Content> relevant2 = contentRetrieverWithHybrid.retrieve(query2);
assertThat(relevant2).hasSizeGreaterThan(0);
log.info("#1 relevant item: {}", relevant2.get(0).textSegment().text());
assertThat(relevant2.get(0).textSegment().text()).contains("Maurice Jean Jacques Merleau-Ponty");
}
@Test
void testAddEmbeddingsAndRetrieveRelevantWithHybrid() {
String content1 = "Albert Camus (7 November 1913 4 January 1960) was a French philosopher, author, dramatist, journalist, world federalist, and political activist. He was the recipient of the 1957 Nobel Prize in Literature at the age of 44, the second-youngest recipient in history. His works include The Stranger, The Plague, The Myth of Sisyphus, The Fall, and The Rebel.\n" +
"\n" +
"Camus was born in Algeria during the French colonization, to pied-noir parents. He spent his childhood in a poor neighbourhood and later studied philosophy at the University of Algiers. He was in Paris when the Germans invaded France during World War II in 1940. Camus tried to flee but finally joined the French Resistance where he served as editor-in-chief at Combat, an outlawed newspaper. After the war, he was a celebrity figure and gave many lectures around the world. He married twice but had many extramarital affairs. Camus was politically active; he was part of the left that opposed Joseph Stalin and the Soviet Union because of their totalitarianism. Camus was a moralist and leaned towards anarcho-syndicalism. He was part of many organisations seeking European integration. During the Algerian War (19541962), he kept a neutral stance, advocating for a multicultural and pluralistic Algeria, a position that was rejected by most parties.\n" +
"\n" +
"Philosophically, Camus' views contributed to the rise of the philosophy known as absurdism. Some consider Camus' work to show him to be an existentialist, even though he himself firmly rejected the term throughout his lifetime.";
String content2 = "Gilles Louis René Deleuze (18 January 1925 4 November 1995) was a French philosopher who, from the early 1950s until his death in 1995, wrote on philosophy, literature, film, and fine art. His most popular works were the two volumes of Capitalism and Schizophrenia: Anti-Oedipus (1972) and A Thousand Plateaus (1980), both co-written with psychoanalyst Félix Guattari. His metaphysical treatise Difference and Repetition (1968) is considered by many scholars to be his magnum opus.\n" +
"\n" +
"An important part of Deleuze's oeuvre is devoted to the reading of other philosophers: the Stoics, Leibniz, Hume, Kant, Nietzsche, and Bergson, with particular influence derived from Spinoza. A. W. Moore, citing Bernard Williams's criteria for a great thinker, ranks Deleuze among the \"greatest philosophers\". Although he once characterized himself as a \"pure metaphysician\", his work has influenced a variety of disciplines across the humanities, including philosophy, art, and literary theory, as well as movements such as post-structuralism and postmodernism.";
String content3 = "Paul-Michel Foucault (15 October 1926 25 June 1984) was a French philosopher, historian of ideas, writer, political activist, and literary critic. Foucault's theories primarily address the relationships between power and knowledge, and how they are used as a form of social control through societal institutions. Though often cited as a structuralist and postmodernist, Foucault rejected these labels. His thought has influenced academics, especially those working in communication studies, anthropology, psychology, sociology, criminology, cultural studies, literary theory, feminism, Marxism and critical theory.\n" +
"\n" +
"Born in Poitiers, France, into an upper-middle-class family, Foucault was educated at the Lycée Henri-IV, at the École Normale Supérieure, where he developed an interest in philosophy and came under the influence of his tutors Jean Hyppolite and Louis Althusser, and at the University of Paris (Sorbonne), where he earned degrees in philosophy and psychology. After several years as a cultural diplomat abroad, he returned to France and published his first major book, The History of Madness (1961). After obtaining work between 1960 and 1966 at the University of Clermont-Ferrand, he produced The Birth of the Clinic (1963) and The Order of Things (1966), publications that displayed his increasing involvement with structuralism, from which he later distanced himself. These first three histories exemplified a historiographical technique Foucault was developing called \"archaeology\".\n" +
"\n" +
"From 1966 to 1968, Foucault lectured at the University of Tunis before returning to France, where he became head of the philosophy department at the new experimental university of Paris VIII. Foucault subsequently published The Archaeology of Knowledge (1969). In 1970, Foucault was admitted to the Collège de France, a membership he retained until his death. He also became active in several left-wing groups involved in campaigns against racism and human rights abuses and for penal reform. Foucault later published Discipline and Punish (1975) and The History of Sexuality (1976), in which he developed archaeological and genealogical methods that emphasized the role that power plays in society.\n" +
"\n" +
"Foucault died in Paris from complications of HIV/AIDS; he became the first public figure in France to die from complications of the disease. His partner Daniel Defert founded the AIDES charity in his memory.";
List<String> contents = asList(content1, content2, content3);
for (String content : contents) {
TextSegment textSegment = TextSegment.from(content);
Embedding embedding = embeddingModel.embed(content).content();
contentRetrieverWithHybrid.add(embedding, textSegment);
}
awaitUntilPersisted();
Query query = Query.from("Algeria");
List<Content> relevant = contentRetrieverWithHybrid.retrieve(query);
assertThat(relevant).hasSizeGreaterThan(0);
log.info("#1 relevant item: {}", relevant.get(0).textSegment().text());
assertThat(relevant.get(0).textSegment().text()).contains("Albert Camus");
Query query2 = Query.from("École Normale Supérieure");
List<Content> relevant2 = contentRetrieverWithHybrid.retrieve(query2);
assertThat(relevant2).hasSizeGreaterThan(0);
log.info("#1 relevant item: {}", relevant2.get(0).textSegment().text());
assertThat(relevant2.get(0).textSegment().text()).contains("Paul-Michel Foucault");
}
@Test
void testAddEmbeddingsAndRetrieveRelevantWithHybridAndReranking() {
String content1 = "Albert Camus (7 November 1913 4 January 1960) was a French philosopher, author, dramatist, journalist, world federalist, and political activist. He was the recipient of the 1957 Nobel Prize in Literature at the age of 44, the second-youngest recipient in history. His works include The Stranger, The Plague, The Myth of Sisyphus, The Fall, and The Rebel.\n" +
"\n" +
"Camus was born in Algeria during the French colonization, to pied-noir parents. He spent his childhood in a poor neighbourhood and later studied philosophy at the University of Algiers. He was in Paris when the Germans invaded France during World War II in 1940. Camus tried to flee but finally joined the French Resistance where he served as editor-in-chief at Combat, an outlawed newspaper. After the war, he was a celebrity figure and gave many lectures around the world. He married twice but had many extramarital affairs. Camus was politically active; he was part of the left that opposed Joseph Stalin and the Soviet Union because of their totalitarianism. Camus was a moralist and leaned towards anarcho-syndicalism. He was part of many organisations seeking European integration. During the Algerian War (19541962), he kept a neutral stance, advocating for a multicultural and pluralistic Algeria, a position that was rejected by most parties.\n" +
"\n" +
"Philosophically, Camus' views contributed to the rise of the philosophy known as absurdism. Some consider Camus' work to show him to be an existentialist, even though he himself firmly rejected the term throughout his lifetime.";
String content2 = "Gilles Louis René Deleuze (18 January 1925 4 November 1995) was a French philosopher who, from the early 1950s until his death in 1995, wrote on philosophy, literature, film, and fine art. His most popular works were the two volumes of Capitalism and Schizophrenia: Anti-Oedipus (1972) and A Thousand Plateaus (1980), both co-written with psychoanalyst Félix Guattari. His metaphysical treatise Difference and Repetition (1968) is considered by many scholars to be his magnum opus.\n" +
"\n" +
"An important part of Deleuze's oeuvre is devoted to the reading of other philosophers: the Stoics, Leibniz, Hume, Kant, Nietzsche, and Bergson, with particular influence derived from Spinoza. A. W. Moore, citing Bernard Williams's criteria for a great thinker, ranks Deleuze among the \"greatest philosophers\". Although he once characterized himself as a \"pure metaphysician\", his work has influenced a variety of disciplines across the humanities, including philosophy, art, and literary theory, as well as movements such as post-structuralism and postmodernism.";
String content3 = "Paul-Michel Foucault (15 October 1926 25 June 1984) was a French philosopher, historian of ideas, writer, political activist, and literary critic. Foucault's theories primarily address the relationships between power and knowledge, and how they are used as a form of social control through societal institutions. Though often cited as a structuralist and postmodernist, Foucault rejected these labels. His thought has influenced academics, especially those working in communication studies, anthropology, psychology, sociology, criminology, cultural studies, literary theory, feminism, Marxism and critical theory.\n" +
"\n" +
"Born in Poitiers, France, into an upper-middle-class family, Foucault was educated at the Lycée Henri-IV, at the École Normale Supérieure, where he developed an interest in philosophy and came under the influence of his tutors Jean Hyppolite and Louis Althusser, and at the University of Paris (Sorbonne), where he earned degrees in philosophy and psychology. After several years as a cultural diplomat abroad, he returned to France and published his first major book, The History of Madness (1961). After obtaining work between 1960 and 1966 at the University of Clermont-Ferrand, he produced The Birth of the Clinic (1963) and The Order of Things (1966), publications that displayed his increasing involvement with structuralism, from which he later distanced himself. These first three histories exemplified a historiographical technique Foucault was developing called \"archaeology\".\n" +
"\n" +
"From 1966 to 1968, Foucault lectured at the University of Tunis before returning to France, where he became head of the philosophy department at the new experimental university of Paris VIII. Foucault subsequently published The Archaeology of Knowledge (1969). In 1970, Foucault was admitted to the Collège de France, a membership he retained until his death. He also became active in several left-wing groups involved in campaigns against racism and human rights abuses and for penal reform. Foucault later published Discipline and Punish (1975) and The History of Sexuality (1976), in which he developed archaeological and genealogical methods that emphasized the role that power plays in society.\n" +
"\n" +
"Foucault died in Paris from complications of HIV/AIDS; he became the first public figure in France to die from complications of the disease. His partner Daniel Defert founded the AIDES charity in his memory.";
List<String> contents = asList(content1, content2, content3);
for (String content : contents) {
TextSegment textSegment = TextSegment.from(content);
Embedding embedding = embeddingModel.embed(content).content();
contentRetrieverWithHybridAndReranking.add(embedding, textSegment);
}
awaitUntilPersisted();
Query query = Query.from("A philosopher who was in the French Resistance");
List<Content> relevant = contentRetrieverWithHybridAndReranking.retrieve(query);
assertThat(relevant).hasSizeGreaterThan(0);
log.info("#1 relevant item: {}", relevant.get(0).textSegment().text());
assertThat(relevant.get(0).textSegment().text()).contains("Albert Camus");
Query query2 = Query.from("A philosopher who studied at the École Normale Supérieure");
List<Content> relevant2 = contentRetrieverWithHybridAndReranking.retrieve(query2);
assertThat(relevant2).hasSizeGreaterThan(0);
log.info("#1 relevant item: {}", relevant2.get(0).textSegment().text());
assertThat(relevant2.get(0).textSegment().text()).contains("Paul-Michel Foucault");
}
@Override
protected EmbeddingStore<TextSegment> embeddingStore() {
return contentRetrieverWithVector;
}
@Override
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}
@Override
protected void clearStore() {
AzureAiSearchContentRetriever azureAiSearchContentRetriever = (AzureAiSearchContentRetriever) contentRetrieverWithVector;
try {
azureAiSearchContentRetriever.deleteIndex();
azureAiSearchContentRetriever.createOrUpdateIndex(dimensions);
} catch (RuntimeException e) {
log.error("Failed to clean up the index. You should look at deleting it manually.", e);
}
}
@Override
protected void awaitUntilPersisted() {
try {
Thread.sleep(1_000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -14,7 +14,6 @@ import java.util.List;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;
@EnabledIfEnvironmentVariable(named = "AZURE_SEARCH_ENDPOINT", matches = ".+")
public class AzureAiSearchEmbeddingStoreIT extends EmbeddingStoreIT {