OpenSearch: reformatted code

This commit is contained in:
deep-learning-dynamo 2023-10-09 12:16:42 +02:00
parent 6f0c962108
commit cf276e844c
4 changed files with 85 additions and 123 deletions

View File

@ -1,15 +1,11 @@
package dev.langchain4j.store.embedding.opensearch;
import com.fasterxml.jackson.core.JsonProcessingException;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.hc.client5.http.auth.AuthScope;
import org.apache.hc.client5.http.auth.UsernamePasswordCredentials;
import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider;
@ -38,6 +34,9 @@ import org.opensearch.client.transport.endpoints.BooleanResponse;
import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import java.io.IOException;
import java.net.URISyntaxException;
@ -65,11 +64,11 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
/**
* Creates an instance of OpenSearchEmbeddingStore to connect with
* OpenSearch clusters running locally and network reacheable.
* OpenSearch clusters running locally and network reachable.
*
* @param serverUrl OpenSearch Server URL.
* @param apiKey OpenSearch API key (optional)
* @param userName OpenSearch user name (optional)
* @param userName OpenSearch username (optional)
* @param password OpenSearch password (optional)
* @param indexName OpenSearch index name (optional). Default value: "default"
*/
@ -78,7 +77,6 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
String userName,
String password,
String indexName) {
HttpHost openSearchHost;
try {
openSearchHost = HttpHost.create(serverUrl);
@ -88,36 +86,31 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
}
OpenSearchTransport transport = ApacheHttpClient5TransportBuilder
.builder(openSearchHost)
.setMapper(new JacksonJsonpMapper())
.setHttpClientConfigCallback(httpClientBuilder -> {
.builder(openSearchHost)
.setMapper(new JacksonJsonpMapper())
.setHttpClientConfigCallback(httpClientBuilder -> {
if (!isNullOrBlank(apiKey)) {
httpClientBuilder.setDefaultHeaders(singletonList(
new BasicHeader("Authorization", "ApiKey " + apiKey)
));
}
if (!isNullOrBlank(apiKey)) {
httpClientBuilder.setDefaultHeaders(singletonList(
new BasicHeader("Authorization", "ApiKey " + apiKey)
));
}
if (!isNullOrBlank(userName) && !isNullOrBlank(password)) {
BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider();
credentialsProvider.setCredentials(new AuthScope(openSearchHost),
new UsernamePasswordCredentials(userName, password.toCharArray()));
httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
}
if (!isNullOrBlank(userName) && !isNullOrBlank(password)) {
BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider();
credentialsProvider.setCredentials(new AuthScope(openSearchHost),
new UsernamePasswordCredentials(userName, password.toCharArray()));
httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
}
httpClientBuilder.setConnectionManager(
PoolingAsyncClientConnectionManagerBuilder
.create()
.build());
httpClientBuilder.setConnectionManager(PoolingAsyncClientConnectionManagerBuilder.create().build());
return httpClientBuilder;
return httpClientBuilder;
})
.build();
})
.build();
this.client = new OpenSearchClient(transport);
this.indexName = ensureNotNull(indexName, "indexName");
}
/**
@ -139,13 +132,10 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
Region selectedRegion = Region.of(region);
SdkHttpClient httpClient = ApacheHttpClient.builder().build();
OpenSearchTransport transport = new AwsSdk2Transport(
httpClient, serverUrl, serviceName, selectedRegion, options
);
OpenSearchTransport transport = new AwsSdk2Transport(httpClient, serverUrl, serviceName, selectedRegion, options);
this.client = new OpenSearchClient(transport);
this.indexName = ensureNotNull(indexName, "indexName");
}
public static Builder builder() {
@ -205,12 +195,9 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
public OpenSearchEmbeddingStore build() {
if (!isNullOrBlank(serviceName) && !isNullOrBlank(region) && options != null) {
return new OpenSearchEmbeddingStore(
serverUrl, serviceName, region, options, indexName
);
return new OpenSearchEmbeddingStore(serverUrl, serviceName, region, options, indexName);
}
return new OpenSearchEmbeddingStore(
serverUrl, apiKey, userName, password, indexName);
return new OpenSearchEmbeddingStore(serverUrl, apiKey, userName, password, indexName);
}
}
@ -258,14 +245,14 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
*/
@Override
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
List<EmbeddingMatch<TextSegment>> matches = null;
List<EmbeddingMatch<TextSegment>> matches;
try {
ScriptScoreQuery scriptScoreQuery = buildDefaultScriptScoreQuery(referenceEmbedding.vector(), (float) minScore);
SearchResponse<Document> response = client.search(
SearchRequest.of(s -> s.index(indexName)
.query(n -> n.scriptScore(scriptScoreQuery))
.size(maxResults)),
Document.class
SearchRequest.of(s -> s.index(indexName)
.query(n -> n.scriptScore(scriptScoreQuery))
.size(maxResults)),
Document.class
);
matches = toEmbeddingMatch(response);
} catch (IOException ex) {
@ -278,21 +265,20 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
private ScriptScoreQuery buildDefaultScriptScoreQuery(float[] vector, float minScore) throws JsonProcessingException {
return ScriptScoreQuery.of(q -> q.minScore(minScore)
.query(Query.of(qu -> qu.matchAll(m -> m)))
.script(s -> s.inline(InlineScript.of(i -> i
.source("knn_score")
.lang("knn")
.params("field", JsonData.of("vector"))
.params("query_value", JsonData.of(vector))
.params("space_type", JsonData.of("cosinesimil")))))
.boost(0.5f));
// ===> From the OpenSearch documentation:
// "Cosine similarity returns a number between -1 and 1, and because OpenSearch
// relevance scores can't be below 0, the k-NN plugin adds 1 to get the final score."
// See https://opensearch.org/docs/latest/search-plugins/knn/knn-score-script
// Thus, the query applies a boost of `0.5` to keep score in the range [0, 1]
.query(Query.of(qu -> qu.matchAll(m -> m)))
.script(s -> s.inline(InlineScript.of(i -> i
.source("knn_score")
.lang("knn")
.params("field", JsonData.of("vector"))
.params("query_value", JsonData.of(vector))
.params("space_type", JsonData.of("cosinesimil")))))
.boost(0.5f));
// ===> From the OpenSearch documentation:
// "Cosine similarity returns a number between -1 and 1, and because OpenSearch
// relevance scores can't be below 0, the k-NN plugin adds 1 to get the final score."
// See https://opensearch.org/docs/latest/search-plugins/knn/knn-score-script
// Thus, the query applies a boost of `0.5` to keep score in the range [0, 1]
}
private void addInternal(String id, Embedding embedding, TextSegment embedded) {
@ -316,16 +302,15 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
log.error("[I/O OpenSearch Exception]", ex);
throw new OpenSearchRequestFailedException(ex.getMessage());
}
}
private void createIndexIfNotExist(int dimension) throws IOException {
BooleanResponse response = client.indices().exists(c -> c.index(indexName));
if (!response.value()) {
client.indices()
.create(c -> c.index(indexName)
.settings(s -> s.knn(true))
.mappings(getDefaultMappings(dimension)));
.create(c -> c.index(indexName)
.settings(s -> s.knn(true))
.mappings(getDefaultMappings(dimension)));
}
}
@ -333,7 +318,7 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
Map<String, Property> properties = new HashMap<>(4);
properties.put("text", Property.of(p -> p.text(TextProperty.of(t -> t))));
properties.put("vector", Property.of(p -> p.knnVector(
k -> k.dimension(dimension)
k -> k.dimension(dimension)
)));
return TypeMapping.of(c -> c.properties(properties));
}
@ -349,14 +334,14 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
.vector(embeddings.get(i).vector())
.text(embedded == null ? null : embedded.get(i).text())
.metadata(embedded == null ? null : Optional.ofNullable(embedded.get(i).metadata())
.map(Metadata::asMap)
.orElse(null))
.map(Metadata::asMap)
.orElse(null))
.build();
bulkBuilder.operations(op -> op.index(
idx -> idx
.index(indexName)
.id(ids.get(finalI))
.document(document)
idx -> idx
.index(indexName)
.id(ids.get(finalI))
.document(document)
));
}
@ -368,27 +353,25 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
ErrorCause errorCause = item.error();
if (errorCause != null) {
throw new OpenSearchRequestFailedException(
"type: " + errorCause.type() + "," +
"reason: " + errorCause.reason());
"type: " + errorCause.type() + "," +
"reason: " + errorCause.reason());
}
}
}
}
}
private List<EmbeddingMatch<TextSegment>> toEmbeddingMatch(SearchResponse<Document> response) {
return response.hits().hits().stream()
.map(hit -> Optional.ofNullable(hit.source())
.map(document -> new EmbeddingMatch<>(
hit.score(),
hit.id(),
new Embedding(document.getVector()),
document.getText() == null
? null
: TextSegment.from(document.getText(), new Metadata(document.getMetadata()))
)).orElse(null))
.collect(toList());
.map(hit -> Optional.ofNullable(hit.source())
.map(document -> new EmbeddingMatch<>(
hit.score(),
hit.id(),
new Embedding(document.getVector()),
document.getText() == null
? null
: TextSegment.from(document.getText(), new Metadata(document.getMetadata()))
)).orElse(null))
.collect(toList());
}
}

View File

@ -13,5 +13,4 @@ class OpenSearchRequestFailedException extends RuntimeException {
public OpenSearchRequestFailedException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@ -1,12 +1,5 @@
package dev.langchain4j.store.embedding.opensearch;
import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
@ -16,8 +9,14 @@ import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import java.util.List;
import java.util.UUID;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;
@ -35,16 +34,16 @@ public class OpenSearchEmbeddingStoreAWSTest {
private final ProfileCredentialsProvider credentials = ProfileCredentialsProvider.create("default");
private final AwsSdk2TransportOptions transportOptions = AwsSdk2TransportOptions.builder()
.setCredentials(credentials)
.build();
.setCredentials(credentials)
.build();
private final EmbeddingStore<TextSegment> embeddingStore = OpenSearchEmbeddingStore.builder()
.serverUrl(domainEndpoint)
.serviceName("es")
.region("us-east-1")
.options(transportOptions)
.indexName(randomUUID())
.build();
.serverUrl(domainEndpoint)
.serviceName("es")
.region("us-east-1")
.options(transportOptions)
.indexName(randomUUID())
.build();
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
@ -65,7 +64,6 @@ public class OpenSearchEmbeddingStoreAWSTest {
assertThat(match.embeddingId()).isEqualTo(id);
assertThat(match.embedding()).isEqualTo(embedding);
assertThat(match.embedded()).isNull();
}
@Test
@ -85,7 +83,6 @@ public class OpenSearchEmbeddingStoreAWSTest {
assertThat(match.embeddingId()).isEqualTo(id);
assertThat(match.embedding()).isEqualTo(embedding);
assertThat(match.embedded()).isNull();
}
@Test
@ -107,7 +104,6 @@ public class OpenSearchEmbeddingStoreAWSTest {
assertThat(match.embeddingId()).isEqualTo(id);
assertThat(match.embedding()).isEqualTo(embedding);
assertThat(match.embedded()).isEqualTo(segment);
}
@Test
@ -129,7 +125,6 @@ public class OpenSearchEmbeddingStoreAWSTest {
assertThat(match.embeddingId()).isEqualTo(id);
assertThat(match.embedding()).isEqualTo(embedding);
assertThat(match.embedded()).isEqualTo(segment);
}
@Test
@ -157,7 +152,6 @@ public class OpenSearchEmbeddingStoreAWSTest {
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
assertThat(secondMatch.embedded()).isNull();
}
@Test
@ -190,7 +184,6 @@ public class OpenSearchEmbeddingStoreAWSTest {
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
}
@Test
@ -240,7 +233,6 @@ public class OpenSearchEmbeddingStoreAWSTest {
);
assertThat(relevant4).hasSize(1);
assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
}
@Test
@ -263,11 +255,9 @@ public class OpenSearchEmbeddingStoreAWSTest {
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
withPercentage(1)
);
}
public static String randomUUID() {
return UUID.randomUUID().toString();
}
}

View File

@ -9,7 +9,6 @@ import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
@ -35,12 +34,12 @@ class OpenSearchEmbeddingStoreLocalTest {
@Container
private static final OpensearchContainer opensearch =
new OpensearchContainer(DockerImageName.parse("opensearchproject/opensearch:2.10.0"));
new OpensearchContainer(DockerImageName.parse("opensearchproject/opensearch:2.10.0"));
private final EmbeddingStore<TextSegment> embeddingStore = OpenSearchEmbeddingStore.builder()
.serverUrl(opensearch.getHttpHostAddress())
.indexName(randomUUID())
.build();
.serverUrl(opensearch.getHttpHostAddress())
.indexName(randomUUID())
.build();
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
@ -66,7 +65,6 @@ class OpenSearchEmbeddingStoreLocalTest {
assertThat(match.embeddingId()).isEqualTo(id);
assertThat(match.embedding()).isEqualTo(embedding);
assertThat(match.embedded()).isNull();
}
@Test
@ -86,7 +84,6 @@ class OpenSearchEmbeddingStoreLocalTest {
assertThat(match.embeddingId()).isEqualTo(id);
assertThat(match.embedding()).isEqualTo(embedding);
assertThat(match.embedded()).isNull();
}
@Test
@ -108,7 +105,6 @@ class OpenSearchEmbeddingStoreLocalTest {
assertThat(match.embeddingId()).isEqualTo(id);
assertThat(match.embedding()).isEqualTo(embedding);
assertThat(match.embedded()).isEqualTo(segment);
}
@Test
@ -130,7 +126,6 @@ class OpenSearchEmbeddingStoreLocalTest {
assertThat(match.embeddingId()).isEqualTo(id);
assertThat(match.embedding()).isEqualTo(embedding);
assertThat(match.embedded()).isEqualTo(segment);
}
@Test
@ -158,7 +153,6 @@ class OpenSearchEmbeddingStoreLocalTest {
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
assertThat(secondMatch.embedded()).isNull();
}
@Test
@ -191,7 +185,6 @@ class OpenSearchEmbeddingStoreLocalTest {
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
}
@Test
@ -241,7 +234,6 @@ class OpenSearchEmbeddingStoreLocalTest {
);
assertThat(relevant4).hasSize(1);
assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
}
@Test
@ -264,7 +256,5 @@ class OpenSearchEmbeddingStoreLocalTest {
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
withPercentage(1)
);
}
}