diff --git a/langchain4j-opensearch/src/main/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStore.java b/langchain4j-opensearch/src/main/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStore.java index f87846cdf..f9f17ca12 100644 --- a/langchain4j-opensearch/src/main/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStore.java +++ b/langchain4j-opensearch/src/main/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStore.java @@ -1,15 +1,11 @@ package dev.langchain4j.store.embedding.opensearch; +import com.fasterxml.jackson.core.JsonProcessingException; import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; -import software.amazon.awssdk.http.SdkHttpClient; -import software.amazon.awssdk.http.apache.ApacheHttpClient; -import software.amazon.awssdk.regions.Region; - -import com.fasterxml.jackson.core.JsonProcessingException; import org.apache.hc.client5.http.auth.AuthScope; import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; @@ -38,6 +34,9 @@ import org.opensearch.client.transport.endpoints.BooleanResponse; import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.regions.Region; import java.io.IOException; import java.net.URISyntaxException; @@ -65,11 +64,11 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore { /** * Creates an instance of OpenSearchEmbeddingStore to connect with - * OpenSearch clusters running locally and network reacheable. + * OpenSearch clusters running locally and network reachable. * * @param serverUrl OpenSearch Server URL. * @param apiKey OpenSearch API key (optional) - * @param userName OpenSearch user name (optional) + * @param userName OpenSearch username (optional) * @param password OpenSearch password (optional) * @param indexName OpenSearch index name (optional). Default value: "default" */ @@ -78,7 +77,6 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore { String userName, String password, String indexName) { - HttpHost openSearchHost; try { openSearchHost = HttpHost.create(serverUrl); @@ -88,36 +86,31 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore { } OpenSearchTransport transport = ApacheHttpClient5TransportBuilder - .builder(openSearchHost) - .setMapper(new JacksonJsonpMapper()) - .setHttpClientConfigCallback(httpClientBuilder -> { + .builder(openSearchHost) + .setMapper(new JacksonJsonpMapper()) + .setHttpClientConfigCallback(httpClientBuilder -> { - if (!isNullOrBlank(apiKey)) { - httpClientBuilder.setDefaultHeaders(singletonList( - new BasicHeader("Authorization", "ApiKey " + apiKey) - )); - } + if (!isNullOrBlank(apiKey)) { + httpClientBuilder.setDefaultHeaders(singletonList( + new BasicHeader("Authorization", "ApiKey " + apiKey) + )); + } - if (!isNullOrBlank(userName) && !isNullOrBlank(password)) { - BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); - credentialsProvider.setCredentials(new AuthScope(openSearchHost), - new UsernamePasswordCredentials(userName, password.toCharArray())); - httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider); - } + if (!isNullOrBlank(userName) && !isNullOrBlank(password)) { + BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + credentialsProvider.setCredentials(new AuthScope(openSearchHost), + new UsernamePasswordCredentials(userName, password.toCharArray())); + httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider); + } - httpClientBuilder.setConnectionManager( - PoolingAsyncClientConnectionManagerBuilder - .create() - .build()); + httpClientBuilder.setConnectionManager(PoolingAsyncClientConnectionManagerBuilder.create().build()); - return httpClientBuilder; + return httpClientBuilder; + }) + .build(); - }) - .build(); - this.client = new OpenSearchClient(transport); this.indexName = ensureNotNull(indexName, "indexName"); - } /** @@ -139,13 +132,10 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore { Region selectedRegion = Region.of(region); SdkHttpClient httpClient = ApacheHttpClient.builder().build(); - OpenSearchTransport transport = new AwsSdk2Transport( - httpClient, serverUrl, serviceName, selectedRegion, options - ); + OpenSearchTransport transport = new AwsSdk2Transport(httpClient, serverUrl, serviceName, selectedRegion, options); this.client = new OpenSearchClient(transport); this.indexName = ensureNotNull(indexName, "indexName"); - } public static Builder builder() { @@ -205,12 +195,9 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore { public OpenSearchEmbeddingStore build() { if (!isNullOrBlank(serviceName) && !isNullOrBlank(region) && options != null) { - return new OpenSearchEmbeddingStore( - serverUrl, serviceName, region, options, indexName - ); + return new OpenSearchEmbeddingStore(serverUrl, serviceName, region, options, indexName); } - return new OpenSearchEmbeddingStore( - serverUrl, apiKey, userName, password, indexName); + return new OpenSearchEmbeddingStore(serverUrl, apiKey, userName, password, indexName); } } @@ -258,14 +245,14 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore { */ @Override public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { - List> matches = null; + List> matches; try { ScriptScoreQuery scriptScoreQuery = buildDefaultScriptScoreQuery(referenceEmbedding.vector(), (float) minScore); SearchResponse response = client.search( - SearchRequest.of(s -> s.index(indexName) - .query(n -> n.scriptScore(scriptScoreQuery)) - .size(maxResults)), - Document.class + SearchRequest.of(s -> s.index(indexName) + .query(n -> n.scriptScore(scriptScoreQuery)) + .size(maxResults)), + Document.class ); matches = toEmbeddingMatch(response); } catch (IOException ex) { @@ -278,21 +265,20 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore { private ScriptScoreQuery buildDefaultScriptScoreQuery(float[] vector, float minScore) throws JsonProcessingException { return ScriptScoreQuery.of(q -> q.minScore(minScore) - .query(Query.of(qu -> qu.matchAll(m -> m))) - .script(s -> s.inline(InlineScript.of(i -> i - .source("knn_score") - .lang("knn") - .params("field", JsonData.of("vector")) - .params("query_value", JsonData.of(vector)) - .params("space_type", JsonData.of("cosinesimil"))))) - .boost(0.5f)); - - // ===> From the OpenSearch documentation: - // "Cosine similarity returns a number between -1 and 1, and because OpenSearch - // relevance scores can't be below 0, the k-NN plugin adds 1 to get the final score." - // See https://opensearch.org/docs/latest/search-plugins/knn/knn-score-script - // Thus, the query applies a boost of `0.5` to keep score in the range [0, 1] + .query(Query.of(qu -> qu.matchAll(m -> m))) + .script(s -> s.inline(InlineScript.of(i -> i + .source("knn_score") + .lang("knn") + .params("field", JsonData.of("vector")) + .params("query_value", JsonData.of(vector)) + .params("space_type", JsonData.of("cosinesimil"))))) + .boost(0.5f)); + // ===> From the OpenSearch documentation: + // "Cosine similarity returns a number between -1 and 1, and because OpenSearch + // relevance scores can't be below 0, the k-NN plugin adds 1 to get the final score." + // See https://opensearch.org/docs/latest/search-plugins/knn/knn-score-script + // Thus, the query applies a boost of `0.5` to keep score in the range [0, 1] } private void addInternal(String id, Embedding embedding, TextSegment embedded) { @@ -316,16 +302,15 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore { log.error("[I/O OpenSearch Exception]", ex); throw new OpenSearchRequestFailedException(ex.getMessage()); } - } private void createIndexIfNotExist(int dimension) throws IOException { BooleanResponse response = client.indices().exists(c -> c.index(indexName)); if (!response.value()) { client.indices() - .create(c -> c.index(indexName) - .settings(s -> s.knn(true)) - .mappings(getDefaultMappings(dimension))); + .create(c -> c.index(indexName) + .settings(s -> s.knn(true)) + .mappings(getDefaultMappings(dimension))); } } @@ -333,7 +318,7 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore { Map properties = new HashMap<>(4); properties.put("text", Property.of(p -> p.text(TextProperty.of(t -> t)))); properties.put("vector", Property.of(p -> p.knnVector( - k -> k.dimension(dimension) + k -> k.dimension(dimension) ))); return TypeMapping.of(c -> c.properties(properties)); } @@ -349,14 +334,14 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore { .vector(embeddings.get(i).vector()) .text(embedded == null ? null : embedded.get(i).text()) .metadata(embedded == null ? null : Optional.ofNullable(embedded.get(i).metadata()) - .map(Metadata::asMap) - .orElse(null)) + .map(Metadata::asMap) + .orElse(null)) .build(); bulkBuilder.operations(op -> op.index( - idx -> idx - .index(indexName) - .id(ids.get(finalI)) - .document(document) + idx -> idx + .index(indexName) + .id(ids.get(finalI)) + .document(document) )); } @@ -368,27 +353,25 @@ public class OpenSearchEmbeddingStore implements EmbeddingStore { ErrorCause errorCause = item.error(); if (errorCause != null) { throw new OpenSearchRequestFailedException( - "type: " + errorCause.type() + "," + - "reason: " + errorCause.reason()); + "type: " + errorCause.type() + "," + + "reason: " + errorCause.reason()); } } } } - } private List> toEmbeddingMatch(SearchResponse response) { return response.hits().hits().stream() - .map(hit -> Optional.ofNullable(hit.source()) - .map(document -> new EmbeddingMatch<>( - hit.score(), - hit.id(), - new Embedding(document.getVector()), - document.getText() == null - ? null - : TextSegment.from(document.getText(), new Metadata(document.getMetadata())) - )).orElse(null)) - .collect(toList()); + .map(hit -> Optional.ofNullable(hit.source()) + .map(document -> new EmbeddingMatch<>( + hit.score(), + hit.id(), + new Embedding(document.getVector()), + document.getText() == null + ? null + : TextSegment.from(document.getText(), new Metadata(document.getMetadata())) + )).orElse(null)) + .collect(toList()); } - } diff --git a/langchain4j-opensearch/src/main/java/dev/langchain4j/store/embedding/opensearch/OpenSearchRequestFailedException.java b/langchain4j-opensearch/src/main/java/dev/langchain4j/store/embedding/opensearch/OpenSearchRequestFailedException.java index a053d9342..97cec95d5 100644 --- a/langchain4j-opensearch/src/main/java/dev/langchain4j/store/embedding/opensearch/OpenSearchRequestFailedException.java +++ b/langchain4j-opensearch/src/main/java/dev/langchain4j/store/embedding/opensearch/OpenSearchRequestFailedException.java @@ -13,5 +13,4 @@ class OpenSearchRequestFailedException extends RuntimeException { public OpenSearchRequestFailedException(String message, Throwable cause) { super(message, cause); } - } diff --git a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAWSTest.java b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAWSTest.java index 1df7f0212..8e709b1a2 100644 --- a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAWSTest.java +++ b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAWSTest.java @@ -1,12 +1,5 @@ package dev.langchain4j.store.embedding.opensearch; -import java.util.List; -import java.util.UUID; - -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; - import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; @@ -16,8 +9,14 @@ import dev.langchain4j.store.embedding.CosineSimilarity; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.RelevanceScore; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; +import java.util.List; +import java.util.UUID; + import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.data.Percentage.withPercentage; @@ -35,16 +34,16 @@ public class OpenSearchEmbeddingStoreAWSTest { private final ProfileCredentialsProvider credentials = ProfileCredentialsProvider.create("default"); private final AwsSdk2TransportOptions transportOptions = AwsSdk2TransportOptions.builder() - .setCredentials(credentials) - .build(); + .setCredentials(credentials) + .build(); private final EmbeddingStore embeddingStore = OpenSearchEmbeddingStore.builder() - .serverUrl(domainEndpoint) - .serviceName("es") - .region("us-east-1") - .options(transportOptions) - .indexName(randomUUID()) - .build(); + .serverUrl(domainEndpoint) + .serviceName("es") + .region("us-east-1") + .options(transportOptions) + .indexName(randomUUID()) + .build(); private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); @@ -65,7 +64,6 @@ public class OpenSearchEmbeddingStoreAWSTest { assertThat(match.embeddingId()).isEqualTo(id); assertThat(match.embedding()).isEqualTo(embedding); assertThat(match.embedded()).isNull(); - } @Test @@ -85,7 +83,6 @@ public class OpenSearchEmbeddingStoreAWSTest { assertThat(match.embeddingId()).isEqualTo(id); assertThat(match.embedding()).isEqualTo(embedding); assertThat(match.embedded()).isNull(); - } @Test @@ -107,7 +104,6 @@ public class OpenSearchEmbeddingStoreAWSTest { assertThat(match.embeddingId()).isEqualTo(id); assertThat(match.embedding()).isEqualTo(embedding); assertThat(match.embedded()).isEqualTo(segment); - } @Test @@ -129,7 +125,6 @@ public class OpenSearchEmbeddingStoreAWSTest { assertThat(match.embeddingId()).isEqualTo(id); assertThat(match.embedding()).isEqualTo(embedding); assertThat(match.embedded()).isEqualTo(segment); - } @Test @@ -157,7 +152,6 @@ public class OpenSearchEmbeddingStoreAWSTest { assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); assertThat(secondMatch.embedded()).isNull(); - } @Test @@ -190,7 +184,6 @@ public class OpenSearchEmbeddingStoreAWSTest { assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); assertThat(secondMatch.embedded()).isEqualTo(secondSegment); - } @Test @@ -240,7 +233,6 @@ public class OpenSearchEmbeddingStoreAWSTest { ); assertThat(relevant4).hasSize(1); assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId); - } @Test @@ -263,11 +255,9 @@ public class OpenSearchEmbeddingStoreAWSTest { RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)), withPercentage(1) ); - } public static String randomUUID() { return UUID.randomUUID().toString(); } - } diff --git a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalTest.java b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalTest.java index 1b55882a8..4faec4900 100644 --- a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalTest.java +++ b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalTest.java @@ -9,7 +9,6 @@ import dev.langchain4j.store.embedding.CosineSimilarity; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.RelevanceScore; - import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -35,12 +34,12 @@ class OpenSearchEmbeddingStoreLocalTest { @Container private static final OpensearchContainer opensearch = - new OpensearchContainer(DockerImageName.parse("opensearchproject/opensearch:2.10.0")); + new OpensearchContainer(DockerImageName.parse("opensearchproject/opensearch:2.10.0")); private final EmbeddingStore embeddingStore = OpenSearchEmbeddingStore.builder() - .serverUrl(opensearch.getHttpHostAddress()) - .indexName(randomUUID()) - .build(); + .serverUrl(opensearch.getHttpHostAddress()) + .indexName(randomUUID()) + .build(); private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); @@ -66,7 +65,6 @@ class OpenSearchEmbeddingStoreLocalTest { assertThat(match.embeddingId()).isEqualTo(id); assertThat(match.embedding()).isEqualTo(embedding); assertThat(match.embedded()).isNull(); - } @Test @@ -86,7 +84,6 @@ class OpenSearchEmbeddingStoreLocalTest { assertThat(match.embeddingId()).isEqualTo(id); assertThat(match.embedding()).isEqualTo(embedding); assertThat(match.embedded()).isNull(); - } @Test @@ -108,7 +105,6 @@ class OpenSearchEmbeddingStoreLocalTest { assertThat(match.embeddingId()).isEqualTo(id); assertThat(match.embedding()).isEqualTo(embedding); assertThat(match.embedded()).isEqualTo(segment); - } @Test @@ -130,7 +126,6 @@ class OpenSearchEmbeddingStoreLocalTest { assertThat(match.embeddingId()).isEqualTo(id); assertThat(match.embedding()).isEqualTo(embedding); assertThat(match.embedded()).isEqualTo(segment); - } @Test @@ -158,7 +153,6 @@ class OpenSearchEmbeddingStoreLocalTest { assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); assertThat(secondMatch.embedded()).isNull(); - } @Test @@ -191,7 +185,6 @@ class OpenSearchEmbeddingStoreLocalTest { assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); assertThat(secondMatch.embedded()).isEqualTo(secondSegment); - } @Test @@ -241,7 +234,6 @@ class OpenSearchEmbeddingStoreLocalTest { ); assertThat(relevant4).hasSize(1); assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId); - } @Test @@ -264,7 +256,5 @@ class OpenSearchEmbeddingStoreLocalTest { RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)), withPercentage(1) ); - } - }