From ef8f04015b3ce12b622582c18bba24050ffd398d Mon Sep 17 00:00:00 2001 From: deep-learning-dynamo Date: Wed, 27 Sep 2023 17:11:01 +0200 Subject: [PATCH] Removed dynamic loading from AstraDB/Cassandra --- .github/workflows/main.yaml | 6 +- langchain4j-cassandra/pom.xml | 30 +--- .../AstraDbEmbeddingConfiguration.java | 11 +- .../cassandra/AstraDbEmbeddingStore.java | 111 ++++++++++++ .../cassandra/AstraDbEmbeddingStoreImpl.java | 34 ---- .../CassandraEmbeddingConfiguration.java | 15 +- .../cassandra/CassandraEmbeddingStore.java | 152 ++++++++++++++++ .../CassandraEmbeddingStoreImpl.java | 67 ------- .../CassandraEmbeddingStoreSupport.java | 105 +++++------ .../cassandra/AstraDbChatMemoryStore.java | 16 +- .../cassandra/CassandraChatMemoryStore.java | 83 ++++----- .../AstraDbEmbeddingConfigurationTest.java | 96 ++++++++++ .../cassandra/AstraDbEmbeddingStoreTest.java | 66 +++---- .../CassandraEmbeddingConfigurationTest.java | 93 ++++++++++ .../CassandraEmbeddingStoreTest.java | 39 ++-- ...mpleDocumentLoaderAndRagWithAstraTest.java | 17 +- .../cassandra/ChatMemoryStoreAstraTest.java | 47 ++--- .../src/test/resources/logback-test.xml | 20 --- .../ElasticsearchEmbeddingStoreTest.java | 3 +- .../embedding/AbstractEmbeddingStore.java | 167 ------------------ .../cassandra/AstraDbEmbeddingStore.java | 152 ---------------- .../cassandra/CassandraEmbeddingStore.java | 161 ----------------- .../AstraDbEmbeddingConfigurationTest.java | 80 --------- .../CassandraEmbeddingConfigurationTest.java | 84 --------- 24 files changed, 639 insertions(+), 1016 deletions(-) rename {langchain4j => langchain4j-cassandra}/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfiguration.java (91%) create mode 100644 langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStore.java delete mode 100644 langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStoreImpl.java rename {langchain4j => langchain4j-cassandra}/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfiguration.java (90%) create mode 100644 langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStore.java delete mode 100644 langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreImpl.java create mode 100644 langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfigurationTest.java create mode 100644 langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfigurationTest.java delete mode 100755 langchain4j-cassandra/src/test/resources/logback-test.xml delete mode 100644 langchain4j/src/main/java/dev/langchain4j/store/embedding/AbstractEmbeddingStore.java delete mode 100644 langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStore.java delete mode 100644 langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStore.java delete mode 100644 langchain4j/src/test/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfigurationTest.java delete mode 100644 langchain4j/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfigurationTest.java diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 9b02e8289..bbd915088 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -7,11 +7,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 8 uses: actions/setup-java@v3 with: - java-version: '11' - distribution: 'adopt' + java-version: '8' + distribution: 'temurin' - name: Test run: mvn --batch-mode test diff --git a/langchain4j-cassandra/pom.xml b/langchain4j-cassandra/pom.xml index 7e035ca7c..f2dce0714 100644 --- a/langchain4j-cassandra/pom.xml +++ b/langchain4j-cassandra/pom.xml @@ -64,6 +64,14 @@ junit-jupiter-engine test + + + org.assertj + assertj-core + ${assertj.version} + test + + dev.langchain4j langchain4j-open-ai @@ -82,26 +90,4 @@ - - - - org.apache.maven.plugins - maven-compiler-plugin - 3.11.0 - - 11 - 11 - false - - - - org.honton.chas - license-maven-plugin - - true - - - - - \ No newline at end of file diff --git a/langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfiguration.java b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfiguration.java similarity index 91% rename from langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfiguration.java rename to langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfiguration.java index f67057bf4..7d669da8b 100644 --- a/langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfiguration.java +++ b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfiguration.java @@ -3,7 +3,6 @@ package dev.langchain4j.store.embedding.cassandra; import lombok.Builder; import lombok.Getter; import lombok.NonNull; -import lombok.experimental.SuperBuilder; /** * Plain old Java Object (POJO) to hold the configuration for the CassandraEmbeddingStore. @@ -13,11 +12,13 @@ import lombok.experimental.SuperBuilder; * * @see CassandraEmbeddingStore */ -@Getter @Builder +@Getter +@Builder public class AstraDbEmbeddingConfiguration { /** * Represents the Api Key to interact with Astra DB + * * @see Astra DB Api Key */ @NonNull @@ -58,8 +59,7 @@ public class AstraDbEmbeddingConfiguration { /** * Initialize the builder. * - * @return - * cassandra embedding configuration buildesr + * @return cassandra embedding configuration builder */ public static AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder builder() { return new AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder(); @@ -68,5 +68,6 @@ public class AstraDbEmbeddingConfiguration { /** * Signature for the builder. */ - public static class AstraDbEmbeddingConfigurationBuilder{} + public static class AstraDbEmbeddingConfigurationBuilder { + } } diff --git a/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStore.java b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStore.java new file mode 100644 index 000000000..06ce94265 --- /dev/null +++ b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStore.java @@ -0,0 +1,111 @@ +package dev.langchain4j.store.embedding.cassandra; + +import com.datastax.astra.sdk.AstraClient; +import com.datastax.oss.driver.api.core.CqlSession; +import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable; +import dev.langchain4j.store.embedding.EmbeddingStore; + +/** + * Implementation of {@link EmbeddingStore} using Cassandra AstraDB. + * + * @see EmbeddingStore + * @see MetadataVectorCassandraTable + */ +public class AstraDbEmbeddingStore extends CassandraEmbeddingStoreSupport { + + /** + * Build the store from the configuration. + * + * @param config configuration + */ + public AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration config) { + CqlSession cqlSession = AstraClient.builder() + .withToken(config.getToken()) + .withCqlKeyspace(config.getKeyspace()) + .withDatabaseId(config.getDatabaseId()) + .withDatabaseRegion(config.getDatabaseRegion()) + .enableCql() + .enableDownloadSecureConnectBundle() + .build().cqlSession(); + embeddingTable = new MetadataVectorCassandraTable(cqlSession, config.getKeyspace(), config.getTable(), config.getDimension()); + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Syntax Sugar Builder. + */ + public static class Builder { + + /** + * Configuration built with the builder + */ + private final AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder conf; + + /** + * Initialization + */ + public Builder() { + conf = AstraDbEmbeddingConfiguration.builder(); + } + + /** + * Populating token. + * + * @param token token + * @return current reference + */ + public Builder token(String token) { + conf.token(token); + return this; + } + + /** + * Populating token. + * + * @param databaseId database Identifier + * @param databaseRegion database region + * @return current reference + */ + public Builder database(String databaseId, String databaseRegion) { + conf.databaseId(databaseId); + conf.databaseRegion(databaseRegion); + return this; + } + + /** + * Populating model dimension. + * + * @param dimension model dimension + * @return current reference + */ + public Builder vectorDimension(int dimension) { + conf.dimension(dimension); + return this; + } + + /** + * Populating table name. + * + * @param keyspace keyspace name + * @param table table name + * @return current reference + */ + public Builder table(String keyspace, String table) { + conf.keyspace(keyspace); + conf.table(table); + return this; + } + + /** + * Building the Store. + * + * @return store for Astra. + */ + public AstraDbEmbeddingStore build() { + return new AstraDbEmbeddingStore(conf.build()); + } + } +} diff --git a/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStoreImpl.java b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStoreImpl.java deleted file mode 100644 index 2b1e954a1..000000000 --- a/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStoreImpl.java +++ /dev/null @@ -1,34 +0,0 @@ -package dev.langchain4j.store.embedding.cassandra; - -import com.datastax.astra.sdk.AstraClient; -import com.datastax.oss.driver.api.core.CqlSession; -import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable; -import dev.langchain4j.store.embedding.EmbeddingStore; - -/** - * Implementation of {@link EmbeddingStore} using Cassandra AstraDB. - * - * @see EmbeddingStore - * @see MetadataVectorCassandraTable - */ -public class AstraDbEmbeddingStoreImpl extends CassandraEmbeddingStoreSupport { - - /** - * Build the store from the configuration. - * - * @param config - * configuration - */ - public AstraDbEmbeddingStoreImpl(AstraDbEmbeddingConfiguration config) { - CqlSession cqlSession = AstraClient.builder() - .withToken(config.getToken()) - .withCqlKeyspace(config.getKeyspace()) - .withDatabaseId(config.getDatabaseId()) - .withDatabaseRegion(config.getDatabaseRegion()) - .enableCql() - .enableDownloadSecureConnectBundle() - .build().cqlSession(); - embeddingTable = new MetadataVectorCassandraTable(cqlSession, config.getKeyspace(), config.getTable(), config.getDimension()); - } - -} diff --git a/langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfiguration.java b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfiguration.java similarity index 90% rename from langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfiguration.java rename to langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfiguration.java index 2b1fbe593..e21bee7a7 100644 --- a/langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfiguration.java +++ b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfiguration.java @@ -3,7 +3,6 @@ package dev.langchain4j.store.embedding.cassandra; import lombok.Builder; import lombok.Getter; import lombok.NonNull; -import lombok.experimental.SuperBuilder; import java.util.List; @@ -15,10 +14,13 @@ import java.util.List; * * @see CassandraEmbeddingStore */ -@Getter @Builder +@Getter +@Builder public class CassandraEmbeddingConfiguration { - /** Default Cassandra Port. */ + /** + * Default Cassandra Port. + */ public static Integer DEFAULT_PORT = 9042; // --- Connectivity Parameters --- @@ -74,8 +76,7 @@ public class CassandraEmbeddingConfiguration { /** * Initialize the builder. * - * @return - * cassandra embedding configuration buildesr + * @return cassandra embedding configuration buildesr */ public static CassandraEmbeddingConfigurationBuilder builder() { return new CassandraEmbeddingConfigurationBuilder(); @@ -84,6 +85,6 @@ public class CassandraEmbeddingConfiguration { /** * Signature for the builder. */ - public static class CassandraEmbeddingConfigurationBuilder{} - + public static class CassandraEmbeddingConfigurationBuilder { + } } diff --git a/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStore.java b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStore.java new file mode 100644 index 000000000..92c3d1c2b --- /dev/null +++ b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStore.java @@ -0,0 +1,152 @@ +package dev.langchain4j.store.embedding.cassandra; + +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.CqlSessionBuilder; +import com.datastax.oss.driver.api.querybuilder.SchemaBuilder; +import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable; +import com.dtsx.astra.sdk.cassio.SimilarityMetric; +import dev.langchain4j.store.embedding.EmbeddingStore; + +import java.net.InetSocketAddress; +import java.util.Arrays; + +/** + * Implementation of {@link EmbeddingStore} using Cassandra AstraDB. + * + * @see EmbeddingStore + * @see MetadataVectorCassandraTable + */ +public class CassandraEmbeddingStore extends CassandraEmbeddingStoreSupport { + + /** + * Build the store from the configuration. + * + * @param config configuration + */ + public CassandraEmbeddingStore(CassandraEmbeddingConfiguration config) { + CqlSessionBuilder sessionBuilder = createCqlSessionBuilder(config); + createKeyspaceIfNotExist(sessionBuilder, config.getKeyspace()); + sessionBuilder.withKeyspace(config.getKeyspace()); + this.embeddingTable = new MetadataVectorCassandraTable(sessionBuilder.build(), + config.getKeyspace(), config.getTable(), config.getDimension(), SimilarityMetric.COS); + } + + /** + * Build the cassandra session from the config. At the difference of adminSession there + * a keyspace attached to it. + * + * @param config current configuration + * @return cassandra session + */ + private CqlSessionBuilder createCqlSessionBuilder(CassandraEmbeddingConfiguration config) { + CqlSessionBuilder cqlSessionBuilder = CqlSession.builder(); + cqlSessionBuilder.withLocalDatacenter(config.getLocalDataCenter()); + if (config.getUserName() != null && config.getPassword() != null) { + cqlSessionBuilder.withAuthCredentials(config.getUserName(), config.getPassword()); + } + config.getContactPoints().forEach(cp -> + cqlSessionBuilder.addContactPoint(new InetSocketAddress(cp, config.getPort()))); + return cqlSessionBuilder; + } + + /** + * Create the keyspace in cassandra Destination if not exist. + */ + private void createKeyspaceIfNotExist(CqlSessionBuilder cqlSessionBuilder, String keyspace) { + try (CqlSession adminSession = cqlSessionBuilder.build()) { + adminSession.execute(SchemaBuilder.createKeyspace(keyspace) + .ifNotExists() + .withSimpleStrategy(1) + .withDurableWrites(true) + .build()); + } + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Syntax Sugar Builder. + */ + public static class Builder { + + /** + * Configuration built with the builder + */ + private final CassandraEmbeddingConfiguration.CassandraEmbeddingConfigurationBuilder conf; + + /** + * Initialization + */ + public Builder() { + conf = CassandraEmbeddingConfiguration.builder(); + } + + /** + * Populating cassandra port. + * + * @param port port + * @return current reference + */ + public CassandraEmbeddingStore.Builder port(int port) { + conf.port(port); + return this; + } + + /** + * Populating cassandra contact points. + * + * @param hosts port + * @return current reference + */ + public CassandraEmbeddingStore.Builder contactPoints(String... hosts) { + conf.contactPoints(Arrays.asList(hosts)); + return this; + } + + /** + * Populating model dimension. + * + * @param dimension model dimension + * @return current reference + */ + public CassandraEmbeddingStore.Builder vectorDimension(int dimension) { + conf.dimension(dimension); + return this; + } + + /** + * Populating datacenter. + * + * @param dc datacenter + * @return current reference + */ + public CassandraEmbeddingStore.Builder localDataCenter(String dc) { + conf.localDataCenter(dc); + return this; + } + + /** + * Populating table name. + * + * @param keyspace keyspace name + * @param table table name + * @return current reference + */ + public CassandraEmbeddingStore.Builder table(String keyspace, String table) { + conf.keyspace(keyspace); + conf.table(table); + return this; + } + + /** + * Building the Store. + * + * @return store for Astra. + */ + public CassandraEmbeddingStore build() { + return new CassandraEmbeddingStore(conf.build()); + } + } +} diff --git a/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreImpl.java b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreImpl.java deleted file mode 100644 index 84598e480..000000000 --- a/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreImpl.java +++ /dev/null @@ -1,67 +0,0 @@ -package dev.langchain4j.store.embedding.cassandra; - -import com.datastax.oss.driver.api.core.CqlSession; -import com.datastax.oss.driver.api.core.CqlSessionBuilder; -import com.datastax.oss.driver.api.querybuilder.SchemaBuilder; -import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable; -import com.dtsx.astra.sdk.cassio.SimilarityMetric; -import dev.langchain4j.store.embedding.EmbeddingStore; - -import java.net.InetSocketAddress; - -/** - * Implementation of {@link EmbeddingStore} using Cassandra AstraDB. - * - * @see EmbeddingStore - * @see MetadataVectorCassandraTable - */ -public class CassandraEmbeddingStoreImpl extends CassandraEmbeddingStoreSupport { - - /** - * Build the store from the configuration. - * - * @param config - * configuration - */ - public CassandraEmbeddingStoreImpl(CassandraEmbeddingConfiguration config) { - CqlSessionBuilder sessionBuilder = createCqlSessionBuilder(config); - createKeyspaceIfNotExist(sessionBuilder, config.getKeyspace()); - sessionBuilder.withKeyspace(config.getKeyspace()); - this.embeddingTable = new MetadataVectorCassandraTable(sessionBuilder.build(), - config.getKeyspace(), config.getTable(), config.getDimension(), SimilarityMetric.COS); - } - - /** - * Build the cassandra session from the config. At the difference of adminSession there - * a keyspace attached to it. - * - * @param config - * current configuration - * @return - * cassandra session - */ - private CqlSessionBuilder createCqlSessionBuilder(CassandraEmbeddingConfiguration config) { - CqlSessionBuilder cqlSessionBuilder = CqlSession.builder(); - cqlSessionBuilder.withLocalDatacenter(config.getLocalDataCenter()); - if (config.getUserName() != null && config.getPassword() != null) { - cqlSessionBuilder.withAuthCredentials(config.getUserName(), config.getPassword()); - } - config.getContactPoints().forEach(cp -> - cqlSessionBuilder.addContactPoint(new InetSocketAddress(cp, config.getPort()))); - return cqlSessionBuilder; - } - - /** - * Create the keyspace in cassandra Destination if not exist. - */ - private void createKeyspaceIfNotExist(CqlSessionBuilder cqlSessionBuilder, String keyspace) { - try(CqlSession adminSession = cqlSessionBuilder.build()) { - adminSession.execute(SchemaBuilder.createKeyspace(keyspace) - .ifNotExists() - .withSimpleStrategy(1) - .withDurableWrites(true) - .build()); - } - } - -} diff --git a/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreSupport.java b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreSupport.java index e1392bf57..4f3835b6b 100644 --- a/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreSupport.java +++ b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreSupport.java @@ -3,6 +3,7 @@ package dev.langchain4j.store.embedding.cassandra; import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable; import com.dtsx.astra.sdk.cassio.SimilarityMetric; import com.dtsx.astra.sdk.cassio.SimilaritySearchQuery; +import com.dtsx.astra.sdk.cassio.SimilaritySearchQuery.SimilaritySearchQueryBuilder; import com.dtsx.astra.sdk.cassio.SimilaritySearchResult; import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; @@ -18,11 +19,15 @@ import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; +import static dev.langchain4j.internal.ValidationUtils.ensureBetween; +import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero; +import static java.util.stream.Collectors.toList; + /** * Support for CassandraEmbeddingStore with and Without Astra. */ @Getter -public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore { +abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore { /** * Represents an embedding table in Cassandra, it is a table with a vector column. @@ -34,10 +39,8 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore addAll(List embeddingList) { @@ -95,18 +91,15 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore addAll(List embeddingList, List textSegmentList) { @@ -115,7 +108,7 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore ids = new ArrayList<>(); - for(int i = 0; i < embeddingList.size(); i++) { + for (int i = 0; i < embeddingList.size(); i++) { ids.add(add(embeddingList.get(i), textSegmentList.get(i))); } return ids; @@ -124,36 +117,30 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore> findRelevant(Embedding embedding, int maxResults, double minScore) { return embeddingTable .similaritySearch(SimilaritySearchQuery.builder() .embeddings(embedding.vectorAsList()) - .recordCount(maxResults) - .threshold(CosineSimilarity.fromRelevanceScore(minScore)) + .recordCount(ensureGreaterThanZero(maxResults, "maxResults")) + .threshold(CosineSimilarity.fromRelevanceScore(ensureBetween(minScore, 0, 1, "minScore"))) .distance(SimilarityMetric.COS) .build()) .stream() .map(CassandraEmbeddingStoreSupport::mapSearchResult) - .collect(Collectors.toList()); + .collect(toList()); } /** * Map Search result coming from Astra. * - * @param record - * current record - * @return - * search result + * @param record current record + * @return search result */ private static EmbeddingMatch mapSearchResult(SimilaritySearchResult record) { return new EmbeddingMatch<>( @@ -163,33 +150,24 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore> findRelevant(Embedding embedding, Integer maxResults, Double minScore, Metadata metadata) { - SimilaritySearchQuery.SimilaritySearchQueryBuilder builder = - SimilaritySearchQuery.builder().embeddings(embedding.vectorAsList()); - if (maxResults == null || maxResults < 1) { - throw new IllegalArgumentException("maxResults (param[1]) must not be null and greater than 0"); - } - if (minScore == null || minScore < 1 || minScore > 1) { - throw new IllegalArgumentException("minScore (param[2]) must not be null and in between 0 and 1."); - } + public List> findRelevant(Embedding embedding, int maxResults, double minScore, Metadata metadata) { + SimilaritySearchQueryBuilder builder = SimilaritySearchQuery.builder() + .embeddings(embedding.vectorAsList()) + .recordCount(ensureGreaterThanZero(maxResults, "maxResults")) + .threshold(CosineSimilarity.fromRelevanceScore(ensureBetween(minScore, 0, 1, "minScore"))); if (metadata != null) { builder.metaData(metadata.asMap()); } @@ -197,7 +175,6 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore * The initialization of the CQLSession will be done through an AstraClient - * - * @author Cedrick Lunven (clun) */ public class AstraDbChatMemoryStore extends CassandraChatMemoryStore { /** * Constructor with default table name. * - * @param token - * token - * @param dbId - * database idendifier - * @param dbRegion - * database region - * @param keyspaceName - * keyspace name + * @param token token + * @param dbId database identifier + * @param dbRegion database region + * @param keyspaceName keyspace name */ public AstraDbChatMemoryStore(String token, String dbId, String dbRegion, String keyspaceName) { this(token, dbId, dbRegion, keyspaceName, DEFAULT_TABLE_NAME); diff --git a/langchain4j-cassandra/src/main/java/dev/langchain4j/store/memory/chat/cassandra/CassandraChatMemoryStore.java b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/memory/chat/cassandra/CassandraChatMemoryStore.java index f1ee6fed1..e1702e5e3 100644 --- a/langchain4j-cassandra/src/main/java/dev/langchain4j/store/memory/chat/cassandra/CassandraChatMemoryStore.java +++ b/langchain4j-cassandra/src/main/java/dev/langchain4j/store/memory/chat/cassandra/CassandraChatMemoryStore.java @@ -1,28 +1,19 @@ package dev.langchain4j.store.memory.chat.cassandra; -import com.datastax.astra.sdk.AstraClient; import com.datastax.oss.driver.api.core.CqlSession; -import static com.dtsx.astra.sdk.cassio.ClusteredCassandraTable.Record; - import com.datastax.oss.driver.api.core.uuid.Uuids; import com.dtsx.astra.sdk.cassio.ClusteredCassandraTable; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.ObjectMapper; -import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.ChatMessageDeserializer; import dev.langchain4j.data.message.ChatMessageSerializer; -import dev.langchain4j.data.message.SystemMessage; -import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.store.memory.chat.ChatMemoryStore; -import lombok.Data; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; + +import static com.dtsx.astra.sdk.cassio.ClusteredCassandraTable.Record; +import static java.util.stream.Collectors.toList; /** * Implementation of {@link ChatMemoryStore} using Astra DB Vector Search. @@ -30,8 +21,6 @@ import java.util.stream.Collectors; * is a partition.Message id is a time uuid. * * @see Astra Vector Store Documentation - * @author Cedrick Lunven (clun) - * @since 0.22.0 */ @Slf4j public class CassandraChatMemoryStore implements ChatMemoryStore { @@ -46,18 +35,12 @@ public class CassandraChatMemoryStore implements ChatMemoryStore { */ private final ClusteredCassandraTable messageTable; - /** Object Mapper. */ - private static final ObjectMapper OM = new ObjectMapper(); - /** * Constructor for message store * - * @param session - * cassandra session - * @param keyspaceName - * keyspace name - * @param tableName - * table name + * @param session cassandra session + * @param keyspaceName keyspace name + * @param tableName table name */ public CassandraChatMemoryStore(CqlSession session, String keyspaceName, String tableName) { messageTable = new ClusteredCassandraTable(session, keyspaceName, tableName); @@ -66,52 +49,54 @@ public class CassandraChatMemoryStore implements ChatMemoryStore { /** * Constructor for message store * - * @param session - * cassandra session - * @param keyspaceName - * keyspace name + * @param session cassandra session + * @param keyspaceName keyspace name */ public CassandraChatMemoryStore(CqlSession session, String keyspaceName) { messageTable = new ClusteredCassandraTable(session, keyspaceName, DEFAULT_TABLE_NAME); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public List getMessages(@NonNull Object memoryId) { return messageTable .findPartition(getMemoryId(memoryId)) .stream() .map(this::toChatMessage) - .collect(Collectors.toList()); + .collect(toList()); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override - public void updateMessages(@NonNull Object memoryId, @NonNull List list) { + public void updateMessages(@NonNull Object memoryId, @NonNull List messages) { deleteMessages(memoryId); - messageTable.upsertPartition(list.stream() - .map(r -> this.fromChatMessage(getMemoryId(memoryId), r)) - .collect(Collectors.toList())); + messageTable.upsertPartition(messages.stream() + .map(record -> fromChatMessage(getMemoryId(memoryId), record)) + .collect(toList())); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override - public void deleteMessages(@NonNull Object memoryId) { + public void deleteMessages(@NonNull Object memoryId) { messageTable.deletePartition(getMemoryId(memoryId)); } /** * Unmarshalling Cassandra row as a Message with proper sub-type. * - * @param record - * cassandra record - * @return - * chat message + * @param record cassandra record + * @return chat message */ private ChatMessage toChatMessage(@NonNull Record record) { try { return ChatMessageDeserializer.messageFromJson(record.getBody()); - } catch(Exception e) { + } catch (Exception e) { log.error("Unable to parse message body", e); throw new IllegalArgumentException("Unable to parse message body"); } @@ -119,12 +104,10 @@ public class CassandraChatMemoryStore implements ChatMemoryStore { /** * Serialize the {@link ChatMessage} as a Cassandra Row. - * @param memoryId - * chat session identifier - * @param chatMessage - * chat message - * @return - * cassandra row. + * + * @param memoryId chat session identifier + * @param chatMessage chat message + * @return cassandra row. */ private Record fromChatMessage(@NonNull String memoryId, @NonNull ChatMessage chatMessage) { try { @@ -133,18 +116,16 @@ public class CassandraChatMemoryStore implements ChatMemoryStore { record.setPartitionId(memoryId); record.setBody(ChatMessageSerializer.messageToJson(chatMessage)); return record; - } catch(Exception e) { + } catch (Exception e) { log.error("Unable to parse message body", e); throw new IllegalArgumentException("Unable to parse message body", e); } } private String getMemoryId(Object memoryId) { - if (!(memoryId instanceof String) ) { + if (!(memoryId instanceof String)) { throw new IllegalArgumentException("memoryId must be a String"); } return (String) memoryId; } - - } \ No newline at end of file diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfigurationTest.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfigurationTest.java new file mode 100644 index 000000000..70f505484 --- /dev/null +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfigurationTest.java @@ -0,0 +1,96 @@ +package dev.langchain4j.store.embedding.cassandra; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * The configuration objects include a few validation rules. + */ +public class AstraDbEmbeddingConfigurationTest { + + @Test + public void should_build_configuration_test() { + AstraDbEmbeddingConfiguration config = AstraDbEmbeddingConfiguration.builder() + .token("token") + .databaseId("dbId") + .databaseRegion("dbRegion") + .keyspace("ks") + .dimension(20) + .table("table") + .build(); + assertNotNull(config); + assertNotNull(config.getToken()); + assertNotNull(config.getDatabaseId()); + assertNotNull(config.getDatabaseRegion()); + } + + @Test + public void should_error_if_no_table_test() { + // Table is required + NullPointerException exception = assertThrows(NullPointerException.class, + () -> AstraDbEmbeddingConfiguration.builder() + .token("token") + .databaseId("dbId") + .databaseRegion("dbRegion") + .keyspace("ks") + .dimension(20) + .build()); + assertEquals("table is marked non-null but is null", exception.getMessage()); + } + + @Test + public void should_error_if_no_keyspace_test() { + // Table is required + NullPointerException exception = assertThrows(NullPointerException.class, + () -> AstraDbEmbeddingConfiguration.builder() + .token("token") + .databaseId("dbId") + .databaseRegion("dbRegion") + .table("ks") + .dimension(20) + .build()); + assertEquals("keyspace is marked non-null but is null", exception.getMessage()); + } + + @Test + public void should_error_if_no_dimension_test() { + // Table is required + NullPointerException exception = assertThrows(NullPointerException.class, + () -> AstraDbEmbeddingConfiguration.builder() + .token("token") + .databaseId("dbId") + .databaseRegion("dbRegion") + .table("ks") + .keyspace("ks") + .build()); + assertEquals("dimension is marked non-null but is null", exception.getMessage()); + } + + @Test + public void should_error_if_no_token_test() { + // Table is required + NullPointerException exception = assertThrows(NullPointerException.class, + () -> AstraDbEmbeddingConfiguration.builder() + .databaseId("dbId") + .databaseRegion("dbRegion") + .table("ks") + .keyspace("ks") + .dimension(20) + .build()); + assertEquals("token is marked non-null but is null", exception.getMessage()); + } + + @Test + public void should_error_if_no_database_test() { + // Table is required + NullPointerException exception = assertThrows(NullPointerException.class, + () -> AstraDbEmbeddingConfiguration.builder() + .token("token") + .table("ks") + .keyspace("ks") + .dimension(20) + .build()); + assertEquals("databaseId is marked non-null but is null", exception.getMessage()); + } +} diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStoreTest.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStoreTest.java index cd99bbb6e..38d271768 100644 --- a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStoreTest.java +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStoreTest.java @@ -1,14 +1,12 @@ package dev.langchain4j.store.embedding.cassandra; import com.datastax.astra.sdk.AstraClient; -import com.datastax.oss.driver.api.core.CqlSession; import com.dtsx.astra.sdk.utils.TestUtils; import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingMatch; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; @@ -16,53 +14,59 @@ import java.util.List; import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken; import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * Testing implementation of Embedding Store using AstraDB. */ class AstraDbEmbeddingStoreTest { - public static final String TEST_DB = "langchain4j"; - public static final String TEST_KEYSPACE = "langchain4j"; - public static final String TEST_INDEX = "test_embedding_store"; + private static final String TEST_KEYSPACE = "langchain4j"; + private static final String TEST_INDEX = "test_embedding_store"; - @Test - @EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*") - public void testAddEmbeddingAndFindRelevant() - throws Exception { + private final String astraToken = getAstraToken(); + private final String databaseId = setupDatabase("langchain4j", TEST_KEYSPACE); + private final AstraDbEmbeddingStore astraDbEmbeddingStore = new AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration + .builder() + .token(astraToken) + .databaseId(databaseId) + .databaseRegion(TestUtils.TEST_REGION) + .keyspace(TEST_KEYSPACE) + .table(TEST_INDEX) + .dimension(11) + .build()); - // Read Token from environment variable ASTRA_DB_APPLICATION_TOKEN - String astraToken = getAstraToken(); - - // Database will be created if not exist (can take 90 seconds on first run) - String databaseId = setupDatabase(TEST_DB, TEST_KEYSPACE); - - // Store initialization - AstraDbEmbeddingStore astraDbEmbeddingStore = new AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration - .builder().token(astraToken).databaseId(databaseId) - .databaseRegion(TestUtils.TEST_REGION) - .keyspace(TEST_KEYSPACE) - .table(TEST_INDEX) - .dimension(11).build()); - - // Flushing Table before Start (idem potent) + @BeforeEach + void truncateTable() { AstraClient.builder() - .withToken(astraToken) + .withToken(getAstraToken()) .withCqlKeyspace(TEST_KEYSPACE) .withDatabaseId(databaseId) .withDatabaseRegion(TestUtils.TEST_REGION) .enableCql() .enableDownloadSecureConnectBundle() .build().cqlSession().execute("TRUNCATE TABLE " + TEST_INDEX); + } + + @Test + @EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*") + void testAddEmbeddingAndFindRelevant() { Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F}); - TextSegment textSegment = TextSegment.textSegment("Text", Metadata.from("Key", "Value")); - String added = astraDbEmbeddingStore.add(embedding, textSegment); - Assertions.assertTrue(added != null && !added.isEmpty()); + TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value")); + String id = astraDbEmbeddingStore.add(embedding, textSegment); + assertTrue(id != null && !id.isEmpty()); Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F}); List> embeddingMatches = astraDbEmbeddingStore.findRelevant(refereceEmbedding, 10); - Assertions.assertEquals(1, embeddingMatches.size()); - } + assertEquals(1, embeddingMatches.size()); + EmbeddingMatch embeddingMatch = embeddingMatches.get(0); + assertThat(embeddingMatch.score()).isBetween(0d, 1d); + assertThat(embeddingMatch.embeddingId()).isEqualTo(id); + assertThat(embeddingMatch.embedding()).isEqualTo(embedding); + assertThat(embeddingMatch.embedded()).isEqualTo(textSegment); + } } diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfigurationTest.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfigurationTest.java new file mode 100644 index 000000000..568142763 --- /dev/null +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfigurationTest.java @@ -0,0 +1,93 @@ +package dev.langchain4j.store.embedding.cassandra; + +import org.junit.jupiter.api.Test; + +import static dev.langchain4j.store.embedding.cassandra.CassandraEmbeddingConfiguration.DEFAULT_PORT; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.*; + +public class CassandraEmbeddingConfigurationTest { + + @Test + public void should_build_configuration_test() { + CassandraEmbeddingConfiguration config = CassandraEmbeddingConfiguration.builder() + .contactPoints(singletonList("localhost")) + .port(DEFAULT_PORT) + .keyspace("ks") + .dimension(20) + .table("table") + .localDataCenter("dc1") + .build(); + assertNotNull(config); + } + + @Test + public void should_error_if_no_datacenter_test() { + // Table is required + NullPointerException exception = assertThrows(NullPointerException.class, + () -> CassandraEmbeddingConfiguration.builder() + .contactPoints(singletonList("localhost")) + .port(DEFAULT_PORT) + .keyspace("ks") + .dimension(20) + .table("table") + .build()); + assertEquals("localDataCenter is marked non-null but is null", exception.getMessage()); + } + + @Test + public void should_error_if_no_table_test() { + // Table is required + NullPointerException exception = assertThrows(NullPointerException.class, + () -> CassandraEmbeddingConfiguration.builder() + .contactPoints(singletonList("localhost")) + .port(DEFAULT_PORT) + .keyspace("ks") + .dimension(20) + .localDataCenter("dc1") + .build()); + assertEquals("table is marked non-null but is null", exception.getMessage()); + } + + @Test + public void should_error_if_no_keyspace_test() { + // Table is required + NullPointerException exception = assertThrows(NullPointerException.class, + () -> CassandraEmbeddingConfiguration.builder() + .contactPoints(singletonList("localhost")) + .port(DEFAULT_PORT) + .table("ks") + .dimension(20) + .localDataCenter("dc1") + .build()); + assertEquals("keyspace is marked non-null but is null", exception.getMessage()); + } + + @Test + public void should_error_if_no_dimension_test() { + // Table is required + NullPointerException exception = assertThrows(NullPointerException.class, + () -> CassandraEmbeddingConfiguration.builder() + .contactPoints(singletonList("localhost")) + .port(DEFAULT_PORT) + .table("ks") + .keyspace("ks") + .localDataCenter("dc1") + .build()); + assertEquals("dimension is marked non-null but is null", exception.getMessage()); + } + + @Test + public void should_error_if_no_contact_points_test() { + // Table is required + NullPointerException exception = assertThrows(NullPointerException.class, + () -> CassandraEmbeddingConfiguration.builder() + .port(DEFAULT_PORT) + .table("ks") + .keyspace("ks") + .dimension(20) + .localDataCenter("dc1") + .build()); + assertEquals("contactPoints is marked non-null but is null", exception.getMessage()); + } +} diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreTest.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreTest.java index e3f936f14..1fe49ed81 100644 --- a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreTest.java +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreTest.java @@ -4,49 +4,50 @@ import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingMatch; -import dev.langchain4j.store.embedding.EmbeddingStore; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import java.util.Collections; import java.util.List; +import static dev.langchain4j.internal.Utils.randomUUID; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + /** * Work with Cassandra Embedding Store. */ class CassandraEmbeddingStoreTest { - public static final String TEST_KEYSPACE = "langchain4j"; - public static final String TEST_INDEX = "test_embedding_store"; - @Test @Disabled("To run this test, you must have a local Cassandra instance, a docker-compose is provided") - public void testAddEmbeddingAndFindRelevant() - throws Exception { + public void testAddEmbeddingAndFindRelevant() { CassandraEmbeddingStore cassandraEmbeddingStore = initStore(); Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F}); - TextSegment textSegment = TextSegment.textSegment("Text", Metadata.from("Key", "Value")); - String added = cassandraEmbeddingStore.add(embedding, textSegment); - Assertions.assertTrue(added != null && !added.isEmpty()); + TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value")); + String id = cassandraEmbeddingStore.add(embedding, textSegment); + assertTrue(id != null && !id.isEmpty()); Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F}); - List> embeddingMatches = cassandraEmbeddingStore.findRelevant(refereceEmbedding, 10); - Assertions.assertEquals(1, embeddingMatches.size()); + List> embeddingMatches = cassandraEmbeddingStore.findRelevant(refereceEmbedding, 1); + assertEquals(1, embeddingMatches.size()); + + EmbeddingMatch embeddingMatch = embeddingMatches.get(0); + assertThat(embeddingMatch.score()).isBetween(0d, 1d); + assertThat(embeddingMatch.embeddingId()).isEqualTo(id); + assertThat(embeddingMatch.embedding()).isEqualTo(embedding); + assertThat(embeddingMatch.embedded()).isEqualTo(textSegment); } - private CassandraEmbeddingStore initStore() - throws Exception { - return CassandraEmbeddingStore - .builder() + private CassandraEmbeddingStore initStore() { + return CassandraEmbeddingStore.builder() .contactPoints("127.0.0.1") .port(9042) .localDataCenter("datacenter1") - .table(TEST_KEYSPACE, TEST_INDEX) + .table("langchain4j", "table_" + randomUUID().replace("-", "")) .vectorDimension(11) .build(); } - } diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/SampleDocumentLoaderAndRagWithAstraTest.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/SampleDocumentLoaderAndRagWithAstraTest.java index bb5101124..38b4efc37 100644 --- a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/SampleDocumentLoaderAndRagWithAstraTest.java +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/SampleDocumentLoaderAndRagWithAstraTest.java @@ -20,7 +20,6 @@ import dev.langchain4j.model.output.Response; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreIngestor; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; @@ -29,7 +28,6 @@ import java.nio.file.Path; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken; import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase; @@ -37,22 +35,23 @@ import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002; import static java.time.Duration.ofSeconds; import static java.util.stream.Collectors.joining; +import static org.junit.jupiter.api.Assertions.assertNotNull; class SampleDocumentLoaderAndRagWithAstraTest { @Test @EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*") @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*") - public void shouldRagWithOpenAiAndAstra() throws InterruptedException { + void shouldRagWithOpenAiAndAstra() { // Initialization - String astraToken = getAstraToken(); - String databaseId = setupDatabase("langchain4j", "langchain4j"); - String openAIKey = System.getenv("OPENAI_API_KEY"); + String astraToken = getAstraToken(); + String databaseId = setupDatabase("langchain4j", "langchain4j"); + String openAIKey = System.getenv("OPENAI_API_KEY"); // Given - Assertions.assertNotNull(openAIKey); - Assertions.assertNotNull(databaseId); - Assertions.assertNotNull(astraToken); + assertNotNull(openAIKey); + assertNotNull(databaseId); + assertNotNull(astraToken); // --- Ingesting documents --- diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/memory/chat/cassandra/ChatMemoryStoreAstraTest.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/memory/chat/cassandra/ChatMemoryStoreAstraTest.java index 1ba55640e..f8fcd84cc 100644 --- a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/memory/chat/cassandra/ChatMemoryStoreAstraTest.java +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/memory/chat/cassandra/ChatMemoryStoreAstraTest.java @@ -1,26 +1,23 @@ package dev.langchain4j.store.memory.chat.cassandra; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.TokenWindowChatMemory; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.openai.OpenAiChatModel; import dev.langchain4j.model.openai.OpenAiTokenizer; -import dev.langchain4j.model.output.Response; import dev.langchain4j.store.memory.chat.ChatMemoryStore; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import java.util.UUID; -import static com.dtsx.astra.sdk.utils.TestUtils.TEST_REGION; -import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken; -import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase; +import static com.dtsx.astra.sdk.utils.TestUtils.*; +import static dev.langchain4j.data.message.AiMessage.aiMessage; import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; -import static java.time.Duration.ofSeconds; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertNotNull; /** * Test Cassandra Chat Memory Store with a Saas DB. @@ -28,27 +25,19 @@ import static java.time.Duration.ofSeconds; public class ChatMemoryStoreAstraTest { @Test + @Disabled("bug: order of retrieved messages is wrong") @EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*") @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*") - void chatMemoryAstraTest() throws InterruptedException { + void chatMemoryAstraTest() { + // Initialization - String astraToken = getAstraToken(); - String databaseId = setupDatabase("langchain4j", "langchain4j"); - String openAIKey = System.getenv("OPENAI_API_KEY"); + String astraToken = getAstraToken(); + String databaseId = setupDatabase("langchain4j", "langchain4j"); // Given - Assertions.assertNotNull(openAIKey); - Assertions.assertNotNull(databaseId); - Assertions.assertNotNull(astraToken); - // Given - ChatLanguageModel model = OpenAiChatModel.builder() - .apiKey(openAIKey) - .modelName(GPT_3_5_TURBO) - .temperature(0.3) - .timeout(ofSeconds(120)) - .logRequests(true) - .logResponses(true) - .build(); + assertNotNull(databaseId); + assertNotNull(astraToken); + // When ChatMemoryStore chatMemoryStore = new AstraDbChatMemoryStore(astraToken, databaseId, TEST_REGION, "langchain4j"); @@ -61,11 +50,13 @@ public class ChatMemoryStoreAstraTest { .build(); // When - chatMemory.add(userMessage("I will ask you a few question about ff4j. Response in a single sentence")); - chatMemory.add(userMessage("Can I use it with Javascript ? ")); + UserMessage userMessage = userMessage("I will ask you a few question about ff4j."); + chatMemory.add(userMessage); + + AiMessage aiMessage = aiMessage("Sure, go ahead!"); + chatMemory.add(aiMessage); // Then - Response output = model.generate(chatMemory.messages()); - Assertions.assertNotNull(output.content().text()); + assertThat(chatMemory.messages()).containsExactly(userMessage, aiMessage); } } diff --git a/langchain4j-cassandra/src/test/resources/logback-test.xml b/langchain4j-cassandra/src/test/resources/logback-test.xml deleted file mode 100755 index d98495634..000000000 --- a/langchain4j-cassandra/src/test/resources/logback-test.xml +++ /dev/null @@ -1,20 +0,0 @@ - - - - - %d{HH:mm:ss.SSS} %magenta(%-5level) %cyan(%-47logger) : %msg%n - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreTest.java b/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreTest.java index 90024cb36..56b919e16 100644 --- a/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreTest.java +++ b/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreTest.java @@ -15,6 +15,7 @@ import java.util.List; import static java.util.Arrays.asList; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; @Disabled("needs Elasticsearch to be running locally") class ElasticsearchEmbeddingStoreTest { @@ -77,7 +78,7 @@ class ElasticsearchEmbeddingStoreTest { @Test void testAddNotEqualSizeEmbeddingAndEmbedded() { - Throwable ex = Assertions.assertThrows(IllegalArgumentException.class, () -> store.addAll(asList( + Throwable ex = assertThrows(IllegalArgumentException.class, () -> store.addAll(asList( Embedding.from(asList(0.3f, 0.87f, 0.90f, 0.24f)), Embedding.from(asList(0.54f, 0.34f, 0.67f, 0.24f, 0.55f)), Embedding.from(asList(0.80f, 0.45f, 0.779f, 0.5556f)) diff --git a/langchain4j/src/main/java/dev/langchain4j/store/embedding/AbstractEmbeddingStore.java b/langchain4j/src/main/java/dev/langchain4j/store/embedding/AbstractEmbeddingStore.java deleted file mode 100644 index b6059d3f8..000000000 --- a/langchain4j/src/main/java/dev/langchain4j/store/embedding/AbstractEmbeddingStore.java +++ /dev/null @@ -1,167 +0,0 @@ -package dev.langchain4j.store.embedding; - -import dev.langchain4j.data.embedding.Embedding; - -import java.lang.reflect.InvocationTargetException; -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -/** - * This parent class would implement the boilerplate code to invoke the concrete implementation. - */ -public abstract class AbstractEmbeddingStore implements EmbeddingStore { - - /** Concrete Implementation if class is available. */ - protected String implementationClassName; - - /** - * Concrete Implementation (delegate Pattern). - */ - protected EmbeddingStore delegateImplementation; - - /** - * Initialize the concrete implementation. - */ - protected abstract EmbeddingStore loadImplementation() - throws ClassNotFoundException, NoSuchMethodException, InstantiationException, - IllegalAccessException, InvocationTargetException; - - /** - * Getter for the concrete implementation with an initialization. - * - * @return - * delegate implementation - */ - protected EmbeddingStore getDelegateImplementation() { - if (delegateImplementation == null) { - try { - this.delegateImplementation = loadImplementation(); - } catch(ClassNotFoundException cnf) { - throw new RuntimeException( - "Class '" + implementationClassName + "' not found, please checks your dependencies.", cnf); - } catch(NoSuchMethodException e) { - throw new IllegalArgumentException( - "Class '" + implementationClassName + "' does not have constructor with expected parameters", e); - } catch (InvocationTargetException | InstantiationException | IllegalAccessException e) { - throw new IllegalArgumentException( - "Constructor of class '" + implementationClassName + "' cannot be invoked check visibility or code.", e); - } catch(Exception e) { - throw new RuntimeException("Unexpected error while loading implementation", e); - } - } - return delegateImplementation; - } - - /** - * Add embedding (vector only) to the store. The id is generated by the store. - * - * @param embedding - * The embedding (vector) to be added to the store. - * @return - * unique id of the embedding - */ - @Override - public String add(Embedding embedding) { - return getDelegateImplementation().add(embedding); - } - - /** - * Add embedding (vector only) to the store enforcing its identifier. - * - * @param embeddingId - * unique identifier of the embedding - * @param embedding - * The embedding (vector) to be added to the store. - */ - @Override - public void add(String embeddingId, Embedding embedding) { - Objects.requireNonNull(embeddingId, "embeddingId (param[0]) must not be null"); - Objects.requireNonNull(embedding, "embedding (param[1]) must not be null"); - getDelegateImplementation().add(embeddingId, embedding); - } - - /** - * Add embedding to the store with text and metadata. The id is generated by the store. - * - * @param embedding - * The embedding (vector) to be added to the store. - * @param textSegment - * Text and metadata - * @return - * unique id of the embedding - */ - @Override - public String add(Embedding embedding, T textSegment) { - Objects.requireNonNull(embedding, "embedding (param[0]) must not be null"); - return getDelegateImplementation().add(embedding, textSegment); - } - - /** - * Add a list of embeddings to the store. - * - * @param embeddings - * list of embeddings - * @return - * list of ids - */ - @Override - public List addAll(List embeddings) { - Objects.requireNonNull(embeddings, "embeddings must not be null"); - return embeddings.stream().map(this::add).collect(Collectors.toList()); - } - - /** - * Add a list of embeddings to the store. - * - * @param embeddings - * list of embeddings - * @param textSegments - * list of text segments - * @return - * list of ids - */ - @Override - public List addAll(List embeddings, List textSegments) { - Objects.requireNonNull(embeddings, "embeddings (param[0] must not be null"); - Objects.requireNonNull(textSegments, "textSegments (param[1] must not be null"); - if (embeddings.size() != textSegments.size()) { - throw new IllegalArgumentException("embeddings and textSegment lists must have the same size"); - } - return getDelegateImplementation().addAll(embeddings, textSegments); - } - - /** - * Search for relevant embeddings. - * - * @param referenceEmbedding - * The embedding used as a reference. Returned embeddings should be relevant (closest) to this one. - * @param maxResults - * The maximum number of embeddings to be returned. - * @return - * List of relevant embeddings. - */ - @Override - public List> findRelevant(Embedding referenceEmbedding, int maxResults) { - return getDelegateImplementation().findRelevant(referenceEmbedding, maxResults); - } - - /** - * Search for relevant embeddings. - * - * @param referenceEmbedding - * The embedding used as a reference. Returned embeddings should be relevant (closest) to this one. - * @param maxResults - * The maximum number of embeddings to be returned. - * @param minScore - * The minimum similarity score of the returned embeddings. - * @return - * List of relevant embeddings. - */ - @Override - public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { - return getDelegateImplementation().findRelevant(referenceEmbedding, maxResults, minScore); - } - -} diff --git a/langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStore.java b/langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStore.java deleted file mode 100644 index 3a72c192e..000000000 --- a/langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingStore.java +++ /dev/null @@ -1,152 +0,0 @@ -package dev.langchain4j.store.embedding.cassandra; - -import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.store.embedding.AbstractEmbeddingStore; -import dev.langchain4j.store.embedding.EmbeddingStore; -import lombok.extern.slf4j.Slf4j; - -import java.lang.reflect.InvocationTargetException; - -/** - * Represent an Embedding store using Cassandra AstraDB. - * - * @author Cedrick Lunven (clun) - */ -@Slf4j -public class AstraDbEmbeddingStore extends AbstractEmbeddingStore { - - /** Default implementation bu can be override. */ - private static final String DEFAULT_IMPLEMENTATION = - "dev.langchain4j.store.embedding.cassandra.AstraDbEmbeddingStoreImpl"; - - /** - * Store Configuration. - */ - private final AstraDbEmbeddingConfiguration configuration; - - /** - * Constructor with default table name. - * - * @param config - * load configuration - */ - public AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration config) { - this(DEFAULT_IMPLEMENTATION, config); - } - - /** - * Constructor with default table name. - * - * @param config - * load configuration - */ - public AstraDbEmbeddingStore(String impl, AstraDbEmbeddingConfiguration config) { - this.configuration = config; - this.implementationClassName = impl; - getDelegateImplementation(); - } - - /** {@inheritDoc} */ - @Override - @SuppressWarnings("unchecked") - protected EmbeddingStore loadImplementation() - throws ClassNotFoundException, NoSuchMethodException, InstantiationException, - IllegalAccessException, InvocationTargetException { - return (EmbeddingStore) Class - .forName(implementationClassName) - .getConstructor(AstraDbEmbeddingConfiguration.class) - .newInstance(configuration); - } - - public static Builder builder() { - return new Builder(); - } - - /** - * Syntax Sugar Builder. - */ - public static class Builder { - - /** - * Configuration built with the builder - */ - private final AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder conf; - - /** - * Initialization - */ - public Builder() { - conf = AstraDbEmbeddingConfiguration.builder(); - } - - /** - * Populating token. - * - * @param token - * token - * @return - * current reference - */ - public Builder token(String token) { - conf.token(token); - return this; - } - - /** - * Populating token. - * - * @param databaseId - * database Identifier - * @param databaseRegion - * database region - * @return - * current reference - */ - public Builder database(String databaseId, String databaseRegion) { - conf.databaseId(databaseId); - conf.databaseRegion(databaseRegion); - return this; - } - - /** - * Populating model dimension. - * - * @param dimension - * model dimension - * @return - * current reference - */ - public Builder vectorDimension(int dimension) { - conf.dimension(dimension); - return this; - } - - /** - * Populating table name. - * - * @param keyspace - * keyspace name - * @param table - * table name - * @return - * current reference - */ - public Builder table(String keyspace, String table) { - conf.keyspace(keyspace); - conf.table(table); - return this; - } - - /** - * Building the Store. - * - * @return - * store for Astra. - */ - public AstraDbEmbeddingStore build() { - return new AstraDbEmbeddingStore(conf.build()); - } - - } - -} diff --git a/langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStore.java b/langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStore.java deleted file mode 100644 index b64082614..000000000 --- a/langchain4j/src/main/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStore.java +++ /dev/null @@ -1,161 +0,0 @@ -package dev.langchain4j.store.embedding.cassandra; - -import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.store.embedding.AbstractEmbeddingStore; -import dev.langchain4j.store.embedding.EmbeddingStore; - -import java.lang.reflect.InvocationTargetException; -import java.util.Arrays; - -import static dev.langchain4j.store.embedding.cassandra.CassandraEmbeddingConfiguration.CassandraEmbeddingConfigurationBuilder; - -/** - * Represents an embeddings with - */ -public class CassandraEmbeddingStore extends AbstractEmbeddingStore { - - /** Default implementation bu can be override. */ - private static final String DEFAULT_IMPLEMENTATION = - "dev.langchain4j.store.embedding.cassandra.CassandraEmbeddingStoreImpl"; - - /** - * Store Configuration. - */ - private final CassandraEmbeddingConfiguration configuration; - - /** - * Constructor with default table name. - * - * @param config - * load configuration - */ - public CassandraEmbeddingStore(CassandraEmbeddingConfiguration config) { - this(DEFAULT_IMPLEMENTATION, config); - } - - /** - * Constructor with default table name. - * - * @param config - * load configuration - */ - public CassandraEmbeddingStore(String impl, CassandraEmbeddingConfiguration config) { - this.configuration = config; - this.implementationClassName = impl; - getDelegateImplementation(); - } - - /** {@inheritDoc} */ - @Override - @SuppressWarnings("unchecked") - protected EmbeddingStore loadImplementation() - throws ClassNotFoundException, NoSuchMethodException, InstantiationException, - IllegalAccessException, InvocationTargetException { - return (EmbeddingStore) Class - .forName(implementationClassName) - .getConstructor(CassandraEmbeddingConfiguration.class) - .newInstance(configuration); - } - - public static CassandraEmbeddingStore.Builder builder() { - return new CassandraEmbeddingStore.Builder(); - } - - /** - * Syntax Sugar Builder. - */ - public static class Builder { - - /** - * Configuration built with the builder - */ - private final CassandraEmbeddingConfigurationBuilder conf; - - /** - * Initialization - */ - public Builder() { - conf = CassandraEmbeddingConfiguration.builder(); - } - - /** - * Populating cassandra port. - * - * @param port - * port - * @return - * current reference - */ - public CassandraEmbeddingStore.Builder port(int port) { - conf.port(port); - return this; - } - - /** - * Populating cassandra contact points. - * - * @param hosts - * port - * @return - * current reference - */ - public CassandraEmbeddingStore.Builder contactPoints(String... hosts) { - conf.contactPoints(Arrays.asList(hosts)); - return this; - } - - /** - * Populating model dimension. - * - * @param dimension - * model dimension - * @return - * current reference - */ - public CassandraEmbeddingStore.Builder vectorDimension(int dimension) { - conf.dimension(dimension); - return this; - } - - /** - * Populating datacenter. - * - * @param dc - * datacenter - * @return - * current reference - */ - public CassandraEmbeddingStore.Builder localDataCenter(String dc) { - conf.localDataCenter(dc); - return this; - } - - /** - * Populating table name. - * - * @param keyspace - * keyspace name - * @param table - * table name - * @return - * current reference - */ - public CassandraEmbeddingStore.Builder table(String keyspace, String table) { - conf.keyspace(keyspace); - conf.table(table); - return this; - } - - /** - * Building the Store. - * - * @return - * store for Astra. - */ - public CassandraEmbeddingStore build() { - return new CassandraEmbeddingStore(conf.build()); - } - - } - -} diff --git a/langchain4j/src/test/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfigurationTest.java b/langchain4j/src/test/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfigurationTest.java deleted file mode 100644 index 3b5f898fd..000000000 --- a/langchain4j/src/test/java/dev/langchain4j/store/embedding/cassandra/AstraDbEmbeddingConfigurationTest.java +++ /dev/null @@ -1,80 +0,0 @@ -package dev.langchain4j.store.embedding.cassandra; - -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -/** - * The configuration objects include a few validation rules. - */ -public class AstraDbEmbeddingConfigurationTest { - - @Test - public void should_build_configuration_test() { - AstraDbEmbeddingConfiguration config = AstraDbEmbeddingConfiguration.builder() - .token("token").databaseId("dbId").databaseRegion("dbRegion") - .keyspace("ks").dimension(20).table("table") - .build(); - Assertions.assertNotNull(config); - Assertions.assertNotNull(config.getToken()); - Assertions.assertNotNull(config.getDatabaseId()); - Assertions.assertNotNull(config.getDatabaseRegion()); - } - @Test - public void should_error_if_no_table_test() { - // Table is required - NullPointerException exception = Assertions.assertThrows(NullPointerException.class, - () -> AstraDbEmbeddingConfiguration.builder() - .token("token").databaseId("dbId").databaseRegion("dbRegion") - .keyspace("ks").dimension(20) - .build()); - Assertions. assertEquals("table is marked non-null but is null", exception.getMessage()); - } - - @Test - public void should_error_if_no_keyspace_test() { - // Table is required - NullPointerException exception = Assertions.assertThrows(NullPointerException.class, - () -> AstraDbEmbeddingConfiguration.builder() - .token("token").databaseId("dbId").databaseRegion("dbRegion") - .table("ks").dimension(20) - .build()); - Assertions. assertEquals("keyspace is marked non-null but is null", exception.getMessage()); - } - - @Test - public void should_error_if_no_dimension_test() { - // Table is required - NullPointerException exception = Assertions.assertThrows(NullPointerException.class, - () -> AstraDbEmbeddingConfiguration.builder() - .token("token").databaseId("dbId").databaseRegion("dbRegion") - .table("ks").keyspace("ks") - .build()); - Assertions. assertEquals("dimension is marked non-null but is null", exception.getMessage()); - } - - @Test - public void should_error_if_no_token_test() { - // Table is required - NullPointerException exception = Assertions.assertThrows(NullPointerException.class, - () -> AstraDbEmbeddingConfiguration.builder() - .databaseId("dbId").databaseRegion("dbRegion") - .table("ks").keyspace("ks").dimension(20) - .build()); - Assertions. assertEquals("token is marked non-null but is null", exception.getMessage()); - } - - @Test - public void should_error_if_no_database_test() { - // Table is required - NullPointerException exception = Assertions.assertThrows(NullPointerException.class, - () -> AstraDbEmbeddingConfiguration.builder() - .token("token") - .table("ks").keyspace("ks").dimension(20) - .build()); - Assertions. assertEquals("databaseId is marked non-null but is null", exception.getMessage()); - } - - - - -} diff --git a/langchain4j/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfigurationTest.java b/langchain4j/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfigurationTest.java deleted file mode 100644 index 74a163357..000000000 --- a/langchain4j/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingConfigurationTest.java +++ /dev/null @@ -1,84 +0,0 @@ -package dev.langchain4j.store.embedding.cassandra; - -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -import java.util.Collections; - -public class CassandraEmbeddingConfigurationTest { - - @Test - public void should_build_configuration_test() { - CassandraEmbeddingConfiguration config = CassandraEmbeddingConfiguration.builder() - .contactPoints(Collections.singletonList("localhost")) - .port(CassandraEmbeddingConfiguration.DEFAULT_PORT) - .keyspace("ks").dimension(20).table("table") - .localDataCenter("dc1") - .build(); - Assertions.assertNotNull(config); - } - - @Test - public void should_error_if_no_datacenter_test() { - // Table is required - NullPointerException exception = Assertions.assertThrows(NullPointerException.class, - () -> CassandraEmbeddingConfiguration.builder() - .contactPoints(Collections.singletonList("localhost")) - .port(CassandraEmbeddingConfiguration.DEFAULT_PORT) - .keyspace("ks").dimension(20).table("table") - .build()); - Assertions. assertEquals("localDataCenter is marked non-null but is null", exception.getMessage()); - } - - @Test - public void should_error_if_no_table_test() { - // Table is required - NullPointerException exception = Assertions.assertThrows(NullPointerException.class, - () -> CassandraEmbeddingConfiguration.builder() - .contactPoints(Collections.singletonList("localhost")) - .port(CassandraEmbeddingConfiguration.DEFAULT_PORT) - .keyspace("ks").dimension(20) - .localDataCenter("dc1") - .build()); - Assertions. assertEquals("table is marked non-null but is null", exception.getMessage()); - } - - @Test - public void should_error_if_no_keyspace_test() { - // Table is required - NullPointerException exception = Assertions.assertThrows(NullPointerException.class, - () ->CassandraEmbeddingConfiguration.builder() - .contactPoints(Collections.singletonList("localhost")) - .port(CassandraEmbeddingConfiguration.DEFAULT_PORT) - .table("ks").dimension(20) - .localDataCenter("dc1") - .build()); - Assertions. assertEquals("keyspace is marked non-null but is null", exception.getMessage()); - } - - @Test - public void should_error_if_no_dimension_test() { - // Table is required - NullPointerException exception = Assertions.assertThrows(NullPointerException.class, - () -> CassandraEmbeddingConfiguration.builder() - .contactPoints(Collections.singletonList("localhost")) - .port(CassandraEmbeddingConfiguration.DEFAULT_PORT) - .table("ks").keyspace("ks") - .localDataCenter("dc1") - .build()); - Assertions. assertEquals("dimension is marked non-null but is null", exception.getMessage()); - } - - @Test - public void should_error_if_no_contact_points_test() { - // Table is required - NullPointerException exception = Assertions.assertThrows(NullPointerException.class, - () -> CassandraEmbeddingConfiguration.builder() - .port(CassandraEmbeddingConfiguration.DEFAULT_PORT) - .table("ks").keyspace("ks").dimension(20) - .localDataCenter("dc1") - .build()); - Assertions. assertEquals("contactPoints is marked non-null but is null", exception.getMessage()); - } - -}