Weaviate support (#57)
Authored-by: Titov, Alexey <alexey.titov@adesso.de>
This commit is contained in:
parent
7b4706d279
commit
d45ddbfc7c
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"printWidth": 120
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build descriptor for the LangChain4j Weaviate embedding-store integration module. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <!-- groupId and version of this module are inherited from the parent. -->
    <parent>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-parent</artifactId>
        <version>0.18.0</version>
        <relativePath>../langchain4j-parent/pom.xml</relativePath>
    </parent>

    <artifactId>langchain4j-weaviate</artifactId>
    <packaging>jar</packaging>

    <name>LangChain4j integration with Weaviate</name>
    <description>Uses io.weaviate.client library which has a BSD 3-Clause license:
        https://github.com/weaviate/java-client/blob/main/LICENSE
    </description>

    <dependencies>

        <!-- Sibling module: always built and released with the same version as this module,
             so the inherited project version is used instead of repeating the literal. -->
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-core</artifactId>
            <version>${project.version}</version>
        </dependency>

        <!-- Official Weaviate Java client (BSD 3-Clause, see description above). -->
        <dependency>
            <groupId>io.weaviate</groupId>
            <artifactId>client</artifactId>
            <version>4.2.0</version>
        </dependency>

    </dependencies>

    <licenses>
        <license>
            <name>Apache-2.0</name>
            <url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
            <distribution>repo</distribution>
            <comments>A business-friendly OSS license</comments>
        </license>
    </licenses>

</project>
|
|
@ -0,0 +1,215 @@
|
||||||
|
package dev.langchain4j.store.embedding;
|
||||||
|
|
||||||
|
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||||
|
import static java.util.Collections.emptyList;
|
||||||
|
import static java.util.Collections.singletonList;
|
||||||
|
import static java.util.stream.Collectors.joining;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import io.weaviate.client.Config;
|
||||||
|
import io.weaviate.client.WeaviateAuthClient;
|
||||||
|
import io.weaviate.client.WeaviateClient;
|
||||||
|
import io.weaviate.client.base.Result;
|
||||||
|
import io.weaviate.client.base.WeaviateErrorMessage;
|
||||||
|
import io.weaviate.client.v1.auth.exception.AuthException;
|
||||||
|
import io.weaviate.client.v1.data.model.WeaviateObject;
|
||||||
|
import io.weaviate.client.v1.data.replication.model.ConsistencyLevel;
|
||||||
|
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
|
||||||
|
import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
|
||||||
|
import io.weaviate.client.v1.graphql.query.fields.Field;
|
||||||
|
import java.security.MessageDigest;
|
||||||
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.Builder;
|
||||||
|
|
||||||
|
public class WeaviateEmbeddingStoreImpl implements EmbeddingStore<TextSegment> {
|
||||||
|
|
||||||
|
private static final String DEFAULT_CLASS = "Default";
|
||||||
|
private static final Double DEFAULT_MIN_CERTAINTY = 0.0;
|
||||||
|
private static final String METADATA_TEXT_SEGMENT = "text";
|
||||||
|
private static final String ADDITIONALS = "_additional";
|
||||||
|
|
||||||
|
private final WeaviateClient client;
|
||||||
|
private final String objectClass;
|
||||||
|
private boolean avoidDups = true;
|
||||||
|
private String consistencyLevel = ConsistencyLevel.QUORUM;
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
public WeaviateEmbeddingStoreImpl(
|
||||||
|
String apiKey,
|
||||||
|
String scheme,
|
||||||
|
String host,
|
||||||
|
String objectClass,
|
||||||
|
boolean avoidDups,
|
||||||
|
String consistencyLevel
|
||||||
|
) {
|
||||||
|
try {
|
||||||
|
client = WeaviateAuthClient.apiKey(new Config(scheme, host), apiKey);
|
||||||
|
} catch (AuthException e) {
|
||||||
|
throw new IllegalArgumentException(e);
|
||||||
|
}
|
||||||
|
this.objectClass = objectClass != null ? objectClass : DEFAULT_CLASS;
|
||||||
|
this.avoidDups = avoidDups;
|
||||||
|
this.consistencyLevel = consistencyLevel;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String add(Embedding embedding) {
|
||||||
|
String id = generateRandomId();
|
||||||
|
add(id, embedding);
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void add(String id, Embedding embedding) {
|
||||||
|
addAll(singletonList(id), singletonList(embedding), null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String add(Embedding embedding, TextSegment textSegment) {
|
||||||
|
return addAll(singletonList(embedding), singletonList(textSegment)).stream().findFirst().orElse(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> addAll(List<Embedding> embeddings) {
|
||||||
|
return addAll(embeddings, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
|
||||||
|
return addAll(null, embeddings, embedded);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) {
|
||||||
|
return findRelevant(referenceEmbedding, maxResults, DEFAULT_MIN_CERTAINTY);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<EmbeddingMatch<TextSegment>> findRelevant(
|
||||||
|
Embedding referenceEmbedding,
|
||||||
|
int maxResults,
|
||||||
|
double minCertainty
|
||||||
|
) {
|
||||||
|
Result<GraphQLResponse> result = client
|
||||||
|
.graphQL()
|
||||||
|
.get()
|
||||||
|
.withClassName(objectClass)
|
||||||
|
.withFields(
|
||||||
|
Field.builder().name(METADATA_TEXT_SEGMENT).build(),
|
||||||
|
Field
|
||||||
|
.builder()
|
||||||
|
.name(ADDITIONALS)
|
||||||
|
.fields(
|
||||||
|
Field.builder().name("id").build(),
|
||||||
|
Field.builder().name("certainty").build(),
|
||||||
|
Field.builder().name("vector").build()
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.withNearVector(
|
||||||
|
NearVectorArgument
|
||||||
|
.builder()
|
||||||
|
.vector(referenceEmbedding.vectorAsList().toArray(new Float[0]))
|
||||||
|
.certainty((float) minCertainty)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.withLimit(maxResults)
|
||||||
|
.run();
|
||||||
|
|
||||||
|
if (result.hasErrors()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
result.getError().getMessages().stream().map(WeaviateErrorMessage::getMessage).collect(joining("\n"))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<Map.Entry<String, Map>> resGetPart =
|
||||||
|
((Map<String, Map>) result.getResult().getData()).entrySet().stream().findFirst();
|
||||||
|
if (!resGetPart.isPresent()) {
|
||||||
|
return emptyList();
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional resItemsPart = resGetPart.get().getValue().entrySet().stream().findFirst();
|
||||||
|
if (!resItemsPart.isPresent()) {
|
||||||
|
return emptyList();
|
||||||
|
}
|
||||||
|
|
||||||
|
List<Map<String, ?>> resItems = ((Map.Entry<String, List<Map<String, ?>>>) resItemsPart.get()).getValue();
|
||||||
|
|
||||||
|
return resItems.stream().map(WeaviateEmbeddingStoreImpl::toEmbeddingMatch).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
|
||||||
|
if (embedded != null && embeddings.size() != embedded.size()) {
|
||||||
|
throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> resIds = new ArrayList<>();
|
||||||
|
List<WeaviateObject> objects = new ArrayList<>();
|
||||||
|
for (int i = 0; i < embeddings.size(); i++) {
|
||||||
|
String id = ids != null
|
||||||
|
? ids.get(i)
|
||||||
|
: avoidDups && embedded != null ? generateUUID(embedded.get(i).text()) : generateRandomId();
|
||||||
|
resIds.add(id);
|
||||||
|
objects.add(buildObject(id, embeddings.get(i), embedded != null ? embedded.get(i).text() : null));
|
||||||
|
}
|
||||||
|
|
||||||
|
client
|
||||||
|
.batch()
|
||||||
|
.objectsBatcher()
|
||||||
|
.withObjects(objects.toArray(new WeaviateObject[0]))
|
||||||
|
.withConsistencyLevel(consistencyLevel)
|
||||||
|
.run();
|
||||||
|
|
||||||
|
return resIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
private WeaviateObject buildObject(String id, Embedding embedding, String text) {
|
||||||
|
WeaviateObject.WeaviateObjectBuilder builder = WeaviateObject
|
||||||
|
.builder()
|
||||||
|
.className(objectClass)
|
||||||
|
.id(id)
|
||||||
|
.vector(embedding.vectorAsList().toArray(new Float[0]));
|
||||||
|
|
||||||
|
if (text != null) {
|
||||||
|
Map<String, Object> props = new HashMap<>();
|
||||||
|
props.put(METADATA_TEXT_SEGMENT, text);
|
||||||
|
|
||||||
|
builder.properties(props);
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static EmbeddingMatch<TextSegment> toEmbeddingMatch(Map<String, ?> item) {
|
||||||
|
Map<String, ?> additional = (Map<String, ?>) item.get(ADDITIONALS);
|
||||||
|
|
||||||
|
return new EmbeddingMatch<>(
|
||||||
|
(String) additional.get("id"),
|
||||||
|
Embedding.from(
|
||||||
|
((List<Double>) additional.get("vector")).stream().map(Double::floatValue).collect(Collectors.toList())
|
||||||
|
),
|
||||||
|
TextSegment.from((String) item.get(METADATA_TEXT_SEGMENT)),
|
||||||
|
(Double) additional.get("certainty")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO this shall be migrated to some common place
|
||||||
|
private static String generateUUID(String input) {
|
||||||
|
try {
|
||||||
|
byte[] hashBytes = MessageDigest.getInstance("SHA-256").digest(input.getBytes(UTF_8));
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
for (byte b : hashBytes) sb.append(String.format("%02x", b));
|
||||||
|
return UUID.nameUUIDFromBytes(sb.toString().getBytes(UTF_8)).toString();
|
||||||
|
} catch (NoSuchAlgorithmException e) {
|
||||||
|
throw new IllegalArgumentException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO this shall be migrated to some common place
|
||||||
|
private static String generateRandomId() {
|
||||||
|
return UUID.randomUUID().toString();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,144 @@
|
||||||
|
package dev.langchain4j.store.embedding;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import java.lang.reflect.Constructor;
|
||||||
|
import java.lang.reflect.InvocationTargetException;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.Builder;
|
||||||
|
|
||||||
|
public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||||
|
|
||||||
|
private final EmbeddingStore<TextSegment> implementation;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new WeaviateEmbeddingStore instance.
|
||||||
|
*
|
||||||
|
* @param apiKey your Weaviate API key
|
||||||
|
* @param scheme the scheme, e.g. "https" of cluster URL. Find in under Details of your Weaviate cluster.
|
||||||
|
* @param host the host, e.g. "langchain4j-4jw7ufd9.weaviate.network" of cluster URL.
|
||||||
|
* Find in under Details of your Weaviate cluster.
|
||||||
|
* @param objectClass the object class you want to store, e.g. "MyGreatClass"
|
||||||
|
* @param avoidDups if true (default), then <code>WeaviateEmbeddingStore</code> will generate a hashed ID based on
|
||||||
|
* provided text segment, which avoids duplicated entries in DB.
|
||||||
|
* If false, then random ID will be generated.
|
||||||
|
* @param consistencyLevel Consistency level: ONE, QUORUM (default) or ALL. Find more details <a href="https://weaviate.io/developers/weaviate/concepts/replication-architecture/consistency#tunable-write-consistency">here</a>.
|
||||||
|
*/
|
||||||
|
@Builder
|
||||||
|
public WeaviateEmbeddingStore(
|
||||||
|
String apiKey,
|
||||||
|
String scheme,
|
||||||
|
String host,
|
||||||
|
String objectClass,
|
||||||
|
boolean avoidDups,
|
||||||
|
String consistencyLevel
|
||||||
|
) {
|
||||||
|
try {
|
||||||
|
implementation =
|
||||||
|
loadDynamically(
|
||||||
|
"dev.langchain4j.store.embedding.WeaviateEmbeddingStoreImpl",
|
||||||
|
apiKey,
|
||||||
|
scheme,
|
||||||
|
host,
|
||||||
|
objectClass,
|
||||||
|
avoidDups,
|
||||||
|
consistencyLevel
|
||||||
|
);
|
||||||
|
} catch (ClassNotFoundException e) {
|
||||||
|
throw new RuntimeException(getMessage(), e);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String getMessage() {
|
||||||
|
return (
|
||||||
|
"To use WeaviateEmbeddingStore, please add the following dependency to your project:\n\n" +
|
||||||
|
"Maven:\n" +
|
||||||
|
"<dependency>\n" +
|
||||||
|
" <groupId>dev.langchain4j</groupId>\n" +
|
||||||
|
" <artifactId>langchain4j-weaviate</artifactId>\n" +
|
||||||
|
" <version>0.18.0</version>\n" +
|
||||||
|
"</dependency>\n\n" +
|
||||||
|
"Gradle:\n" +
|
||||||
|
"implementation 'dev.langchain4j:langchain4j-weaviate:0.18.0'\n"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static EmbeddingStore<TextSegment> loadDynamically(
|
||||||
|
String implementationClassName,
|
||||||
|
String apiKey,
|
||||||
|
String scheme,
|
||||||
|
String host,
|
||||||
|
String objectClass,
|
||||||
|
boolean avoidDups,
|
||||||
|
String consistencyLevel
|
||||||
|
)
|
||||||
|
throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException {
|
||||||
|
Class<?> implementationClass = Class.forName(implementationClassName);
|
||||||
|
Class<?>[] constructorParameterTypes = new Class<?>[] {
|
||||||
|
String.class,
|
||||||
|
String.class,
|
||||||
|
String.class,
|
||||||
|
String.class,
|
||||||
|
boolean.class,
|
||||||
|
String.class,
|
||||||
|
};
|
||||||
|
Constructor<?> constructor = implementationClass.getConstructor(constructorParameterTypes);
|
||||||
|
return (EmbeddingStore<TextSegment>) constructor.newInstance(
|
||||||
|
apiKey,
|
||||||
|
scheme,
|
||||||
|
host,
|
||||||
|
objectClass,
|
||||||
|
avoidDups,
|
||||||
|
consistencyLevel
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String add(Embedding embedding) {
|
||||||
|
return implementation.add(embedding);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds a new embedding with provided ID to the store.
|
||||||
|
*
|
||||||
|
* @param id the ID of the embedding to add in UUID format, since it's Weaviate requirement.
|
||||||
|
* See <a href="https://weaviate.io/developers/weaviate/manage-data/create#id">Weaviate docs</a> and
|
||||||
|
* <a href="https://en.wikipedia.org/wiki/Universally_unique_identifier">UUID on Wikipedia</a>
|
||||||
|
* @param embedding the embedding to add
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void add(String id, Embedding embedding) {
|
||||||
|
implementation.add(id, embedding);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String add(Embedding embedding, TextSegment textSegment) {
|
||||||
|
return implementation.add(embedding, textSegment);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> addAll(List<Embedding> embeddings) {
|
||||||
|
return implementation.addAll(embeddings);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> textSegments) {
|
||||||
|
return implementation.addAll(embeddings, textSegments);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) {
|
||||||
|
return implementation.findRelevant(referenceEmbedding, maxResults);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<EmbeddingMatch<TextSegment>> findRelevant(
|
||||||
|
Embedding referenceEmbedding,
|
||||||
|
int maxResults,
|
||||||
|
double minSimilarity
|
||||||
|
) {
|
||||||
|
return implementation.findRelevant(referenceEmbedding, maxResults, minSimilarity);
|
||||||
|
}
|
||||||
|
}
|
1
pom.xml
1
pom.xml
|
@ -16,6 +16,7 @@
|
||||||
<module>langchain4j</module>
|
<module>langchain4j</module>
|
||||||
|
|
||||||
<module>langchain4j-pinecone</module>
|
<module>langchain4j-pinecone</module>
|
||||||
|
<module>langchain4j-weaviate</module>
|
||||||
|
|
||||||
<module>langchain4j-embeddings</module>
|
<module>langchain4j-embeddings</module>
|
||||||
<module>langchain4j-embeddings-all-minilm-l6-v2</module>
|
<module>langchain4j-embeddings-all-minilm-l6-v2</module>
|
||||||
|
|
Loading…
Reference in New Issue