Integration with Elastic (#95)
I've added an integration with Elasticsearch and ran some local tests to verify that it works (some of the logic is translated from LangChain Python to Java). The Elasticsearch Java client does not support `Gson`, so the new module needs a `Jackson` dependency.
This commit is contained in:
parent
80c3880062
commit
3bffc971df
|
@ -33,6 +33,10 @@ public class Embedding {
|
|||
return list;
|
||||
}
|
||||
|
||||
public int dimensions() {
|
||||
return vector.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package dev.langchain4j.internal;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.UUID;
|
||||
|
||||
public class Utils {
|
||||
|
@ -8,6 +9,10 @@ public class Utils {
|
|||
return string == null || string.trim().isEmpty();
|
||||
}
|
||||
|
||||
public static boolean isCollectionEmpty(Collection<?> collection) {
|
||||
return collection == null || collection.isEmpty();
|
||||
}
|
||||
|
||||
public static String repeat(String string, int times) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (int i = 0; i < times; i++) {
|
||||
|
|
|
@ -30,6 +30,11 @@ public class ValidationUtils {
|
|||
return string;
|
||||
}
|
||||
|
||||
public static void ensureTrue(boolean expression, String msg) {
|
||||
if (!expression) {
|
||||
throw illegalArgument(msg);
|
||||
}
|
||||
}
|
||||
|
||||
public static int ensureGreaterThanZero(Integer i, String name) {
|
||||
if (i == null || i <= 0) {
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.22.0</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
<artifactId>langchain4j-elasticsearch</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>LangChain4j integration with Elastic</name>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>co.elastic.clients</groupId>
|
||||
<artifactId>elasticsearch-java</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-junit-jupiter</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
|
@ -0,0 +1,22 @@
|
|||
package dev.langchain4j.store.embedding.elasticsearch;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Elasticsearch document object, for the purpose of construct document object from embedding and text segment
|
||||
*/
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Builder
|
||||
public class Document {
|
||||
|
||||
private float[] vector;
|
||||
private String text;
|
||||
private Map<String, String> metadata;
|
||||
}
|
|
@ -0,0 +1,229 @@
|
|||
package dev.langchain4j.store.embedding.elasticsearch;
|
||||
|
||||
import co.elastic.clients.elasticsearch.ElasticsearchClient;
|
||||
import co.elastic.clients.elasticsearch._types.InlineScript;
|
||||
import co.elastic.clients.elasticsearch._types.mapping.DenseVectorProperty;
|
||||
import co.elastic.clients.elasticsearch._types.mapping.Property;
|
||||
import co.elastic.clients.elasticsearch._types.mapping.TextProperty;
|
||||
import co.elastic.clients.elasticsearch._types.mapping.TypeMapping;
|
||||
import co.elastic.clients.elasticsearch._types.query_dsl.Query;
|
||||
import co.elastic.clients.elasticsearch._types.query_dsl.ScriptScoreQuery;
|
||||
import co.elastic.clients.elasticsearch.core.BulkRequest;
|
||||
import co.elastic.clients.elasticsearch.core.BulkResponse;
|
||||
import co.elastic.clients.elasticsearch.core.SearchRequest;
|
||||
import co.elastic.clients.elasticsearch.core.SearchResponse;
|
||||
import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem;
|
||||
import co.elastic.clients.json.JsonData;
|
||||
import co.elastic.clients.json.jackson.JacksonJsonpMapper;
|
||||
import co.elastic.clients.transport.ElasticsearchTransport;
|
||||
import co.elastic.clients.transport.endpoints.BooleanResponse;
|
||||
import co.elastic.clients.transport.rest_client.RestClientTransport;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.internal.ValidationUtils;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import lombok.Builder;
|
||||
import org.apache.http.Header;
|
||||
import org.apache.http.HttpHost;
|
||||
import org.apache.http.auth.AuthScope;
|
||||
import org.apache.http.auth.UsernamePasswordCredentials;
|
||||
import org.apache.http.client.CredentialsProvider;
|
||||
import org.apache.http.impl.client.BasicCredentialsProvider;
|
||||
import org.apache.http.message.BasicHeader;
|
||||
import org.elasticsearch.client.RestClient;
|
||||
import org.elasticsearch.client.RestClientBuilder;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.*;
|
||||
|
||||
/**
|
||||
* Elastic Embedding Store Implementation
|
||||
*/
|
||||
public class ElasticsearchEmbeddingStoreImpl implements EmbeddingStore<TextSegment> {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(ElasticsearchEmbeddingStoreImpl.class);
|
||||
private final ElasticsearchClient client;
|
||||
private final String indexName;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
@Builder
|
||||
public ElasticsearchEmbeddingStoreImpl(String serverUrl,
|
||||
String username,
|
||||
String password,
|
||||
String apiKey,
|
||||
String indexName) {
|
||||
serverUrl = ValidationUtils.ensureNotNull(serverUrl, "serverUrl");
|
||||
indexName = ValidationUtils.ensureNotNull(indexName, "indexName");
|
||||
|
||||
RestClientBuilder restClientBuilder = RestClient
|
||||
.builder(HttpHost.create(serverUrl));
|
||||
if (!isNullOrBlank(username)) {
|
||||
CredentialsProvider provider = new BasicCredentialsProvider();
|
||||
provider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(username, password));
|
||||
restClientBuilder.setHttpClientConfigCallback(httpClientBuilder -> httpClientBuilder.setDefaultCredentialsProvider(provider));
|
||||
}
|
||||
if (!isNullOrBlank(apiKey)) {
|
||||
restClientBuilder.setDefaultHeaders(new Header[]{
|
||||
new BasicHeader("Authorization", "Apikey " + apiKey)
|
||||
});
|
||||
}
|
||||
ElasticsearchTransport transport = new RestClientTransport(restClientBuilder.build(), new JacksonJsonpMapper());
|
||||
|
||||
this.client = new ElasticsearchClient(transport);
|
||||
this.indexName = indexName;
|
||||
objectMapper = new ObjectMapper();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding) {
|
||||
String id = randomUUID();
|
||||
add(id, embedding);
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(String id, Embedding embedding) {
|
||||
addInternal(id, embedding, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding, TextSegment textSegment) {
|
||||
String id = randomUUID();
|
||||
addInternal(id, embedding, textSegment);
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings) {
|
||||
List<String> ids = embeddings.stream()
|
||||
.map(ignored -> randomUUID())
|
||||
.collect(Collectors.toList());
|
||||
addAllInternal(ids, embeddings, null);
|
||||
return ids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
|
||||
List<String> ids = embeddings.stream()
|
||||
.map(ignored -> randomUUID())
|
||||
.collect(Collectors.toList());
|
||||
addAllInternal(ids, embeddings, embedded);
|
||||
return ids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
|
||||
try {
|
||||
// Use Script Score and cosineSimilarity to calculate
|
||||
// see https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-script-score-query.html#vector-functions-cosine
|
||||
ScriptScoreQuery scriptScoreQuery = buildDefaultScriptScoreQuery(referenceEmbedding.vector(), (float) minScore);
|
||||
SearchResponse<Document> response = client.search(
|
||||
SearchRequest.of(s -> s.query(n -> n.scriptScore(scriptScoreQuery)).size(maxResults)), Document.class);
|
||||
|
||||
return toEmbeddingMatch(response);
|
||||
} catch (IOException e) {
|
||||
log.error("[ElasticSearch encounter I/O Exception]", e);
|
||||
throw new ElasticsearchRequestFailedException(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private void addInternal(String id, Embedding embedding, TextSegment embedded) {
|
||||
addAllInternal(Collections.singletonList(id), Collections.singletonList(embedding), embedded == null ? null : Collections.singletonList(embedded));
|
||||
}
|
||||
|
||||
private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
|
||||
if (isCollectionEmpty(ids) || isCollectionEmpty(embeddings)) {
|
||||
log.info("[do not add empty embeddings to elasticsearch]");
|
||||
return;
|
||||
}
|
||||
ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
|
||||
ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");
|
||||
|
||||
try {
|
||||
createIndexIfNotExist(embeddings.get(0).dimensions());
|
||||
|
||||
bulk(ids, embeddings, embedded);
|
||||
} catch (IOException e) {
|
||||
log.error("[ElasticSearch encounter I/O Exception]", e);
|
||||
throw new ElasticsearchRequestFailedException(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private void createIndexIfNotExist(int dim) throws IOException {
|
||||
BooleanResponse response = client.indices().exists(c -> c.index(indexName));
|
||||
if (!response.value()) {
|
||||
client.indices().create(c -> c.index(indexName)
|
||||
.mappings(getDefaultMappings(dim)));
|
||||
}
|
||||
}
|
||||
|
||||
private TypeMapping getDefaultMappings(int dim) {
|
||||
// do this like LangChain do
|
||||
Map<String, Property> properties = new HashMap<>(4);
|
||||
properties.put("text", Property.of(p -> p.text(TextProperty.of(t -> t))));
|
||||
properties.put("vector", Property.of(p -> p.denseVector(DenseVectorProperty.of(d -> d.dims(dim)))));
|
||||
return TypeMapping.of(c -> c.properties(properties));
|
||||
}
|
||||
|
||||
private void bulk(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) throws IOException {
|
||||
int size = ids.size();
|
||||
BulkRequest.Builder bulkBuilder = new BulkRequest.Builder();
|
||||
for (int i = 0; i < size; i++) {
|
||||
int finalI = i;
|
||||
Document document = Document.builder()
|
||||
.vector(embeddings.get(i).vector())
|
||||
.text(embedded == null ? null : embedded.get(i).text())
|
||||
.metadata(embedded == null ? null : Optional.ofNullable(embedded.get(i).metadata())
|
||||
.map(Metadata::asMap)
|
||||
.orElse(null))
|
||||
.build();
|
||||
bulkBuilder.operations(op -> op.index(idx -> idx
|
||||
.index(indexName)
|
||||
.id(ids.get(finalI))
|
||||
.document(document)));
|
||||
}
|
||||
|
||||
BulkResponse response = client.bulk(bulkBuilder.build());
|
||||
if (response.errors()) {
|
||||
for (BulkResponseItem item : response.items()) {
|
||||
if (item.error() != null) {
|
||||
throw new ElasticsearchRequestFailedException("type: " + item.error().type() + ", reason: " + item.error().reason());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private ScriptScoreQuery buildDefaultScriptScoreQuery(float[] vector, float minScore) throws JsonProcessingException {
|
||||
JsonData queryVector = toJsonData(vector);
|
||||
return ScriptScoreQuery.of(q -> q
|
||||
.minScore(minScore)
|
||||
.query(Query.of(qu -> qu.matchAll(m -> m)))
|
||||
.script(s -> s.inline(InlineScript.of(i -> i
|
||||
// The script adds 1.0 to the cosine similarity to prevent the score from being negative.
|
||||
// divided by 2 to keep score in the range [0, 1]
|
||||
.source("(cosineSimilarity(params.query_vector, 'vector') + 1.0) / 2")
|
||||
.params("query_vector", queryVector)))));
|
||||
}
|
||||
|
||||
private <T> JsonData toJsonData(T rawData) throws JsonProcessingException {
|
||||
return JsonData.fromJson(objectMapper.writeValueAsString(rawData));
|
||||
}
|
||||
|
||||
private List<EmbeddingMatch<TextSegment>> toEmbeddingMatch(SearchResponse<Document> response) {
|
||||
return response.hits().hits().stream()
|
||||
.map(hit -> Optional.ofNullable(hit.source())
|
||||
.map(document -> new EmbeddingMatch<>(hit.score(), hit.id(), new Embedding(document.getVector()),
|
||||
// TextSegment ensure that must have text and metadata
|
||||
document.getText() == null ? null : new TextSegment(document.getText(), new Metadata(document.getMetadata()))))
|
||||
.orElse(null)).collect(Collectors.toList());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
package dev.langchain4j.store.embedding.elasticsearch;
|
||||
|
||||
class ElasticsearchRequestFailedException extends RuntimeException {
|
||||
|
||||
public ElasticsearchRequestFailedException() {
|
||||
super();
|
||||
}
|
||||
|
||||
public ElasticsearchRequestFailedException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public ElasticsearchRequestFailedException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
package dev.langchain4j.store.embedding.elasticsearch;
|
||||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.internal.Utils;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* disabled default, because this need local deployment of Elasticsearch
|
||||
*/
|
||||
@Disabled
|
||||
class ElasticsearchEmbeddingStoreImplTest {
|
||||
|
||||
private final EmbeddingStore<TextSegment> store = new ElasticsearchEmbeddingStoreImpl(
|
||||
"http://localhost:9200",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
"test-index");
|
||||
|
||||
@Test
|
||||
void testAdd() {
|
||||
// test add without id
|
||||
String id = store.add(Embedding.from(Arrays.asList(0.50f, 0.85f, 0.760f, 0.24f)),
|
||||
TextSegment.from("test string", Metadata.metadata("field", "value")));
|
||||
System.out.println("id=" + id);
|
||||
|
||||
// test add with id
|
||||
String selfId = Utils.randomUUID();
|
||||
store.add(selfId, Embedding.from(Arrays.asList(0.80f, 0.45f, 0.89f, 0.24f)));
|
||||
System.out.println("id=" + selfId);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAddAll() {
|
||||
// test add All Method without embedded
|
||||
List<String> ids = store.addAll(Arrays.asList(
|
||||
Embedding.from(Arrays.asList(0.3f, 0.87f, 0.90f, 0.24f)),
|
||||
Embedding.from(Arrays.asList(0.54f, 0.34f, 0.67f, 0.24f)),
|
||||
Embedding.from(Arrays.asList(0.80f, 0.45f, 0.779f, 0.5556f))
|
||||
));
|
||||
System.out.println("ids=" + ids);
|
||||
|
||||
// test add all method with embedded
|
||||
ids = store.addAll(Arrays.asList(
|
||||
Embedding.from(Arrays.asList(0.3f, 0.87f, 0.90f, 0.24f)),
|
||||
Embedding.from(Arrays.asList(0.54f, 0.34f, 0.67f, 0.24f)),
|
||||
Embedding.from(Arrays.asList(0.80f, 0.45f, 0.779f, 0.5556f))
|
||||
), Arrays.asList(
|
||||
TextSegment.from("testString1", Metadata.metadata("field1", "value1")),
|
||||
TextSegment.from("testString2", Metadata.metadata("field2", "value2")),
|
||||
TextSegment.from("testingString3", Metadata.metadata("field3", "value3"))
|
||||
));
|
||||
System.out.println("ids=" + ids);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAddEmpty() {
|
||||
// see log
|
||||
store.addAll(Collections.emptyList());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAddNotEqualSizeEmbeddingAndEmbedded() {
|
||||
Throwable ex = Assertions.assertThrows(IllegalArgumentException.class, () -> store.addAll(Arrays.asList(
|
||||
Embedding.from(Arrays.asList(0.3f, 0.87f, 0.90f, 0.24f)),
|
||||
Embedding.from(Arrays.asList(0.54f, 0.34f, 0.67f, 0.24f, 0.55f)),
|
||||
Embedding.from(Arrays.asList(0.80f, 0.45f, 0.779f, 0.5556f))
|
||||
), Arrays.asList(
|
||||
TextSegment.from("testString1", Metadata.metadata("field1", "value1")),
|
||||
TextSegment.from("testString2", Metadata.metadata("field2", "value2"))
|
||||
)));
|
||||
Assertions.assertEquals("embeddings size is not equal to embedded size", ex.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFindRelevant() {
|
||||
List<EmbeddingMatch<TextSegment>> res = store.findRelevant(Embedding.from(Arrays.asList(0.80f, 0.70f, 0.90f, 0.55f)), 5);
|
||||
res.forEach(System.out::println);
|
||||
}
|
||||
}
|
|
@ -17,6 +17,22 @@
|
|||
<maven.compiler.source>1.8</maven.compiler.source>
|
||||
<maven.compiler.target>1.8</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<openai4j.version>0.9.0</openai4j.version>
|
||||
<jtokkit.version>0.6.1</jtokkit.version>
|
||||
<lombok.version>1.18.28</lombok.version>
|
||||
<pdfbox.version>2.0.29</pdfbox.version>
|
||||
<jsoup.veresion>1.16.1</jsoup.veresion>
|
||||
<mustache.version>0.9.10</mustache.version>
|
||||
<slf4j-api.version>2.0.7</slf4j-api.version>
|
||||
<gson.version>2.10.1</gson.version>
|
||||
<junit.version>5.9.3</junit.version>
|
||||
<mockito.version>4.11.0</mockito.version>
|
||||
<assertj.version>3.24.2</assertj.version>
|
||||
<tinylog.version>2.6.2</tinylog.version>
|
||||
<spring-boot.version>2.7.14</spring-boot.version>
|
||||
<snakeyaml.version>2.0</snakeyaml.version>
|
||||
<elastic.version>8.9.0</elastic.version>
|
||||
<jackson.version>2.12.7.1</jackson.version>
|
||||
</properties>
|
||||
|
||||
<dependencyManagement>
|
||||
|
@ -25,97 +41,121 @@
|
|||
<dependency>
|
||||
<groupId>dev.ai4j</groupId>
|
||||
<artifactId>openai4j</artifactId>
|
||||
<version>0.9.0</version>
|
||||
<version>${openai4j.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.knuddels</groupId>
|
||||
<artifactId>jtokkit</artifactId>
|
||||
<version>0.6.1</version>
|
||||
<version>${jtokkit.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<version>1.18.28</version>
|
||||
<version>${lombok.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.pdfbox</groupId>
|
||||
<artifactId>pdfbox</artifactId>
|
||||
<version>2.0.29</version>
|
||||
<version>${pdfbox.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.jsoup</groupId>
|
||||
<artifactId>jsoup</artifactId>
|
||||
<version>1.16.1</version>
|
||||
<version>${jsoup.veresion}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.github.spullara.mustache.java</groupId>
|
||||
<artifactId>compiler</artifactId>
|
||||
<version>0.9.10</version>
|
||||
<version>${mustache.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
<version>2.0.7</version>
|
||||
<version>${slf4j-api.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.google.code.gson</groupId>
|
||||
<artifactId>gson</artifactId>
|
||||
<version>2.10.1</version>
|
||||
<version>${gson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<version>5.9.3</version>
|
||||
<version>${junit.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-params</artifactId>
|
||||
<version>5.9.3</version>
|
||||
<version>${junit.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
<version>4.11.0</version>
|
||||
<version>${mockito.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-junit-jupiter</artifactId>
|
||||
<version>4.11.0</version>
|
||||
<version>${mockito.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.assertj</groupId>
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<version>3.24.2</version>
|
||||
<version>${assertj.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.tinylog</groupId>
|
||||
<artifactId>tinylog-impl</artifactId>
|
||||
<version>2.6.2</version>
|
||||
<version>${tinylog.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.tinylog</groupId>
|
||||
<artifactId>slf4j-tinylog</artifactId>
|
||||
<version>2.6.2</version>
|
||||
<version>${tinylog.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter</artifactId>
|
||||
<version>2.7.14</version>
|
||||
<version>${spring-boot.version}</version>
|
||||
<exclusions>
|
||||
<!-- due to vulnerabilities -->
|
||||
<exclusion>
|
||||
<groupId>org.yaml</groupId>
|
||||
<artifactId>snakeyaml</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.yaml</groupId>
|
||||
<artifactId>snakeyaml</artifactId>
|
||||
<version>${snakeyaml.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>co.elastic.clients</groupId>
|
||||
<artifactId>elasticsearch-java</artifactId>
|
||||
<version>${elastic.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
|
|
@ -0,0 +1,154 @@
|
|||
package dev.langchain4j.store.embedding.elasticsearch;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Represents a <a href="https://www.elastic.co/">Elasticsearch</a> index as an embedding store.
|
||||
* Current implementation assumes the index uses the cosine distance metric.
|
||||
* To use ElasticsearchEmbeddingStore, please add the "langchain4j-elasticsearch" dependency to your project.
|
||||
*/
|
||||
public class ElasticsearchEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
|
||||
private final EmbeddingStore<TextSegment> implementation;
|
||||
|
||||
/**
|
||||
* Creates an instance of ElasticsearchEmbeddingStore
|
||||
*
|
||||
* @param serverUrl Elasticsearch Server URL.
|
||||
* @param apiKey apiKey to connect to elasticsearch (optional if elasticsearch is local deployment).
|
||||
* @param indexName The name of the index (e.g., "test").
|
||||
*/
|
||||
public ElasticsearchEmbeddingStore(String serverUrl, String username, String password, String apiKey, String indexName) {
|
||||
try {
|
||||
implementation = loadDynamically(
|
||||
"dev.langchain4j.store.embedding.elasticsearch.ElasticsearchEmbeddingStoreImpl",
|
||||
serverUrl, username, password, apiKey, indexName
|
||||
);
|
||||
} catch (ClassNotFoundException e) {
|
||||
throw new RuntimeException(getMessage(), e);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static String getMessage() {
|
||||
return "To use ElasticsearchEmbeddingStore, please add the following dependency to your project:\n\n"
|
||||
+ "Maven:\n"
|
||||
+ "<dependency>\n" +
|
||||
" <groupId>dev.langchain4j</groupId>\n" +
|
||||
" <artifactId>langchain4j-elasticsearch</artifactId>\n" +
|
||||
" <version>0.20.0</version>\n" +
|
||||
"</dependency>\n\n"
|
||||
+ "Gradle:\n"
|
||||
+ "implementation 'dev.langchain4j:langchain4j-elasticsearch:0.20.0'\n";
|
||||
}
|
||||
|
||||
private static EmbeddingStore<TextSegment> loadDynamically(String implementationClassName, String serverUrl, String username, String password, String apiKey, String indexName) throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException {
|
||||
Class<?> implementationClass = Class.forName(implementationClassName);
|
||||
Class<?>[] constructorParameterTypes = new Class<?>[]{String.class, String.class, String.class, String.class, String.class};
|
||||
Constructor<?> constructor = implementationClass.getConstructor(constructorParameterTypes);
|
||||
return (EmbeddingStore<TextSegment>) constructor.newInstance(serverUrl, username, password, apiKey, indexName);
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding) {
|
||||
return implementation.add(embedding);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(String id, Embedding embedding) {
|
||||
implementation.add(id, embedding);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding, TextSegment textSegment) {
|
||||
return implementation.add(embedding, textSegment);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings) {
|
||||
return implementation.addAll(embeddings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> textSegments) {
|
||||
return implementation.addAll(embeddings, textSegments);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) {
|
||||
return implementation.findRelevant(referenceEmbedding, maxResults);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
|
||||
return implementation.findRelevant(referenceEmbedding, maxResults, minScore);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String serverUrl;
|
||||
private String username;
|
||||
private String password;
|
||||
private String apiKey;
|
||||
private String indexName;
|
||||
|
||||
/**
|
||||
* @param serverUrl Elasticsearch Server URL
|
||||
*/
|
||||
public Builder serverUrl(String serverUrl) {
|
||||
this.serverUrl = serverUrl;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param username Elasticsearch username
|
||||
* @return builder
|
||||
*/
|
||||
public Builder username(String username) {
|
||||
this.username = username;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param password Elasticsearch password
|
||||
* @return builder
|
||||
*/
|
||||
public Builder password(String password) {
|
||||
this.password = password;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param apikey Elasticsearch apikey
|
||||
* @return builder
|
||||
*/
|
||||
public Builder apikey(String apikey) {
|
||||
this.apiKey = apikey;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param indexName The name of the index (e.g., "test").
|
||||
*/
|
||||
public Builder indexName(String indexName) {
|
||||
this.indexName = indexName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ElasticsearchEmbeddingStore build() {
|
||||
return new ElasticsearchEmbeddingStore(serverUrl, username, password, apiKey, indexName);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue