Weaviate support (#57)

Authored-by: Titov, Alexey <alexey.titov@adesso.de>
This commit is contained in:
Heezer 2023-08-06 17:03:39 +02:00 committed by GitHub
parent 7b4706d279
commit d45ddbfc7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 411 additions and 0 deletions

3
.prettierrc Normal file
View File

@ -0,0 +1,3 @@
{
"printWidth": 120
}

View File

@ -0,0 +1,48 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.18.0</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
<artifactId>langchain4j-weaviate</artifactId>
<packaging>jar</packaging>
<name>LangChain4j integration with Weaviate</name>
<description>Uses io.weaviate.client library which has a BSD 3-Clause license:
https://github.com/weaviate/java-client/blob/main/LICENSE
</description>
<dependencies>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>0.18.0</version>
</dependency>
<dependency>
<groupId>io.weaviate</groupId>
<artifactId>client</artifactId>
<version>4.2.0</version>
</dependency>
</dependencies>
<licenses>
<license>
<name>Apache-2.0</name>
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
<comments>A business-friendly OSS license</comments>
</license>
</licenses>
</project>

View File

@ -0,0 +1,215 @@
package dev.langchain4j.store.embedding;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.joining;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateAuthClient;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.base.WeaviateErrorMessage;
import io.weaviate.client.v1.auth.exception.AuthException;
import io.weaviate.client.v1.data.model.WeaviateObject;
import io.weaviate.client.v1.data.replication.model.ConsistencyLevel;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
import io.weaviate.client.v1.graphql.query.fields.Field;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.*;
import java.util.stream.Collectors;
import lombok.Builder;
public class WeaviateEmbeddingStoreImpl implements EmbeddingStore<TextSegment> {
private static final String DEFAULT_CLASS = "Default";
private static final Double DEFAULT_MIN_CERTAINTY = 0.0;
private static final String METADATA_TEXT_SEGMENT = "text";
private static final String ADDITIONALS = "_additional";
private final WeaviateClient client;
private final String objectClass;
private boolean avoidDups = true;
private String consistencyLevel = ConsistencyLevel.QUORUM;
@Builder
public WeaviateEmbeddingStoreImpl(
String apiKey,
String scheme,
String host,
String objectClass,
boolean avoidDups,
String consistencyLevel
) {
try {
client = WeaviateAuthClient.apiKey(new Config(scheme, host), apiKey);
} catch (AuthException e) {
throw new IllegalArgumentException(e);
}
this.objectClass = objectClass != null ? objectClass : DEFAULT_CLASS;
this.avoidDups = avoidDups;
this.consistencyLevel = consistencyLevel;
}
@Override
public String add(Embedding embedding) {
String id = generateRandomId();
add(id, embedding);
return id;
}
@Override
public void add(String id, Embedding embedding) {
addAll(singletonList(id), singletonList(embedding), null);
}
@Override
public String add(Embedding embedding, TextSegment textSegment) {
return addAll(singletonList(embedding), singletonList(textSegment)).stream().findFirst().orElse(null);
}
@Override
public List<String> addAll(List<Embedding> embeddings) {
return addAll(embeddings, null);
}
@Override
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
return addAll(null, embeddings, embedded);
}
@Override
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) {
return findRelevant(referenceEmbedding, maxResults, DEFAULT_MIN_CERTAINTY);
}
@Override
public List<EmbeddingMatch<TextSegment>> findRelevant(
Embedding referenceEmbedding,
int maxResults,
double minCertainty
) {
Result<GraphQLResponse> result = client
.graphQL()
.get()
.withClassName(objectClass)
.withFields(
Field.builder().name(METADATA_TEXT_SEGMENT).build(),
Field
.builder()
.name(ADDITIONALS)
.fields(
Field.builder().name("id").build(),
Field.builder().name("certainty").build(),
Field.builder().name("vector").build()
)
.build()
)
.withNearVector(
NearVectorArgument
.builder()
.vector(referenceEmbedding.vectorAsList().toArray(new Float[0]))
.certainty((float) minCertainty)
.build()
)
.withLimit(maxResults)
.run();
if (result.hasErrors()) {
throw new IllegalArgumentException(
result.getError().getMessages().stream().map(WeaviateErrorMessage::getMessage).collect(joining("\n"))
);
}
Optional<Map.Entry<String, Map>> resGetPart =
((Map<String, Map>) result.getResult().getData()).entrySet().stream().findFirst();
if (!resGetPart.isPresent()) {
return emptyList();
}
Optional resItemsPart = resGetPart.get().getValue().entrySet().stream().findFirst();
if (!resItemsPart.isPresent()) {
return emptyList();
}
List<Map<String, ?>> resItems = ((Map.Entry<String, List<Map<String, ?>>>) resItemsPart.get()).getValue();
return resItems.stream().map(WeaviateEmbeddingStoreImpl::toEmbeddingMatch).collect(Collectors.toList());
}
private List<String> addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
if (embedded != null && embeddings.size() != embedded.size()) {
throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
}
List<String> resIds = new ArrayList<>();
List<WeaviateObject> objects = new ArrayList<>();
for (int i = 0; i < embeddings.size(); i++) {
String id = ids != null
? ids.get(i)
: avoidDups && embedded != null ? generateUUID(embedded.get(i).text()) : generateRandomId();
resIds.add(id);
objects.add(buildObject(id, embeddings.get(i), embedded != null ? embedded.get(i).text() : null));
}
client
.batch()
.objectsBatcher()
.withObjects(objects.toArray(new WeaviateObject[0]))
.withConsistencyLevel(consistencyLevel)
.run();
return resIds;
}
private WeaviateObject buildObject(String id, Embedding embedding, String text) {
WeaviateObject.WeaviateObjectBuilder builder = WeaviateObject
.builder()
.className(objectClass)
.id(id)
.vector(embedding.vectorAsList().toArray(new Float[0]));
if (text != null) {
Map<String, Object> props = new HashMap<>();
props.put(METADATA_TEXT_SEGMENT, text);
builder.properties(props);
}
return builder.build();
}
private static EmbeddingMatch<TextSegment> toEmbeddingMatch(Map<String, ?> item) {
Map<String, ?> additional = (Map<String, ?>) item.get(ADDITIONALS);
return new EmbeddingMatch<>(
(String) additional.get("id"),
Embedding.from(
((List<Double>) additional.get("vector")).stream().map(Double::floatValue).collect(Collectors.toList())
),
TextSegment.from((String) item.get(METADATA_TEXT_SEGMENT)),
(Double) additional.get("certainty")
);
}
// TODO this shall be migrated to some common place
private static String generateUUID(String input) {
try {
byte[] hashBytes = MessageDigest.getInstance("SHA-256").digest(input.getBytes(UTF_8));
StringBuilder sb = new StringBuilder();
for (byte b : hashBytes) sb.append(String.format("%02x", b));
return UUID.nameUUIDFromBytes(sb.toString().getBytes(UTF_8)).toString();
} catch (NoSuchAlgorithmException e) {
throw new IllegalArgumentException(e);
}
}
// TODO this shall be migrated to some common place
private static String generateRandomId() {
return UUID.randomUUID().toString();
}
}

View File

@ -0,0 +1,144 @@
package dev.langchain4j.store.embedding;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.List;
import lombok.Builder;
public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
private final EmbeddingStore<TextSegment> implementation;
/**
* Creates a new WeaviateEmbeddingStore instance.
*
* @param apiKey your Weaviate API key
* @param scheme the scheme, e.g. "https" of cluster URL. Find in under Details of your Weaviate cluster.
* @param host the host, e.g. "langchain4j-4jw7ufd9.weaviate.network" of cluster URL.
* Find in under Details of your Weaviate cluster.
* @param objectClass the object class you want to store, e.g. "MyGreatClass"
* @param avoidDups if true (default), then <code>WeaviateEmbeddingStore</code> will generate a hashed ID based on
* provided text segment, which avoids duplicated entries in DB.
* If false, then random ID will be generated.
* @param consistencyLevel Consistency level: ONE, QUORUM (default) or ALL. Find more details <a href="https://weaviate.io/developers/weaviate/concepts/replication-architecture/consistency#tunable-write-consistency">here</a>.
*/
@Builder
public WeaviateEmbeddingStore(
String apiKey,
String scheme,
String host,
String objectClass,
boolean avoidDups,
String consistencyLevel
) {
try {
implementation =
loadDynamically(
"dev.langchain4j.store.embedding.WeaviateEmbeddingStoreImpl",
apiKey,
scheme,
host,
objectClass,
avoidDups,
consistencyLevel
);
} catch (ClassNotFoundException e) {
throw new RuntimeException(getMessage(), e);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private static String getMessage() {
return (
"To use WeaviateEmbeddingStore, please add the following dependency to your project:\n\n" +
"Maven:\n" +
"<dependency>\n" +
" <groupId>dev.langchain4j</groupId>\n" +
" <artifactId>langchain4j-weaviate</artifactId>\n" +
" <version>0.18.0</version>\n" +
"</dependency>\n\n" +
"Gradle:\n" +
"implementation 'dev.langchain4j:langchain4j-weaviate:0.18.0'\n"
);
}
private static EmbeddingStore<TextSegment> loadDynamically(
String implementationClassName,
String apiKey,
String scheme,
String host,
String objectClass,
boolean avoidDups,
String consistencyLevel
)
throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException {
Class<?> implementationClass = Class.forName(implementationClassName);
Class<?>[] constructorParameterTypes = new Class<?>[] {
String.class,
String.class,
String.class,
String.class,
boolean.class,
String.class,
};
Constructor<?> constructor = implementationClass.getConstructor(constructorParameterTypes);
return (EmbeddingStore<TextSegment>) constructor.newInstance(
apiKey,
scheme,
host,
objectClass,
avoidDups,
consistencyLevel
);
}
@Override
public String add(Embedding embedding) {
return implementation.add(embedding);
}
/**
* Adds a new embedding with provided ID to the store.
*
* @param id the ID of the embedding to add in UUID format, since it's Weaviate requirement.
* See <a href="https://weaviate.io/developers/weaviate/manage-data/create#id">Weaviate docs</a> and
* <a href="https://en.wikipedia.org/wiki/Universally_unique_identifier">UUID on Wikipedia</a>
* @param embedding the embedding to add
*/
@Override
public void add(String id, Embedding embedding) {
implementation.add(id, embedding);
}
@Override
public String add(Embedding embedding, TextSegment textSegment) {
return implementation.add(embedding, textSegment);
}
@Override
public List<String> addAll(List<Embedding> embeddings) {
return implementation.addAll(embeddings);
}
@Override
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> textSegments) {
return implementation.addAll(embeddings, textSegments);
}
@Override
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) {
return implementation.findRelevant(referenceEmbedding, maxResults);
}
@Override
public List<EmbeddingMatch<TextSegment>> findRelevant(
Embedding referenceEmbedding,
int maxResults,
double minSimilarity
) {
return implementation.findRelevant(referenceEmbedding, maxResults, minSimilarity);
}
}

View File

@ -16,6 +16,7 @@
<module>langchain4j</module>
<module>langchain4j-pinecone</module>
<module>langchain4j-weaviate</module>
<module>langchain4j-embeddings</module>
<module>langchain4j-embeddings-all-minilm-l6-v2</module>