Weaviate support (#57)
Authored-by: Titov, Alexey <alexey.titov@adesso.de>
This commit is contained in:
parent
7b4706d279
commit
d45ddbfc7c
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"printWidth": 120
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.18.0</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
<artifactId>langchain4j-weaviate</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>LangChain4j integration with Weaviate</name>
|
||||
<description>Uses io.weaviate.client library which has a BSD 3-Clause license:
|
||||
https://github.com/weaviate/java-client/blob/main/LICENSE
|
||||
</description>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>0.18.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>io.weaviate</groupId>
|
||||
<artifactId>client</artifactId>
|
||||
<version>4.2.0</version>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<licenses>
|
||||
<license>
|
||||
<name>Apache-2.0</name>
|
||||
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
|
||||
<distribution>repo</distribution>
|
||||
<comments>A business-friendly OSS license</comments>
|
||||
</license>
|
||||
</licenses>
|
||||
|
||||
</project>
|
|
@ -0,0 +1,215 @@
|
|||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
import static java.util.Collections.emptyList;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.stream.Collectors.joining;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import io.weaviate.client.Config;
|
||||
import io.weaviate.client.WeaviateAuthClient;
|
||||
import io.weaviate.client.WeaviateClient;
|
||||
import io.weaviate.client.base.Result;
|
||||
import io.weaviate.client.base.WeaviateErrorMessage;
|
||||
import io.weaviate.client.v1.auth.exception.AuthException;
|
||||
import io.weaviate.client.v1.data.model.WeaviateObject;
|
||||
import io.weaviate.client.v1.data.replication.model.ConsistencyLevel;
|
||||
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
|
||||
import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
|
||||
import io.weaviate.client.v1.graphql.query.fields.Field;
|
||||
import java.security.MessageDigest;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.Builder;
|
||||
|
||||
public class WeaviateEmbeddingStoreImpl implements EmbeddingStore<TextSegment> {
|
||||
|
||||
private static final String DEFAULT_CLASS = "Default";
|
||||
private static final Double DEFAULT_MIN_CERTAINTY = 0.0;
|
||||
private static final String METADATA_TEXT_SEGMENT = "text";
|
||||
private static final String ADDITIONALS = "_additional";
|
||||
|
||||
private final WeaviateClient client;
|
||||
private final String objectClass;
|
||||
private boolean avoidDups = true;
|
||||
private String consistencyLevel = ConsistencyLevel.QUORUM;
|
||||
|
||||
@Builder
|
||||
public WeaviateEmbeddingStoreImpl(
|
||||
String apiKey,
|
||||
String scheme,
|
||||
String host,
|
||||
String objectClass,
|
||||
boolean avoidDups,
|
||||
String consistencyLevel
|
||||
) {
|
||||
try {
|
||||
client = WeaviateAuthClient.apiKey(new Config(scheme, host), apiKey);
|
||||
} catch (AuthException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
this.objectClass = objectClass != null ? objectClass : DEFAULT_CLASS;
|
||||
this.avoidDups = avoidDups;
|
||||
this.consistencyLevel = consistencyLevel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding) {
|
||||
String id = generateRandomId();
|
||||
add(id, embedding);
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(String id, Embedding embedding) {
|
||||
addAll(singletonList(id), singletonList(embedding), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding, TextSegment textSegment) {
|
||||
return addAll(singletonList(embedding), singletonList(textSegment)).stream().findFirst().orElse(null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings) {
|
||||
return addAll(embeddings, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
|
||||
return addAll(null, embeddings, embedded);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) {
|
||||
return findRelevant(referenceEmbedding, maxResults, DEFAULT_MIN_CERTAINTY);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(
|
||||
Embedding referenceEmbedding,
|
||||
int maxResults,
|
||||
double minCertainty
|
||||
) {
|
||||
Result<GraphQLResponse> result = client
|
||||
.graphQL()
|
||||
.get()
|
||||
.withClassName(objectClass)
|
||||
.withFields(
|
||||
Field.builder().name(METADATA_TEXT_SEGMENT).build(),
|
||||
Field
|
||||
.builder()
|
||||
.name(ADDITIONALS)
|
||||
.fields(
|
||||
Field.builder().name("id").build(),
|
||||
Field.builder().name("certainty").build(),
|
||||
Field.builder().name("vector").build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
.withNearVector(
|
||||
NearVectorArgument
|
||||
.builder()
|
||||
.vector(referenceEmbedding.vectorAsList().toArray(new Float[0]))
|
||||
.certainty((float) minCertainty)
|
||||
.build()
|
||||
)
|
||||
.withLimit(maxResults)
|
||||
.run();
|
||||
|
||||
if (result.hasErrors()) {
|
||||
throw new IllegalArgumentException(
|
||||
result.getError().getMessages().stream().map(WeaviateErrorMessage::getMessage).collect(joining("\n"))
|
||||
);
|
||||
}
|
||||
|
||||
Optional<Map.Entry<String, Map>> resGetPart =
|
||||
((Map<String, Map>) result.getResult().getData()).entrySet().stream().findFirst();
|
||||
if (!resGetPart.isPresent()) {
|
||||
return emptyList();
|
||||
}
|
||||
|
||||
Optional resItemsPart = resGetPart.get().getValue().entrySet().stream().findFirst();
|
||||
if (!resItemsPart.isPresent()) {
|
||||
return emptyList();
|
||||
}
|
||||
|
||||
List<Map<String, ?>> resItems = ((Map.Entry<String, List<Map<String, ?>>>) resItemsPart.get()).getValue();
|
||||
|
||||
return resItems.stream().map(WeaviateEmbeddingStoreImpl::toEmbeddingMatch).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private List<String> addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
|
||||
if (embedded != null && embeddings.size() != embedded.size()) {
|
||||
throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
|
||||
}
|
||||
|
||||
List<String> resIds = new ArrayList<>();
|
||||
List<WeaviateObject> objects = new ArrayList<>();
|
||||
for (int i = 0; i < embeddings.size(); i++) {
|
||||
String id = ids != null
|
||||
? ids.get(i)
|
||||
: avoidDups && embedded != null ? generateUUID(embedded.get(i).text()) : generateRandomId();
|
||||
resIds.add(id);
|
||||
objects.add(buildObject(id, embeddings.get(i), embedded != null ? embedded.get(i).text() : null));
|
||||
}
|
||||
|
||||
client
|
||||
.batch()
|
||||
.objectsBatcher()
|
||||
.withObjects(objects.toArray(new WeaviateObject[0]))
|
||||
.withConsistencyLevel(consistencyLevel)
|
||||
.run();
|
||||
|
||||
return resIds;
|
||||
}
|
||||
|
||||
private WeaviateObject buildObject(String id, Embedding embedding, String text) {
|
||||
WeaviateObject.WeaviateObjectBuilder builder = WeaviateObject
|
||||
.builder()
|
||||
.className(objectClass)
|
||||
.id(id)
|
||||
.vector(embedding.vectorAsList().toArray(new Float[0]));
|
||||
|
||||
if (text != null) {
|
||||
Map<String, Object> props = new HashMap<>();
|
||||
props.put(METADATA_TEXT_SEGMENT, text);
|
||||
|
||||
builder.properties(props);
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
private static EmbeddingMatch<TextSegment> toEmbeddingMatch(Map<String, ?> item) {
|
||||
Map<String, ?> additional = (Map<String, ?>) item.get(ADDITIONALS);
|
||||
|
||||
return new EmbeddingMatch<>(
|
||||
(String) additional.get("id"),
|
||||
Embedding.from(
|
||||
((List<Double>) additional.get("vector")).stream().map(Double::floatValue).collect(Collectors.toList())
|
||||
),
|
||||
TextSegment.from((String) item.get(METADATA_TEXT_SEGMENT)),
|
||||
(Double) additional.get("certainty")
|
||||
);
|
||||
}
|
||||
|
||||
// TODO this shall be migrated to some common place
|
||||
private static String generateUUID(String input) {
|
||||
try {
|
||||
byte[] hashBytes = MessageDigest.getInstance("SHA-256").digest(input.getBytes(UTF_8));
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (byte b : hashBytes) sb.append(String.format("%02x", b));
|
||||
return UUID.nameUUIDFromBytes(sb.toString().getBytes(UTF_8)).toString();
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO this shall be migrated to some common place
|
||||
private static String generateRandomId() {
|
||||
return UUID.randomUUID().toString();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.List;
|
||||
import lombok.Builder;
|
||||
|
||||
public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
|
||||
private final EmbeddingStore<TextSegment> implementation;
|
||||
|
||||
/**
|
||||
* Creates a new WeaviateEmbeddingStore instance.
|
||||
*
|
||||
* @param apiKey your Weaviate API key
|
||||
* @param scheme the scheme, e.g. "https" of cluster URL. Find in under Details of your Weaviate cluster.
|
||||
* @param host the host, e.g. "langchain4j-4jw7ufd9.weaviate.network" of cluster URL.
|
||||
* Find in under Details of your Weaviate cluster.
|
||||
* @param objectClass the object class you want to store, e.g. "MyGreatClass"
|
||||
* @param avoidDups if true (default), then <code>WeaviateEmbeddingStore</code> will generate a hashed ID based on
|
||||
* provided text segment, which avoids duplicated entries in DB.
|
||||
* If false, then random ID will be generated.
|
||||
* @param consistencyLevel Consistency level: ONE, QUORUM (default) or ALL. Find more details <a href="https://weaviate.io/developers/weaviate/concepts/replication-architecture/consistency#tunable-write-consistency">here</a>.
|
||||
*/
|
||||
@Builder
|
||||
public WeaviateEmbeddingStore(
|
||||
String apiKey,
|
||||
String scheme,
|
||||
String host,
|
||||
String objectClass,
|
||||
boolean avoidDups,
|
||||
String consistencyLevel
|
||||
) {
|
||||
try {
|
||||
implementation =
|
||||
loadDynamically(
|
||||
"dev.langchain4j.store.embedding.WeaviateEmbeddingStoreImpl",
|
||||
apiKey,
|
||||
scheme,
|
||||
host,
|
||||
objectClass,
|
||||
avoidDups,
|
||||
consistencyLevel
|
||||
);
|
||||
} catch (ClassNotFoundException e) {
|
||||
throw new RuntimeException(getMessage(), e);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static String getMessage() {
|
||||
return (
|
||||
"To use WeaviateEmbeddingStore, please add the following dependency to your project:\n\n" +
|
||||
"Maven:\n" +
|
||||
"<dependency>\n" +
|
||||
" <groupId>dev.langchain4j</groupId>\n" +
|
||||
" <artifactId>langchain4j-weaviate</artifactId>\n" +
|
||||
" <version>0.18.0</version>\n" +
|
||||
"</dependency>\n\n" +
|
||||
"Gradle:\n" +
|
||||
"implementation 'dev.langchain4j:langchain4j-weaviate:0.18.0'\n"
|
||||
);
|
||||
}
|
||||
|
||||
private static EmbeddingStore<TextSegment> loadDynamically(
|
||||
String implementationClassName,
|
||||
String apiKey,
|
||||
String scheme,
|
||||
String host,
|
||||
String objectClass,
|
||||
boolean avoidDups,
|
||||
String consistencyLevel
|
||||
)
|
||||
throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException {
|
||||
Class<?> implementationClass = Class.forName(implementationClassName);
|
||||
Class<?>[] constructorParameterTypes = new Class<?>[] {
|
||||
String.class,
|
||||
String.class,
|
||||
String.class,
|
||||
String.class,
|
||||
boolean.class,
|
||||
String.class,
|
||||
};
|
||||
Constructor<?> constructor = implementationClass.getConstructor(constructorParameterTypes);
|
||||
return (EmbeddingStore<TextSegment>) constructor.newInstance(
|
||||
apiKey,
|
||||
scheme,
|
||||
host,
|
||||
objectClass,
|
||||
avoidDups,
|
||||
consistencyLevel
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding) {
|
||||
return implementation.add(embedding);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a new embedding with provided ID to the store.
|
||||
*
|
||||
* @param id the ID of the embedding to add in UUID format, since it's Weaviate requirement.
|
||||
* See <a href="https://weaviate.io/developers/weaviate/manage-data/create#id">Weaviate docs</a> and
|
||||
* <a href="https://en.wikipedia.org/wiki/Universally_unique_identifier">UUID on Wikipedia</a>
|
||||
* @param embedding the embedding to add
|
||||
*/
|
||||
@Override
|
||||
public void add(String id, Embedding embedding) {
|
||||
implementation.add(id, embedding);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding, TextSegment textSegment) {
|
||||
return implementation.add(embedding, textSegment);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings) {
|
||||
return implementation.addAll(embeddings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> textSegments) {
|
||||
return implementation.addAll(embeddings, textSegments);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) {
|
||||
return implementation.findRelevant(referenceEmbedding, maxResults);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(
|
||||
Embedding referenceEmbedding,
|
||||
int maxResults,
|
||||
double minSimilarity
|
||||
) {
|
||||
return implementation.findRelevant(referenceEmbedding, maxResults, minSimilarity);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue