Rework support of AstraDB and Cassandra (#548)
In the Datastax Astra DB saas solution, a new way to integrate with vector databases has been introduced: using an HTTP APi instead of the Cassandra Cluster. It is called the DataAPI and use the MongoDB principles with collections. The pull request includes the following: ### Update on previous implementations - Previous implementations of embedding stores have been grouped in a single `CassandraEmbeddingStore`. It can be instantiated for Astra or OSS Cassandra based on 2 different constructor builders but everything else is the same. - Previous implementations of chat memory stores have been grouped in a single `CassandraChatMemoryStore`. It can be instantiated for Astra or OSS Cassandra based on 2 different constructor builders but everything else is the same. - Integration test for OSS Cassandra now using test containers (as Cassandra 5-alpha2 image is out) - Usage ```java // Using with Astra (Cassandra AAS in the cloud) CassandraEmbeddingStore.builderAstra() .token(token) .databaseId(dbId) .databaseRegion(TEST_REGION) .keyspace(KEYSPACE) .table(TEST_INDEX) .dimension(11) .metric(CassandraSimilarityMetric.COSINE) .build(); // Using OSS Cassandra CassandraEmbeddingStore.builder() .contactPoints(Arrays.asList(contactPoint.getHostName())) .port(contactPoint.getPort()) .localDataCenter(DATACENTER) .keyspace(KEYSPACE) .table(TEST_INDEX) .dimension(11) .metric(CassandraSimilarityMetric.COSINE) .build(); ``` -Adding jdk11 in the pom ``` <maven.compiler.source>11</maven.compiler.source> <maven.compiler.target>11</maven.compiler.target> ``` - introducing `insertMany()`, distributed to all bulk loading - Extending the variables `EmbeddingStoreIT` - Using `MessageWindowChatMemory` for the tests.
This commit is contained in:
parent
b375c7b42f
commit
cd006b166c
|
@ -15,7 +15,7 @@ jobs:
|
|||
java_version: [8, 11, 17, 21]
|
||||
include:
|
||||
- java_version: '8'
|
||||
included_modules: '-pl !code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-infinispan,!langchain4j-neo4j,!langchain4j-opensearch'
|
||||
included_modules: '-pl !code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-cassandra,!langchain4j-infinispan,!langchain4j-neo4j,!langchain4j-opensearch'
|
||||
- java_version: '11'
|
||||
included_modules: '-pl !code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-infinispan,!langchain4j-neo4j'
|
||||
- java_version: '17'
|
||||
|
|
|
@ -15,7 +15,11 @@
|
|||
</parent>
|
||||
|
||||
<properties>
|
||||
<astra-sdk.version>0.6.11</astra-sdk.version>
|
||||
<astra-db-client.version>1.2.4</astra-db-client.version>
|
||||
<jackson.version>2.16.1</jackson.version>
|
||||
<logback.version>1.4.14</logback.version>
|
||||
<maven.compiler.source>11</maven.compiler.source>
|
||||
<maven.compiler.target>11</maven.compiler.target>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
|
@ -25,6 +29,18 @@
|
|||
<artifactId>langchain4j-core</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-core</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.datastax.astra</groupId>
|
||||
<artifactId>astra-db-client</artifactId>
|
||||
<version>${astra-db-client.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
|
@ -36,53 +52,59 @@
|
|||
<artifactId>slf4j-api</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.datastax.astra</groupId>
|
||||
<artifactId>astra-sdk-vector</artifactId>
|
||||
<version>${astra-sdk.version}</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>ch.qos.logback</groupId>
|
||||
<artifactId>logback-classic</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<!-- removing cve -->
|
||||
<dependency>
|
||||
<groupId>org.json</groupId>
|
||||
<artifactId>json</artifactId>
|
||||
<version>20231013</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-beanutils</groupId>
|
||||
<artifactId>commons-beanutils</artifactId>
|
||||
<version>1.9.4</version>
|
||||
</dependency>
|
||||
<!-- TESTS -->
|
||||
|
||||
<!-- Visibility for EmbeddingStoreIT -->
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<classifier>tests</classifier>
|
||||
<type>test-jar</type>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<!-- Same embeddings model to keep the 1% -->
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-open-ai</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.assertj</groupId>
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<version>${assertj.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j</artifactId>
|
||||
<version>${project.parent.version}</version>
|
||||
<groupId>ch.qos.logback</groupId>
|
||||
<artifactId>logback-classic</artifactId>
|
||||
<version>${logback.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-open-ai</artifactId>
|
||||
<version>${project.parent.version}</version>
|
||||
<groupId>org.testcontainers</groupId>
|
||||
<artifactId>cassandra</artifactId>
|
||||
<version>${testcontainers.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.testcontainers</groupId>
|
||||
<artifactId>junit-jupiter</artifactId>
|
||||
<version>${testcontainers.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
|
|
|
@ -0,0 +1,243 @@
|
|||
package dev.langchain4j.store.embedding.astradb;
|
||||
|
||||
import com.dtsx.astra.sdk.AstraDBCollection;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import io.stargate.sdk.data.domain.JsonDocument;
|
||||
import io.stargate.sdk.data.domain.JsonDocumentMutationResult;
|
||||
import io.stargate.sdk.data.domain.JsonDocumentResult;
|
||||
import io.stargate.sdk.data.domain.odm.Document;
|
||||
import io.stargate.sdk.data.domain.query.Filter;
|
||||
import io.stargate.sdk.data.domain.query.SelectQuery;
|
||||
import io.stargate.sdk.data.domain.query.SelectQueryBuilder;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import lombok.experimental.Accessors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Implementation of {@link EmbeddingStore} using AstraDB.
|
||||
*
|
||||
* @see EmbeddingStore
|
||||
*/
|
||||
@Slf4j
|
||||
@Getter @Setter
|
||||
@Accessors(fluent = true)
|
||||
public class AstraDbEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
|
||||
/**
|
||||
* Saving the text chunk as an attribut.
|
||||
*/
|
||||
public static final String KEY_ATTRIBUTES_BLOB = "body_blob";
|
||||
|
||||
/**
|
||||
* Metadata used for similarity.
|
||||
*/
|
||||
public static final String KEY_SIMILARITY = "$similarity";
|
||||
|
||||
/**
|
||||
* Client to work with an Astra Collection
|
||||
*/
|
||||
private final AstraDBCollection astraDBCollection;
|
||||
|
||||
/**
|
||||
* Bulk loading are processed in chunks, size of 1 chunk in between 1 and 20
|
||||
*/
|
||||
private final int itemsPerChunk;
|
||||
|
||||
/**
|
||||
* Bulk loading is distributed,the is the number threads
|
||||
*/
|
||||
private final int concurrentThreads;
|
||||
|
||||
/**
|
||||
* Initialization of the store with an EXISTING collection.
|
||||
*
|
||||
* @param client
|
||||
* astra db collection client
|
||||
*/
|
||||
public AstraDbEmbeddingStore(@NonNull AstraDBCollection client) {
|
||||
this(client, 20, 8);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialization of the store with an EXISTING collection.
|
||||
*
|
||||
* @param client
|
||||
* astra db collection client
|
||||
* @param itemsPerChunk
|
||||
* size of 1 chunk in between 1 and 20
|
||||
*/
|
||||
public AstraDbEmbeddingStore(@NonNull AstraDBCollection client, int itemsPerChunk, int concurrentThreads) {
|
||||
if (itemsPerChunk>20 || itemsPerChunk<1) {
|
||||
throw new IllegalArgumentException("'itemsPerChunk' should be in between 1 and 20");
|
||||
}
|
||||
if (concurrentThreads<1) {
|
||||
throw new IllegalArgumentException("'concurrentThreads' should be at least 1");
|
||||
}
|
||||
this.astraDBCollection = client;
|
||||
this.itemsPerChunk = itemsPerChunk;
|
||||
this.concurrentThreads = concurrentThreads;
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete all records from the table.
|
||||
*/
|
||||
public void clear() {
|
||||
astraDBCollection.deleteAll();
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public String add(Embedding embedding) {
|
||||
return add(embedding, null);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public String add(Embedding embedding, TextSegment textSegment) {
|
||||
return astraDBCollection
|
||||
.insertOne(mapRecord(embedding, textSegment))
|
||||
.getDocument().getId();
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public void add(String id, Embedding embedding) {
|
||||
astraDBCollection.upsertOne(new JsonDocument().id(id).vector(embedding.vector()));
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings) {
|
||||
if (embeddings == null) return null;
|
||||
|
||||
// Map as a JsonDocument list.
|
||||
List<JsonDocument> recordList = embeddings
|
||||
.stream()
|
||||
.map(e -> mapRecord(e, null))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// No upsert needed as ids will be generated.
|
||||
return astraDBCollection
|
||||
.insertManyChunkedJsonDocuments(recordList, itemsPerChunk, concurrentThreads)
|
||||
.stream()
|
||||
.map(JsonDocumentMutationResult::getDocument)
|
||||
.map(Document::getId)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* Add multiple embeddings as a single action.
|
||||
*
|
||||
* @param embeddingList
|
||||
* list of embeddings
|
||||
* @param textSegmentList
|
||||
* list of text segment
|
||||
*
|
||||
* @return list of new row if (same order as the input)
|
||||
*/
|
||||
public List<String> addAll(List<Embedding> embeddingList, List<TextSegment> textSegmentList) {
|
||||
if (embeddingList == null || textSegmentList == null || embeddingList.size() != textSegmentList.size()) {
|
||||
throw new IllegalArgumentException("embeddingList and textSegmentList must not be null and have the same size");
|
||||
}
|
||||
|
||||
// Map as JsonDocument list
|
||||
List<JsonDocument> recordList = new ArrayList<>();
|
||||
for (int i = 0; i < embeddingList.size(); i++) {
|
||||
recordList.add(mapRecord(embeddingList.get(i), textSegmentList.get(i)));
|
||||
}
|
||||
|
||||
// No upsert needed (ids will be generated)
|
||||
return astraDBCollection
|
||||
.insertManyChunkedJsonDocuments(recordList, itemsPerChunk, concurrentThreads)
|
||||
.stream()
|
||||
.map(JsonDocumentMutationResult::getDocument)
|
||||
.map(Document::getId)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
|
||||
return findRelevant(referenceEmbedding, (Filter) null, maxResults, minScore);
|
||||
}
|
||||
|
||||
/**
|
||||
* Semantic search with metadata filtering.
|
||||
*
|
||||
* @param referenceEmbedding
|
||||
* vector
|
||||
* @param metaDatafilter
|
||||
* fileter for metadata
|
||||
* @param maxResults
|
||||
* limit
|
||||
* @param minScore
|
||||
* threshold
|
||||
* @return
|
||||
* records
|
||||
*/
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, Filter metaDatafilter, int maxResults, double minScore) {
|
||||
return astraDBCollection.findVector(referenceEmbedding.vector(), metaDatafilter, maxResults)
|
||||
.filter(r -> r.getSimilarity() >= minScore)
|
||||
.map(this::mapJsonResult)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* Mapping the output of the query to a {@link EmbeddingMatch}..
|
||||
*
|
||||
* @param jsonRes
|
||||
* returned object as Json
|
||||
* @return
|
||||
* embedding match as expected by langchain4j
|
||||
*/
|
||||
private EmbeddingMatch<TextSegment> mapJsonResult(JsonDocumentResult jsonRes) {
|
||||
Double score = (double) jsonRes.getSimilarity();
|
||||
String embeddingId = jsonRes.getId();
|
||||
Embedding embedding = Embedding.from(jsonRes.getVector());
|
||||
TextSegment embedded = null;
|
||||
Map<String, Object> properties = jsonRes.getData();
|
||||
if (properties!= null) {
|
||||
Object body = properties.get(KEY_ATTRIBUTES_BLOB);
|
||||
if (body != null) {
|
||||
Metadata metadata = new Metadata(properties.entrySet().stream()
|
||||
.collect(Collectors.toMap(Map.Entry::getKey,
|
||||
entry -> entry.getValue() == null ? "" : entry.getValue().toString()
|
||||
)));
|
||||
metadata.remove(KEY_ATTRIBUTES_BLOB);
|
||||
metadata.remove(KEY_SIMILARITY);
|
||||
embedded = new TextSegment(body.toString(), metadata);
|
||||
}
|
||||
}
|
||||
return new EmbeddingMatch<TextSegment>(score, embeddingId, embedding, embedded);
|
||||
}
|
||||
|
||||
/**
|
||||
* Map from LangChain4j record to AstraDB record.
|
||||
*
|
||||
* @param embedding
|
||||
* embedding (vector)
|
||||
* @param textSegment
|
||||
* text segment (text to encode)
|
||||
* @return
|
||||
* a json document
|
||||
*/
|
||||
private JsonDocument mapRecord(Embedding embedding, TextSegment textSegment) {
|
||||
JsonDocument record = new JsonDocument().vector(embedding.vector());
|
||||
if (textSegment != null) {
|
||||
record.put(KEY_ATTRIBUTES_BLOB, textSegment.text());
|
||||
textSegment.metadata().asMap().forEach(record::put);
|
||||
}
|
||||
return record;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,73 +0,0 @@
|
|||
package dev.langchain4j.store.embedding.cassandra;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
|
||||
/**
|
||||
* Plain old Java Object (POJO) to hold the configuration for the CassandraEmbeddingStore.
|
||||
* Wrapping all arguments needed to initialize a store in a single object makes it easier to pass them around.
|
||||
* It also makes it easier to add new arguments in the future, without having to change the constructor of the store.
|
||||
* This is especially useful when the store is used in a pipeline, where the arguments are passed around multiple times.
|
||||
*
|
||||
* @see CassandraEmbeddingStore
|
||||
*/
|
||||
@Getter
|
||||
@Builder
|
||||
public class AstraDbEmbeddingConfiguration {
|
||||
|
||||
/**
|
||||
* Represents the Api Key to interact with Astra DB
|
||||
*
|
||||
* @see <a href="https://docs.datastax.com/en/astra/docs/manage-application-tokens.html">Astra DB Api Key</a>
|
||||
*/
|
||||
@NonNull
|
||||
private String token;
|
||||
|
||||
/**
|
||||
* Represents the unique identifier for your database.
|
||||
*/
|
||||
@NonNull
|
||||
private String databaseId;
|
||||
|
||||
/**
|
||||
* Represents the region where your database is hosted. A database can be deployed
|
||||
* in multiple regions at the same time, and you can choose the region that is closest to your users.
|
||||
* If a database has a single region, it will be picked for you.
|
||||
*/
|
||||
private String databaseRegion;
|
||||
|
||||
/**
|
||||
* Represents the workspace name where you create your tables. One database can hold multiple keyspaces.
|
||||
* Best practice is to provide a keyspace for each application.
|
||||
*/
|
||||
@NonNull
|
||||
protected String keyspace;
|
||||
|
||||
/**
|
||||
* Represents the name of the table.
|
||||
*/
|
||||
@NonNull
|
||||
protected String table;
|
||||
|
||||
/**
|
||||
* Represents the dimension of the vector used to save the embeddings.
|
||||
*/
|
||||
@NonNull
|
||||
protected Integer dimension;
|
||||
|
||||
/**
|
||||
* Initialize the builder.
|
||||
*
|
||||
* @return cassandra embedding configuration builder
|
||||
*/
|
||||
public static AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder builder() {
|
||||
return new AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder();
|
||||
}
|
||||
|
||||
/**
|
||||
* Signature for the builder.
|
||||
*/
|
||||
public static class AstraDbEmbeddingConfigurationBuilder {
|
||||
}
|
||||
}
|
|
@ -1,111 +0,0 @@
|
|||
package dev.langchain4j.store.embedding.cassandra;
|
||||
|
||||
import com.datastax.astra.sdk.AstraClient;
|
||||
import com.datastax.oss.driver.api.core.CqlSession;
|
||||
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
|
||||
/**
|
||||
* Implementation of {@link EmbeddingStore} using Cassandra AstraDB.
|
||||
*
|
||||
* @see EmbeddingStore
|
||||
* @see MetadataVectorCassandraTable
|
||||
*/
|
||||
public class AstraDbEmbeddingStore extends CassandraEmbeddingStoreSupport {
|
||||
|
||||
/**
|
||||
* Build the store from the configuration.
|
||||
*
|
||||
* @param config configuration
|
||||
*/
|
||||
public AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration config) {
|
||||
CqlSession cqlSession = AstraClient.builder()
|
||||
.withToken(config.getToken())
|
||||
.withCqlKeyspace(config.getKeyspace())
|
||||
.withDatabaseId(config.getDatabaseId())
|
||||
.withDatabaseRegion(config.getDatabaseRegion())
|
||||
.enableCql()
|
||||
.enableDownloadSecureConnectBundle()
|
||||
.build().cqlSession();
|
||||
embeddingTable = new MetadataVectorCassandraTable(cqlSession, config.getKeyspace(), config.getTable(), config.getDimension());
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
/**
|
||||
* Syntax Sugar Builder.
|
||||
*/
|
||||
public static class Builder {
|
||||
|
||||
/**
|
||||
* Configuration built with the builder
|
||||
*/
|
||||
private final AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder conf;
|
||||
|
||||
/**
|
||||
* Initialization
|
||||
*/
|
||||
public Builder() {
|
||||
conf = AstraDbEmbeddingConfiguration.builder();
|
||||
}
|
||||
|
||||
/**
|
||||
* Populating token.
|
||||
*
|
||||
* @param token token
|
||||
* @return current reference
|
||||
*/
|
||||
public Builder token(String token) {
|
||||
conf.token(token);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Populating token.
|
||||
*
|
||||
* @param databaseId database Identifier
|
||||
* @param databaseRegion database region
|
||||
* @return current reference
|
||||
*/
|
||||
public Builder database(String databaseId, String databaseRegion) {
|
||||
conf.databaseId(databaseId);
|
||||
conf.databaseRegion(databaseRegion);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Populating model dimension.
|
||||
*
|
||||
* @param dimension model dimension
|
||||
* @return current reference
|
||||
*/
|
||||
public Builder vectorDimension(int dimension) {
|
||||
conf.dimension(dimension);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Populating table name.
|
||||
*
|
||||
* @param keyspace keyspace name
|
||||
* @param table table name
|
||||
* @return current reference
|
||||
*/
|
||||
public Builder table(String keyspace, String table) {
|
||||
conf.keyspace(keyspace);
|
||||
conf.table(table);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Building the Store.
|
||||
*
|
||||
* @return store for Astra.
|
||||
*/
|
||||
public AstraDbEmbeddingStore build() {
|
||||
return new AstraDbEmbeddingStore(conf.build());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,90 +0,0 @@
|
|||
package dev.langchain4j.store.embedding.cassandra;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Plain old Java Object (POJO) to hold the configuration for the CassandraEmbeddingStore.
|
||||
* Wrapping all arguments needed to initialize a store in a single object makes it easier to pass them around.
|
||||
* It also makes it easier to add new arguments in the future, without having to change the constructor of the store.
|
||||
* This is especially useful when the store is used in a pipeline, where the arguments are passed around multiple times.
|
||||
*
|
||||
* @see CassandraEmbeddingStore
|
||||
*/
|
||||
@Getter
|
||||
@Builder
|
||||
public class CassandraEmbeddingConfiguration {
|
||||
|
||||
/**
|
||||
* Default Cassandra Port.
|
||||
*/
|
||||
public static Integer DEFAULT_PORT = 9042;
|
||||
|
||||
// --- Connectivity Parameters ---
|
||||
|
||||
/**
|
||||
* Represents the cassandra Contact points.
|
||||
*/
|
||||
@NonNull
|
||||
private List<String> contactPoints;
|
||||
|
||||
/**
|
||||
* Represent the local data center.
|
||||
*/
|
||||
@NonNull
|
||||
private String localDataCenter;
|
||||
|
||||
/**
|
||||
* Connection Port
|
||||
*/
|
||||
@NonNull
|
||||
private Integer port;
|
||||
|
||||
/**
|
||||
* (Optional) Represents the username to connect to the database.
|
||||
*/
|
||||
private String userName;
|
||||
|
||||
/**
|
||||
* (Optional) Represents the password to connect to the database.
|
||||
*/
|
||||
private String password;
|
||||
|
||||
/**
|
||||
* Represents the workspace name where you create your tables. One database can hold multiple keyspaces.
|
||||
* Best practice is to provide a keyspace for each application.
|
||||
*/
|
||||
@NonNull
|
||||
protected String keyspace;
|
||||
|
||||
/**
|
||||
* Represents the name of the table.
|
||||
*/
|
||||
@NonNull
|
||||
protected String table;
|
||||
|
||||
/**
|
||||
* Represents the dimension of the model use to create the embeddings. The vector holding the embeddings
|
||||
* is a fixed size. The dimension of the vector is the dimension of the model used to create the embeddings.
|
||||
*/
|
||||
@NonNull
|
||||
protected Integer dimension;
|
||||
|
||||
/**
|
||||
* Initialize the builder.
|
||||
*
|
||||
* @return cassandra embedding configuration buildesr
|
||||
*/
|
||||
public static CassandraEmbeddingConfigurationBuilder builder() {
|
||||
return new CassandraEmbeddingConfigurationBuilder();
|
||||
}
|
||||
|
||||
/**
|
||||
* Signature for the builder.
|
||||
*/
|
||||
public static class CassandraEmbeddingConfigurationBuilder {
|
||||
}
|
||||
}
|
|
@ -2,63 +2,166 @@ package dev.langchain4j.store.embedding.cassandra;
|
|||
|
||||
import com.datastax.oss.driver.api.core.CqlSession;
|
||||
import com.datastax.oss.driver.api.core.CqlSessionBuilder;
|
||||
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
|
||||
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
|
||||
import com.dtsx.astra.sdk.cassio.SimilarityMetric;
|
||||
import com.dtsx.astra.sdk.cassio.AnnQuery;
|
||||
import com.dtsx.astra.sdk.cassio.AnnResult;
|
||||
import com.dtsx.astra.sdk.cassio.CassIO;
|
||||
import com.dtsx.astra.sdk.cassio.MetadataVectorRecord;
|
||||
import com.dtsx.astra.sdk.cassio.MetadataVectorTable;
|
||||
import com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric;
|
||||
import com.dtsx.astra.sdk.utils.AstraEnvironment;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
|
||||
import java.net.InetSocketAddress;
|
||||
import java.util.Arrays;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
/**
|
||||
* Implementation of {@link EmbeddingStore} using Cassandra AstraDB.
|
||||
* Implementation of {@link EmbeddingStore} using Cassandra.
|
||||
*
|
||||
* @see EmbeddingStore
|
||||
* @see MetadataVectorCassandraTable
|
||||
* @see MetadataVectorTable
|
||||
*/
|
||||
public class CassandraEmbeddingStore extends CassandraEmbeddingStoreSupport {
|
||||
public class CassandraEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
|
||||
/**
|
||||
* Build the store from the configuration.
|
||||
*
|
||||
* @param config configuration
|
||||
* Represents an embedding table in Cassandra, it is a table with a vector column.
|
||||
*/
|
||||
public CassandraEmbeddingStore(CassandraEmbeddingConfiguration config) {
|
||||
CqlSessionBuilder sessionBuilder = createCqlSessionBuilder(config);
|
||||
createKeyspaceIfNotExist(sessionBuilder, config.getKeyspace());
|
||||
sessionBuilder.withKeyspace(config.getKeyspace());
|
||||
this.embeddingTable = new MetadataVectorCassandraTable(sessionBuilder.build(),
|
||||
config.getKeyspace(), config.getTable(), config.getDimension(), SimilarityMetric.COS);
|
||||
protected MetadataVectorTable embeddingTable;
|
||||
|
||||
/**
|
||||
* Cassandra question.
|
||||
*/
|
||||
@Getter
|
||||
protected CqlSession cassandraSession;
|
||||
|
||||
/**
|
||||
* Embedding Store.
|
||||
*
|
||||
* @param session
|
||||
* cassandra Session
|
||||
* @param tableName
|
||||
* table name
|
||||
* @param dimension
|
||||
* dimension
|
||||
*/
|
||||
public CassandraEmbeddingStore(CqlSession session, String tableName, int dimension) {
|
||||
this(session, tableName, dimension, CassandraSimilarityMetric.COSINE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the cassandra session from the config. At the difference of adminSession there
|
||||
* a keyspace attached to it.
|
||||
* Embedding Store.
|
||||
*
|
||||
* @param config current configuration
|
||||
* @return cassandra session
|
||||
* @param session
|
||||
* cassandra Session
|
||||
* @param tableName
|
||||
* table name
|
||||
* @param dimension
|
||||
* dimension
|
||||
* @param metric
|
||||
* metric
|
||||
*/
|
||||
private CqlSessionBuilder createCqlSessionBuilder(CassandraEmbeddingConfiguration config) {
|
||||
CqlSessionBuilder cqlSessionBuilder = CqlSession.builder();
|
||||
cqlSessionBuilder.withLocalDatacenter(config.getLocalDataCenter());
|
||||
if (config.getUserName() != null && config.getPassword() != null) {
|
||||
cqlSessionBuilder.withAuthCredentials(config.getUserName(), config.getPassword());
|
||||
public CassandraEmbeddingStore(CqlSession session, String tableName, int dimension, CassandraSimilarityMetric metric) {
|
||||
this.cassandraSession = session;
|
||||
this.embeddingTable = new MetadataVectorTable(session, session.getKeyspace().get().asInternal(), tableName, dimension, metric);
|
||||
embeddingTable.create();
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete the table.
|
||||
*/
|
||||
public void delete() {
|
||||
embeddingTable.delete();
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete all rows.
|
||||
*/
|
||||
public void clear() {
|
||||
embeddingTable.clear();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
public static Integer DEFAULT_PORT = 9042;
|
||||
private List<String> contactPoints;
|
||||
private String localDataCenter;
|
||||
private Integer port = DEFAULT_PORT;
|
||||
private String userName;
|
||||
private String password;
|
||||
protected String keyspace;
|
||||
protected String table;
|
||||
protected Integer dimension;
|
||||
protected CassandraSimilarityMetric metric = CassandraSimilarityMetric.COSINE;
|
||||
|
||||
public Builder contactPoints(List<String> contactPoints) {
|
||||
this.contactPoints = contactPoints;
|
||||
return this;
|
||||
}
|
||||
config.getContactPoints().forEach(cp ->
|
||||
cqlSessionBuilder.addContactPoint(new InetSocketAddress(cp, config.getPort())));
|
||||
return cqlSessionBuilder;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the keyspace in cassandra Destination if not exist.
|
||||
*/
|
||||
private void createKeyspaceIfNotExist(CqlSessionBuilder cqlSessionBuilder, String keyspace) {
|
||||
try (CqlSession adminSession = cqlSessionBuilder.build()) {
|
||||
adminSession.execute(SchemaBuilder.createKeyspace(keyspace)
|
||||
.ifNotExists()
|
||||
.withSimpleStrategy(1)
|
||||
.withDurableWrites(true)
|
||||
.build());
|
||||
public Builder localDataCenter(String localDataCenter) {
|
||||
this.localDataCenter = localDataCenter;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder port(Integer port) {
|
||||
this.port = port;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder userName(String userName) {
|
||||
this.userName = userName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder password(String password) {
|
||||
this.password = password;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder keyspace(String keyspace) {
|
||||
this.keyspace = keyspace;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder table(String table) {
|
||||
this.table = table;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder dimension(Integer dimension) {
|
||||
this.dimension = dimension;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder metric(CassandraSimilarityMetric metric) {
|
||||
this.metric = metric;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder() {
|
||||
}
|
||||
|
||||
public CassandraEmbeddingStore build() {
|
||||
CqlSessionBuilder builder = CqlSession.builder()
|
||||
.withKeyspace(keyspace)
|
||||
.withLocalDatacenter(localDataCenter);
|
||||
if (userName != null && password != null) {
|
||||
builder.withAuthCredentials(userName, password);
|
||||
}
|
||||
contactPoints.forEach(cp -> builder.addContactPoint(new InetSocketAddress(cp, port)));
|
||||
return new CassandraEmbeddingStore(builder.build(),table, dimension, metric);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -66,87 +169,216 @@ public class CassandraEmbeddingStore extends CassandraEmbeddingStoreSupport {
|
|||
return new Builder();
|
||||
}
|
||||
|
||||
/**
|
||||
* Syntax Sugar Builder.
|
||||
*/
|
||||
public static class Builder {
|
||||
public static BuilderAstra builderAstra() {
|
||||
return new BuilderAstra();
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration built with the builder
|
||||
*/
|
||||
private final CassandraEmbeddingConfiguration.CassandraEmbeddingConfigurationBuilder conf;
|
||||
public static class BuilderAstra {
|
||||
private String token;
|
||||
private UUID dbId;
|
||||
private String tableName;
|
||||
private int dimension;
|
||||
private String keyspaceName = "default_keyspace";
|
||||
private String dbRegion = "us-east1";
|
||||
private CassandraSimilarityMetric metric = CassandraSimilarityMetric.COSINE;
|
||||
private AstraEnvironment env = AstraEnvironment.PROD;
|
||||
|
||||
/**
|
||||
* Initialization
|
||||
*/
|
||||
public Builder() {
|
||||
conf = CassandraEmbeddingConfiguration.builder();
|
||||
}
|
||||
|
||||
/**
|
||||
* Populating cassandra port.
|
||||
*
|
||||
* @param port port
|
||||
* @return current reference
|
||||
*/
|
||||
public CassandraEmbeddingStore.Builder port(int port) {
|
||||
conf.port(port);
|
||||
public BuilderAstra token(String token) {
|
||||
this.token = token;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Populating cassandra contact points.
|
||||
*
|
||||
* @param hosts port
|
||||
* @return current reference
|
||||
*/
|
||||
public CassandraEmbeddingStore.Builder contactPoints(String... hosts) {
|
||||
conf.contactPoints(Arrays.asList(hosts));
|
||||
public BuilderAstra env(AstraEnvironment env) {
|
||||
this.env = env;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Populating model dimension.
|
||||
*
|
||||
* @param dimension model dimension
|
||||
* @return current reference
|
||||
*/
|
||||
public CassandraEmbeddingStore.Builder vectorDimension(int dimension) {
|
||||
conf.dimension(dimension);
|
||||
public BuilderAstra databaseId(UUID dbId) {
|
||||
this.dbId = dbId;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Populating datacenter.
|
||||
*
|
||||
* @param dc datacenter
|
||||
* @return current reference
|
||||
*/
|
||||
public CassandraEmbeddingStore.Builder localDataCenter(String dc) {
|
||||
conf.localDataCenter(dc);
|
||||
public BuilderAstra databaseRegion(String dbRegion) {
|
||||
this.dbRegion = dbRegion;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Populating table name.
|
||||
*
|
||||
* @param keyspace keyspace name
|
||||
* @param table table name
|
||||
* @return current reference
|
||||
*/
|
||||
public CassandraEmbeddingStore.Builder table(String keyspace, String table) {
|
||||
conf.keyspace(keyspace);
|
||||
conf.table(table);
|
||||
public BuilderAstra keyspace(String keyspaceName) {
|
||||
this.keyspaceName = keyspaceName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public BuilderAstra table(String tableName) {
|
||||
this.tableName = tableName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public BuilderAstra dimension(int dimension) {
|
||||
this.dimension = dimension;
|
||||
return this;
|
||||
}
|
||||
|
||||
public BuilderAstra metric(CassandraSimilarityMetric metric) {
|
||||
this.metric = metric;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Building the Store.
|
||||
*
|
||||
* @return store for Astra.
|
||||
*/
|
||||
public CassandraEmbeddingStore build() {
|
||||
return new CassandraEmbeddingStore(conf.build());
|
||||
CqlSession cqlSession = CassIO.init(token, dbId, dbRegion, keyspaceName, env);
|
||||
return new CassandraEmbeddingStore(cqlSession, tableName, dimension, metric);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a new embedding to the store.
|
||||
* - the row id is generated
|
||||
* - text and metadata are not stored
|
||||
*
|
||||
* @param embedding representation of the list of floats
|
||||
* @return newly created row id
|
||||
*/
|
||||
@Override
|
||||
public String add(@NonNull Embedding embedding) {
|
||||
return add(embedding, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a new embedding to the store.
|
||||
* - the row id is generated
|
||||
* - text and metadata coming from the text Segment
|
||||
*
|
||||
* @param embedding representation of the list of floats
|
||||
* @param textSegment text content and metadata
|
||||
* @return newly created row id
|
||||
*/
|
||||
@Override
|
||||
public String add(@NonNull Embedding embedding, TextSegment textSegment) {
|
||||
MetadataVectorRecord record = new MetadataVectorRecord(embedding.vectorAsList());
|
||||
if (textSegment != null) {
|
||||
record.setBody(textSegment.text());
|
||||
record.setMetadata(textSegment.metadata().asMap());
|
||||
}
|
||||
embeddingTable.put(record);
|
||||
return record.getRowId();
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a new embedding to the store.
|
||||
*
|
||||
* @param rowId the row id
|
||||
* @param embedding representation of the list of floats
|
||||
*/
|
||||
@Override
|
||||
public void add(@NonNull String rowId, @NonNull Embedding embedding) {
|
||||
embeddingTable.put(new MetadataVectorRecord(rowId, embedding.vectorAsList()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Add multiple embeddings as a single action.
|
||||
*
|
||||
* @param embeddingList embeddings list
|
||||
* @return list of new row if (same order as the input)
|
||||
*/
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddingList) {
|
||||
return embeddingList.stream()
|
||||
.map(Embedding::vectorAsList)
|
||||
.map(MetadataVectorRecord::new)
|
||||
.peek(embeddingTable::putAsync)
|
||||
.map(MetadataVectorRecord::getRowId)
|
||||
.collect(toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* Add multiple embeddings as a single action.
|
||||
*
|
||||
* @param embeddingList embeddings
|
||||
* @param textSegmentList text segments
|
||||
* @return list of new row if (same order as the input)
|
||||
*/
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddingList, List<TextSegment> textSegmentList) {
|
||||
if (embeddingList == null || textSegmentList == null || embeddingList.size() != textSegmentList.size()) {
|
||||
throw new IllegalArgumentException("embeddingList and textSegmentList must not be null and have the same size");
|
||||
}
|
||||
// Looping on both list with an index
|
||||
List<String> ids = new ArrayList<>();
|
||||
for (int i = 0; i < embeddingList.size(); i++) {
|
||||
ids.add(add(embeddingList.get(i), textSegmentList.get(i)));
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
/**
|
||||
* Search for relevant.
|
||||
*
|
||||
* @param embedding current embeddings
|
||||
* @param maxResults max number of result
|
||||
* @param minScore threshold
|
||||
* @return list of matching elements
|
||||
*/
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, int maxResults, double minScore) {
|
||||
return embeddingTable
|
||||
.similaritySearch(AnnQuery.builder()
|
||||
.embeddings(embedding.vectorAsList())
|
||||
.recordCount(ensureGreaterThanZero(maxResults, "maxResults"))
|
||||
.threshold(CosineSimilarity.fromRelevanceScore(ensureBetween(minScore, 0, 1, "minScore")))
|
||||
.metric(CassandraSimilarityMetric.COSINE)
|
||||
.build())
|
||||
.stream()
|
||||
.map(CassandraEmbeddingStore::mapSearchResult)
|
||||
.collect(toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* Map Search result coming from Astra.
|
||||
*
|
||||
* @param record current record
|
||||
* @return search result
|
||||
*/
|
||||
private static EmbeddingMatch<TextSegment> mapSearchResult(AnnResult<MetadataVectorRecord> record) {
|
||||
|
||||
TextSegment embedded = null;
|
||||
String body = record.getEmbedded().getBody();
|
||||
if (body != null
|
||||
&& !body.isEmpty()
|
||||
&& record.getEmbedded().getMetadata() != null) {
|
||||
embedded = TextSegment.from(record.getEmbedded().getBody(),
|
||||
new Metadata(record.getEmbedded().getMetadata()));
|
||||
}
|
||||
return new EmbeddingMatch<>(
|
||||
// Score
|
||||
RelevanceScore.fromCosineSimilarity(record.getSimilarity()),
|
||||
// EmbeddingId : unique identifier
|
||||
record.getEmbedded().getRowId(),
|
||||
// Embeddings vector
|
||||
Embedding.from(record.getEmbedded().getVector()),
|
||||
// Text segment and metadata
|
||||
embedded);
|
||||
}
|
||||
|
||||
/**
|
||||
* Similarity Search ANN based on the embedding.
|
||||
*
|
||||
* @param embedding vector
|
||||
* @param maxResults max number of results
|
||||
* @param minScore score minScore
|
||||
* @param metadata map key-value to build a metadata filter
|
||||
* @return list of matching results
|
||||
*/
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, int maxResults, double minScore, Metadata metadata) {
|
||||
AnnQuery.AnnQueryBuilder builder = AnnQuery.builder()
|
||||
.embeddings(embedding.vectorAsList())
|
||||
.metric(CassandraSimilarityMetric.COSINE)
|
||||
.recordCount(ensureGreaterThanZero(maxResults, "maxResults"))
|
||||
.threshold(CosineSimilarity.fromRelevanceScore(ensureBetween(minScore, 0, 1, "minScore")));
|
||||
if (metadata != null) {
|
||||
builder.metaData(metadata.asMap());
|
||||
}
|
||||
return embeddingTable
|
||||
.similaritySearch(builder.build())
|
||||
.stream()
|
||||
.map(CassandraEmbeddingStore::mapSearchResult)
|
||||
.collect(toList());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,179 +0,0 @@
|
|||
package dev.langchain4j.store.embedding.cassandra;
|
||||
|
||||
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
|
||||
import com.dtsx.astra.sdk.cassio.SimilarityMetric;
|
||||
import com.dtsx.astra.sdk.cassio.SimilaritySearchQuery;
|
||||
import com.dtsx.astra.sdk.cassio.SimilaritySearchQuery.SimilaritySearchQueryBuilder;
|
||||
import com.dtsx.astra.sdk.cassio.SimilaritySearchResult;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
/**
|
||||
* Support for CassandraEmbeddingStore with and Without Astra.
|
||||
*/
|
||||
@Getter
|
||||
abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<TextSegment> {
|
||||
|
||||
/**
|
||||
* Represents an embedding table in Cassandra, it is a table with a vector column.
|
||||
*/
|
||||
protected MetadataVectorCassandraTable embeddingTable;
|
||||
|
||||
/**
|
||||
* Add a new embedding to the store.
|
||||
* - the row id is generated
|
||||
* - text and metadata are not stored
|
||||
*
|
||||
* @param embedding representation of the list of floats
|
||||
* @return newly created row id
|
||||
*/
|
||||
@Override
|
||||
public String add(@NonNull Embedding embedding) {
|
||||
return add(embedding, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a new embedding to the store.
|
||||
* - the row id is generated
|
||||
* - text and metadata coming from the text Segment
|
||||
*
|
||||
* @param embedding representation of the list of floats
|
||||
* @param textSegment text content and metadata
|
||||
* @return newly created row id
|
||||
*/
|
||||
@Override
|
||||
public String add(@NonNull Embedding embedding, TextSegment textSegment) {
|
||||
MetadataVectorCassandraTable.Record record = new MetadataVectorCassandraTable.Record(embedding.vectorAsList());
|
||||
if (textSegment != null) {
|
||||
record.setBody(textSegment.text());
|
||||
record.setMetadata(textSegment.metadata().asMap());
|
||||
}
|
||||
embeddingTable.put(record);
|
||||
return record.getRowId();
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a new embedding to the store.
|
||||
*
|
||||
* @param rowId the row id
|
||||
* @param embedding representation of the list of floats
|
||||
*/
|
||||
@Override
|
||||
public void add(@NonNull String rowId, @NonNull Embedding embedding) {
|
||||
embeddingTable.put(new MetadataVectorCassandraTable.Record(rowId, embedding.vectorAsList()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Add multiple embeddings as a single action.
|
||||
*
|
||||
* @param embeddingList embeddings list
|
||||
* @return list of new row if (same order as the input)
|
||||
*/
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddingList) {
|
||||
return embeddingList.stream()
|
||||
.map(Embedding::vectorAsList)
|
||||
.map(MetadataVectorCassandraTable.Record::new)
|
||||
.peek(embeddingTable::putAsync)
|
||||
.map(MetadataVectorCassandraTable.Record::getRowId)
|
||||
.collect(toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* Add multiple embeddings as a single action.
|
||||
*
|
||||
* @param embeddingList embeddings
|
||||
* @param textSegmentList text segments
|
||||
* @return list of new row if (same order as the input)
|
||||
*/
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddingList, List<TextSegment> textSegmentList) {
|
||||
if (embeddingList == null || textSegmentList == null || embeddingList.size() != textSegmentList.size()) {
|
||||
throw new IllegalArgumentException("embeddingList and textSegmentList must not be null and have the same size");
|
||||
}
|
||||
// Looping on both list with an index
|
||||
List<String> ids = new ArrayList<>();
|
||||
for (int i = 0; i < embeddingList.size(); i++) {
|
||||
ids.add(add(embeddingList.get(i), textSegmentList.get(i)));
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
/**
|
||||
* Search for relevant.
|
||||
*
|
||||
* @param embedding current embeddings
|
||||
* @param maxResults max number of result
|
||||
* @param minScore threshold
|
||||
* @return list of matching elements
|
||||
*/
|
||||
@Override
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, int maxResults, double minScore) {
|
||||
return embeddingTable
|
||||
.similaritySearch(SimilaritySearchQuery.builder()
|
||||
.embeddings(embedding.vectorAsList())
|
||||
.recordCount(ensureGreaterThanZero(maxResults, "maxResults"))
|
||||
.threshold(CosineSimilarity.fromRelevanceScore(ensureBetween(minScore, 0, 1, "minScore")))
|
||||
.distance(SimilarityMetric.COS)
|
||||
.build())
|
||||
.stream()
|
||||
.map(CassandraEmbeddingStoreSupport::mapSearchResult)
|
||||
.collect(toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* Map Search result coming from Astra.
|
||||
*
|
||||
* @param record current record
|
||||
* @return search result
|
||||
*/
|
||||
private static EmbeddingMatch<TextSegment> mapSearchResult(SimilaritySearchResult<MetadataVectorCassandraTable.Record> record) {
|
||||
return new EmbeddingMatch<>(
|
||||
// Score
|
||||
RelevanceScore.fromCosineSimilarity(record.getSimilarity()),
|
||||
// EmbeddingId : unique identifier
|
||||
record.getEmbedded().getRowId(),
|
||||
// Embeddings vector
|
||||
Embedding.from(record.getEmbedded().getVector()),
|
||||
// Text segment and metadata
|
||||
TextSegment.from(record.getEmbedded().getBody(), new Metadata(record.getEmbedded().getMetadata())));
|
||||
}
|
||||
|
||||
/**
|
||||
* Similarity Search ANN based on the embedding.
|
||||
*
|
||||
* @param embedding vector
|
||||
* @param maxResults max number of results
|
||||
* @param minScore score minScore
|
||||
* @param metadata map key-value to build a metadata filter
|
||||
* @return list of matching results
|
||||
*/
|
||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, int maxResults, double minScore, Metadata metadata) {
|
||||
SimilaritySearchQueryBuilder builder = SimilaritySearchQuery.builder()
|
||||
.embeddings(embedding.vectorAsList())
|
||||
.recordCount(ensureGreaterThanZero(maxResults, "maxResults"))
|
||||
.threshold(CosineSimilarity.fromRelevanceScore(ensureBetween(minScore, 0, 1, "minScore")));
|
||||
if (metadata != null) {
|
||||
builder.metaData(metadata.asMap());
|
||||
}
|
||||
return embeddingTable
|
||||
.similaritySearch(builder.build())
|
||||
.stream()
|
||||
.map(CassandraEmbeddingStoreSupport::mapSearchResult)
|
||||
.collect(toList());
|
||||
}
|
||||
}
|
|
@ -1,44 +0,0 @@
|
|||
package dev.langchain4j.store.memory.chat.cassandra;
|
||||
|
||||
import com.datastax.astra.sdk.AstraClient;
|
||||
|
||||
/**
|
||||
* AstraDb is a version of Cassandra running in Saas Mode.
|
||||
* <p>
|
||||
* The initialization of the CQLSession will be done through an AstraClient
|
||||
*/
|
||||
public class AstraDbChatMemoryStore extends CassandraChatMemoryStore {
|
||||
|
||||
/**
|
||||
* Constructor with default table name.
|
||||
*
|
||||
* @param token token
|
||||
* @param dbId database identifier
|
||||
* @param dbRegion database region
|
||||
* @param keyspaceName keyspace name
|
||||
*/
|
||||
public AstraDbChatMemoryStore(String token, String dbId, String dbRegion, String keyspaceName) {
|
||||
this(token, dbId, dbRegion, keyspaceName, DEFAULT_TABLE_NAME);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor with explicit table name.
|
||||
*
|
||||
* @param token token
|
||||
* @param dbId database identifier
|
||||
* @param dbRegion database region
|
||||
* @param keyspaceName keyspace name
|
||||
* @param tableName table name
|
||||
*/
|
||||
public AstraDbChatMemoryStore(String token, String dbId, String dbRegion, String keyspaceName, String tableName) {
|
||||
super(AstraClient.builder()
|
||||
.withToken(token)
|
||||
.withCqlKeyspace(keyspaceName)
|
||||
.withDatabaseId(dbId)
|
||||
.withDatabaseRegion(dbRegion)
|
||||
.enableCql()
|
||||
.enableDownloadSecureConnectBundle()
|
||||
.build().cqlSession(), keyspaceName, tableName);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,8 +1,12 @@
|
|||
package dev.langchain4j.store.memory.chat.cassandra;
|
||||
|
||||
import com.datastax.oss.driver.api.core.CqlSession;
|
||||
import com.datastax.oss.driver.api.core.CqlSessionBuilder;
|
||||
import com.datastax.oss.driver.api.core.uuid.Uuids;
|
||||
import com.dtsx.astra.sdk.cassio.ClusteredCassandraTable;
|
||||
import com.dtsx.astra.sdk.cassio.CassIO;
|
||||
import com.dtsx.astra.sdk.cassio.ClusteredRecord;
|
||||
import com.dtsx.astra.sdk.cassio.ClusteredTable;
|
||||
import com.dtsx.astra.sdk.utils.AstraEnvironment;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.ChatMessageDeserializer;
|
||||
import dev.langchain4j.data.message.ChatMessageSerializer;
|
||||
|
@ -10,11 +14,11 @@ import dev.langchain4j.store.memory.chat.ChatMemoryStore;
|
|||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static com.dtsx.astra.sdk.cassio.ClusteredCassandraTable.Record;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
/**
|
||||
|
@ -35,27 +39,56 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
|
|||
/**
|
||||
* Message Table.
|
||||
*/
|
||||
private final ClusteredCassandraTable messageTable;
|
||||
private final ClusteredTable messageTable;
|
||||
|
||||
/**
|
||||
* Constructor for message store
|
||||
*
|
||||
* @param session cassandra session
|
||||
* @param keyspaceName keyspace name
|
||||
* @param tableName table name
|
||||
*/
|
||||
public CassandraChatMemoryStore(CqlSession session, String keyspaceName, String tableName) {
|
||||
messageTable = new ClusteredCassandraTable(session, keyspaceName, tableName);
|
||||
public CassandraChatMemoryStore(CqlSession session) {
|
||||
this(session, DEFAULT_TABLE_NAME);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor for message store
|
||||
*
|
||||
* @param session cassandra session
|
||||
* @param keyspaceName keyspace name
|
||||
* @param tableName table name
|
||||
*/
|
||||
public CassandraChatMemoryStore(CqlSession session, String keyspaceName) {
|
||||
messageTable = new ClusteredCassandraTable(session, keyspaceName, DEFAULT_TABLE_NAME);
|
||||
public CassandraChatMemoryStore(CqlSession session, String tableName) {
|
||||
messageTable = new ClusteredTable(session, session.getKeyspace().get().asInternal(), tableName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the table if not exist.
|
||||
*/
|
||||
public void create() {
|
||||
messageTable.create();
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete the table.
|
||||
*/
|
||||
public void delete() {
|
||||
messageTable.delete();
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete all rows.
|
||||
*/
|
||||
public void clear() {
|
||||
messageTable.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Access the cassandra session for fined grained operation.
|
||||
*
|
||||
* @return
|
||||
* current cassandra session
|
||||
*/
|
||||
public CqlSession getCassandraSession() {
|
||||
return messageTable.getCqlSession();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -67,7 +100,7 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
|
|||
* RATIONAL:
|
||||
* In the cassandra table the order is explicitly put to DESC with
|
||||
* latest to come first (for long conversation for instance). Here we ask
|
||||
* for the full history. Instead of changing the multi purpose table
|
||||
* for the full history. Instead of changing the multipurpose table
|
||||
* we reverse the list.
|
||||
*/
|
||||
List<ChatMessage> latestFirstList = messageTable
|
||||
|
@ -99,12 +132,12 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
|
|||
}
|
||||
|
||||
/**
|
||||
* Unmarshalling Cassandra row as a Message with proper sub-type.
|
||||
* Unmarshalling Cassandra row as a Message with proper subtype.
|
||||
*
|
||||
* @param record cassandra record
|
||||
* @return chat message
|
||||
*/
|
||||
private ChatMessage toChatMessage(@NonNull Record record) {
|
||||
private ChatMessage toChatMessage(@NonNull ClusteredRecord record) {
|
||||
try {
|
||||
return ChatMessageDeserializer.messageFromJson(record.getBody());
|
||||
} catch (Exception e) {
|
||||
|
@ -120,9 +153,9 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
|
|||
* @param chatMessage chat message
|
||||
* @return cassandra row.
|
||||
*/
|
||||
private Record fromChatMessage(@NonNull String memoryId, @NonNull ChatMessage chatMessage) {
|
||||
private ClusteredRecord fromChatMessage(@NonNull String memoryId, @NonNull ChatMessage chatMessage) {
|
||||
try {
|
||||
Record record = new Record();
|
||||
ClusteredRecord record = new ClusteredRecord();
|
||||
record.setRowId(Uuids.timeBased());
|
||||
record.setPartitionId(memoryId);
|
||||
record.setBody(ChatMessageSerializer.messageToJson(chatMessage));
|
||||
|
@ -139,4 +172,117 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
|
|||
}
|
||||
return (String) memoryId;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
public static Integer DEFAULT_PORT = 9042;
|
||||
private List<String> contactPoints;
|
||||
private String localDataCenter;
|
||||
private Integer port = DEFAULT_PORT;
|
||||
private String userName;
|
||||
private String password;
|
||||
protected String keyspace;
|
||||
protected String table = DEFAULT_TABLE_NAME;
|
||||
|
||||
public CassandraChatMemoryStore.Builder contactPoints(List<String> contactPoints) {
|
||||
this.contactPoints = contactPoints;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CassandraChatMemoryStore.Builder localDataCenter(String localDataCenter) {
|
||||
this.localDataCenter = localDataCenter;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CassandraChatMemoryStore.Builder port(Integer port) {
|
||||
this.port = port;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CassandraChatMemoryStore.Builder userName(String userName) {
|
||||
this.userName = userName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CassandraChatMemoryStore.Builder password(String password) {
|
||||
this.password = password;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CassandraChatMemoryStore.Builder keyspace(String keyspace) {
|
||||
this.keyspace = keyspace;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CassandraChatMemoryStore.Builder table(String table) {
|
||||
this.table = table;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder() {
|
||||
}
|
||||
|
||||
public CassandraChatMemoryStore build() {
|
||||
CqlSessionBuilder builder = CqlSession.builder()
|
||||
.withKeyspace(keyspace)
|
||||
.withLocalDatacenter(localDataCenter);
|
||||
if (userName != null && password != null) {
|
||||
builder.withAuthCredentials(userName, password);
|
||||
}
|
||||
contactPoints.forEach(cp -> builder.addContactPoint(new InetSocketAddress(cp, port)));
|
||||
return new CassandraChatMemoryStore(builder.build(), table);
|
||||
}
|
||||
}
|
||||
|
||||
public static CassandraChatMemoryStore.Builder builder() {
|
||||
return new CassandraChatMemoryStore.Builder();
|
||||
}
|
||||
|
||||
public static CassandraChatMemoryStore.BuilderAstra builderAstra() {
|
||||
return new CassandraChatMemoryStore.BuilderAstra();
|
||||
}
|
||||
|
||||
public static class BuilderAstra {
|
||||
private String token;
|
||||
private UUID dbId;
|
||||
private String tableName = DEFAULT_TABLE_NAME;
|
||||
private String keyspaceName = "default_keyspace";
|
||||
private String dbRegion = "us-east1";
|
||||
private AstraEnvironment env = AstraEnvironment.PROD;
|
||||
|
||||
public BuilderAstra token(String token) {
|
||||
this.token = token;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CassandraChatMemoryStore.BuilderAstra databaseId(UUID dbId) {
|
||||
this.dbId = dbId;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CassandraChatMemoryStore.BuilderAstra env(AstraEnvironment env) {
|
||||
this.env = env;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CassandraChatMemoryStore.BuilderAstra databaseRegion(String dbRegion) {
|
||||
this.dbRegion = dbRegion;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CassandraChatMemoryStore.BuilderAstra keyspace(String keyspaceName) {
|
||||
this.keyspaceName = keyspaceName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CassandraChatMemoryStore.BuilderAstra table(String tableName) {
|
||||
this.tableName = tableName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CassandraChatMemoryStore build() {
|
||||
CqlSession cqlSession = CassIO.init(token, dbId, dbRegion, keyspaceName, env);
|
||||
return new CassandraChatMemoryStore(cqlSession, tableName);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
package dev.langchain4j.store.embedding.astradb;
|
||||
|
||||
import com.dtsx.astra.sdk.AstraDB;
|
||||
import com.dtsx.astra.sdk.AstraDBAdmin;
|
||||
import com.dtsx.astra.sdk.AstraDBCollection;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiModelName;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import io.stargate.sdk.data.domain.SimilarityMetric;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.MethodOrderer;
|
||||
import org.junit.jupiter.api.TestMethodOrder;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Disabled("AstraDB is not available in the CI")
|
||||
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
|
||||
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*")
|
||||
@Slf4j
|
||||
class AstraDbEmbeddingStoreIT extends EmbeddingStoreIT {
|
||||
|
||||
static final String TEST_DB = "test_langchain4j";
|
||||
static final String TEST_COLLECTION = "test_collection";
|
||||
static AstraDbEmbeddingStore embeddingStore;
|
||||
static EmbeddingModel embeddingModel;
|
||||
|
||||
static UUID dbId;
|
||||
static AstraDB db;
|
||||
|
||||
@BeforeAll
|
||||
public static void initStoreForTests() {
|
||||
AstraDBAdmin astraDBAdminClient = new AstraDBAdmin(getAstraToken());
|
||||
dbId = astraDBAdminClient.createDatabase(TEST_DB);
|
||||
assertNotNull(dbId);
|
||||
log.info("[init] - Database exists id={}", dbId);
|
||||
|
||||
// Select the Database as working object
|
||||
db = astraDBAdminClient.database(dbId);
|
||||
assertNotNull(db);
|
||||
|
||||
AstraDBCollection collection =
|
||||
db.createCollection(TEST_COLLECTION, 1536, SimilarityMetric.cosine);
|
||||
log.info("[init] - Collection create name={}", TEST_COLLECTION);
|
||||
|
||||
// Creating the store (and collection) if not exists
|
||||
embeddingStore = new AstraDbEmbeddingStore(collection);
|
||||
log.info("[init] - Embedding Store initialized");
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void clearStore() {
|
||||
embeddingStore.clear();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingModel embeddingModel() {
|
||||
if (embeddingModel == null) {
|
||||
embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.modelName(OpenAiModelName.TEXT_EMBEDDING_ADA_002)
|
||||
.build();
|
||||
}
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
void testAddEmbeddingAndFindRelevant() {
|
||||
|
||||
Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
|
||||
TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value"));
|
||||
String id = embeddingStore.add(embedding, textSegment);
|
||||
assertTrue(id != null && !id.isEmpty());
|
||||
|
||||
Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
|
||||
List<EmbeddingMatch<TextSegment>> embeddingMatches = embeddingStore.findRelevant(refereceEmbedding, 1);
|
||||
assertEquals(1, embeddingMatches.size());
|
||||
|
||||
EmbeddingMatch<TextSegment> embeddingMatch = embeddingMatches.get(0);
|
||||
assertThat(embeddingMatch.score()).isBetween(0d, 1d);
|
||||
assertThat(embeddingMatch.embeddingId()).isEqualTo(id);
|
||||
assertThat(embeddingMatch.embedding()).isEqualTo(embedding);
|
||||
assertThat(embeddingMatch.embedded()).isEqualTo(textSegment);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,96 +0,0 @@
|
|||
package dev.langchain4j.store.embedding.cassandra;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
/**
|
||||
* The configuration objects include a few validation rules.
|
||||
*/
|
||||
public class AstraDbEmbeddingConfigurationTest {
|
||||
|
||||
@Test
|
||||
public void should_build_configuration_test() {
|
||||
AstraDbEmbeddingConfiguration config = AstraDbEmbeddingConfiguration.builder()
|
||||
.token("token")
|
||||
.databaseId("dbId")
|
||||
.databaseRegion("dbRegion")
|
||||
.keyspace("ks")
|
||||
.dimension(20)
|
||||
.table("table")
|
||||
.build();
|
||||
assertNotNull(config);
|
||||
assertNotNull(config.getToken());
|
||||
assertNotNull(config.getDatabaseId());
|
||||
assertNotNull(config.getDatabaseRegion());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void should_error_if_no_table_test() {
|
||||
// Table is required
|
||||
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||
() -> AstraDbEmbeddingConfiguration.builder()
|
||||
.token("token")
|
||||
.databaseId("dbId")
|
||||
.databaseRegion("dbRegion")
|
||||
.keyspace("ks")
|
||||
.dimension(20)
|
||||
.build());
|
||||
assertEquals("table is marked non-null but is null", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void should_error_if_no_keyspace_test() {
|
||||
// Table is required
|
||||
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||
() -> AstraDbEmbeddingConfiguration.builder()
|
||||
.token("token")
|
||||
.databaseId("dbId")
|
||||
.databaseRegion("dbRegion")
|
||||
.table("ks")
|
||||
.dimension(20)
|
||||
.build());
|
||||
assertEquals("keyspace is marked non-null but is null", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void should_error_if_no_dimension_test() {
|
||||
// Table is required
|
||||
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||
() -> AstraDbEmbeddingConfiguration.builder()
|
||||
.token("token")
|
||||
.databaseId("dbId")
|
||||
.databaseRegion("dbRegion")
|
||||
.table("ks")
|
||||
.keyspace("ks")
|
||||
.build());
|
||||
assertEquals("dimension is marked non-null but is null", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void should_error_if_no_token_test() {
|
||||
// Table is required
|
||||
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||
() -> AstraDbEmbeddingConfiguration.builder()
|
||||
.databaseId("dbId")
|
||||
.databaseRegion("dbRegion")
|
||||
.table("ks")
|
||||
.keyspace("ks")
|
||||
.dimension(20)
|
||||
.build());
|
||||
assertEquals("token is marked non-null but is null", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void should_error_if_no_database_test() {
|
||||
// Table is required
|
||||
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||
() -> AstraDbEmbeddingConfiguration.builder()
|
||||
.token("token")
|
||||
.table("ks")
|
||||
.keyspace("ks")
|
||||
.dimension(20)
|
||||
.build());
|
||||
assertEquals("databaseId is marked non-null but is null", exception.getMessage());
|
||||
}
|
||||
}
|
|
@ -1,81 +0,0 @@
|
|||
package dev.langchain4j.store.embedding.cassandra;
|
||||
|
||||
import com.datastax.astra.sdk.AstraClient;
|
||||
import com.dtsx.astra.sdk.utils.TestUtils;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
/**
|
||||
* Testing implementation of Embedding Store using AstraDB.
|
||||
*/
|
||||
class AstraDbEmbeddingStoreIT {
|
||||
|
||||
private static final String TEST_KEYSPACE = "langchain4j";
|
||||
private static final String TEST_INDEX = "test_embedding_store";
|
||||
|
||||
/**
|
||||
* We want to trigger the test only if the expected variable
|
||||
* is present.
|
||||
*/
|
||||
@Test
|
||||
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
||||
void testAddEmbeddingAndFindRelevant() {
|
||||
String astraToken = getAstraToken();
|
||||
String databaseId = setupDatabase("langchain4j", TEST_KEYSPACE);
|
||||
|
||||
// Flush Table for test to be idempotent
|
||||
truncateTable(databaseId, TEST_KEYSPACE, TEST_INDEX);
|
||||
|
||||
// Create the Store with the builder
|
||||
AstraDbEmbeddingStore astraDbEmbeddingStore = new AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration
|
||||
.builder()
|
||||
.token(astraToken)
|
||||
.databaseId(databaseId)
|
||||
.databaseRegion(TestUtils.TEST_REGION)
|
||||
.keyspace(TEST_KEYSPACE)
|
||||
.table(TEST_INDEX)
|
||||
.dimension(11)
|
||||
.build());
|
||||
|
||||
Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
|
||||
TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value"));
|
||||
String id = astraDbEmbeddingStore.add(embedding, textSegment);
|
||||
assertTrue(id != null && !id.isEmpty());
|
||||
|
||||
Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
|
||||
List<EmbeddingMatch<TextSegment>> embeddingMatches = astraDbEmbeddingStore.findRelevant(refereceEmbedding, 10);
|
||||
assertEquals(1, embeddingMatches.size());
|
||||
|
||||
EmbeddingMatch<TextSegment> embeddingMatch = embeddingMatches.get(0);
|
||||
assertThat(embeddingMatch.score()).isBetween(0d, 1d);
|
||||
assertThat(embeddingMatch.embeddingId()).isEqualTo(id);
|
||||
assertThat(embeddingMatch.embedding()).isEqualTo(embedding);
|
||||
assertThat(embeddingMatch.embedded()).isEqualTo(textSegment);
|
||||
}
|
||||
|
||||
private void truncateTable(String databaseId, String keyspace, String table) {
|
||||
try (AstraClient astraClient = AstraClient.builder()
|
||||
.withToken(getAstraToken())
|
||||
.withCqlKeyspace(keyspace)
|
||||
.withDatabaseId(databaseId)
|
||||
.withDatabaseRegion(TestUtils.TEST_REGION)
|
||||
.enableCql()
|
||||
.enableDownloadSecureConnectBundle()
|
||||
.build()) {
|
||||
astraClient.cqlSession()
|
||||
.execute("TRUNCATE TABLE " + table);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,93 +0,0 @@
|
|||
package dev.langchain4j.store.embedding.cassandra;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static dev.langchain4j.store.embedding.cassandra.CassandraEmbeddingConfiguration.DEFAULT_PORT;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class CassandraEmbeddingConfigurationTest {
|
||||
|
||||
@Test
|
||||
public void should_build_configuration_test() {
|
||||
CassandraEmbeddingConfiguration config = CassandraEmbeddingConfiguration.builder()
|
||||
.contactPoints(singletonList("localhost"))
|
||||
.port(DEFAULT_PORT)
|
||||
.keyspace("ks")
|
||||
.dimension(20)
|
||||
.table("table")
|
||||
.localDataCenter("dc1")
|
||||
.build();
|
||||
assertNotNull(config);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void should_error_if_no_datacenter_test() {
|
||||
// Table is required
|
||||
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||
() -> CassandraEmbeddingConfiguration.builder()
|
||||
.contactPoints(singletonList("localhost"))
|
||||
.port(DEFAULT_PORT)
|
||||
.keyspace("ks")
|
||||
.dimension(20)
|
||||
.table("table")
|
||||
.build());
|
||||
assertEquals("localDataCenter is marked non-null but is null", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void should_error_if_no_table_test() {
|
||||
// Table is required
|
||||
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||
() -> CassandraEmbeddingConfiguration.builder()
|
||||
.contactPoints(singletonList("localhost"))
|
||||
.port(DEFAULT_PORT)
|
||||
.keyspace("ks")
|
||||
.dimension(20)
|
||||
.localDataCenter("dc1")
|
||||
.build());
|
||||
assertEquals("table is marked non-null but is null", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void should_error_if_no_keyspace_test() {
|
||||
// Table is required
|
||||
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||
() -> CassandraEmbeddingConfiguration.builder()
|
||||
.contactPoints(singletonList("localhost"))
|
||||
.port(DEFAULT_PORT)
|
||||
.table("ks")
|
||||
.dimension(20)
|
||||
.localDataCenter("dc1")
|
||||
.build());
|
||||
assertEquals("keyspace is marked non-null but is null", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void should_error_if_no_dimension_test() {
|
||||
// Table is required
|
||||
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||
() -> CassandraEmbeddingConfiguration.builder()
|
||||
.contactPoints(singletonList("localhost"))
|
||||
.port(DEFAULT_PORT)
|
||||
.table("ks")
|
||||
.keyspace("ks")
|
||||
.localDataCenter("dc1")
|
||||
.build());
|
||||
assertEquals("dimension is marked non-null but is null", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void should_error_if_no_contact_points_test() {
|
||||
// Table is required
|
||||
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||
() -> CassandraEmbeddingConfiguration.builder()
|
||||
.port(DEFAULT_PORT)
|
||||
.table("ks")
|
||||
.keyspace("ks")
|
||||
.dimension(20)
|
||||
.localDataCenter("dc1")
|
||||
.build());
|
||||
assertEquals("contactPoints is marked non-null but is null", exception.getMessage());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package dev.langchain4j.store.embedding.cassandra;
|
||||
|
||||
import com.dtsx.astra.sdk.AstraDBAdmin;
|
||||
import com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.TEST_REGION;
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
|
||||
|
||||
/**
|
||||
* Integration test where Cassandra is running in AstraDB (dbaas).
|
||||
*/
|
||||
@Disabled("AstraDB is not available in the CI")
|
||||
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
||||
class CassandraEmbeddingStoreAstraIT extends CassandraEmbeddingStoreIT {
|
||||
|
||||
/**
|
||||
* Initializing the embedding store to work with Saas ASTRA DB.
|
||||
*
|
||||
* @return
|
||||
* embedding store.
|
||||
*/
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
if (embeddingStore == null) {
|
||||
// Create if not exists
|
||||
UUID dbId = new AstraDBAdmin((getAstraToken())).createDatabase("test_langchain4j");
|
||||
embeddingStore = CassandraEmbeddingStore.builderAstra()
|
||||
.token(getAstraToken())
|
||||
.databaseId(dbId)
|
||||
.databaseRegion(TEST_REGION)
|
||||
.keyspace(KEYSPACE)
|
||||
.table(TEST_INDEX)
|
||||
.dimension(embeddingModelDimension()) // openai model
|
||||
.metric(CassandraSimilarityMetric.COSINE)
|
||||
.build();
|
||||
}
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
package dev.langchain4j.store.embedding.cassandra;
|
||||
|
||||
import com.datastax.oss.driver.api.core.CqlSession;
|
||||
import com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.testcontainers.DockerClientFactory;
|
||||
import org.testcontainers.containers.CassandraContainer;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.utility.DockerImageName;
|
||||
|
||||
import java.net.InetSocketAddress;
|
||||
import java.util.Collections;
|
||||
|
||||
/**
|
||||
* Work with Cassandra Embedding Store.
|
||||
*/
|
||||
@Disabled("No Docker in the CI")
|
||||
@Testcontainers
|
||||
class CassandraEmbeddingStoreDockerIT extends CassandraEmbeddingStoreIT {
|
||||
|
||||
static final String CASSANDRA_IMAGE = "cassandra:5.0";
|
||||
static final String DATACENTER = "datacenter1";
|
||||
static final String CLUSTER = "langchain4j";
|
||||
static CassandraContainer<?> cassandraContainer;
|
||||
|
||||
/**
|
||||
* Check Docker is installed and running on host
|
||||
*/
|
||||
@BeforeAll
|
||||
static void ensureDockerIsRunning() {
|
||||
DockerClientFactory.instance().client();
|
||||
if (cassandraContainer == null) {
|
||||
cassandraContainer = new CassandraContainer<>(
|
||||
DockerImageName.parse(CASSANDRA_IMAGE))
|
||||
.withEnv("CLUSTER_NAME", CLUSTER)
|
||||
.withEnv("DC", DATACENTER);
|
||||
cassandraContainer.start();
|
||||
|
||||
// Part of Database Creation, creating keyspace
|
||||
final InetSocketAddress contactPoint = cassandraContainer.getContactPoint();
|
||||
CqlSession.builder()
|
||||
.addContactPoint(contactPoint)
|
||||
.withLocalDatacenter(DATACENTER)
|
||||
.build().execute(
|
||||
"CREATE KEYSPACE IF NOT EXISTS " + KEYSPACE +
|
||||
" WITH replication = {'class':'SimpleStrategy', 'replication_factor':'1'};");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop Cassandra Node
|
||||
*/
|
||||
@AfterAll
|
||||
static void afterTests() throws Exception {
|
||||
cassandraContainer.stop();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
final InetSocketAddress contactPoint = cassandraContainer.getContactPoint();
|
||||
if (embeddingStore == null) {
|
||||
embeddingStore = CassandraEmbeddingStore.builder()
|
||||
.contactPoints(Collections.singletonList(contactPoint.getHostName()))
|
||||
.port(contactPoint.getPort())
|
||||
.localDataCenter(DATACENTER)
|
||||
.keyspace(KEYSPACE)
|
||||
.table(TEST_INDEX)
|
||||
.dimension(embeddingModelDimension())
|
||||
.metric(CassandraSimilarityMetric.COSINE)
|
||||
.build();
|
||||
}
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
}
|
|
@ -3,51 +3,250 @@ package dev.langchain4j.store.embedding.cassandra;
|
|||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiModelName;
|
||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.MethodOrderer;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestMethodOrder;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.data.Percentage.withPercentage;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
/**
|
||||
* Work with Cassandra Embedding Store.
|
||||
*/
|
||||
class CassandraEmbeddingStoreIT {
|
||||
@Slf4j
|
||||
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
|
||||
abstract class CassandraEmbeddingStoreIT extends EmbeddingStoreIT {
|
||||
|
||||
protected static final String KEYSPACE = "langchain4j";
|
||||
|
||||
protected static final String TEST_INDEX = "test_embedding_store";
|
||||
|
||||
CassandraEmbeddingStore embeddingStore;
|
||||
|
||||
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.modelName(OpenAiModelName.TEXT_EMBEDDING_ADA_002)
|
||||
.timeout(Duration.ofSeconds(15))
|
||||
.build();
|
||||
|
||||
@Override
|
||||
protected EmbeddingModel embeddingModel() {
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
protected int embeddingModelDimension() {
|
||||
return 1536;
|
||||
}
|
||||
|
||||
/**
|
||||
* It is required to clean the repository in between tests
|
||||
*/
|
||||
@Override
|
||||
protected void clearStore() {
|
||||
((CassandraEmbeddingStore) embeddingStore()).clear();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void awaitUntilPersisted() {
|
||||
try {
|
||||
Thread.sleep(1000);
|
||||
} catch(Exception e) {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled("To run this test, you must have a local Cassandra instance, a docker-compose is provided")
|
||||
public void testAddEmbeddingAndFindRelevant() {
|
||||
|
||||
CassandraEmbeddingStore cassandraEmbeddingStore = initStore();
|
||||
|
||||
Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
|
||||
TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value"));
|
||||
String id = cassandraEmbeddingStore.add(embedding, textSegment);
|
||||
void should_retrieve_inserted_vector_by_ann() {
|
||||
String sourceSentence = "Testing is doubting !";
|
||||
Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content();
|
||||
TextSegment sourceTextSegment = TextSegment.from(sourceSentence);
|
||||
String id = embeddingStore().add(sourceEmbedding, sourceTextSegment);
|
||||
assertTrue(id != null && !id.isEmpty());
|
||||
|
||||
Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
|
||||
List<EmbeddingMatch<TextSegment>> embeddingMatches = cassandraEmbeddingStore.findRelevant(refereceEmbedding, 1);
|
||||
List<EmbeddingMatch<TextSegment>> embeddingMatches = embeddingStore.findRelevant(sourceEmbedding, 10);
|
||||
assertEquals(1, embeddingMatches.size());
|
||||
|
||||
EmbeddingMatch<TextSegment> embeddingMatch = embeddingMatches.get(0);
|
||||
assertThat(embeddingMatch.score()).isBetween(0d, 1d);
|
||||
assertThat(embeddingMatch.embeddingId()).isEqualTo(id);
|
||||
assertThat(embeddingMatch.embedding()).isEqualTo(embedding);
|
||||
assertThat(embeddingMatch.embedded()).isEqualTo(textSegment);
|
||||
assertThat(embeddingMatch.embedding()).isEqualTo(sourceEmbedding);
|
||||
assertThat(embeddingMatch.embedded()).isEqualTo(sourceTextSegment);
|
||||
}
|
||||
|
||||
private CassandraEmbeddingStore initStore() {
|
||||
return CassandraEmbeddingStore.builder()
|
||||
.contactPoints("127.0.0.1")
|
||||
.port(9042)
|
||||
.localDataCenter("datacenter1")
|
||||
.table("langchain4j", "table_" + randomUUID().replace("-", ""))
|
||||
.vectorDimension(11)
|
||||
.build();
|
||||
@Test
|
||||
void should_retrieve_inserted_vector_by_ann_and_metadata() {
|
||||
String sourceSentence = "In GOD we trust, everything else we test!";
|
||||
Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content();
|
||||
TextSegment sourceTextSegment = TextSegment.from(sourceSentence, new Metadata()
|
||||
.add("user", "GOD")
|
||||
.add("test", "false"));
|
||||
String id = embeddingStore().add(sourceEmbedding, sourceTextSegment);
|
||||
assertTrue(id != null && !id.isEmpty());
|
||||
|
||||
// Should be found with no filter
|
||||
List<EmbeddingMatch<TextSegment>> matchesAnnOnly = embeddingStore
|
||||
.findRelevant(sourceEmbedding, 10);
|
||||
assertEquals(1, matchesAnnOnly.size());
|
||||
|
||||
// Should retrieve if user is god
|
||||
List<EmbeddingMatch<TextSegment>> matchesGod = embeddingStore
|
||||
.findRelevant(sourceEmbedding, 10, .5d, Metadata.from("user", "GOD"));
|
||||
assertEquals(1, matchesGod.size());
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> matchesJohn = embeddingStore
|
||||
.findRelevant(sourceEmbedding, 10, .5d, Metadata.from("user", "JOHN"));
|
||||
assertEquals(0, matchesJohn.size());
|
||||
}
|
||||
|
||||
// metrics returned are 1.95% off we updated to "withPercentage(2)"
|
||||
|
||||
@Test
|
||||
void should_return_correct_score() {
|
||||
Embedding embedding = embeddingModel().embed("hello").content();
|
||||
String id = embeddingStore().add(embedding);
|
||||
assertThat(id).isNotBlank();
|
||||
Embedding referenceEmbedding = embeddingModel().embed("hi").content();
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(referenceEmbedding, 1);
|
||||
assertThat(relevant).hasSize(1);
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
|
||||
withPercentage(2)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_find_with_min_score() {
|
||||
String firstId = randomUUID();
|
||||
Embedding firstEmbedding = embeddingModel().embed("hello").content();
|
||||
embeddingStore().add(firstId, firstEmbedding);
|
||||
String secondId = randomUUID();
|
||||
Embedding secondEmbedding = embeddingModel().embed("hi").content();
|
||||
embeddingStore().add(secondId, secondEmbedding);
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)),
|
||||
withPercentage(2)
|
||||
);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore().findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score() - 0.01
|
||||
);
|
||||
assertThat(relevant2).hasSize(2);
|
||||
assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
|
||||
assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore().findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score()
|
||||
);
|
||||
assertThat(relevant3).hasSize(2);
|
||||
assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
|
||||
assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore().findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score() + 0.01
|
||||
);
|
||||
assertThat(relevant4).hasSize(1);
|
||||
assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_multiple_embeddings_with_segments() {
|
||||
|
||||
TextSegment firstSegment = TextSegment.from("hello");
|
||||
Embedding firstEmbedding = embeddingModel().embed(firstSegment.text()).content();
|
||||
|
||||
TextSegment secondSegment = TextSegment.from("hi");
|
||||
Embedding secondEmbedding = embeddingModel().embed(secondSegment.text()).content();
|
||||
|
||||
List<String> ids = embeddingStore().addAll(
|
||||
asList(firstEmbedding, secondEmbedding),
|
||||
asList(firstSegment, secondSegment)
|
||||
);
|
||||
assertThat(ids).hasSize(2);
|
||||
assertThat(ids.get(0)).isNotBlank();
|
||||
assertThat(ids.get(1)).isNotBlank();
|
||||
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));
|
||||
|
||||
awaitUntilPersisted();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
|
||||
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)),
|
||||
withPercentage(2)
|
||||
);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
||||
assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
void should_add_multiple_embeddings() {
|
||||
|
||||
Embedding firstEmbedding = embeddingModel().embed("hello").content();
|
||||
Embedding secondEmbedding = embeddingModel().embed("hi").content();
|
||||
|
||||
List<String> ids = embeddingStore().addAll(asList(firstEmbedding, secondEmbedding));
|
||||
assertThat(ids).hasSize(2);
|
||||
assertThat(ids.get(0)).isNotBlank();
|
||||
assertThat(ids.get(1)).isNotBlank();
|
||||
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));
|
||||
|
||||
awaitUntilPersisted();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(2));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
assertThat(firstMatch.embedded()).isNull();
|
||||
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)),
|
||||
withPercentage(2)
|
||||
);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
||||
assertThat(secondMatch.embedded()).isNull();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
package dev.langchain4j.store.memory.chat.cassandra;
|
||||
|
||||
import com.dtsx.astra.sdk.AstraDBAdmin;
|
||||
import com.dtsx.astra.sdk.db.domain.CloudProviderType;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.TEST_REGION;
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
|
||||
/**
|
||||
* Test Cassandra Chat Memory Store with a Saas DB.
|
||||
*/
|
||||
@Disabled("AstraDB is not available in the CI")
|
||||
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
||||
class CassandraChatMemoryStoreAstraIT extends CassandraChatMemoryStoreTestSupport {
|
||||
|
||||
static final String DB = "test_langchain4j";
|
||||
static String token;
|
||||
static UUID dbId;
|
||||
|
||||
@Override
|
||||
void createDatabase() {
|
||||
token = getAstraToken();
|
||||
assertNotNull(token);
|
||||
dbId = new AstraDBAdmin(token).createDatabase(DB, CloudProviderType.GCP, "us-east1");
|
||||
assertNotNull(dbId);
|
||||
}
|
||||
|
||||
@Override
|
||||
CassandraChatMemoryStore createChatMemoryStore() {
|
||||
return CassandraChatMemoryStore.builderAstra()
|
||||
.token(getAstraToken())
|
||||
.databaseId(dbId)
|
||||
.databaseRegion(TEST_REGION)
|
||||
.keyspace(KEYSPACE)
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package dev.langchain4j.store.memory.chat.cassandra;
|
||||
|
||||
import com.datastax.oss.driver.api.core.CqlSession;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.testcontainers.DockerClientFactory;
|
||||
import org.testcontainers.containers.CassandraContainer;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.utility.DockerImageName;
|
||||
|
||||
import java.net.InetSocketAddress;
|
||||
|
||||
/**
|
||||
* Test Cassandra Chat Memory Store with a Saas DB.
|
||||
*/
|
||||
@Disabled("No Docker in the CI")
|
||||
@Testcontainers
|
||||
class CassandraChatMemoryStoreDockerIT extends CassandraChatMemoryStoreTestSupport {
|
||||
static final String DATACENTER = "datacenter1";
|
||||
static final DockerImageName CASSANDRA_IMAGE = DockerImageName.parse("cassandra:5.0");
|
||||
static CassandraContainer<?> cassandraContainer;
|
||||
|
||||
@BeforeAll
|
||||
public static void ensureDockerIsRunning() {
|
||||
DockerClientFactory.instance().client();
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("resource")
|
||||
void createDatabase() {
|
||||
cassandraContainer = new CassandraContainer<>(CASSANDRA_IMAGE)
|
||||
.withEnv("CLUSTER_NAME", "langchain4j")
|
||||
.withEnv("DC", DATACENTER);
|
||||
cassandraContainer.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("resource")
|
||||
CassandraChatMemoryStore createChatMemoryStore() {
|
||||
final InetSocketAddress contactPoint =
|
||||
cassandraContainer.getContactPoint();
|
||||
CqlSession.builder()
|
||||
.addContactPoint(contactPoint)
|
||||
.withLocalDatacenter(DATACENTER)
|
||||
.build().execute(
|
||||
"CREATE KEYSPACE IF NOT EXISTS " + KEYSPACE +
|
||||
" WITH replication = {'class':'SimpleStrategy', 'replication_factor':'1'};");
|
||||
return new CassandraChatMemoryStore(CqlSession.builder()
|
||||
.addContactPoint(contactPoint)
|
||||
.withLocalDatacenter(DATACENTER)
|
||||
.withKeyspace(KEYSPACE)
|
||||
.build());
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
static void afterTests() throws Exception {
|
||||
cassandraContainer.stop();
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
package dev.langchain4j.store.memory.chat.cassandra;
|
||||
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.memory.ChatMemory;
|
||||
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
|
||||
import dev.langchain4j.memory.chat.TokenWindowChatMemory;
|
||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.MethodOrderer;
|
||||
import org.junit.jupiter.api.Order;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestMethodOrder;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
import static dev.langchain4j.data.message.AiMessage.aiMessage;
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
|
||||
@Slf4j
|
||||
abstract class CassandraChatMemoryStoreTestSupport {
|
||||
protected final String KEYSPACE = "langchain4j";
|
||||
protected static CassandraChatMemoryStore chatMemoryStore;
|
||||
|
||||
@Test
|
||||
@Order(1)
|
||||
@DisplayName("1. Should create a database")
|
||||
void shouldInitializeDatabase() {
|
||||
createDatabase();
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(2)
|
||||
@DisplayName("2. Connection to the database")
|
||||
void shouldConnectToDatabase() {
|
||||
chatMemoryStore = createChatMemoryStore();
|
||||
log.info("Chat memory store is created.");
|
||||
// Connection to Cassandra is established
|
||||
Assertions.assertTrue(chatMemoryStore.getCassandraSession()
|
||||
.getMetadata()
|
||||
.getKeyspace(KEYSPACE)
|
||||
.isPresent());
|
||||
log.info("Chat memory table is present.");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
@DisplayName("3. ChatMemoryStore initialization (table)")
|
||||
void shouldCreateChatMemoryStore() {
|
||||
chatMemoryStore.create();
|
||||
// Table exists
|
||||
Assertions.assertTrue(chatMemoryStore.getCassandraSession()
|
||||
.refreshSchema()
|
||||
.getKeyspace(KEYSPACE).get()
|
||||
.getTable(CassandraChatMemoryStore.DEFAULT_TABLE_NAME).isPresent());
|
||||
chatMemoryStore.clear();
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(4)
|
||||
@DisplayName("4. Insert items")
|
||||
void shouldInsertItems() {
|
||||
// When
|
||||
String chatSessionId = "chat-" + UUID.randomUUID();
|
||||
|
||||
ChatMemory chatMemory = MessageWindowChatMemory.builder()
|
||||
.chatMemoryStore(chatMemoryStore)
|
||||
.maxMessages(100)
|
||||
.id(chatSessionId)
|
||||
.build();
|
||||
|
||||
// When
|
||||
UserMessage userMessage = userMessage("I will ask you a few question about ff4j.");
|
||||
chatMemory.add(userMessage);
|
||||
|
||||
AiMessage aiMessage = aiMessage("Sure, go ahead!");
|
||||
chatMemory.add(aiMessage);
|
||||
|
||||
// Then
|
||||
assertThat(chatMemory.messages()).containsExactly(userMessage, aiMessage);
|
||||
}
|
||||
|
||||
abstract void createDatabase();
|
||||
|
||||
abstract CassandraChatMemoryStore createChatMemoryStore();
|
||||
|
||||
}
|
|
@ -1,83 +0,0 @@
|
|||
package dev.langchain4j.store.memory.chat.cassandra;
|
||||
|
||||
import com.datastax.astra.sdk.AstraClient;
|
||||
import com.dtsx.astra.sdk.utils.TestUtils;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.memory.ChatMemory;
|
||||
import dev.langchain4j.memory.chat.TokenWindowChatMemory;
|
||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.*;
|
||||
import static dev.langchain4j.data.message.AiMessage.aiMessage;
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
|
||||
/**
|
||||
* Test Cassandra Chat Memory Store with a Saas DB.
|
||||
*/
|
||||
public class ChatMemoryStoreAstraTest {
|
||||
|
||||
private static final String TEST_DATABASE = "langchain4j";
|
||||
private static final String TEST_KEYSPACE = "langchain4j";
|
||||
|
||||
@Test
|
||||
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*")
|
||||
void chatMemoryAstraTest() {
|
||||
|
||||
// Initialization
|
||||
String astraToken = getAstraToken();
|
||||
String databaseId = setupDatabase(TEST_DATABASE, TEST_KEYSPACE);
|
||||
|
||||
// Given
|
||||
assertNotNull(databaseId);
|
||||
assertNotNull(astraToken);
|
||||
|
||||
// Flush Table before test
|
||||
truncateTable(databaseId, TEST_KEYSPACE, CassandraChatMemoryStore.DEFAULT_TABLE_NAME);
|
||||
|
||||
// When
|
||||
ChatMemoryStore chatMemoryStore =
|
||||
new AstraDbChatMemoryStore(astraToken, databaseId, TEST_REGION, "langchain4j");
|
||||
|
||||
// When
|
||||
String chatSessionId = "chat-" + UUID.randomUUID();
|
||||
ChatMemory chatMemory = TokenWindowChatMemory.builder()
|
||||
.chatMemoryStore(chatMemoryStore)
|
||||
.id(chatSessionId)
|
||||
.maxTokens(300, new OpenAiTokenizer(GPT_3_5_TURBO))
|
||||
.build();
|
||||
|
||||
// When
|
||||
UserMessage userMessage = userMessage("I will ask you a few question about ff4j.");
|
||||
chatMemory.add(userMessage);
|
||||
|
||||
AiMessage aiMessage = aiMessage("Sure, go ahead!");
|
||||
chatMemory.add(aiMessage);
|
||||
|
||||
// Then
|
||||
assertThat(chatMemory.messages()).containsExactly(userMessage, aiMessage);
|
||||
}
|
||||
|
||||
private void truncateTable(String databaseId, String keyspace, String table) {
|
||||
try (AstraClient astraClient = AstraClient.builder()
|
||||
.withToken(getAstraToken())
|
||||
.withCqlKeyspace(keyspace)
|
||||
.withDatabaseId(databaseId)
|
||||
.withDatabaseRegion(TestUtils.TEST_REGION)
|
||||
.enableCql()
|
||||
.enableDownloadSecureConnectBundle()
|
||||
.build()) {
|
||||
astraClient.cqlSession()
|
||||
.execute("TRUNCATE TABLE " + table);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,5 +1,7 @@
|
|||
package dev.langchain4j.store.embedding.cassandra;
|
||||
package dev.langchain4j.store.memory.chat.cassandra;
|
||||
|
||||
import com.dtsx.astra.sdk.AstraDBAdmin;
|
||||
import com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric;
|
||||
import com.dtsx.astra.sdk.utils.TestUtils;
|
||||
import dev.langchain4j.data.document.Document;
|
||||
import dev.langchain4j.data.document.DocumentSplitter;
|
||||
|
@ -20,6 +22,8 @@ import dev.langchain4j.model.output.Response;
|
|||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
|
||||
import dev.langchain4j.store.embedding.astradb.AstraDbEmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.cassandra.CassandraEmbeddingStore;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
|
@ -28,7 +32,9 @@ import java.nio.file.Path;
|
|||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.TEST_REGION;
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||
|
@ -37,23 +43,24 @@ import static java.time.Duration.ofSeconds;
|
|||
import static java.util.stream.Collectors.joining;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
|
||||
class SampleDocumentLoaderAndRagWithAstraTest {
|
||||
class DocumentLoaderAndRagWithAstraTest {
|
||||
|
||||
public static final String DB_NAME = "langchain4j";
|
||||
|
||||
@Test
|
||||
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*")
|
||||
void shouldRagWithOpenAiAndAstra() {
|
||||
// Initialization
|
||||
String astraToken = getAstraToken();
|
||||
String databaseId = setupDatabase("langchain4j", "langchain4j");
|
||||
String openAIKey = System.getenv("OPENAI_API_KEY");
|
||||
|
||||
// Given
|
||||
assertNotNull(openAIKey);
|
||||
// Database Id
|
||||
UUID databaseId = new AstraDBAdmin(getAstraToken()).createDatabase(DB_NAME);
|
||||
assertNotNull(databaseId);
|
||||
assertNotNull(astraToken);
|
||||
|
||||
// --- Ingesting documents ---
|
||||
// OpenAI Key
|
||||
String openAIKey = System.getenv("OPENAI_API_KEY");
|
||||
assertNotNull(openAIKey);
|
||||
|
||||
// --- Documents Ingestion ---
|
||||
|
||||
// Parsing input file
|
||||
Path path = new File(getClass().getResource("/story-about-happy-carrot.txt").getFile()).toPath();
|
||||
|
@ -65,20 +72,20 @@ class SampleDocumentLoaderAndRagWithAstraTest {
|
|||
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey(openAIKey)
|
||||
.modelName(TEXT_EMBEDDING_ADA_002)
|
||||
.timeout(ofSeconds(15))
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build();
|
||||
|
||||
// Embed the document and it in the store
|
||||
EmbeddingStore<TextSegment> embeddingStore = AstraDbEmbeddingStore.builder()
|
||||
.token(astraToken)
|
||||
.database(databaseId, TestUtils.TEST_REGION)
|
||||
.table("langchain4j", "table_story")
|
||||
.vectorDimension(1536)
|
||||
EmbeddingStore<TextSegment> embeddingStore = CassandraEmbeddingStore.builderAstra()
|
||||
.token(getAstraToken())
|
||||
.databaseId(databaseId)
|
||||
.databaseRegion(TEST_REGION)
|
||||
.keyspace("default_keyspace")
|
||||
.table( "table_story")
|
||||
.dimension(1536) // openai model
|
||||
.metric(CassandraSimilarityMetric.COSINE)
|
||||
.build();
|
||||
|
||||
// Ingest method 2
|
||||
// Ingest method
|
||||
EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
|
||||
.documentSplitter(splitter)
|
||||
.embeddingModel(embeddingModel)
|
|
@ -0,0 +1,170 @@
|
|||
package dev.langchain4j.store.memory.chat.cassandra;
|
||||
|
||||
import com.dtsx.astra.sdk.AstraDBAdmin;
|
||||
import com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric;
|
||||
import dev.langchain4j.data.document.Document;
|
||||
import dev.langchain4j.data.document.DocumentSplitter;
|
||||
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
|
||||
import dev.langchain4j.data.document.loader.UrlDocumentLoader;
|
||||
import dev.langchain4j.data.document.parser.TextDocumentParser;
|
||||
import dev.langchain4j.data.document.source.UrlSource;
|
||||
import dev.langchain4j.data.document.splitter.DocumentSplitters;
|
||||
import dev.langchain4j.data.document.transformer.HtmlTextExtractor;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
|
||||
import dev.langchain4j.store.embedding.cassandra.CassandraEmbeddingStore;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.nio.file.Path;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.TEST_REGION;
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002;
|
||||
import static java.time.Duration.ofSeconds;
|
||||
import static java.util.stream.Collectors.joining;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
|
||||
public class WebPageLoaderAndRagWIthAstraTest {
|
||||
|
||||
public static final String DB_NAME = "langchain4j";
|
||||
|
||||
|
||||
@Test
|
||||
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*")
|
||||
void shouldRagWithOpenAiAndAstra() throws IOException {
|
||||
|
||||
// Database Id
|
||||
UUID databaseId = new AstraDBAdmin(getAstraToken()).createDatabase(DB_NAME);
|
||||
assertNotNull(databaseId);
|
||||
|
||||
// OpenAI Key
|
||||
String openAIKey = System.getenv("OPENAI_API_KEY");
|
||||
assertNotNull(openAIKey);
|
||||
|
||||
// --- Documents Ingestion ---
|
||||
|
||||
// Parsing input file
|
||||
//Path path = new File(getClass().getResource("/story-about-happy-carrot.txt").getFile()).toPath();
|
||||
//Document document = FileSystemDocumentLoader.loadDocument(path, new TextDocumentParser());
|
||||
|
||||
//Document document = UrlDocumentLoader.load("https://beta.goodbards.ai", new HtmlDocumentParser());;
|
||||
|
||||
HtmlTextExtractor transformer = new HtmlTextExtractor();
|
||||
|
||||
UrlSource.from("https://beta.goodbards.ai").inputStream();
|
||||
|
||||
Document htmlDocument = Document.from("https://beta.goodbards.ai");
|
||||
Document goodbardsBetaHomePage = transformer.transform(htmlDocument);
|
||||
|
||||
System.out.println(goodbardsBetaHomePage.text());
|
||||
|
||||
DocumentSplitter splitter = DocumentSplitters
|
||||
.recursive(100, 10, new OpenAiTokenizer(GPT_3_5_TURBO));
|
||||
|
||||
// Embedding model (OpenAI)
|
||||
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey(openAIKey)
|
||||
.modelName(TEXT_EMBEDDING_ADA_002)
|
||||
.build();
|
||||
|
||||
// Embed the document and it in the store
|
||||
CassandraEmbeddingStore embeddingStore = CassandraEmbeddingStore.builderAstra()
|
||||
.token(getAstraToken())
|
||||
.databaseId(databaseId)
|
||||
.databaseRegion(TEST_REGION)
|
||||
.keyspace("default_keyspace")
|
||||
.table( "goodbards")
|
||||
.dimension(1536) // openai model
|
||||
.metric(CassandraSimilarityMetric.COSINE)
|
||||
.build();
|
||||
embeddingStore.clear();
|
||||
|
||||
// Ingest method
|
||||
EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
|
||||
.documentSplitter(splitter)
|
||||
.embeddingModel(embeddingModel)
|
||||
.embeddingStore(embeddingStore)
|
||||
.build();
|
||||
ingestor.ingest(goodbardsBetaHomePage);
|
||||
|
||||
// --------- RAG -------------
|
||||
|
||||
// Specify the question you want to ask the model
|
||||
String question = "What is goodbards ?";
|
||||
|
||||
// Embed the question
|
||||
Response<Embedding> questionEmbedding = embeddingModel.embed(question);
|
||||
|
||||
// Find relevant embeddings in embedding store by semantic similarity
|
||||
// You can play with parameters below to find a sweet spot for your specific use case
|
||||
int maxResults = 3;
|
||||
double minScore = 0.8;
|
||||
List<EmbeddingMatch<TextSegment>> relevantEmbeddings =
|
||||
embeddingStore.findRelevant(questionEmbedding.content(), maxResults, minScore);
|
||||
|
||||
// --------- Chat Template -------------
|
||||
|
||||
// Create a prompt for the model that includes question and relevant embeddings
|
||||
PromptTemplate promptTemplate = PromptTemplate.from(
|
||||
"Answer the following question to the best of your ability:\n"
|
||||
+ "\n"
|
||||
+ "Question:\n"
|
||||
+ "{{question}}\n"
|
||||
+ "\n"
|
||||
+ "Base your answer on the following information:\n"
|
||||
+ "{{information}}\n"
|
||||
+ "Put each sentence on a different line:\n"
|
||||
);
|
||||
|
||||
String information = relevantEmbeddings.stream()
|
||||
.map(match -> match.embedded().text())
|
||||
.collect(joining("\n\n"));
|
||||
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("question", question);
|
||||
variables.put("information", information);
|
||||
|
||||
Prompt prompt = promptTemplate.apply(variables);
|
||||
|
||||
// Send the prompt to the OpenAI chat model
|
||||
ChatLanguageModel chatModel = OpenAiChatModel.builder()
|
||||
.apiKey(openAIKey)
|
||||
.modelName(GPT_3_5_TURBO)
|
||||
.temperature(0.7)
|
||||
.timeout(ofSeconds(15))
|
||||
.maxRetries(3)
|
||||
.logResponses(true)
|
||||
.logRequests(true)
|
||||
.build();
|
||||
|
||||
Response<AiMessage> aiMessage = chatModel.generate(prompt.toUserMessage());
|
||||
|
||||
// See an answer from the model
|
||||
String answer = aiMessage.content().text();
|
||||
System.out.println(answer);
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,6 +0,0 @@
|
|||
version: '3'
|
||||
services:
|
||||
cassandra:
|
||||
image: stargateio/dse-next:4.0.7-e47eb8e14b96
|
||||
ports:
|
||||
- 9042:9042
|
|
@ -0,0 +1,32 @@
|
|||
<configuration scan="true">
|
||||
|
||||
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<encoder>
|
||||
<pattern>%d{HH:mm:ss.SSS} %magenta(%-5level) %cyan(%-20logger) : %msg%n</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<!--
|
||||
<logger name="com.datastax.astra.sdk" level="INFO" additivity="false">
|
||||
<appender-ref ref="STDOUT" />
|
||||
</logger>
|
||||
-->
|
||||
|
||||
<logger name="com.dtsx.astra.sdk" level="DEBUG" additivity="false">
|
||||
<appender-ref ref="STDOUT" />
|
||||
</logger>
|
||||
|
||||
<logger name="io.stargate.sdk.data" level="DEBUG" additivity="false">
|
||||
<appender-ref ref="STDOUT" />
|
||||
</logger>
|
||||
|
||||
|
||||
<logger name="org.springframework" level="ERROR" additivity="false">
|
||||
<appender-ref ref="STDOUT" />
|
||||
</logger>
|
||||
|
||||
<root level="ERROR">
|
||||
<appender-ref ref="STDOUT" />
|
||||
</root>
|
||||
|
||||
</configuration>
|
Loading…
Reference in New Issue