Refactor the code to avoid duplication between integrations (#845)

Refactoring to allow reusing the code between integrations
This commit is contained in:
Katia Aresti 2024-04-23 17:03:00 +02:00 committed by GitHub
parent a4c256d8d0
commit 6a87b9b608
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 322 additions and 73 deletions

View File

@ -9,11 +9,11 @@ import org.infinispan.client.hotrod.RemoteCache;
import org.infinispan.client.hotrod.RemoteCacheManager; import org.infinispan.client.hotrod.RemoteCacheManager;
import org.infinispan.client.hotrod.configuration.ConfigurationBuilder; import org.infinispan.client.hotrod.configuration.ConfigurationBuilder;
import org.infinispan.commons.api.query.Query; import org.infinispan.commons.api.query.Query;
import org.infinispan.commons.configuration.StringConfiguration;
import org.infinispan.commons.marshall.ProtoStreamMarshaller; import org.infinispan.commons.marshall.ProtoStreamMarshaller;
import org.infinispan.protostream.FileDescriptorSource; import org.infinispan.protostream.FileDescriptorSource;
import org.infinispan.protostream.SerializationContext; import org.infinispan.protostream.SerializationContext;
import org.infinispan.protostream.schema.Schema; import org.infinispan.protostream.schema.Schema;
import org.infinispan.protostream.schema.Type;
import org.infinispan.query.remote.client.ProtobufMetadataManagerConstants; import org.infinispan.query.remote.client.ProtobufMetadataManagerConstants;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -31,6 +31,8 @@ import static dev.langchain4j.internal.Utils.randomUUID;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.internal.ValidationUtils.ensureTrue; import static dev.langchain4j.internal.ValidationUtils.ensureTrue;
import static dev.langchain4j.store.embedding.infinispan.InfinispanStoreConfiguration.DEFAULT_CACHE_CONFIG;
import static dev.langchain4j.store.embedding.infinispan.InfinispanStoreConfiguration.DEFAULT_DISTANCE;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toList;
@ -42,86 +44,90 @@ public class InfinispanEmbeddingStore implements EmbeddingStore<TextSegment> {
private static final Logger log = LoggerFactory.getLogger(InfinispanEmbeddingStore.class); private static final Logger log = LoggerFactory.getLogger(InfinispanEmbeddingStore.class);
private final RemoteCache<String, LangChainInfinispanItem> remoteCache; private final RemoteCache<String, LangChainInfinispanItem> remoteCache;
private final InfinispanStoreConfiguration storeConfiguration;
private final LangChainItemMarshaller itemMarshaller; /**
private final LangChainMetadataMarshaller metadataMarshaller; * Creates an Infinispan embedding store from a RemoteCacheManager
* Assumes marshalling configuration is already provided by the RemoteCacheManager instance.
*
* @param remoteCacheManager, the already configured remote cache manager
* @param storeConfiguration, the store configuration
*/
public InfinispanEmbeddingStore(RemoteCacheManager remoteCacheManager,
InfinispanStoreConfiguration storeConfiguration) {
private static final String DEFAULT_CACHE_CONFIG = ensureNotNull(remoteCacheManager, "remoteCacheManager");
"<distributed-cache name=\"CACHE_NAME\">\n" ensureNotNull(storeConfiguration, "storeConfiguration");
+ "<indexing storage=\"local-heap\">\n" ensureNotNull(storeConfiguration.dimension(), "dimension");
+ "<indexed-entities>\n" ensureNotBlank(storeConfiguration.cacheName(), "cacheName");
+ "<indexed-entity>LANGCHAINITEM</indexed-entity>\n"
+ "<indexed-entity>LANGCHAIN_METADATA</indexed-entity>\n"
+ "</indexed-entities>\n"
+ "</indexing>\n"
+ "</distributed-cache>";
public static final String ITEM_PACKAGE = "dev.langchain4j"; this.storeConfiguration = storeConfiguration;
public static final String LANGCHAIN_ITEM = "LangChainItem";
public static final String METADATA_ITEM = "LangChainMetadata"; if (storeConfiguration.createCache()) {
this.remoteCache = remoteCacheManager.administration()
.getOrCreateCache(storeConfiguration.cacheName(), new StringConfiguration(computeCacheConfiguration(storeConfiguration)));
} else {
this.remoteCache = remoteCacheManager.getCache(storeConfiguration.cacheName());
}
}
/** /**
* Creates an instance of InfinispanEmbeddingStore * Creates an instance of InfinispanEmbeddingStore
*
* @param builder Infinispan Configuration Builder
* @param name The name of the store
* @param dimension The dimension of the store
*/ */
public InfinispanEmbeddingStore(ConfigurationBuilder builder, public InfinispanEmbeddingStore(ConfigurationBuilder builder,
String name, InfinispanStoreConfiguration storeConfiguration) {
Integer dimension) {
ensureNotNull(builder, "builder"); ensureNotNull(builder, "builder");
ensureNotBlank(name, "name"); ensureNotNull(storeConfiguration, "storeConfiguration");
ensureNotNull(dimension, "dimension"); ensureNotBlank(storeConfiguration.cacheName(), "cacheName");
String langchainType = LANGCHAIN_ITEM + dimension; ensureNotNull(storeConfiguration.dimension(), "dimension");
String metadataType = METADATA_ITEM + dimension; this.storeConfiguration = storeConfiguration;
itemMarshaller = new LangChainItemMarshaller(computeTypeWithPackage(langchainType)); Schema schema = LangchainSchemaCreator.buildSchema(storeConfiguration);
metadataMarshaller = new LangChainMetadataMarshaller(computeTypeWithPackage(metadataType));
builder.remoteCache(name) if (storeConfiguration.createCache()) {
.configuration(DEFAULT_CACHE_CONFIG.replace("CACHE_NAME", name) String remoteCacheConfig = computeCacheConfiguration(storeConfiguration);
.replace("LANGCHAINITEM", itemMarshaller.getTypeName()) builder.remoteCache(storeConfiguration.cacheName()).configuration(remoteCacheConfig);
.replace("LANGCHAIN_METADATA", metadataMarshaller.getTypeName())); }
// Registers the schema on the client // Registers the schema on the client
ProtoStreamMarshaller marshaller = new ProtoStreamMarshaller(); ProtoStreamMarshaller marshaller = new ProtoStreamMarshaller();
SerializationContext serializationContext = marshaller.getSerializationContext(); SerializationContext serializationContext = marshaller.getSerializationContext();
String fileName = ITEM_PACKAGE + "." + "dimension." + dimension + ".proto";
Schema schema = new Schema.Builder("magazine.proto")
.packageName(ITEM_PACKAGE)
.addMessage(metadataType)
.addComment("@Indexed")
.addField(Type.Scalar.STRING, "name", 1)
.addComment("@Text")
.addField(Type.Scalar.STRING, "value", 2)
.addComment("@Text")
.addMessage(langchainType)
.addComment("@Indexed")
.addField(Type.Scalar.STRING, "id", 1)
.addComment("@Text")
.addField(Type.Scalar.STRING, "text", 2)
.addComment("@Keyword")
.addRepeatedField(Type.Scalar.FLOAT, "embedding", 3)
.addComment("@Vector(dimension=" + dimension + ", similarity=COSINE)")
.addRepeatedField(Type.create(metadataType), "metadata", 4)
.build();
String schemaContent = schema.toString(); String schemaContent = schema.toString();
FileDescriptorSource fileDescriptorSource = FileDescriptorSource.fromString(fileName, schemaContent); FileDescriptorSource fileDescriptorSource = FileDescriptorSource.fromString(storeConfiguration.fileName(), schemaContent);
serializationContext.registerProtoFiles(fileDescriptorSource); serializationContext.registerProtoFiles(fileDescriptorSource);
serializationContext.registerMarshaller(metadataMarshaller); serializationContext.registerMarshaller(new LangChainItemMarshaller(storeConfiguration.langchainItemFullType()));
serializationContext.registerMarshaller(itemMarshaller); serializationContext.registerMarshaller(new LangChainMetadataMarshaller(storeConfiguration.metadataFullType()));
builder.marshaller(marshaller); builder.marshaller(marshaller);
// Uploads the schema to the server
RemoteCacheManager rmc = new RemoteCacheManager(builder.build());
RemoteCache<String, String> metadataCache = rmc
.getCache(ProtobufMetadataManagerConstants.PROTOBUF_METADATA_CACHE_NAME);
metadataCache.put(fileName, schemaContent);
this.remoteCache = rmc.getCache(name); // creates the client
RemoteCacheManager rmc = new RemoteCacheManager(builder.build());
// Uploads the schema to the server
if (storeConfiguration.registerSchema()) {
RemoteCache<String, String> metadataCache = rmc
.getCache(ProtobufMetadataManagerConstants.PROTOBUF_METADATA_CACHE_NAME);
metadataCache.put(storeConfiguration.fileName(), schemaContent);
}
this.remoteCache = rmc.getCache(storeConfiguration.cacheName());
} }
private static String computeTypeWithPackage(String langchainType) { /**
return ITEM_PACKAGE + "." + langchainType; * Gets the underlying Infinispan remote cache
*
* @return RemoteCache
*/
public RemoteCache<String, LangChainInfinispanItem> getRemoteCache() {
return remoteCache;
}
private String computeCacheConfiguration(InfinispanStoreConfiguration storeConfiguration) {
String remoteCacheConfig = storeConfiguration.cacheConfig();
if (remoteCacheConfig == null) {
remoteCacheConfig = DEFAULT_CACHE_CONFIG.replace("CACHE_NAME", storeConfiguration.cacheName())
.replace("LANGCHAINITEM", storeConfiguration.langchainItemFullType())
.replace("LANGCHAIN_METADATA", storeConfiguration.metadataFullType());
}
return remoteCacheConfig;
} }
@Override @Override
@ -163,7 +169,7 @@ public class InfinispanEmbeddingStore implements EmbeddingStore<TextSegment> {
@Override @Override
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
Query<Object[]> query = remoteCache.query("select i, score(i) from " + itemMarshaller.getTypeName() + " i where i.embedding <-> " + Arrays.toString(referenceEmbedding.vector()) + "~3"); Query<Object[]> query = remoteCache.query("select i, score(i) from " + storeConfiguration.langchainItemFullType() + " i where i.embedding <-> " + Arrays.toString(referenceEmbedding.vector()) + "~" + storeConfiguration.distance());
List<Object[]> hits = query.maxResults(maxResults).list(); List<Object[]> hits = query.maxResults(maxResults).list();
return hits.stream().map(obj -> { return hits.stream().map(obj -> {
@ -228,15 +234,32 @@ public class InfinispanEmbeddingStore implements EmbeddingStore<TextSegment> {
} }
public static class Builder { public static class Builder {
private ConfigurationBuilder builder; private ConfigurationBuilder configurationBuilder;
private String name; private String cacheName;
private Integer dimension; private Integer dimension;
private Integer distance;
private String similarity;
private String cacheConfig;
private String packageName;
private String fileName;
private String langchainItemName;
private String metadataItemName;
private boolean registerSchema = true;
private boolean createCache = true;
/** /**
* Infinispan cache name to be used, will be created on first access * Infinispan cache name to be used, will be created on first access
*/ */
public Builder cacheName(String name) { public Builder cacheName(String name) {
this.name = name; this.cacheName = name;
return this;
}
/**
* Infinispan cache config to be used, will be created on first access
*/
public Builder cacheConfig(String cacheConfig) {
this.cacheConfig = cacheConfig;
return this; return this;
} }
@ -248,14 +271,75 @@ public class InfinispanEmbeddingStore implements EmbeddingStore<TextSegment> {
return this; return this;
} }
/**
* Infinispan distance for knn query
*/
public Builder distance(Integer distance) {
this.distance = distance;
return this;
}
/**
* Infinispan similarity for the embedding definition
*/
public Builder similarity(String similarity) {
this.similarity = similarity;
return this;
}
/**
* Infinispan schema package name
*/
public Builder packageName(String packageName) {
this.packageName = packageName;
return this;
}
/**
* Infinispan schema file name
*/
public Builder fileName(String fileName) {
this.fileName = fileName;
return this;
}
/**
* Infinispan schema langchainItemName
*/
public Builder langchainItemName(String langchainItemName) {
this.langchainItemName = langchainItemName;
return this;
}
/**
* Infinispan schema metadataItemName
*/
public Builder metadataItemName(String metadataItemName) {
this.metadataItemName = metadataItemName;
return this;
}
/**
* Register Langchain schema in the server
*/
public Builder registerSchema(boolean registerSchema) {
this.registerSchema = registerSchema;
return this;
}
/**
* Create cache in the server
*/
public Builder createCache(boolean createCache) {
this.createCache = createCache;
return this;
}
/** /**
* Infinispan Configuration Builder * Infinispan Configuration Builder
*
* @param builder, Infinispan client configuration builder
* @return this Builder
*/ */
public Builder infinispanConfigBuilder(ConfigurationBuilder builder) { public Builder infinispanConfigBuilder(ConfigurationBuilder builder) {
this.builder = builder; this.configurationBuilder = builder;
return this; return this;
} }
@ -265,7 +349,9 @@ public class InfinispanEmbeddingStore implements EmbeddingStore<TextSegment> {
* @return InfinispanEmbeddingStore * @return InfinispanEmbeddingStore
*/ */
public InfinispanEmbeddingStore build() { public InfinispanEmbeddingStore build() {
return new InfinispanEmbeddingStore(builder, name, dimension); return new InfinispanEmbeddingStore(configurationBuilder,
new InfinispanStoreConfiguration(cacheName, dimension, distance, similarity, cacheConfig, packageName, fileName,
langchainItemName, metadataItemName, createCache, registerSchema));
} }
} }
} }

View File

@ -0,0 +1,104 @@
package dev.langchain4j.store.embedding.infinispan;
/**
* Holds configuration for the store
*/
public record InfinispanStoreConfiguration(String cacheName,
Integer dimension,
Integer distance,
String similarity,
String cacheConfig,
String packageItem,
String fileName,
String langchainItemName,
String metadataItemName,
boolean createCache,
boolean registerSchema){
/**
* Default Cache Config
*/
public static final String DEFAULT_CACHE_CONFIG =
"<distributed-cache name=\"CACHE_NAME\">\n"
+ "<indexing storage=\"local-heap\">\n"
+ "<indexed-entities>\n"
+ "<indexed-entity>LANGCHAINITEM</indexed-entity>\n"
+ "<indexed-entity>LANGCHAIN_METADATA</indexed-entity>\n"
+ "</indexed-entities>\n"
+ "</indexing>\n"
+ "</distributed-cache>";
/**
* Default package of the schema
*/
public static final String DEFAULT_ITEM_PACKAGE = "dev.langchain4j";
/**
* Default name of the protobuf langchain item. Size will be added
*/
public static final String DEFAULT_LANGCHAIN_ITEM = "LangChainItem";
/**
* Default name of the protobuf metadata item. Size will be added
*/
public static final String DEFAULT_METADATA_ITEM = "LangChainMetadata";
/**
* The default distance to for the search
*/
public static final int DEFAULT_DISTANCE = 3;
/**
* Default vector similarity
*/
public static final String DEFAULT_SIMILARITY = "COSINE";
/**
* Creates the configuration and sets default values
*
* @param cacheName, mandatory
* @param dimension, mandatory
* @param distance, defaults to 3
* @param similarity, defaults COUSINE
* @param cacheConfig, the full cache configuration
* @param packageItem, optional the package item
* @param fileName, optional file name
* @param langchainItemName, optional item name
* @param metadataItemName, optional metadata item name
* @param createCache, defaults to true. Disables creating the cache on startup
* @param registerSchema, defaults to true. Disables registering the schema in the server
*/
public InfinispanStoreConfiguration(String cacheName, Integer dimension, Integer distance, String similarity, String cacheConfig,
String packageItem, String fileName, String langchainItemName,
String metadataItemName, boolean createCache, boolean registerSchema) {
this.cacheName = cacheName;
this.dimension = dimension;
this.cacheConfig = cacheConfig;
this.distance = distance != null ? distance : DEFAULT_DISTANCE;
this.similarity = similarity != null ? similarity : DEFAULT_SIMILARITY;
this.packageItem = packageItem != null ? packageItem : DEFAULT_ITEM_PACKAGE;
this.fileName = fileName != null ? fileName: computeFileName(packageItem, dimension);
this.langchainItemName = langchainItemName != null? langchainItemName : DEFAULT_LANGCHAIN_ITEM + dimension;
this.metadataItemName = metadataItemName != null? metadataItemName : DEFAULT_METADATA_ITEM + dimension;
this.createCache = createCache;
this.registerSchema = registerSchema;
}
/**
* Get the full name of the langchainItem protobuf type
* @return langchainItemFullType
*/
public String langchainItemFullType() {
return packageItem + "." + langchainItemName;
}
/**
* Get the full name of the metadata protobuf type
* @return metadataFullType
*/
public String metadataFullType() {
return packageItem + "." + metadataItemName;
}
private static String computeFileName(String itemPackage, int dimension) {
return itemPackage + "." + "dimension." + dimension + ".proto";
}
}

View File

@ -2,5 +2,13 @@ package dev.langchain4j.store.embedding.infinispan;
import java.util.Set; import java.util.Set;
/**
* Langchain item that is serialized for the langchain integration use case
*
* @param id, the id of the item
* @param embedding, the vector
* @param text, associated text
* @param metadata, additional set of metadata
*/
public record LangChainInfinispanItem(String id, float[] embedding, String text, Set<LangChainMetadata> metadata) {} public record LangChainInfinispanItem(String id, float[] embedding, String text, Set<LangChainMetadata> metadata) {}

View File

@ -9,10 +9,14 @@ import java.util.Set;
/** /**
* Marshaller to read and write embeddings to Infinispan * Marshaller to read and write embeddings to Infinispan
*/ */
class LangChainItemMarshaller implements MessageMarshaller<LangChainInfinispanItem> { public class LangChainItemMarshaller implements MessageMarshaller<LangChainInfinispanItem> {
private final String typeName; private final String typeName;
/**
* Constructor for the LangChainItemMarshaller Marshaller
* @param typeName, the full type of the protobuf entity
*/
public LangChainItemMarshaller(String typeName) { public LangChainItemMarshaller(String typeName) {
this.typeName = typeName; this.typeName = typeName;
} }

View File

@ -1,4 +1,9 @@
package dev.langchain4j.store.embedding.infinispan; package dev.langchain4j.store.embedding.infinispan;
/**
* Langchain Metadata item that is serialized for the langchain integration use case
* @param name, the name of the metadata
* @param value, the value of the metadata
*/
public record LangChainMetadata(String name, String value) {} public record LangChainMetadata(String name, String value) {}

View File

@ -7,10 +7,14 @@ import java.io.IOException;
/** /**
* Marshaller to read and write metadata to Infinispan * Marshaller to read and write metadata to Infinispan
*/ */
class LangChainMetadataMarshaller implements MessageMarshaller<LangChainMetadata> { public class LangChainMetadataMarshaller implements MessageMarshaller<LangChainMetadata> {
private final String typeName; private final String typeName;
/**
* Constructor for the LangChainMetadata Marshaller
* @param typeName, the full type of the protobuf entity
*/
public LangChainMetadataMarshaller(String typeName) { public LangChainMetadataMarshaller(String typeName) {
this.typeName = typeName; this.typeName = typeName;
} }

View File

@ -0,0 +1,38 @@
package dev.langchain4j.store.embedding.infinispan;
import org.infinispan.protostream.schema.Schema;
import org.infinispan.protostream.schema.Type;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
/**
* LangchainSchemaCreator for Infinispan
*/
public final class LangchainSchemaCreator {
/**
* Build the Infinispan Schema to marshall embeddings
*
* @param storeConfiguration, the configuration of the store
* @return produced Schema
*/
public static Schema buildSchema(InfinispanStoreConfiguration storeConfiguration) {
return new Schema.Builder(storeConfiguration.fileName())
.packageName(storeConfiguration.packageItem())
.addMessage(storeConfiguration.metadataItemName())
.addComment("@Indexed")
.addField(Type.Scalar.STRING, "name", 1)
.addComment("@Text")
.addField(Type.Scalar.STRING, "value", 2)
.addComment("@Text")
.addMessage(storeConfiguration.langchainItemName())
.addComment("@Indexed")
.addField(Type.Scalar.STRING, "id", 1)
.addComment("@Text")
.addField(Type.Scalar.STRING, "text", 2)
.addComment("@Keyword")
.addRepeatedField(Type.Scalar.FLOAT, "embedding", 3)
.addComment("@Vector(dimension=" + storeConfiguration.dimension() + ", similarity=" + storeConfiguration.similarity() + ")")
.addRepeatedField(Type.create(storeConfiguration.metadataItemName()), "metadata", 4)
.build();
}
}