From 6a87b9b6089ab6f217bf03f1a39f7184afeecf85 Mon Sep 17 00:00:00 2001 From: Katia Aresti Date: Tue, 23 Apr 2024 17:03:00 +0200 Subject: [PATCH] Refactor the code to avoid duplication between integrations (#845) Refactoring to allow reusing the code between integrations --- .../infinispan/InfinispanEmbeddingStore.java | 228 ++++++++++++------ .../InfinispanStoreConfiguration.java | 104 ++++++++ .../infinispan/LangChainInfinispanItem.java | 8 + .../infinispan/LangChainItemMarshaller.java | 6 +- .../infinispan/LangChainMetadata.java | 5 + .../LangChainMetadataMarshaller.java | 6 +- .../infinispan/LangchainSchemaCreator.java | 38 +++ 7 files changed, 322 insertions(+), 73 deletions(-) create mode 100644 langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/InfinispanStoreConfiguration.java create mode 100644 langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangchainSchemaCreator.java diff --git a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/InfinispanEmbeddingStore.java b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/InfinispanEmbeddingStore.java index 51b82ae89..f4dcdfc23 100644 --- a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/InfinispanEmbeddingStore.java +++ b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/InfinispanEmbeddingStore.java @@ -9,11 +9,11 @@ import org.infinispan.client.hotrod.RemoteCache; import org.infinispan.client.hotrod.RemoteCacheManager; import org.infinispan.client.hotrod.configuration.ConfigurationBuilder; import org.infinispan.commons.api.query.Query; +import org.infinispan.commons.configuration.StringConfiguration; import org.infinispan.commons.marshall.ProtoStreamMarshaller; import org.infinispan.protostream.FileDescriptorSource; import org.infinispan.protostream.SerializationContext; import org.infinispan.protostream.schema.Schema; -import org.infinispan.protostream.schema.Type; import org.infinispan.query.remote.client.ProtobufMetadataManagerConstants; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,6 +31,8 @@ import static dev.langchain4j.internal.Utils.randomUUID; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.internal.ValidationUtils.ensureTrue; +import static dev.langchain4j.store.embedding.infinispan.InfinispanStoreConfiguration.DEFAULT_CACHE_CONFIG; +import static dev.langchain4j.store.embedding.infinispan.InfinispanStoreConfiguration.DEFAULT_DISTANCE; import static java.util.Collections.singletonList; import static java.util.stream.Collectors.toList; @@ -42,86 +44,90 @@ public class InfinispanEmbeddingStore implements EmbeddingStore { private static final Logger log = LoggerFactory.getLogger(InfinispanEmbeddingStore.class); private final RemoteCache remoteCache; + private final InfinispanStoreConfiguration storeConfiguration; - private final LangChainItemMarshaller itemMarshaller; - private final LangChainMetadataMarshaller metadataMarshaller; + /** + * Creates an Infinispan embedding store from a RemoteCacheManager + * Assumes marshalling configuration is already provided by the RemoteCacheManager instance. + * + * @param remoteCacheManager, the already configured remote cache manager + * @param storeConfiguration, the store configuration + */ + public InfinispanEmbeddingStore(RemoteCacheManager remoteCacheManager, + InfinispanStoreConfiguration storeConfiguration) { - private static final String DEFAULT_CACHE_CONFIG = - "\n" - + "\n" - + "\n" - + "LANGCHAINITEM\n" - + "LANGCHAIN_METADATA\n" - + "\n" - + "\n" - + ""; + ensureNotNull(remoteCacheManager, "remoteCacheManager"); + ensureNotNull(storeConfiguration, "storeConfiguration"); + ensureNotNull(storeConfiguration.dimension(), "dimension"); + ensureNotBlank(storeConfiguration.cacheName(), "cacheName"); - public static final String ITEM_PACKAGE = "dev.langchain4j"; - public static final String LANGCHAIN_ITEM = "LangChainItem"; - public static final String METADATA_ITEM = "LangChainMetadata"; + this.storeConfiguration = storeConfiguration; + + if (storeConfiguration.createCache()) { + this.remoteCache = remoteCacheManager.administration() + .getOrCreateCache(storeConfiguration.cacheName(), new StringConfiguration(computeCacheConfiguration(storeConfiguration))); + } else { + this.remoteCache = remoteCacheManager.getCache(storeConfiguration.cacheName()); + } + } /** * Creates an instance of InfinispanEmbeddingStore - * - * @param builder Infinispan Configuration Builder - * @param name The name of the store - * @param dimension The dimension of the store */ public InfinispanEmbeddingStore(ConfigurationBuilder builder, - String name, - Integer dimension) { + InfinispanStoreConfiguration storeConfiguration) { ensureNotNull(builder, "builder"); - ensureNotBlank(name, "name"); - ensureNotNull(dimension, "dimension"); - String langchainType = LANGCHAIN_ITEM + dimension; - String metadataType = METADATA_ITEM + dimension; - itemMarshaller = new LangChainItemMarshaller(computeTypeWithPackage(langchainType)); - metadataMarshaller = new LangChainMetadataMarshaller(computeTypeWithPackage(metadataType)); - builder.remoteCache(name) - .configuration(DEFAULT_CACHE_CONFIG.replace("CACHE_NAME", name) - .replace("LANGCHAINITEM", itemMarshaller.getTypeName()) - .replace("LANGCHAIN_METADATA", metadataMarshaller.getTypeName())); + ensureNotNull(storeConfiguration, "storeConfiguration"); + ensureNotBlank(storeConfiguration.cacheName(), "cacheName"); + ensureNotNull(storeConfiguration.dimension(), "dimension"); + this.storeConfiguration = storeConfiguration; + Schema schema = LangchainSchemaCreator.buildSchema(storeConfiguration); + + if (storeConfiguration.createCache()) { + String remoteCacheConfig = computeCacheConfiguration(storeConfiguration); + builder.remoteCache(storeConfiguration.cacheName()).configuration(remoteCacheConfig); + } // Registers the schema on the client ProtoStreamMarshaller marshaller = new ProtoStreamMarshaller(); SerializationContext serializationContext = marshaller.getSerializationContext(); - String fileName = ITEM_PACKAGE + "." + "dimension." + dimension + ".proto"; - Schema schema = new Schema.Builder("magazine.proto") - .packageName(ITEM_PACKAGE) - .addMessage(metadataType) - .addComment("@Indexed") - .addField(Type.Scalar.STRING, "name", 1) - .addComment("@Text") - .addField(Type.Scalar.STRING, "value", 2) - .addComment("@Text") - .addMessage(langchainType) - .addComment("@Indexed") - .addField(Type.Scalar.STRING, "id", 1) - .addComment("@Text") - .addField(Type.Scalar.STRING, "text", 2) - .addComment("@Keyword") - .addRepeatedField(Type.Scalar.FLOAT, "embedding", 3) - .addComment("@Vector(dimension=" + dimension + ", similarity=COSINE)") - .addRepeatedField(Type.create(metadataType), "metadata", 4) - .build(); - String schemaContent = schema.toString(); - FileDescriptorSource fileDescriptorSource = FileDescriptorSource.fromString(fileName, schemaContent); + FileDescriptorSource fileDescriptorSource = FileDescriptorSource.fromString(storeConfiguration.fileName(), schemaContent); serializationContext.registerProtoFiles(fileDescriptorSource); - serializationContext.registerMarshaller(metadataMarshaller); - serializationContext.registerMarshaller(itemMarshaller); + serializationContext.registerMarshaller(new LangChainItemMarshaller(storeConfiguration.langchainItemFullType())); + serializationContext.registerMarshaller(new LangChainMetadataMarshaller(storeConfiguration.metadataFullType())); builder.marshaller(marshaller); - // Uploads the schema to the server - RemoteCacheManager rmc = new RemoteCacheManager(builder.build()); - RemoteCache metadataCache = rmc - .getCache(ProtobufMetadataManagerConstants.PROTOBUF_METADATA_CACHE_NAME); - metadataCache.put(fileName, schemaContent); - this.remoteCache = rmc.getCache(name); + // creates the client + RemoteCacheManager rmc = new RemoteCacheManager(builder.build()); + + // Uploads the schema to the server + if (storeConfiguration.registerSchema()) { + RemoteCache metadataCache = rmc + .getCache(ProtobufMetadataManagerConstants.PROTOBUF_METADATA_CACHE_NAME); + metadataCache.put(storeConfiguration.fileName(), schemaContent); + } + + this.remoteCache = rmc.getCache(storeConfiguration.cacheName()); } - private static String computeTypeWithPackage(String langchainType) { - return ITEM_PACKAGE + "." + langchainType; + /** + * Gets the underlying Infinispan remote cache + * + * @return RemoteCache + */ + public RemoteCache getRemoteCache() { + return remoteCache; + } + + private String computeCacheConfiguration(InfinispanStoreConfiguration storeConfiguration) { + String remoteCacheConfig = storeConfiguration.cacheConfig(); + if (remoteCacheConfig == null) { + remoteCacheConfig = DEFAULT_CACHE_CONFIG.replace("CACHE_NAME", storeConfiguration.cacheName()) + .replace("LANGCHAINITEM", storeConfiguration.langchainItemFullType()) + .replace("LANGCHAIN_METADATA", storeConfiguration.metadataFullType()); + } + return remoteCacheConfig; } @Override @@ -163,7 +169,7 @@ public class InfinispanEmbeddingStore implements EmbeddingStore { @Override public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { - Query query = remoteCache.query("select i, score(i) from " + itemMarshaller.getTypeName() + " i where i.embedding <-> " + Arrays.toString(referenceEmbedding.vector()) + "~3"); + Query query = remoteCache.query("select i, score(i) from " + storeConfiguration.langchainItemFullType() + " i where i.embedding <-> " + Arrays.toString(referenceEmbedding.vector()) + "~" + storeConfiguration.distance()); List hits = query.maxResults(maxResults).list(); return hits.stream().map(obj -> { @@ -228,15 +234,32 @@ public class InfinispanEmbeddingStore implements EmbeddingStore { } public static class Builder { - private ConfigurationBuilder builder; - private String name; + private ConfigurationBuilder configurationBuilder; + private String cacheName; private Integer dimension; + private Integer distance; + private String similarity; + private String cacheConfig; + private String packageName; + private String fileName; + private String langchainItemName; + private String metadataItemName; + private boolean registerSchema = true; + private boolean createCache = true; /** * Infinispan cache name to be used, will be created on first access */ public Builder cacheName(String name) { - this.name = name; + this.cacheName = name; + return this; + } + + /** + * Infinispan cache config to be used, will be created on first access + */ + public Builder cacheConfig(String cacheConfig) { + this.cacheConfig = cacheConfig; return this; } @@ -248,14 +271,75 @@ public class InfinispanEmbeddingStore implements EmbeddingStore { return this; } + /** + * Infinispan distance for knn query + */ + public Builder distance(Integer distance) { + this.distance = distance; + return this; + } + + /** + * Infinispan similarity for the embedding definition + */ + public Builder similarity(String similarity) { + this.similarity = similarity; + return this; + } + + /** + * Infinispan schema package name + */ + public Builder packageName(String packageName) { + this.packageName = packageName; + return this; + } + + /** + * Infinispan schema file name + */ + public Builder fileName(String fileName) { + this.fileName = fileName; + return this; + } + + /** + * Infinispan schema langchainItemName + */ + public Builder langchainItemName(String langchainItemName) { + this.langchainItemName = langchainItemName; + return this; + } + + /** + * Infinispan schema metadataItemName + */ + public Builder metadataItemName(String metadataItemName) { + this.metadataItemName = metadataItemName; + return this; + } + + /** + * Register Langchain schema in the server + */ + public Builder registerSchema(boolean registerSchema) { + this.registerSchema = registerSchema; + return this; + } + + /** + * Create cache in the server + */ + public Builder createCache(boolean createCache) { + this.createCache = createCache; + return this; + } + /** * Infinispan Configuration Builder - * - * @param builder, Infinispan client configuration builder - * @return this Builder */ public Builder infinispanConfigBuilder(ConfigurationBuilder builder) { - this.builder = builder; + this.configurationBuilder = builder; return this; } @@ -265,7 +349,9 @@ public class InfinispanEmbeddingStore implements EmbeddingStore { * @return InfinispanEmbeddingStore */ public InfinispanEmbeddingStore build() { - return new InfinispanEmbeddingStore(builder, name, dimension); + return new InfinispanEmbeddingStore(configurationBuilder, + new InfinispanStoreConfiguration(cacheName, dimension, distance, similarity, cacheConfig, packageName, fileName, + langchainItemName, metadataItemName, createCache, registerSchema)); } } } diff --git a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/InfinispanStoreConfiguration.java b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/InfinispanStoreConfiguration.java new file mode 100644 index 000000000..4c3d65f8f --- /dev/null +++ b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/InfinispanStoreConfiguration.java @@ -0,0 +1,104 @@ +package dev.langchain4j.store.embedding.infinispan; + +/** + * Holds configuration for the store + */ +public record InfinispanStoreConfiguration(String cacheName, + Integer dimension, + Integer distance, + String similarity, + String cacheConfig, + String packageItem, + String fileName, + String langchainItemName, + String metadataItemName, + boolean createCache, + boolean registerSchema){ + + /** + * Default Cache Config + */ + public static final String DEFAULT_CACHE_CONFIG = + "\n" + + "\n" + + "\n" + + "LANGCHAINITEM\n" + + "LANGCHAIN_METADATA\n" + + "\n" + + "\n" + + ""; + + /** + * Default package of the schema + */ + public static final String DEFAULT_ITEM_PACKAGE = "dev.langchain4j"; + + /** + * Default name of the protobuf langchain item. Size will be added + */ + public static final String DEFAULT_LANGCHAIN_ITEM = "LangChainItem"; + /** + * Default name of the protobuf metadata item. Size will be added + */ + public static final String DEFAULT_METADATA_ITEM = "LangChainMetadata"; + /** + * The default distance to for the search + */ + public static final int DEFAULT_DISTANCE = 3; + /** + * Default vector similarity + */ + public static final String DEFAULT_SIMILARITY = "COSINE"; + + /** + * Creates the configuration and sets default values + * + * @param cacheName, mandatory + * @param dimension, mandatory + * @param distance, defaults to 3 + * @param similarity, defaults COUSINE + * @param cacheConfig, the full cache configuration + * @param packageItem, optional the package item + * @param fileName, optional file name + * @param langchainItemName, optional item name + * @param metadataItemName, optional metadata item name + * @param createCache, defaults to true. Disables creating the cache on startup + * @param registerSchema, defaults to true. Disables registering the schema in the server + */ + public InfinispanStoreConfiguration(String cacheName, Integer dimension, Integer distance, String similarity, String cacheConfig, + String packageItem, String fileName, String langchainItemName, + String metadataItemName, boolean createCache, boolean registerSchema) { + this.cacheName = cacheName; + this.dimension = dimension; + this.cacheConfig = cacheConfig; + this.distance = distance != null ? distance : DEFAULT_DISTANCE; + this.similarity = similarity != null ? similarity : DEFAULT_SIMILARITY; + this.packageItem = packageItem != null ? packageItem : DEFAULT_ITEM_PACKAGE; + this.fileName = fileName != null ? fileName: computeFileName(packageItem, dimension); + this.langchainItemName = langchainItemName != null? langchainItemName : DEFAULT_LANGCHAIN_ITEM + dimension; + this.metadataItemName = metadataItemName != null? metadataItemName : DEFAULT_METADATA_ITEM + dimension; + this.createCache = createCache; + this.registerSchema = registerSchema; + } + + /** + * Get the full name of the langchainItem protobuf type + * @return langchainItemFullType + */ + public String langchainItemFullType() { + return packageItem + "." + langchainItemName; + } + + /** + * Get the full name of the metadata protobuf type + * @return metadataFullType + */ + public String metadataFullType() { + return packageItem + "." + metadataItemName; + } + + private static String computeFileName(String itemPackage, int dimension) { + return itemPackage + "." + "dimension." + dimension + ".proto"; + } + +} diff --git a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainInfinispanItem.java b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainInfinispanItem.java index affcdf33e..f10a9171a 100644 --- a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainInfinispanItem.java +++ b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainInfinispanItem.java @@ -2,5 +2,13 @@ package dev.langchain4j.store.embedding.infinispan; import java.util.Set; +/** + * Langchain item that is serialized for the langchain integration use case + * + * @param id, the id of the item + * @param embedding, the vector + * @param text, associated text + * @param metadata, additional set of metadata + */ public record LangChainInfinispanItem(String id, float[] embedding, String text, Set metadata) {} diff --git a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainItemMarshaller.java b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainItemMarshaller.java index 26d8f1677..9ab4fc040 100644 --- a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainItemMarshaller.java +++ b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainItemMarshaller.java @@ -9,10 +9,14 @@ import java.util.Set; /** * Marshaller to read and write embeddings to Infinispan */ -class LangChainItemMarshaller implements MessageMarshaller { +public class LangChainItemMarshaller implements MessageMarshaller { private final String typeName; + /** + * Constructor for the LangChainItemMarshaller Marshaller + * @param typeName, the full type of the protobuf entity + */ public LangChainItemMarshaller(String typeName) { this.typeName = typeName; } diff --git a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainMetadata.java b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainMetadata.java index de6f26de3..e8818e132 100644 --- a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainMetadata.java +++ b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainMetadata.java @@ -1,4 +1,9 @@ package dev.langchain4j.store.embedding.infinispan; +/** + * Langchain Metadata item that is serialized for the langchain integration use case + * @param name, the name of the metadata + * @param value, the value of the metadata + */ public record LangChainMetadata(String name, String value) {} diff --git a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainMetadataMarshaller.java b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainMetadataMarshaller.java index 522204335..ac0d19a8e 100644 --- a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainMetadataMarshaller.java +++ b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangChainMetadataMarshaller.java @@ -7,10 +7,14 @@ import java.io.IOException; /** * Marshaller to read and write metadata to Infinispan */ -class LangChainMetadataMarshaller implements MessageMarshaller { +public class LangChainMetadataMarshaller implements MessageMarshaller { private final String typeName; + /** + * Constructor for the LangChainMetadata Marshaller + * @param typeName, the full type of the protobuf entity + */ public LangChainMetadataMarshaller(String typeName) { this.typeName = typeName; } diff --git a/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangchainSchemaCreator.java b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangchainSchemaCreator.java new file mode 100644 index 000000000..d7c66a0e7 --- /dev/null +++ b/langchain4j-infinispan/src/main/java/dev/langchain4j/store/embedding/infinispan/LangchainSchemaCreator.java @@ -0,0 +1,38 @@ +package dev.langchain4j.store.embedding.infinispan; + +import org.infinispan.protostream.schema.Schema; +import org.infinispan.protostream.schema.Type; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +/** + * LangchainSchemaCreator for Infinispan + */ +public final class LangchainSchemaCreator { + /** + * Build the Infinispan Schema to marshall embeddings + * + * @param storeConfiguration, the configuration of the store + * @return produced Schema + */ + public static Schema buildSchema(InfinispanStoreConfiguration storeConfiguration) { + return new Schema.Builder(storeConfiguration.fileName()) + .packageName(storeConfiguration.packageItem()) + .addMessage(storeConfiguration.metadataItemName()) + .addComment("@Indexed") + .addField(Type.Scalar.STRING, "name", 1) + .addComment("@Text") + .addField(Type.Scalar.STRING, "value", 2) + .addComment("@Text") + .addMessage(storeConfiguration.langchainItemName()) + .addComment("@Indexed") + .addField(Type.Scalar.STRING, "id", 1) + .addComment("@Text") + .addField(Type.Scalar.STRING, "text", 2) + .addComment("@Keyword") + .addRepeatedField(Type.Scalar.FLOAT, "embedding", 3) + .addComment("@Vector(dimension=" + storeConfiguration.dimension() + ", similarity=" + storeConfiguration.similarity() + ")") + .addRepeatedField(Type.create(storeConfiguration.metadataItemName()), "metadata", 4) + .build(); + } +}