diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 000000000..963354f23 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,3 @@ +{ + "printWidth": 120 +} diff --git a/langchain4j-weaviate/pom.xml b/langchain4j-weaviate/pom.xml new file mode 100644 index 000000000..e13b28c12 --- /dev/null +++ b/langchain4j-weaviate/pom.xml @@ -0,0 +1,48 @@ + + + 4.0.0 + + + dev.langchain4j + langchain4j-parent + 0.18.0 + ../langchain4j-parent/pom.xml + + + + langchain4j-weaviate + jar + + LangChain4j integration with Weaviate + Uses io.weaviate.client library which has a BSD 3-Clause license: + https://github.com/weaviate/java-client/blob/main/LICENSE + + + + + + dev.langchain4j + langchain4j-core + 0.18.0 + + + + io.weaviate + client + 4.2.0 + + + + + + + Apache-2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + repo + A business-friendly OSS license + + + + \ No newline at end of file diff --git a/langchain4j-weaviate/src/main/java/dev/langchain4j/store/embedding/WeaviateEmbeddingStoreImpl.java b/langchain4j-weaviate/src/main/java/dev/langchain4j/store/embedding/WeaviateEmbeddingStoreImpl.java new file mode 100644 index 000000000..477b04bbc --- /dev/null +++ b/langchain4j-weaviate/src/main/java/dev/langchain4j/store/embedding/WeaviateEmbeddingStoreImpl.java @@ -0,0 +1,215 @@ +package dev.langchain4j.store.embedding; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.joining; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateAuthClient; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.base.Result; +import io.weaviate.client.base.WeaviateErrorMessage; +import io.weaviate.client.v1.auth.exception.AuthException; +import io.weaviate.client.v1.data.model.WeaviateObject; +import io.weaviate.client.v1.data.replication.model.ConsistencyLevel; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument; +import io.weaviate.client.v1.graphql.query.fields.Field; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.*; +import java.util.stream.Collectors; +import lombok.Builder; + +public class WeaviateEmbeddingStoreImpl implements EmbeddingStore { + + private static final String DEFAULT_CLASS = "Default"; + private static final Double DEFAULT_MIN_CERTAINTY = 0.0; + private static final String METADATA_TEXT_SEGMENT = "text"; + private static final String ADDITIONALS = "_additional"; + + private final WeaviateClient client; + private final String objectClass; + private boolean avoidDups = true; + private String consistencyLevel = ConsistencyLevel.QUORUM; + + @Builder + public WeaviateEmbeddingStoreImpl( + String apiKey, + String scheme, + String host, + String objectClass, + boolean avoidDups, + String consistencyLevel + ) { + try { + client = WeaviateAuthClient.apiKey(new Config(scheme, host), apiKey); + } catch (AuthException e) { + throw new IllegalArgumentException(e); + } + this.objectClass = objectClass != null ? objectClass : DEFAULT_CLASS; + this.avoidDups = avoidDups; + this.consistencyLevel = consistencyLevel; + } + + @Override + public String add(Embedding embedding) { + String id = generateRandomId(); + add(id, embedding); + return id; + } + + @Override + public void add(String id, Embedding embedding) { + addAll(singletonList(id), singletonList(embedding), null); + } + + @Override + public String add(Embedding embedding, TextSegment textSegment) { + return addAll(singletonList(embedding), singletonList(textSegment)).stream().findFirst().orElse(null); + } + + @Override + public List addAll(List embeddings) { + return addAll(embeddings, null); + } + + @Override + public List addAll(List embeddings, List embedded) { + return addAll(null, embeddings, embedded); + } + + @Override + public List> findRelevant(Embedding referenceEmbedding, int maxResults) { + return findRelevant(referenceEmbedding, maxResults, DEFAULT_MIN_CERTAINTY); + } + + @Override + public List> findRelevant( + Embedding referenceEmbedding, + int maxResults, + double minCertainty + ) { + Result result = client + .graphQL() + .get() + .withClassName(objectClass) + .withFields( + Field.builder().name(METADATA_TEXT_SEGMENT).build(), + Field + .builder() + .name(ADDITIONALS) + .fields( + Field.builder().name("id").build(), + Field.builder().name("certainty").build(), + Field.builder().name("vector").build() + ) + .build() + ) + .withNearVector( + NearVectorArgument + .builder() + .vector(referenceEmbedding.vectorAsList().toArray(new Float[0])) + .certainty((float) minCertainty) + .build() + ) + .withLimit(maxResults) + .run(); + + if (result.hasErrors()) { + throw new IllegalArgumentException( + result.getError().getMessages().stream().map(WeaviateErrorMessage::getMessage).collect(joining("\n")) + ); + } + + Optional> resGetPart = + ((Map) result.getResult().getData()).entrySet().stream().findFirst(); + if (!resGetPart.isPresent()) { + return emptyList(); + } + + Optional resItemsPart = resGetPart.get().getValue().entrySet().stream().findFirst(); + if (!resItemsPart.isPresent()) { + return emptyList(); + } + + List> resItems = ((Map.Entry>>) resItemsPart.get()).getValue(); + + return resItems.stream().map(WeaviateEmbeddingStoreImpl::toEmbeddingMatch).collect(Collectors.toList()); + } + + private List addAll(List ids, List embeddings, List embedded) { + if (embedded != null && embeddings.size() != embedded.size()) { + throw new IllegalArgumentException("The list of embeddings and embedded must have the same size"); + } + + List resIds = new ArrayList<>(); + List objects = new ArrayList<>(); + for (int i = 0; i < embeddings.size(); i++) { + String id = ids != null + ? ids.get(i) + : avoidDups && embedded != null ? generateUUID(embedded.get(i).text()) : generateRandomId(); + resIds.add(id); + objects.add(buildObject(id, embeddings.get(i), embedded != null ? embedded.get(i).text() : null)); + } + + client + .batch() + .objectsBatcher() + .withObjects(objects.toArray(new WeaviateObject[0])) + .withConsistencyLevel(consistencyLevel) + .run(); + + return resIds; + } + + private WeaviateObject buildObject(String id, Embedding embedding, String text) { + WeaviateObject.WeaviateObjectBuilder builder = WeaviateObject + .builder() + .className(objectClass) + .id(id) + .vector(embedding.vectorAsList().toArray(new Float[0])); + + if (text != null) { + Map props = new HashMap<>(); + props.put(METADATA_TEXT_SEGMENT, text); + + builder.properties(props); + } + + return builder.build(); + } + + private static EmbeddingMatch toEmbeddingMatch(Map item) { + Map additional = (Map) item.get(ADDITIONALS); + + return new EmbeddingMatch<>( + (String) additional.get("id"), + Embedding.from( + ((List) additional.get("vector")).stream().map(Double::floatValue).collect(Collectors.toList()) + ), + TextSegment.from((String) item.get(METADATA_TEXT_SEGMENT)), + (Double) additional.get("certainty") + ); + } + + // TODO this shall be migrated to some common place + private static String generateUUID(String input) { + try { + byte[] hashBytes = MessageDigest.getInstance("SHA-256").digest(input.getBytes(UTF_8)); + StringBuilder sb = new StringBuilder(); + for (byte b : hashBytes) sb.append(String.format("%02x", b)); + return UUID.nameUUIDFromBytes(sb.toString().getBytes(UTF_8)).toString(); + } catch (NoSuchAlgorithmException e) { + throw new IllegalArgumentException(e); + } + } + + // TODO this shall be migrated to some common place + private static String generateRandomId() { + return UUID.randomUUID().toString(); + } +} diff --git a/langchain4j/src/main/java/dev/langchain4j/store/embedding/WeaviateEmbeddingStore.java b/langchain4j/src/main/java/dev/langchain4j/store/embedding/WeaviateEmbeddingStore.java new file mode 100644 index 000000000..99c1a6ad1 --- /dev/null +++ b/langchain4j/src/main/java/dev/langchain4j/store/embedding/WeaviateEmbeddingStore.java @@ -0,0 +1,144 @@ +package dev.langchain4j.store.embedding; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.List; +import lombok.Builder; + +public class WeaviateEmbeddingStore implements EmbeddingStore { + + private final EmbeddingStore implementation; + + /** + * Creates a new WeaviateEmbeddingStore instance. + * + * @param apiKey your Weaviate API key + * @param scheme the scheme, e.g. "https" of cluster URL. Find in under Details of your Weaviate cluster. + * @param host the host, e.g. "langchain4j-4jw7ufd9.weaviate.network" of cluster URL. + * Find in under Details of your Weaviate cluster. + * @param objectClass the object class you want to store, e.g. "MyGreatClass" + * @param avoidDups if true (default), then WeaviateEmbeddingStore will generate a hashed ID based on + * provided text segment, which avoids duplicated entries in DB. + * If false, then random ID will be generated. + * @param consistencyLevel Consistency level: ONE, QUORUM (default) or ALL. Find more details here. + */ + @Builder + public WeaviateEmbeddingStore( + String apiKey, + String scheme, + String host, + String objectClass, + boolean avoidDups, + String consistencyLevel + ) { + try { + implementation = + loadDynamically( + "dev.langchain4j.store.embedding.WeaviateEmbeddingStoreImpl", + apiKey, + scheme, + host, + objectClass, + avoidDups, + consistencyLevel + ); + } catch (ClassNotFoundException e) { + throw new RuntimeException(getMessage(), e); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static String getMessage() { + return ( + "To use WeaviateEmbeddingStore, please add the following dependency to your project:\n\n" + + "Maven:\n" + + "\n" + + " dev.langchain4j\n" + + " langchain4j-weaviate\n" + + " 0.18.0\n" + + "\n\n" + + "Gradle:\n" + + "implementation 'dev.langchain4j:langchain4j-weaviate:0.18.0'\n" + ); + } + + private static EmbeddingStore loadDynamically( + String implementationClassName, + String apiKey, + String scheme, + String host, + String objectClass, + boolean avoidDups, + String consistencyLevel + ) + throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException { + Class implementationClass = Class.forName(implementationClassName); + Class[] constructorParameterTypes = new Class[] { + String.class, + String.class, + String.class, + String.class, + boolean.class, + String.class, + }; + Constructor constructor = implementationClass.getConstructor(constructorParameterTypes); + return (EmbeddingStore) constructor.newInstance( + apiKey, + scheme, + host, + objectClass, + avoidDups, + consistencyLevel + ); + } + + @Override + public String add(Embedding embedding) { + return implementation.add(embedding); + } + + /** + * Adds a new embedding with provided ID to the store. + * + * @param id the ID of the embedding to add in UUID format, since it's Weaviate requirement. + * See Weaviate docs and + * UUID on Wikipedia + * @param embedding the embedding to add + */ + @Override + public void add(String id, Embedding embedding) { + implementation.add(id, embedding); + } + + @Override + public String add(Embedding embedding, TextSegment textSegment) { + return implementation.add(embedding, textSegment); + } + + @Override + public List addAll(List embeddings) { + return implementation.addAll(embeddings); + } + + @Override + public List addAll(List embeddings, List textSegments) { + return implementation.addAll(embeddings, textSegments); + } + + @Override + public List> findRelevant(Embedding referenceEmbedding, int maxResults) { + return implementation.findRelevant(referenceEmbedding, maxResults); + } + + @Override + public List> findRelevant( + Embedding referenceEmbedding, + int maxResults, + double minSimilarity + ) { + return implementation.findRelevant(referenceEmbedding, maxResults, minSimilarity); + } +} diff --git a/pom.xml b/pom.xml index 5139f7e31..e182eca1a 100644 --- a/pom.xml +++ b/pom.xml @@ -16,6 +16,7 @@ langchain4j langchain4j-pinecone + langchain4j-weaviate langchain4j-embeddings langchain4j-embeddings-all-minilm-l6-v2