Removed dynamic loading from AstraDB/Cassandra

This commit is contained in:
deep-learning-dynamo 2023-09-27 17:11:01 +02:00
parent c632322493
commit ef8f04015b
24 changed files with 639 additions and 1016 deletions

View File

@ -7,11 +7,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Set up JDK 11 - name: Set up JDK 8
uses: actions/setup-java@v3 uses: actions/setup-java@v3
with: with:
java-version: '11' java-version: '8'
distribution: 'adopt' distribution: 'temurin'
- name: Test - name: Test
run: mvn --batch-mode test run: mvn --batch-mode test

View File

@ -64,6 +64,14 @@
<artifactId>junit-jupiter-engine</artifactId> <artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>${assertj.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>dev.langchain4j</groupId> <groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId> <artifactId>langchain4j-open-ai</artifactId>
@ -82,26 +90,4 @@
</license> </license>
</licenses> </licenses>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.11.0</version>
<configuration>
<source>11</source>
<target>11</target>
<showWarnings>false</showWarnings>
</configuration>
</plugin>
<plugin>
<groupId>org.honton.chas</groupId>
<artifactId>license-maven-plugin</artifactId>
<configuration>
<skipCompliance>true</skipCompliance>
</configuration>
</plugin>
</plugins>
</build>
</project> </project>

View File

@ -3,7 +3,6 @@ package dev.langchain4j.store.embedding.cassandra;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import lombok.experimental.SuperBuilder;
/** /**
* Plain old Java Object (POJO) to hold the configuration for the CassandraEmbeddingStore. * Plain old Java Object (POJO) to hold the configuration for the CassandraEmbeddingStore.
@ -13,11 +12,13 @@ import lombok.experimental.SuperBuilder;
* *
* @see CassandraEmbeddingStore * @see CassandraEmbeddingStore
*/ */
@Getter @Builder @Getter
@Builder
public class AstraDbEmbeddingConfiguration { public class AstraDbEmbeddingConfiguration {
/** /**
* Represents the Api Key to interact with Astra DB * Represents the Api Key to interact with Astra DB
*
* @see <a href="https://docs.datastax.com/en/astra/docs/manage-application-tokens.html">Astra DB Api Key</a> * @see <a href="https://docs.datastax.com/en/astra/docs/manage-application-tokens.html">Astra DB Api Key</a>
*/ */
@NonNull @NonNull
@ -58,8 +59,7 @@ public class AstraDbEmbeddingConfiguration {
/** /**
* Initialize the builder. * Initialize the builder.
* *
* @return * @return cassandra embedding configuration builder
* cassandra embedding configuration buildesr
*/ */
public static AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder builder() { public static AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder builder() {
return new AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder(); return new AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder();
@ -68,5 +68,6 @@ public class AstraDbEmbeddingConfiguration {
/** /**
* Signature for the builder. * Signature for the builder.
*/ */
public static class AstraDbEmbeddingConfigurationBuilder{} public static class AstraDbEmbeddingConfigurationBuilder {
}
} }

View File

@ -0,0 +1,111 @@
package dev.langchain4j.store.embedding.cassandra;
import com.datastax.astra.sdk.AstraClient;
import com.datastax.oss.driver.api.core.CqlSession;
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
import dev.langchain4j.store.embedding.EmbeddingStore;
/**
 * {@link EmbeddingStore} implementation backed by Astra DB.
 * <p>
 * The CQL session is obtained through an {@link AstraClient} and then handed to a
 * {@link MetadataVectorCassandraTable}, which is stored as the embedding table of this store.
 *
 * @see EmbeddingStore
 * @see MetadataVectorCassandraTable
 */
public class AstraDbEmbeddingStore extends CassandraEmbeddingStoreSupport {

    /**
     * Creates the store from an already assembled configuration.
     *
     * @param config token, database id/region, keyspace, table and vector dimension to use
     */
    public AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration config) {
        CqlSession session = openSession(config);
        this.embeddingTable = new MetadataVectorCassandraTable(
                session,
                config.getKeyspace(),
                config.getTable(),
                config.getDimension());
    }

    /**
     * Opens a CQL session through the Astra client using the values held by the configuration.
     *
     * @param config source of the token, keyspace, database id and database region
     * @return session returned by the Astra client
     */
    private static CqlSession openSession(AstraDbEmbeddingConfiguration config) {
        return AstraClient.builder()
                .withToken(config.getToken())
                .withCqlKeyspace(config.getKeyspace())
                .withDatabaseId(config.getDatabaseId())
                .withDatabaseRegion(config.getDatabaseRegion())
                .enableCql()
                .enableDownloadSecureConnectBundle()
                .build()
                .cqlSession();
    }

    /**
     * Entry point of the fluent API.
     *
     * @return a fresh, empty {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent helper that collects the settings into an {@link AstraDbEmbeddingConfiguration}
     * and creates the store from it.
     */
    public static class Builder {

        /**
         * Configuration builder receiving every value set on this builder.
         */
        private final AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder conf;

        /**
         * Starts with an empty configuration builder.
         */
        public Builder() {
            this.conf = AstraDbEmbeddingConfiguration.builder();
        }

        /**
         * Sets the Astra token.
         *
         * @param token token
         * @return this builder
         */
        public Builder token(String token) {
            this.conf.token(token);
            return this;
        }

        /**
         * Sets the database identifier and the database region.
         *
         * @param databaseId     database identifier
         * @param databaseRegion database region
         * @return this builder
         */
        public Builder database(String databaseId, String databaseRegion) {
            this.conf.databaseId(databaseId);
            this.conf.databaseRegion(databaseRegion);
            return this;
        }

        /**
         * Sets the dimension of the embedding vectors.
         *
         * @param dimension model dimension
         * @return this builder
         */
        public Builder vectorDimension(int dimension) {
            this.conf.dimension(dimension);
            return this;
        }

        /**
         * Sets the keyspace and the table name.
         *
         * @param keyspace keyspace name
         * @param table    table name
         * @return this builder
         */
        public Builder table(String keyspace, String table) {
            this.conf.keyspace(keyspace);
            this.conf.table(table);
            return this;
        }

        /**
         * Builds the configuration and creates the store from it.
         *
         * @return store for Astra DB
         */
        public AstraDbEmbeddingStore build() {
            AstraDbEmbeddingConfiguration configuration = this.conf.build();
            return new AstraDbEmbeddingStore(configuration);
        }
    }
}

View File

@ -1,34 +0,0 @@
package dev.langchain4j.store.embedding.cassandra;
import com.datastax.astra.sdk.AstraClient;
import com.datastax.oss.driver.api.core.CqlSession;
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
import dev.langchain4j.store.embedding.EmbeddingStore;
/**
 * {@link EmbeddingStore} implementation backed by Astra DB.
 * <p>
 * The CQL session comes from an {@link AstraClient}; embeddings are kept in a
 * {@link MetadataVectorCassandraTable}.
 *
 * @see EmbeddingStore
 * @see MetadataVectorCassandraTable
 */
public class AstraDbEmbeddingStoreImpl extends CassandraEmbeddingStoreSupport {

    /**
     * Creates the store from the given configuration.
     *
     * @param config token, database id/region, keyspace, table and vector dimension to use
     */
    public AstraDbEmbeddingStoreImpl(AstraDbEmbeddingConfiguration config) {
        CqlSession session = connect(config);
        this.embeddingTable = new MetadataVectorCassandraTable(
                session,
                config.getKeyspace(),
                config.getTable(),
                config.getDimension());
    }

    /**
     * Obtains a CQL session from the Astra client configured with the supplied values.
     *
     * @param config source of the token, keyspace, database id and database region
     * @return session returned by the Astra client
     */
    private static CqlSession connect(AstraDbEmbeddingConfiguration config) {
        return AstraClient.builder()
                .withToken(config.getToken())
                .withCqlKeyspace(config.getKeyspace())
                .withDatabaseId(config.getDatabaseId())
                .withDatabaseRegion(config.getDatabaseRegion())
                .enableCql()
                .enableDownloadSecureConnectBundle()
                .build()
                .cqlSession();
    }
}

View File

@ -3,7 +3,6 @@ package dev.langchain4j.store.embedding.cassandra;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import lombok.experimental.SuperBuilder;
import java.util.List; import java.util.List;
@ -15,10 +14,13 @@ import java.util.List;
* *
* @see CassandraEmbeddingStore * @see CassandraEmbeddingStore
*/ */
@Getter @Builder @Getter
@Builder
public class CassandraEmbeddingConfiguration { public class CassandraEmbeddingConfiguration {
/** Default Cassandra Port. */ /**
* Default Cassandra Port.
*/
public static Integer DEFAULT_PORT = 9042; public static Integer DEFAULT_PORT = 9042;
// --- Connectivity Parameters --- // --- Connectivity Parameters ---
@ -74,8 +76,7 @@ public class CassandraEmbeddingConfiguration {
/** /**
* Initialize the builder. * Initialize the builder.
* *
* @return * @return cassandra embedding configuration builder
* cassandra embedding configuration buildesr
*/ */
public static CassandraEmbeddingConfigurationBuilder builder() { public static CassandraEmbeddingConfigurationBuilder builder() {
return new CassandraEmbeddingConfigurationBuilder(); return new CassandraEmbeddingConfigurationBuilder();
@ -84,6 +85,6 @@ public class CassandraEmbeddingConfiguration {
/** /**
* Signature for the builder. * Signature for the builder.
*/ */
public static class CassandraEmbeddingConfigurationBuilder{} public static class CassandraEmbeddingConfigurationBuilder {
}
} }

View File

@ -0,0 +1,152 @@
package dev.langchain4j.store.embedding.cassandra;
import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.CqlSessionBuilder;
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
import com.dtsx.astra.sdk.cassio.SimilarityMetric;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.net.InetSocketAddress;
import java.util.Arrays;
/**
 * {@link EmbeddingStore} implementation backed by a Cassandra cluster reached through
 * explicit contact points.
 * <p>
 * The keyspace is created when missing, then embeddings are kept in a
 * {@link MetadataVectorCassandraTable} using the cosine similarity metric.
 *
 * @see EmbeddingStore
 * @see MetadataVectorCassandraTable
 */
public class CassandraEmbeddingStore extends CassandraEmbeddingStoreSupport {

    /**
     * Creates the store from an already assembled configuration.
     *
     * @param config connectivity settings, keyspace, table and vector dimension to use
     */
    public CassandraEmbeddingStore(CassandraEmbeddingConfiguration config) {
        String keyspace = config.getKeyspace();
        CqlSessionBuilder sessionBuilder = createCqlSessionBuilder(config);
        // The keyspace must exist before a session bound to it can be opened.
        createKeyspaceIfNotExist(sessionBuilder, keyspace);
        sessionBuilder.withKeyspace(keyspace);
        CqlSession session = sessionBuilder.build();
        this.embeddingTable = new MetadataVectorCassandraTable(
                session,
                keyspace,
                config.getTable(),
                config.getDimension(),
                SimilarityMetric.COS);
    }

    /**
     * Prepares a session builder from the configuration: local datacenter, credentials
     * (only when both user name and password are present) and one contact point per host.
     * No keyspace is attached here.
     *
     * @param config current configuration
     * @return configured session builder
     */
    private CqlSessionBuilder createCqlSessionBuilder(CassandraEmbeddingConfiguration config) {
        CqlSessionBuilder builder = CqlSession.builder();
        builder.withLocalDatacenter(config.getLocalDataCenter());

        boolean hasCredentials = config.getUserName() != null && config.getPassword() != null;
        if (hasCredentials) {
            builder.withAuthCredentials(config.getUserName(), config.getPassword());
        }

        for (String host : config.getContactPoints()) {
            // The port is read per host, exactly as each address is created.
            InetSocketAddress address = new InetSocketAddress(host, config.getPort());
            builder.addContactPoint(address);
        }
        return builder;
    }

    /**
     * Creates the keyspace if it does not exist yet (simple strategy, replication factor 1,
     * durable writes enabled), using a short-lived session that is closed afterwards.
     *
     * @param cqlSessionBuilder builder used to open the temporary session
     * @param keyspace          name of the keyspace to create
     */
    private void createKeyspaceIfNotExist(CqlSessionBuilder cqlSessionBuilder, String keyspace) {
        try (CqlSession temporarySession = cqlSessionBuilder.build()) {
            temporarySession.execute(
                    SchemaBuilder.createKeyspace(keyspace)
                            .ifNotExists()
                            .withSimpleStrategy(1)
                            .withDurableWrites(true)
                            .build());
        }
    }

    /**
     * Entry point of the fluent API.
     *
     * @return a fresh, empty {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent helper that collects the settings into a {@link CassandraEmbeddingConfiguration}
     * and creates the store from it.
     */
    public static class Builder {

        /**
         * Configuration builder receiving every value set on this builder.
         */
        private final CassandraEmbeddingConfiguration.CassandraEmbeddingConfigurationBuilder conf;

        /**
         * Starts with an empty configuration builder.
         */
        public Builder() {
            this.conf = CassandraEmbeddingConfiguration.builder();
        }

        /**
         * Sets the Cassandra port.
         *
         * @param port port
         * @return this builder
         */
        public Builder port(int port) {
            this.conf.port(port);
            return this;
        }

        /**
         * Sets the Cassandra contact points.
         *
         * @param hosts contact point hosts
         * @return this builder
         */
        public Builder contactPoints(String... hosts) {
            this.conf.contactPoints(Arrays.asList(hosts));
            return this;
        }

        /**
         * Sets the dimension of the embedding vectors.
         *
         * @param dimension model dimension
         * @return this builder
         */
        public Builder vectorDimension(int dimension) {
            this.conf.dimension(dimension);
            return this;
        }

        /**
         * Sets the local datacenter.
         *
         * @param dc datacenter
         * @return this builder
         */
        public Builder localDataCenter(String dc) {
            this.conf.localDataCenter(dc);
            return this;
        }

        /**
         * Sets the keyspace and the table name.
         *
         * @param keyspace keyspace name
         * @param table    table name
         * @return this builder
         */
        public Builder table(String keyspace, String table) {
            this.conf.keyspace(keyspace);
            this.conf.table(table);
            return this;
        }

        /**
         * Builds the configuration and creates the store from it.
         *
         * @return store for Cassandra
         */
        public CassandraEmbeddingStore build() {
            CassandraEmbeddingConfiguration configuration = this.conf.build();
            return new CassandraEmbeddingStore(configuration);
        }
    }
}

View File

@ -1,67 +0,0 @@
package dev.langchain4j.store.embedding.cassandra;
import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.CqlSessionBuilder;
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
import com.dtsx.astra.sdk.cassio.SimilarityMetric;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.net.InetSocketAddress;
/**
 * {@link EmbeddingStore} implementation backed by a Cassandra cluster reached through
 * explicit contact points.
 * <p>
 * The keyspace is created when missing, then embeddings are kept in a
 * {@link MetadataVectorCassandraTable} using the cosine similarity metric.
 *
 * @see EmbeddingStore
 * @see MetadataVectorCassandraTable
 */
public class CassandraEmbeddingStoreImpl extends CassandraEmbeddingStoreSupport {

    /**
     * Creates the store from the given configuration.
     *
     * @param config connectivity settings, keyspace, table and vector dimension to use
     */
    public CassandraEmbeddingStoreImpl(CassandraEmbeddingConfiguration config) {
        String keyspace = config.getKeyspace();
        CqlSessionBuilder sessionBuilder = createCqlSessionBuilder(config);
        // The keyspace must exist before a session bound to it can be opened.
        createKeyspaceIfNotExist(sessionBuilder, keyspace);
        sessionBuilder.withKeyspace(keyspace);
        CqlSession session = sessionBuilder.build();
        this.embeddingTable = new MetadataVectorCassandraTable(
                session,
                keyspace,
                config.getTable(),
                config.getDimension(),
                SimilarityMetric.COS);
    }

    /**
     * Prepares a session builder from the configuration: local datacenter, credentials
     * (only when both user name and password are present) and one contact point per host.
     * No keyspace is attached here.
     *
     * @param config current configuration
     * @return configured session builder
     */
    private CqlSessionBuilder createCqlSessionBuilder(CassandraEmbeddingConfiguration config) {
        CqlSessionBuilder builder = CqlSession.builder();
        builder.withLocalDatacenter(config.getLocalDataCenter());

        boolean hasCredentials = config.getUserName() != null && config.getPassword() != null;
        if (hasCredentials) {
            builder.withAuthCredentials(config.getUserName(), config.getPassword());
        }

        for (String host : config.getContactPoints()) {
            // The port is read per host, exactly as each address is created.
            builder.addContactPoint(new InetSocketAddress(host, config.getPort()));
        }
        return builder;
    }

    /**
     * Creates the keyspace if it does not exist yet (simple strategy, replication factor 1,
     * durable writes enabled), using a short-lived session that is closed afterwards.
     *
     * @param cqlSessionBuilder builder used to open the temporary session
     * @param keyspace          name of the keyspace to create
     */
    private void createKeyspaceIfNotExist(CqlSessionBuilder cqlSessionBuilder, String keyspace) {
        try (CqlSession temporarySession = cqlSessionBuilder.build()) {
            temporarySession.execute(
                    SchemaBuilder.createKeyspace(keyspace)
                            .ifNotExists()
                            .withSimpleStrategy(1)
                            .withDurableWrites(true)
                            .build());
        }
    }
}

View File

@ -3,6 +3,7 @@ package dev.langchain4j.store.embedding.cassandra;
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable; import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
import com.dtsx.astra.sdk.cassio.SimilarityMetric; import com.dtsx.astra.sdk.cassio.SimilarityMetric;
import com.dtsx.astra.sdk.cassio.SimilaritySearchQuery; import com.dtsx.astra.sdk.cassio.SimilaritySearchQuery;
import com.dtsx.astra.sdk.cassio.SimilaritySearchQuery.SimilaritySearchQueryBuilder;
import com.dtsx.astra.sdk.cassio.SimilaritySearchResult; import com.dtsx.astra.sdk.cassio.SimilaritySearchResult;
import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.embedding.Embedding;
@ -18,11 +19,15 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero;
import static java.util.stream.Collectors.toList;
/** /**
* Support for CassandraEmbeddingStore with and Without Astra. * Support for CassandraEmbeddingStore with and Without Astra.
*/ */
@Getter @Getter
public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<TextSegment> { abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<TextSegment> {
/** /**
* Represents an embedding table in Cassandra, it is a table with a vector column. * Represents an embedding table in Cassandra, it is a table with a vector column.
@ -34,10 +39,8 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
* - the row id is generated * - the row id is generated
* - text and metadata are not stored * - text and metadata are not stored
* *
* @param embedding * @param embedding representation of the list of floats
* representation of the list of floats * @return newly created row id
* @return
* newly created row id
*/ */
@Override @Override
public String add(@NonNull Embedding embedding) { public String add(@NonNull Embedding embedding) {
@ -49,12 +52,9 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
* - the row id is generated * - the row id is generated
* - text and metadata coming from the text Segment * - text and metadata coming from the text Segment
* *
* @param embedding * @param embedding representation of the list of floats
* representation of the list of floats * @param textSegment text content and metadata
* @param textSegment * @return newly created row id
* text content and metadata
* @return
* newly created row id
*/ */
@Override @Override
public String add(@NonNull Embedding embedding, TextSegment textSegment) { public String add(@NonNull Embedding embedding, TextSegment textSegment) {
@ -70,10 +70,8 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
/** /**
* Add a new embedding to the store. * Add a new embedding to the store.
* *
* @param rowId * @param rowId the row id
* the row id * @param embedding representation of the list of floats
* @param embedding
* representation of the list of floats
*/ */
@Override @Override
public void add(@NonNull String rowId, @NonNull Embedding embedding) { public void add(@NonNull String rowId, @NonNull Embedding embedding) {
@ -83,10 +81,8 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
/** /**
* Add multiple embeddings as a single action. * Add multiple embeddings as a single action.
* *
* @param embeddingList * @param embeddingList embeddings list
* embeddings list * @return list of new row ids (same order as the input)
* @return
* list of new row if (same order as the input)
*/ */
@Override @Override
public List<String> addAll(List<Embedding> embeddingList) { public List<String> addAll(List<Embedding> embeddingList) {
@ -95,18 +91,15 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
.map(MetadataVectorCassandraTable.Record::new) .map(MetadataVectorCassandraTable.Record::new)
.peek(embeddingTable::putAsync) .peek(embeddingTable::putAsync)
.map(MetadataVectorCassandraTable.Record::getRowId) .map(MetadataVectorCassandraTable.Record::getRowId)
.collect(Collectors.toList()); .collect(toList());
} }
/** /**
* Add multiple embeddings as a single action. * Add multiple embeddings as a single action.
* *
* @param embeddingList * @param embeddingList embeddings
* embeddings * @param textSegmentList text segments
* @param textSegmentList * @return list of new row if (same order as the input)
* text segments
* @return
* list of new row if (same order as the input)
*/ */
@Override @Override
public List<String> addAll(List<Embedding> embeddingList, List<TextSegment> textSegmentList) { public List<String> addAll(List<Embedding> embeddingList, List<TextSegment> textSegmentList) {
@ -115,7 +108,7 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
} }
// Looping on both list with an index // Looping on both list with an index
List<String> ids = new ArrayList<>(); List<String> ids = new ArrayList<>();
for(int i = 0; i < embeddingList.size(); i++) { for (int i = 0; i < embeddingList.size(); i++) {
ids.add(add(embeddingList.get(i), textSegmentList.get(i))); ids.add(add(embeddingList.get(i), textSegmentList.get(i)));
} }
return ids; return ids;
@ -124,36 +117,30 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
/** /**
* Search for relevant. * Search for relevant.
* *
* @param embedding * @param embedding current embeddings
* current embeddings * @param maxResults max number of result
* @param maxResults * @param minScore threshold
* max number of result * @return list of matching elements
* @param minScore
* threshold
* @return
* list of matching elements
*/ */
@Override @Override
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, int maxResults, double minScore) { public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, int maxResults, double minScore) {
return embeddingTable return embeddingTable
.similaritySearch(SimilaritySearchQuery.builder() .similaritySearch(SimilaritySearchQuery.builder()
.embeddings(embedding.vectorAsList()) .embeddings(embedding.vectorAsList())
.recordCount(maxResults) .recordCount(ensureGreaterThanZero(maxResults, "maxResults"))
.threshold(CosineSimilarity.fromRelevanceScore(minScore)) .threshold(CosineSimilarity.fromRelevanceScore(ensureBetween(minScore, 0, 1, "minScore")))
.distance(SimilarityMetric.COS) .distance(SimilarityMetric.COS)
.build()) .build())
.stream() .stream()
.map(CassandraEmbeddingStoreSupport::mapSearchResult) .map(CassandraEmbeddingStoreSupport::mapSearchResult)
.collect(Collectors.toList()); .collect(toList());
} }
/** /**
* Map Search result coming from Astra. * Map Search result coming from Astra.
* *
* @param record * @param record current record
* current record * @return search result
* @return
* search result
*/ */
private static EmbeddingMatch<TextSegment> mapSearchResult(SimilaritySearchResult<MetadataVectorCassandraTable.Record> record) { private static EmbeddingMatch<TextSegment> mapSearchResult(SimilaritySearchResult<MetadataVectorCassandraTable.Record> record) {
return new EmbeddingMatch<>( return new EmbeddingMatch<>(
@ -163,33 +150,24 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
record.getEmbedded().getRowId(), record.getEmbedded().getRowId(),
// Embeddings vector // Embeddings vector
Embedding.from(record.getEmbedded().getVector()), Embedding.from(record.getEmbedded().getVector()),
// Text Fragment and metadata // Text segment and metadata
TextSegment.from(record.getEmbedded().getBody(), new Metadata(record.getEmbedded().getMetadata()))); TextSegment.from(record.getEmbedded().getBody(), new Metadata(record.getEmbedded().getMetadata())));
} }
/** /**
* Similarity Search ANN based on the embedding. * Similarity Search ANN based on the embedding.
* *
* @param embedding * @param embedding vector
* vector * @param maxResults max number of results
* @param maxResults * @param minScore score minScore
* max number of record * @param metadata map key-value to build a metadata filter
* @param minScore * @return list of matching results
* score minScore
* @param metadata
* map key-value to build a metadata filter
* @return
* list of matching results
*/ */
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, Integer maxResults, Double minScore, Metadata metadata) { public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, int maxResults, double minScore, Metadata metadata) {
SimilaritySearchQuery.SimilaritySearchQueryBuilder builder = SimilaritySearchQueryBuilder builder = SimilaritySearchQuery.builder()
SimilaritySearchQuery.builder().embeddings(embedding.vectorAsList()); .embeddings(embedding.vectorAsList())
if (maxResults == null || maxResults < 1) { .recordCount(ensureGreaterThanZero(maxResults, "maxResults"))
throw new IllegalArgumentException("maxResults (param[1]) must not be null and greater than 0"); .threshold(CosineSimilarity.fromRelevanceScore(ensureBetween(minScore, 0, 1, "minScore")));
}
if (minScore == null || minScore < 1 || minScore > 1) {
throw new IllegalArgumentException("minScore (param[2]) must not be null and in between 0 and 1.");
}
if (metadata != null) { if (metadata != null) {
builder.metaData(metadata.asMap()); builder.metaData(metadata.asMap());
} }
@ -197,7 +175,6 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
.similaritySearch(builder.build()) .similaritySearch(builder.build())
.stream() .stream()
.map(CassandraEmbeddingStoreSupport::mapSearchResult) .map(CassandraEmbeddingStoreSupport::mapSearchResult)
.collect(Collectors.toList()); .collect(toList());
} }
} }

View File

@ -4,24 +4,18 @@ import com.datastax.astra.sdk.AstraClient;
/** /**
* AstraDb is a version of Cassandra running in Saas Mode. * AstraDb is a version of Cassandra running in Saas Mode.
* * <p>
* The initialization of the CQLSession will be done through an AstraClient * The initialization of the CQLSession will be done through an AstraClient
*
* @author Cedrick Lunven (clun)
*/ */
public class AstraDbChatMemoryStore extends CassandraChatMemoryStore { public class AstraDbChatMemoryStore extends CassandraChatMemoryStore {
/** /**
* Constructor with default table name. * Constructor with default table name.
* *
* @param token * @param token token
* token * @param dbId database identifier
* @param dbId * @param dbRegion database region
* database idendifier * @param keyspaceName keyspace name
* @param dbRegion
* database region
* @param keyspaceName
* keyspace name
*/ */
public AstraDbChatMemoryStore(String token, String dbId, String dbRegion, String keyspaceName) { public AstraDbChatMemoryStore(String token, String dbId, String dbRegion, String keyspaceName) {
this(token, dbId, dbRegion, keyspaceName, DEFAULT_TABLE_NAME); this(token, dbId, dbRegion, keyspaceName, DEFAULT_TABLE_NAME);

View File

@ -1,28 +1,19 @@
package dev.langchain4j.store.memory.chat.cassandra; package dev.langchain4j.store.memory.chat.cassandra;
import com.datastax.astra.sdk.AstraClient;
import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.CqlSession;
import static com.dtsx.astra.sdk.cassio.ClusteredCassandraTable.Record;
import com.datastax.oss.driver.api.core.uuid.Uuids; import com.datastax.oss.driver.api.core.uuid.Uuids;
import com.dtsx.astra.sdk.cassio.ClusteredCassandraTable; import com.dtsx.astra.sdk.cassio.ClusteredCassandraTable;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageDeserializer; import dev.langchain4j.data.message.ChatMessageDeserializer;
import dev.langchain4j.data.message.ChatMessageSerializer; import dev.langchain4j.data.message.ChatMessageSerializer;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.store.memory.chat.ChatMemoryStore; import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import lombok.Data;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.stream.Collectors; import static com.dtsx.astra.sdk.cassio.ClusteredCassandraTable.Record;
import static java.util.stream.Collectors.toList;
/** /**
* Implementation of {@link ChatMemoryStore} using Astra DB Vector Search. * Implementation of {@link ChatMemoryStore} using Astra DB Vector Search.
@ -30,8 +21,6 @@ import java.util.stream.Collectors;
* is a partition.Message id is a time uuid. * is a partition.Message id is a time uuid.
* *
* @see <a href="https://docs.datastax.com/en/astra-serverless/docs/vector-search/overview.html">Astra Vector Store Documentation</a> * @see <a href="https://docs.datastax.com/en/astra-serverless/docs/vector-search/overview.html">Astra Vector Store Documentation</a>
* @author Cedrick Lunven (clun)
* @since 0.22.0
*/ */
@Slf4j @Slf4j
public class CassandraChatMemoryStore implements ChatMemoryStore { public class CassandraChatMemoryStore implements ChatMemoryStore {
@ -46,18 +35,12 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
*/ */
private final ClusteredCassandraTable messageTable; private final ClusteredCassandraTable messageTable;
/** Object Mapper. */
private static final ObjectMapper OM = new ObjectMapper();
/** /**
* Constructor for message store * Constructor for message store
* *
* @param session * @param session cassandra session
* cassandra session * @param keyspaceName keyspace name
* @param keyspaceName * @param tableName table name
* keyspace name
* @param tableName
* table name
*/ */
public CassandraChatMemoryStore(CqlSession session, String keyspaceName, String tableName) { public CassandraChatMemoryStore(CqlSession session, String keyspaceName, String tableName) {
messageTable = new ClusteredCassandraTable(session, keyspaceName, tableName); messageTable = new ClusteredCassandraTable(session, keyspaceName, tableName);
@ -66,52 +49,54 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
/** /**
* Constructor for message store * Constructor for message store
* *
* @param session * @param session cassandra session
* cassandra session * @param keyspaceName keyspace name
* @param keyspaceName
* keyspace name
*/ */
public CassandraChatMemoryStore(CqlSession session, String keyspaceName) { public CassandraChatMemoryStore(CqlSession session, String keyspaceName) {
messageTable = new ClusteredCassandraTable(session, keyspaceName, DEFAULT_TABLE_NAME); messageTable = new ClusteredCassandraTable(session, keyspaceName, DEFAULT_TABLE_NAME);
} }
/** {@inheritDoc} */ /**
* {@inheritDoc}
*/
@Override @Override
public List<ChatMessage> getMessages(@NonNull Object memoryId) { public List<ChatMessage> getMessages(@NonNull Object memoryId) {
return messageTable return messageTable
.findPartition(getMemoryId(memoryId)) .findPartition(getMemoryId(memoryId))
.stream() .stream()
.map(this::toChatMessage) .map(this::toChatMessage)
.collect(Collectors.toList()); .collect(toList());
} }
/** {@inheritDoc} */ /**
* {@inheritDoc}
*/
@Override @Override
public void updateMessages(@NonNull Object memoryId, @NonNull List<ChatMessage> list) { public void updateMessages(@NonNull Object memoryId, @NonNull List<ChatMessage> messages) {
deleteMessages(memoryId); deleteMessages(memoryId);
messageTable.upsertPartition(list.stream() messageTable.upsertPartition(messages.stream()
.map(r -> this.fromChatMessage(getMemoryId(memoryId), r)) .map(record -> fromChatMessage(getMemoryId(memoryId), record))
.collect(Collectors.toList())); .collect(toList()));
} }
/** {@inheritDoc} */ /**
* {@inheritDoc}
*/
@Override @Override
public void deleteMessages(@NonNull Object memoryId) { public void deleteMessages(@NonNull Object memoryId) {
messageTable.deletePartition(getMemoryId(memoryId)); messageTable.deletePartition(getMemoryId(memoryId));
} }
/** /**
* Unmarshalling Cassandra row as a Message with proper sub-type. * Unmarshalling Cassandra row as a Message with proper sub-type.
* *
* @param record * @param record cassandra record
* cassandra record * @return chat message
* @return
* chat message
*/ */
private ChatMessage toChatMessage(@NonNull Record record) { private ChatMessage toChatMessage(@NonNull Record record) {
try { try {
return ChatMessageDeserializer.messageFromJson(record.getBody()); return ChatMessageDeserializer.messageFromJson(record.getBody());
} catch(Exception e) { } catch (Exception e) {
log.error("Unable to parse message body", e); log.error("Unable to parse message body", e);
throw new IllegalArgumentException("Unable to parse message body"); throw new IllegalArgumentException("Unable to parse message body");
} }
@ -119,12 +104,10 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
/** /**
* Serialize the {@link ChatMessage} as a Cassandra Row. * Serialize the {@link ChatMessage} as a Cassandra Row.
* @param memoryId *
* chat session identifier * @param memoryId chat session identifier
* @param chatMessage * @param chatMessage chat message
* chat message * @return cassandra row.
* @return
* cassandra row.
*/ */
private Record fromChatMessage(@NonNull String memoryId, @NonNull ChatMessage chatMessage) { private Record fromChatMessage(@NonNull String memoryId, @NonNull ChatMessage chatMessage) {
try { try {
@ -133,18 +116,16 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
record.setPartitionId(memoryId); record.setPartitionId(memoryId);
record.setBody(ChatMessageSerializer.messageToJson(chatMessage)); record.setBody(ChatMessageSerializer.messageToJson(chatMessage));
return record; return record;
} catch(Exception e) { } catch (Exception e) {
log.error("Unable to parse message body", e); log.error("Unable to parse message body", e);
throw new IllegalArgumentException("Unable to parse message body", e); throw new IllegalArgumentException("Unable to parse message body", e);
} }
} }
private String getMemoryId(Object memoryId) { private String getMemoryId(Object memoryId) {
if (!(memoryId instanceof String) ) { if (!(memoryId instanceof String)) {
throw new IllegalArgumentException("memoryId must be a String"); throw new IllegalArgumentException("memoryId must be a String");
} }
return (String) memoryId; return (String) memoryId;
} }
} }

View File

@ -0,0 +1,96 @@
package dev.langchain4j.store.embedding.cassandra;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
/**
 * Checks the validation rules applied when building an {@link AstraDbEmbeddingConfiguration}:
 * a fully populated builder succeeds, and leaving a mandatory field unset makes
 * {@code build()} fail with a {@link NullPointerException} whose message names that field.
 */
public class AstraDbEmbeddingConfigurationTest {

    @Test
    public void should_build_configuration_test() {
        AstraDbEmbeddingConfiguration config = AstraDbEmbeddingConfiguration.builder()
                .token("token")
                .databaseId("dbId")
                .databaseRegion("dbRegion")
                .keyspace("ks")
                .dimension(20)
                .table("table")
                .build();
        assertNotNull(config);
        // The values must round-trip through the builder, not merely be non-null.
        assertEquals("token", config.getToken());
        assertEquals("dbId", config.getDatabaseId());
        assertEquals("dbRegion", config.getDatabaseRegion());
    }

    @Test
    public void should_error_if_no_table_test() {
        // Table is required
        NullPointerException exception = assertThrows(NullPointerException.class,
                () -> AstraDbEmbeddingConfiguration.builder()
                        .token("token")
                        .databaseId("dbId")
                        .databaseRegion("dbRegion")
                        .keyspace("ks")
                        .dimension(20)
                        .build());
        assertEquals("table is marked non-null but is null", exception.getMessage());
    }

    @Test
    public void should_error_if_no_keyspace_test() {
        // Keyspace is required
        NullPointerException exception = assertThrows(NullPointerException.class,
                () -> AstraDbEmbeddingConfiguration.builder()
                        .token("token")
                        .databaseId("dbId")
                        .databaseRegion("dbRegion")
                        .table("table")
                        .dimension(20)
                        .build());
        assertEquals("keyspace is marked non-null but is null", exception.getMessage());
    }

    @Test
    public void should_error_if_no_dimension_test() {
        // Dimension is required
        NullPointerException exception = assertThrows(NullPointerException.class,
                () -> AstraDbEmbeddingConfiguration.builder()
                        .token("token")
                        .databaseId("dbId")
                        .databaseRegion("dbRegion")
                        .table("table")
                        .keyspace("ks")
                        .build());
        assertEquals("dimension is marked non-null but is null", exception.getMessage());
    }

    @Test
    public void should_error_if_no_token_test() {
        // Token is required
        NullPointerException exception = assertThrows(NullPointerException.class,
                () -> AstraDbEmbeddingConfiguration.builder()
                        .databaseId("dbId")
                        .databaseRegion("dbRegion")
                        .table("table")
                        .keyspace("ks")
                        .dimension(20)
                        .build());
        assertEquals("token is marked non-null but is null", exception.getMessage());
    }

    @Test
    public void should_error_if_no_database_test() {
        // Database id is required. The region is supplied so that the missing id is the only
        // possible cause of the failure, whatever order the null checks run in.
        NullPointerException exception = assertThrows(NullPointerException.class,
                () -> AstraDbEmbeddingConfiguration.builder()
                        .token("token")
                        .databaseRegion("dbRegion")
                        .table("table")
                        .keyspace("ks")
                        .dimension(20)
                        .build());
        assertEquals("databaseId is marked non-null but is null", exception.getMessage());
    }
}

View File

@ -1,14 +1,12 @@
package dev.langchain4j.store.embedding.cassandra; package dev.langchain4j.store.embedding.cassandra;
import com.datastax.astra.sdk.AstraClient; import com.datastax.astra.sdk.AstraClient;
import com.datastax.oss.driver.api.core.CqlSession;
import com.dtsx.astra.sdk.utils.TestUtils; import com.dtsx.astra.sdk.utils.TestUtils;
import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingMatch;
import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
@ -16,53 +14,59 @@ import java.util.List;
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken; import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase; import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
/** /**
* Testing implementation of Embedding Store using AstraDB. * Testing implementation of Embedding Store using AstraDB.
*/ */
class AstraDbEmbeddingStoreTest { class AstraDbEmbeddingStoreTest {
public static final String TEST_DB = "langchain4j"; private static final String TEST_KEYSPACE = "langchain4j";
public static final String TEST_KEYSPACE = "langchain4j"; private static final String TEST_INDEX = "test_embedding_store";
public static final String TEST_INDEX = "test_embedding_store";
@Test private final String astraToken = getAstraToken();
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*") private final String databaseId = setupDatabase("langchain4j", TEST_KEYSPACE);
public void testAddEmbeddingAndFindRelevant() private final AstraDbEmbeddingStore astraDbEmbeddingStore = new AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration
throws Exception { .builder()
.token(astraToken)
.databaseId(databaseId)
.databaseRegion(TestUtils.TEST_REGION)
.keyspace(TEST_KEYSPACE)
.table(TEST_INDEX)
.dimension(11)
.build());
// Read Token from environment variable ASTRA_DB_APPLICATION_TOKEN @BeforeEach
String astraToken = getAstraToken(); void truncateTable() {
// Database will be created if not exist (can take 90 seconds on first run)
String databaseId = setupDatabase(TEST_DB, TEST_KEYSPACE);
// Store initialization
AstraDbEmbeddingStore astraDbEmbeddingStore = new AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration
.builder().token(astraToken).databaseId(databaseId)
.databaseRegion(TestUtils.TEST_REGION)
.keyspace(TEST_KEYSPACE)
.table(TEST_INDEX)
.dimension(11).build());
// Flushing Table before Start (idem potent)
AstraClient.builder() AstraClient.builder()
.withToken(astraToken) .withToken(getAstraToken())
.withCqlKeyspace(TEST_KEYSPACE) .withCqlKeyspace(TEST_KEYSPACE)
.withDatabaseId(databaseId) .withDatabaseId(databaseId)
.withDatabaseRegion(TestUtils.TEST_REGION) .withDatabaseRegion(TestUtils.TEST_REGION)
.enableCql() .enableCql()
.enableDownloadSecureConnectBundle() .enableDownloadSecureConnectBundle()
.build().cqlSession().execute("TRUNCATE TABLE " + TEST_INDEX); .build().cqlSession().execute("TRUNCATE TABLE " + TEST_INDEX);
}
@Test
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
void testAddEmbeddingAndFindRelevant() {
Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F}); Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
TextSegment textSegment = TextSegment.textSegment("Text", Metadata.from("Key", "Value")); TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value"));
String added = astraDbEmbeddingStore.add(embedding, textSegment); String id = astraDbEmbeddingStore.add(embedding, textSegment);
Assertions.assertTrue(added != null && !added.isEmpty()); assertTrue(id != null && !id.isEmpty());
Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F}); Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
List<EmbeddingMatch<TextSegment>> embeddingMatches = astraDbEmbeddingStore.findRelevant(refereceEmbedding, 10); List<EmbeddingMatch<TextSegment>> embeddingMatches = astraDbEmbeddingStore.findRelevant(refereceEmbedding, 10);
Assertions.assertEquals(1, embeddingMatches.size()); assertEquals(1, embeddingMatches.size());
}
EmbeddingMatch<TextSegment> embeddingMatch = embeddingMatches.get(0);
assertThat(embeddingMatch.score()).isBetween(0d, 1d);
assertThat(embeddingMatch.embeddingId()).isEqualTo(id);
assertThat(embeddingMatch.embedding()).isEqualTo(embedding);
assertThat(embeddingMatch.embedded()).isEqualTo(textSegment);
}
} }

View File

@ -0,0 +1,93 @@
package dev.langchain4j.store.embedding.cassandra;
import org.junit.jupiter.api.Test;
import static dev.langchain4j.store.embedding.cassandra.CassandraEmbeddingConfiguration.DEFAULT_PORT;
import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.*;
/**
 * Exercises the builder of {@link CassandraEmbeddingConfiguration}: a complete set of values
 * builds successfully, and each mandatory value left unset is reported by a
 * {@link NullPointerException} whose message names the missing field.
 */
public class CassandraEmbeddingConfigurationTest {

    @Test
    public void should_build_configuration_test() {
        assertNotNull(CassandraEmbeddingConfiguration.builder()
                .contactPoints(singletonList("localhost"))
                .port(DEFAULT_PORT)
                .keyspace("ks")
                .dimension(20)
                .table("table")
                .localDataCenter("dc1")
                .build());
    }

    @Test
    public void should_error_if_no_datacenter_test() {
        // Local data center left unset
        String message = assertThrows(NullPointerException.class,
                () -> CassandraEmbeddingConfiguration.builder()
                        .contactPoints(singletonList("localhost"))
                        .port(DEFAULT_PORT)
                        .keyspace("ks")
                        .dimension(20)
                        .table("table")
                        .build())
                .getMessage();
        assertEquals("localDataCenter is marked non-null but is null", message);
    }

    @Test
    public void should_error_if_no_table_test() {
        // Table left unset
        String message = assertThrows(NullPointerException.class,
                () -> CassandraEmbeddingConfiguration.builder()
                        .contactPoints(singletonList("localhost"))
                        .port(DEFAULT_PORT)
                        .keyspace("ks")
                        .dimension(20)
                        .localDataCenter("dc1")
                        .build())
                .getMessage();
        assertEquals("table is marked non-null but is null", message);
    }

    @Test
    public void should_error_if_no_keyspace_test() {
        // Keyspace left unset
        String message = assertThrows(NullPointerException.class,
                () -> CassandraEmbeddingConfiguration.builder()
                        .contactPoints(singletonList("localhost"))
                        .port(DEFAULT_PORT)
                        .table("ks")
                        .dimension(20)
                        .localDataCenter("dc1")
                        .build())
                .getMessage();
        assertEquals("keyspace is marked non-null but is null", message);
    }

    @Test
    public void should_error_if_no_dimension_test() {
        // Vector dimension left unset
        String message = assertThrows(NullPointerException.class,
                () -> CassandraEmbeddingConfiguration.builder()
                        .contactPoints(singletonList("localhost"))
                        .port(DEFAULT_PORT)
                        .table("ks")
                        .keyspace("ks")
                        .localDataCenter("dc1")
                        .build())
                .getMessage();
        assertEquals("dimension is marked non-null but is null", message);
    }

    @Test
    public void should_error_if_no_contact_points_test() {
        // Contact points left unset
        String message = assertThrows(NullPointerException.class,
                () -> CassandraEmbeddingConfiguration.builder()
                        .port(DEFAULT_PORT)
                        .table("ks")
                        .keyspace("ks")
                        .dimension(20)
                        .localDataCenter("dc1")
                        .build())
                .getMessage();
        assertEquals("contactPoints is marked non-null but is null", message);
    }
}

View File

@ -4,49 +4,50 @@ import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.util.Collections;
import java.util.List; import java.util.List;
import static dev.langchain4j.internal.Utils.randomUUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
/** /**
* Work with Cassandra Embedding Store. * Work with Cassandra Embedding Store.
*/ */
class CassandraEmbeddingStoreTest { class CassandraEmbeddingStoreTest {
public static final String TEST_KEYSPACE = "langchain4j";
public static final String TEST_INDEX = "test_embedding_store";
@Test @Test
@Disabled("To run this test, you must have a local Cassandra instance, a docker-compose is provided") @Disabled("To run this test, you must have a local Cassandra instance, a docker-compose is provided")
public void testAddEmbeddingAndFindRelevant() public void testAddEmbeddingAndFindRelevant() {
throws Exception {
CassandraEmbeddingStore cassandraEmbeddingStore = initStore(); CassandraEmbeddingStore cassandraEmbeddingStore = initStore();
Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F}); Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
TextSegment textSegment = TextSegment.textSegment("Text", Metadata.from("Key", "Value")); TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value"));
String added = cassandraEmbeddingStore.add(embedding, textSegment); String id = cassandraEmbeddingStore.add(embedding, textSegment);
Assertions.assertTrue(added != null && !added.isEmpty()); assertTrue(id != null && !id.isEmpty());
Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F}); Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
List<EmbeddingMatch<TextSegment>> embeddingMatches = cassandraEmbeddingStore.findRelevant(refereceEmbedding, 10); List<EmbeddingMatch<TextSegment>> embeddingMatches = cassandraEmbeddingStore.findRelevant(refereceEmbedding, 1);
Assertions.assertEquals(1, embeddingMatches.size()); assertEquals(1, embeddingMatches.size());
EmbeddingMatch<TextSegment> embeddingMatch = embeddingMatches.get(0);
assertThat(embeddingMatch.score()).isBetween(0d, 1d);
assertThat(embeddingMatch.embeddingId()).isEqualTo(id);
assertThat(embeddingMatch.embedding()).isEqualTo(embedding);
assertThat(embeddingMatch.embedded()).isEqualTo(textSegment);
} }
private CassandraEmbeddingStore initStore() private CassandraEmbeddingStore initStore() {
throws Exception { return CassandraEmbeddingStore.builder()
return CassandraEmbeddingStore
.builder()
.contactPoints("127.0.0.1") .contactPoints("127.0.0.1")
.port(9042) .port(9042)
.localDataCenter("datacenter1") .localDataCenter("datacenter1")
.table(TEST_KEYSPACE, TEST_INDEX) .table("langchain4j", "table_" + randomUUID().replace("-", ""))
.vectorDimension(11) .vectorDimension(11)
.build(); .build();
} }
} }

View File

@ -20,7 +20,6 @@ import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor; import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
@ -29,7 +28,6 @@ import java.nio.file.Path;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken; import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase; import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase;
@ -37,22 +35,23 @@ import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002; import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002;
import static java.time.Duration.ofSeconds; import static java.time.Duration.ofSeconds;
import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.joining;
import static org.junit.jupiter.api.Assertions.assertNotNull;
class SampleDocumentLoaderAndRagWithAstraTest { class SampleDocumentLoaderAndRagWithAstraTest {
@Test @Test
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*") @EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*") @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*")
public void shouldRagWithOpenAiAndAstra() throws InterruptedException { void shouldRagWithOpenAiAndAstra() {
// Initialization // Initialization
String astraToken = getAstraToken(); String astraToken = getAstraToken();
String databaseId = setupDatabase("langchain4j", "langchain4j"); String databaseId = setupDatabase("langchain4j", "langchain4j");
String openAIKey = System.getenv("OPENAI_API_KEY"); String openAIKey = System.getenv("OPENAI_API_KEY");
// Given // Given
Assertions.assertNotNull(openAIKey); assertNotNull(openAIKey);
Assertions.assertNotNull(databaseId); assertNotNull(databaseId);
Assertions.assertNotNull(astraToken); assertNotNull(astraToken);
// --- Ingesting documents --- // --- Ingesting documents ---

View File

@ -1,26 +1,23 @@
package dev.langchain4j.store.memory.chat.cassandra; package dev.langchain4j.store.memory.chat.cassandra;
import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.TokenWindowChatMemory; import dev.langchain4j.memory.chat.TokenWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiTokenizer; import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.memory.chat.ChatMemoryStore; import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import java.util.UUID; import java.util.UUID;
import static com.dtsx.astra.sdk.utils.TestUtils.TEST_REGION; import static com.dtsx.astra.sdk.utils.TestUtils.*;
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken; import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase;
import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static java.time.Duration.ofSeconds; import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotNull;
/** /**
* Test Cassandra Chat Memory Store with a Saas DB. * Test Cassandra Chat Memory Store with a Saas DB.
@ -28,27 +25,19 @@ import static java.time.Duration.ofSeconds;
public class ChatMemoryStoreAstraTest { public class ChatMemoryStoreAstraTest {
@Test @Test
@Disabled("bug: order of retrieved messages is wrong")
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*") @EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*") @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*")
void chatMemoryAstraTest() throws InterruptedException { void chatMemoryAstraTest() {
// Initialization // Initialization
String astraToken = getAstraToken(); String astraToken = getAstraToken();
String databaseId = setupDatabase("langchain4j", "langchain4j"); String databaseId = setupDatabase("langchain4j", "langchain4j");
String openAIKey = System.getenv("OPENAI_API_KEY");
// Given // Given
Assertions.assertNotNull(openAIKey); assertNotNull(databaseId);
Assertions.assertNotNull(databaseId); assertNotNull(astraToken);
Assertions.assertNotNull(astraToken);
// Given
ChatLanguageModel model = OpenAiChatModel.builder()
.apiKey(openAIKey)
.modelName(GPT_3_5_TURBO)
.temperature(0.3)
.timeout(ofSeconds(120))
.logRequests(true)
.logResponses(true)
.build();
// When // When
ChatMemoryStore chatMemoryStore = ChatMemoryStore chatMemoryStore =
new AstraDbChatMemoryStore(astraToken, databaseId, TEST_REGION, "langchain4j"); new AstraDbChatMemoryStore(astraToken, databaseId, TEST_REGION, "langchain4j");
@ -61,11 +50,13 @@ public class ChatMemoryStoreAstraTest {
.build(); .build();
// When // When
chatMemory.add(userMessage("I will ask you a few question about ff4j. Response in a single sentence")); UserMessage userMessage = userMessage("I will ask you a few question about ff4j.");
chatMemory.add(userMessage("Can I use it with Javascript ? ")); chatMemory.add(userMessage);
AiMessage aiMessage = aiMessage("Sure, go ahead!");
chatMemory.add(aiMessage);
// Then // Then
Response<AiMessage> output = model.generate(chatMemory.messages()); assertThat(chatMemory.messages()).containsExactly(userMessage, aiMessage);
Assertions.assertNotNull(output.content().text());
} }
} }

View File

@ -1,20 +0,0 @@
<!-- Logback configuration: everything is written to the console; the root logger is at WARN
     and a few packages are given their own level below. -->
<configuration debug="false">
<!-- No-op status listener, so Logback's own status messages are not printed. -->
<statusListener class="ch.qos.logback.core.status.NopStatusListener" />
<!-- Console appender: time, colored level, colored logger name, then the message. -->
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%d{HH:mm:ss.SSS} %magenta(%-5level) %cyan(%-47logger) : %msg%n</pattern>
</encoder>
</appender>
<!-- Astra SDK at INFO; additivity is off so events are not also handled by the root logger. -->
<logger name="com.datastax.astra.sdk" level="INFO" additivity="false">
<appender-ref ref="STDOUT" />
</logger>
<!-- langchain4j store packages at INFO. -->
<logger name="dev.langchain4j.store" level="INFO" additivity="false">
<appender-ref ref="STDOUT" />
</logger>
<!-- DataStax driver restricted to ERROR. -->
<logger name="com.datastax.oss.driver" level="ERROR" additivity="false">
<appender-ref ref="STDOUT" />
</logger>
<!-- Everything else: WARN and above. -->
<root level="WARN">
<appender-ref ref="STDOUT" />
</root>
</configuration>

View File

@ -15,6 +15,7 @@ import java.util.List;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
@Disabled("needs Elasticsearch to be running locally") @Disabled("needs Elasticsearch to be running locally")
class ElasticsearchEmbeddingStoreTest { class ElasticsearchEmbeddingStoreTest {
@ -77,7 +78,7 @@ class ElasticsearchEmbeddingStoreTest {
@Test @Test
void testAddNotEqualSizeEmbeddingAndEmbedded() { void testAddNotEqualSizeEmbeddingAndEmbedded() {
Throwable ex = Assertions.assertThrows(IllegalArgumentException.class, () -> store.addAll(asList( Throwable ex = assertThrows(IllegalArgumentException.class, () -> store.addAll(asList(
Embedding.from(asList(0.3f, 0.87f, 0.90f, 0.24f)), Embedding.from(asList(0.3f, 0.87f, 0.90f, 0.24f)),
Embedding.from(asList(0.54f, 0.34f, 0.67f, 0.24f, 0.55f)), Embedding.from(asList(0.54f, 0.34f, 0.67f, 0.24f, 0.55f)),
Embedding.from(asList(0.80f, 0.45f, 0.779f, 0.5556f)) Embedding.from(asList(0.80f, 0.45f, 0.779f, 0.5556f))

View File

@ -1,167 +0,0 @@
package dev.langchain4j.store.embedding;
import dev.langchain4j.data.embedding.Embedding;
import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
 * Base class for embedding stores that forward every operation to a concrete implementation
 * (delegate pattern). Subclasses provide the implementation through
 * {@link #loadImplementation()}; it is created on the first call to
 * {@link #getDelegateImplementation()} and reused afterwards.
 *
 * @param <T> type of the content stored alongside each embedding
 */
public abstract class AbstractEmbeddingStore<T> implements EmbeddingStore<T> {

    /** Class name of the concrete implementation; this class only uses it in error messages. */
    protected String implementationClassName;

    /**
     * Concrete implementation (delegate pattern), {@code null} until first loaded.
     */
    protected EmbeddingStore<T> delegateImplementation;

    /**
     * Creates the concrete implementation every operation is forwarded to.
     *
     * @return the store to delegate to
     * @throws ClassNotFoundException    if the implementation class cannot be found
     * @throws NoSuchMethodException     if the implementation lacks the expected constructor
     * @throws InstantiationException    if the implementation cannot be instantiated
     * @throws IllegalAccessException    if the constructor is not accessible
     * @throws InvocationTargetException if the constructor itself throws
     */
    protected abstract EmbeddingStore<T> loadImplementation()
            throws ClassNotFoundException, NoSuchMethodException, InstantiationException,
            IllegalAccessException, InvocationTargetException;

    /**
     * Returns the concrete implementation, loading it on first access.
     *
     * @return delegate implementation
     * @throws IllegalArgumentException if the implementation has no matching constructor or the
     *                                  constructor cannot be invoked
     * @throws RuntimeException         if the implementation class is missing or loading fails
     *                                  for any other reason
     */
    protected EmbeddingStore<T> getDelegateImplementation() {
        if (delegateImplementation == null) {
            try {
                this.delegateImplementation = loadImplementation();
            } catch (ClassNotFoundException cnf) {
                throw new RuntimeException(
                        "Class '" + implementationClassName + "' not found, please checks your dependencies.", cnf);
            } catch (NoSuchMethodException e) {
                throw new IllegalArgumentException(
                        "Class '" + implementationClassName + "' does not have constructor with expected parameters", e);
            } catch (InvocationTargetException | InstantiationException | IllegalAccessException e) {
                throw new IllegalArgumentException(
                        "Constructor of class '" + implementationClassName + "' cannot be invoked check visibility or code.", e);
            } catch (Exception e) {
                // Anything else (including unchecked failures) is reported with its cause attached.
                throw new RuntimeException("Unexpected error while loading implementation", e);
            }
        }
        return delegateImplementation;
    }

    /**
     * Add embedding (vector only) to the store. The id is generated by the store.
     *
     * @param embedding The embedding (vector) to be added to the store.
     * @return unique id of the embedding
     * @throws NullPointerException if {@code embedding} is null
     */
    @Override
    public String add(Embedding embedding) {
        // Validate here as the other add(...) overloads do, instead of relying on the delegate.
        Objects.requireNonNull(embedding, "embedding must not be null");
        return getDelegateImplementation().add(embedding);
    }

    /**
     * Add embedding (vector only) to the store enforcing its identifier.
     *
     * @param embeddingId unique identifier of the embedding
     * @param embedding   The embedding (vector) to be added to the store.
     * @throws NullPointerException if either argument is null
     */
    @Override
    public void add(String embeddingId, Embedding embedding) {
        Objects.requireNonNull(embeddingId, "embeddingId (param[0]) must not be null");
        Objects.requireNonNull(embedding, "embedding (param[1]) must not be null");
        getDelegateImplementation().add(embeddingId, embedding);
    }

    /**
     * Add embedding to the store with text and metadata. The id is generated by the store.
     *
     * @param embedding   The embedding (vector) to be added to the store.
     * @param textSegment Text and metadata
     * @return unique id of the embedding
     * @throws NullPointerException if {@code embedding} is null
     */
    @Override
    public String add(Embedding embedding, T textSegment) {
        Objects.requireNonNull(embedding, "embedding (param[0]) must not be null");
        return getDelegateImplementation().add(embedding, textSegment);
    }

    /**
     * Add a list of embeddings to the store, one {@link #add(Embedding)} call per element.
     *
     * @param embeddings list of embeddings
     * @return list of ids, in the order of {@code embeddings}
     * @throws NullPointerException if {@code embeddings} is null
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        Objects.requireNonNull(embeddings, "embeddings must not be null");
        return embeddings.stream().map(this::add).collect(Collectors.toList());
    }

    /**
     * Add a list of embeddings with their text segments to the store.
     *
     * @param embeddings   list of embeddings
     * @param textSegments list of text segments, same size as {@code embeddings}
     * @return list of ids
     * @throws NullPointerException     if either list is null
     * @throws IllegalArgumentException if the two lists differ in size
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<T> textSegments) {
        Objects.requireNonNull(embeddings, "embeddings (param[0]) must not be null");
        Objects.requireNonNull(textSegments, "textSegments (param[1]) must not be null");
        if (embeddings.size() != textSegments.size()) {
            throw new IllegalArgumentException("embeddings and textSegment lists must have the same size");
        }
        return getDelegateImplementation().addAll(embeddings, textSegments);
    }

    /**
     * Search for relevant embeddings.
     *
     * @param referenceEmbedding The embedding used as a reference. Returned embeddings should be
     *                           relevant (closest) to this one.
     * @param maxResults         The maximum number of embeddings to be returned.
     * @return List of relevant embeddings.
     */
    @Override
    public List<EmbeddingMatch<T>> findRelevant(Embedding referenceEmbedding, int maxResults) {
        return getDelegateImplementation().findRelevant(referenceEmbedding, maxResults);
    }

    /**
     * Search for relevant embeddings.
     *
     * @param referenceEmbedding The embedding used as a reference. Returned embeddings should be
     *                           relevant (closest) to this one.
     * @param maxResults         The maximum number of embeddings to be returned.
     * @param minScore           The minimum similarity score of the returned embeddings.
     * @return List of relevant embeddings.
     */
    @Override
    public List<EmbeddingMatch<T>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        return getDelegateImplementation().findRelevant(referenceEmbedding, maxResults, minScore);
    }
}

View File

@ -1,152 +0,0 @@
package dev.langchain4j.store.embedding.cassandra;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.AbstractEmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import java.lang.reflect.InvocationTargetException;
/**
 * Embedding store backed by Cassandra AstraDB. The actual work is done by an implementation
 * class that is looked up by name and instantiated reflectively with the store configuration.
 *
 * @author Cedrick Lunven (clun)
 */
@Slf4j
public class AstraDbEmbeddingStore extends AbstractEmbeddingStore<TextSegment> {

    /** Implementation class used when the caller does not name one. */
    private static final String DEFAULT_IMPLEMENTATION =
            "dev.langchain4j.store.embedding.cassandra.AstraDbEmbeddingStoreImpl";

    /** Configuration handed to the implementation's constructor. */
    private final AstraDbEmbeddingConfiguration configuration;

    /**
     * Creates a store that uses the default implementation class.
     *
     * @param config store configuration
     */
    public AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration config) {
        this(DEFAULT_IMPLEMENTATION, config);
    }

    /**
     * Creates a store that uses the named implementation class. The implementation is loaded
     * immediately, so a missing or incompatible class is reported by this constructor.
     *
     * @param impl   fully qualified name of the implementation class
     * @param config store configuration
     */
    public AstraDbEmbeddingStore(String impl, AstraDbEmbeddingConfiguration config) {
        this.implementationClassName = impl;
        this.configuration = config;
        // Load eagerly rather than on the first store operation.
        getDelegateImplementation();
    }

    /**
     * {@inheritDoc}
     */
    @Override
    @SuppressWarnings("unchecked")
    protected EmbeddingStore<TextSegment> loadImplementation()
            throws ClassNotFoundException, NoSuchMethodException, InstantiationException,
            IllegalAccessException, InvocationTargetException {
        Class<?> implementationClass = Class.forName(implementationClassName);
        Object instance = implementationClass
                .getConstructor(AstraDbEmbeddingConfiguration.class)
                .newInstance(configuration);
        return (EmbeddingStore<TextSegment>) instance;
    }

    /**
     * Entry point of the fluent builder.
     *
     * @return a new, empty builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder that collects the configuration values and creates the store.
     */
    public static class Builder {

        /** Configuration builder that accumulates the values supplied so far. */
        private final AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder conf =
                AstraDbEmbeddingConfiguration.builder();

        /**
         * Creates an empty builder.
         */
        public Builder() {
        }

        /**
         * Sets the Astra token.
         *
         * @param token token
         * @return current reference
         */
        public Builder token(String token) {
            conf.token(token);
            return this;
        }

        /**
         * Sets the target database.
         *
         * @param databaseId     database identifier
         * @param databaseRegion database region
         * @return current reference
         */
        public Builder database(String databaseId, String databaseRegion) {
            conf.databaseId(databaseId).databaseRegion(databaseRegion);
            return this;
        }

        /**
         * Sets the dimension of the embedding model.
         *
         * @param dimension model dimension
         * @return current reference
         */
        public Builder vectorDimension(int dimension) {
            conf.dimension(dimension);
            return this;
        }

        /**
         * Sets the keyspace and table used by the store.
         *
         * @param keyspace keyspace name
         * @param table    table name
         * @return current reference
         */
        public Builder table(String keyspace, String table) {
            conf.keyspace(keyspace).table(table);
            return this;
        }

        /**
         * Builds the configuration and creates the store with the default implementation.
         *
         * @return store for Astra.
         */
        public AstraDbEmbeddingStore build() {
            AstraDbEmbeddingConfiguration builtConfiguration = conf.build();
            return new AstraDbEmbeddingStore(builtConfiguration);
        }
    }
}

View File

@ -1,161 +0,0 @@
package dev.langchain4j.store.embedding.cassandra;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.AbstractEmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import static dev.langchain4j.store.embedding.cassandra.CassandraEmbeddingConfiguration.CassandraEmbeddingConfigurationBuilder;
/**
 * Embedding store for Cassandra.
 *
 * <p>The class itself holds only the configuration and the name of an implementation class;
 * {@link #loadImplementation()} instantiates that class reflectively.
 */
public class CassandraEmbeddingStore extends AbstractEmbeddingStore<TextSegment> {

    /** Implementation class used when the caller does not name one explicitly. */
    private static final String DEFAULT_IMPLEMENTATION =
            "dev.langchain4j.store.embedding.cassandra.CassandraEmbeddingStoreImpl";

    /**
     * Store configuration, passed to the constructor of the implementation class.
     */
    private final CassandraEmbeddingConfiguration configuration;

    /**
     * Creates a store that uses the default implementation class.
     *
     * @param config
     *      store configuration
     */
    public CassandraEmbeddingStore(CassandraEmbeddingConfiguration config) {
        this(DEFAULT_IMPLEMENTATION, config);
    }

    /**
     * Creates a store that uses the given implementation class.
     *
     * @param impl
     *      fully-qualified name of the implementation class to instantiate
     * @param config
     *      store configuration
     */
    public CassandraEmbeddingStore(String impl, CassandraEmbeddingConfiguration config) {
        this.configuration = config;
        this.implementationClassName = impl;
        // NOTE(review): presumably triggers loadImplementation() right away so that a bad class
        // name or configuration fails at construction time - confirm in AbstractEmbeddingStore.
        getDelegateImplementation();
    }

    /**
     * Instantiates the class named by {@code implementationClassName} through its public
     * constructor taking a single {@link CassandraEmbeddingConfiguration}.
     *
     * @return
     *      the newly created implementation
     */
    @Override
    @SuppressWarnings("unchecked")
    protected EmbeddingStore<TextSegment> loadImplementation()
            throws ClassNotFoundException, NoSuchMethodException, InstantiationException,
            IllegalAccessException, InvocationTargetException {
        return (EmbeddingStore<TextSegment>) Class
                .forName(implementationClassName)
                .getConstructor(CassandraEmbeddingConfiguration.class)
                .newInstance(configuration);
    }

    /**
     * Entry point of the fluent builder.
     *
     * @return
     *      a new, empty builder
     */
    public static CassandraEmbeddingStore.Builder builder() {
        return new CassandraEmbeddingStore.Builder();
    }

    /**
     * Fluent helper that collects the settings of a {@link CassandraEmbeddingConfiguration}
     * and creates the store from them.
     */
    public static class Builder {

        /**
         * Underlying configuration builder that receives every setting.
         */
        private final CassandraEmbeddingConfigurationBuilder conf;

        /**
         * Starts from an empty configuration builder.
         */
        public Builder() {
            conf = CassandraEmbeddingConfiguration.builder();
        }

        /**
         * Sets the Cassandra port.
         *
         * @param port
         *      port
         * @return
         *      current reference
         */
        public CassandraEmbeddingStore.Builder port(int port) {
            conf.port(port);
            return this;
        }

        /**
         * Sets the Cassandra contact points.
         *
         * @param hosts
         *      contact points, one entry per host
         * @return
         *      current reference
         * @throws NullPointerException
         *      if the {@code hosts} array itself is {@code null}
         */
        public CassandraEmbeddingStore.Builder contactPoints(String... hosts) {
            // Arrays.asList returns a view backed by the array it is given. Cloning first means
            // the list handed over is detached from the caller's array, so later writes to that
            // array cannot show through it.
            conf.contactPoints(Arrays.asList(hosts.clone()));
            return this;
        }

        /**
         * Sets the model dimension.
         *
         * @param dimension
         *      model dimension
         * @return
         *      current reference
         */
        public CassandraEmbeddingStore.Builder vectorDimension(int dimension) {
            conf.dimension(dimension);
            return this;
        }

        /**
         * Sets the local datacenter.
         *
         * @param dc
         *      datacenter
         * @return
         *      current reference
         */
        public CassandraEmbeddingStore.Builder localDataCenter(String dc) {
            conf.localDataCenter(dc);
            return this;
        }

        /**
         * Sets the keyspace name and the table name.
         *
         * @param keyspace
         *      keyspace name
         * @param table
         *      table name
         * @return
         *      current reference
         */
        public CassandraEmbeddingStore.Builder table(String keyspace, String table) {
            conf.keyspace(keyspace);
            conf.table(table);
            return this;
        }

        /**
         * Builds the configuration and creates the store with the default implementation.
         *
         * @return
         *      store for Cassandra
         */
        public CassandraEmbeddingStore build() {
            return new CassandraEmbeddingStore(conf.build());
        }
    }
}

View File

@ -1,80 +0,0 @@
package dev.langchain4j.store.embedding.cassandra;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
/**
 * Checks the validation rules applied when building an {@link AstraDbEmbeddingConfiguration}.
 *
 * <p>One test builds a fully populated configuration. Every other test leaves out part of the
 * builder input and expects a {@link NullPointerException} whose message names the missing
 * property.
 */
public class AstraDbEmbeddingConfigurationTest {

    @Test
    public void should_build_configuration_test() {
        // Every property is provided, so the configuration is expected to build.
        AstraDbEmbeddingConfiguration built = AstraDbEmbeddingConfiguration.builder()
                .token("token")
                .databaseId("dbId")
                .databaseRegion("dbRegion")
                .keyspace("ks")
                .dimension(20)
                .table("table")
                .build();
        Assertions.assertNotNull(built);
        Assertions.assertNotNull(built.getToken());
        Assertions.assertNotNull(built.getDatabaseId());
        Assertions.assertNotNull(built.getDatabaseRegion());
    }

    @Test
    public void should_error_if_no_table_test() {
        // The table is left out.
        NullPointerException missingTable = Assertions.assertThrows(
                NullPointerException.class,
                () -> AstraDbEmbeddingConfiguration.builder()
                        .token("token")
                        .databaseId("dbId")
                        .databaseRegion("dbRegion")
                        .keyspace("ks")
                        .dimension(20)
                        .build());
        Assertions.assertEquals("table is marked non-null but is null", missingTable.getMessage());
    }

    @Test
    public void should_error_if_no_keyspace_test() {
        // The keyspace is left out.
        NullPointerException missingKeyspace = Assertions.assertThrows(
                NullPointerException.class,
                () -> AstraDbEmbeddingConfiguration.builder()
                        .token("token")
                        .databaseId("dbId")
                        .databaseRegion("dbRegion")
                        .table("ks")
                        .dimension(20)
                        .build());
        Assertions.assertEquals("keyspace is marked non-null but is null", missingKeyspace.getMessage());
    }

    @Test
    public void should_error_if_no_dimension_test() {
        // The dimension is left out.
        NullPointerException missingDimension = Assertions.assertThrows(
                NullPointerException.class,
                () -> AstraDbEmbeddingConfiguration.builder()
                        .token("token")
                        .databaseId("dbId")
                        .databaseRegion("dbRegion")
                        .table("ks")
                        .keyspace("ks")
                        .build());
        Assertions.assertEquals("dimension is marked non-null but is null", missingDimension.getMessage());
    }

    @Test
    public void should_error_if_no_token_test() {
        // The token is left out.
        NullPointerException missingToken = Assertions.assertThrows(
                NullPointerException.class,
                () -> AstraDbEmbeddingConfiguration.builder()
                        .databaseId("dbId")
                        .databaseRegion("dbRegion")
                        .table("ks")
                        .keyspace("ks")
                        .dimension(20)
                        .build());
        Assertions.assertEquals("token is marked non-null but is null", missingToken.getMessage());
    }

    @Test
    public void should_error_if_no_database_test() {
        // Neither databaseId nor databaseRegion is set; the expected message names databaseId.
        NullPointerException missingDatabase = Assertions.assertThrows(
                NullPointerException.class,
                () -> AstraDbEmbeddingConfiguration.builder()
                        .token("token")
                        .table("ks")
                        .keyspace("ks")
                        .dimension(20)
                        .build());
        Assertions.assertEquals("databaseId is marked non-null but is null", missingDatabase.getMessage());
    }
}

View File

@ -1,84 +0,0 @@
package dev.langchain4j.store.embedding.cassandra;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.Collections;
/**
 * Checks the validation rules applied when building a {@link CassandraEmbeddingConfiguration}.
 *
 * <p>One test builds a fully populated configuration. Every other test leaves out a single
 * property and expects a {@link NullPointerException} whose message names that property.
 */
public class CassandraEmbeddingConfigurationTest {

    @Test
    public void should_build_configuration_test() {
        // Every property is provided, so the configuration is expected to build.
        CassandraEmbeddingConfiguration built = CassandraEmbeddingConfiguration.builder()
                .contactPoints(Collections.singletonList("localhost"))
                .port(CassandraEmbeddingConfiguration.DEFAULT_PORT)
                .keyspace("ks")
                .dimension(20)
                .table("table")
                .localDataCenter("dc1")
                .build();
        Assertions.assertNotNull(built);
    }

    @Test
    public void should_error_if_no_datacenter_test() {
        // The local datacenter is left out.
        NullPointerException missingDatacenter = Assertions.assertThrows(
                NullPointerException.class,
                () -> CassandraEmbeddingConfiguration.builder()
                        .contactPoints(Collections.singletonList("localhost"))
                        .port(CassandraEmbeddingConfiguration.DEFAULT_PORT)
                        .keyspace("ks")
                        .dimension(20)
                        .table("table")
                        .build());
        Assertions.assertEquals("localDataCenter is marked non-null but is null", missingDatacenter.getMessage());
    }

    @Test
    public void should_error_if_no_table_test() {
        // The table is left out.
        NullPointerException missingTable = Assertions.assertThrows(
                NullPointerException.class,
                () -> CassandraEmbeddingConfiguration.builder()
                        .contactPoints(Collections.singletonList("localhost"))
                        .port(CassandraEmbeddingConfiguration.DEFAULT_PORT)
                        .keyspace("ks")
                        .dimension(20)
                        .localDataCenter("dc1")
                        .build());
        Assertions.assertEquals("table is marked non-null but is null", missingTable.getMessage());
    }

    @Test
    public void should_error_if_no_keyspace_test() {
        // The keyspace is left out.
        NullPointerException missingKeyspace = Assertions.assertThrows(
                NullPointerException.class,
                () -> CassandraEmbeddingConfiguration.builder()
                        .contactPoints(Collections.singletonList("localhost"))
                        .port(CassandraEmbeddingConfiguration.DEFAULT_PORT)
                        .table("ks")
                        .dimension(20)
                        .localDataCenter("dc1")
                        .build());
        Assertions.assertEquals("keyspace is marked non-null but is null", missingKeyspace.getMessage());
    }

    @Test
    public void should_error_if_no_dimension_test() {
        // The dimension is left out.
        NullPointerException missingDimension = Assertions.assertThrows(
                NullPointerException.class,
                () -> CassandraEmbeddingConfiguration.builder()
                        .contactPoints(Collections.singletonList("localhost"))
                        .port(CassandraEmbeddingConfiguration.DEFAULT_PORT)
                        .table("ks")
                        .keyspace("ks")
                        .localDataCenter("dc1")
                        .build());
        Assertions.assertEquals("dimension is marked non-null but is null", missingDimension.getMessage());
    }

    @Test
    public void should_error_if_no_contact_points_test() {
        // The contact points are left out.
        NullPointerException missingContactPoints = Assertions.assertThrows(
                NullPointerException.class,
                () -> CassandraEmbeddingConfiguration.builder()
                        .port(CassandraEmbeddingConfiguration.DEFAULT_PORT)
                        .table("ks")
                        .keyspace("ks")
                        .dimension(20)
                        .localDataCenter("dc1")
                        .build());
        Assertions.assertEquals("contactPoints is marked non-null but is null", missingContactPoints.getMessage());
    }
}