Removed dynamic loading from AstraDB/Cassandra
This commit is contained in:
parent
c632322493
commit
ef8f04015b
|
@ -7,11 +7,11 @@ jobs:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v3
|
||||||
- name: Set up JDK 11
|
- name: Set up JDK 8
|
||||||
uses: actions/setup-java@v3
|
uses: actions/setup-java@v3
|
||||||
with:
|
with:
|
||||||
java-version: '11'
|
java-version: '8'
|
||||||
distribution: 'adopt'
|
distribution: 'temurin'
|
||||||
- name: Test
|
- name: Test
|
||||||
run: mvn --batch-mode test
|
run: mvn --batch-mode test
|
||||||
|
|
||||||
|
|
|
@ -64,6 +64,14 @@
|
||||||
<artifactId>junit-jupiter-engine</artifactId>
|
<artifactId>junit-jupiter-engine</artifactId>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.assertj</groupId>
|
||||||
|
<artifactId>assertj-core</artifactId>
|
||||||
|
<version>${assertj.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-open-ai</artifactId>
|
<artifactId>langchain4j-open-ai</artifactId>
|
||||||
|
@ -82,26 +90,4 @@
|
||||||
</license>
|
</license>
|
||||||
</licenses>
|
</licenses>
|
||||||
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-compiler-plugin</artifactId>
|
|
||||||
<version>3.11.0</version>
|
|
||||||
<configuration>
|
|
||||||
<source>11</source>
|
|
||||||
<target>11</target>
|
|
||||||
<showWarnings>false</showWarnings>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.honton.chas</groupId>
|
|
||||||
<artifactId>license-maven-plugin</artifactId>
|
|
||||||
<configuration>
|
|
||||||
<skipCompliance>true</skipCompliance>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
|
|
||||||
</project>
|
</project>
|
|
@ -3,7 +3,6 @@ package dev.langchain4j.store.embedding.cassandra;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.experimental.SuperBuilder;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Plain old Java Object (POJO) to hold the configuration for the CassandraEmbeddingStore.
|
* Plain old Java Object (POJO) to hold the configuration for the CassandraEmbeddingStore.
|
||||||
|
@ -13,11 +12,13 @@ import lombok.experimental.SuperBuilder;
|
||||||
*
|
*
|
||||||
* @see CassandraEmbeddingStore
|
* @see CassandraEmbeddingStore
|
||||||
*/
|
*/
|
||||||
@Getter @Builder
|
@Getter
|
||||||
|
@Builder
|
||||||
public class AstraDbEmbeddingConfiguration {
|
public class AstraDbEmbeddingConfiguration {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents the Api Key to interact with Astra DB
|
* Represents the Api Key to interact with Astra DB
|
||||||
|
*
|
||||||
* @see <a href="https://docs.datastax.com/en/astra/docs/manage-application-tokens.html">Astra DB Api Key</a>
|
* @see <a href="https://docs.datastax.com/en/astra/docs/manage-application-tokens.html">Astra DB Api Key</a>
|
||||||
*/
|
*/
|
||||||
@NonNull
|
@NonNull
|
||||||
|
@ -58,8 +59,7 @@ public class AstraDbEmbeddingConfiguration {
|
||||||
/**
|
/**
|
||||||
* Initialize the builder.
|
* Initialize the builder.
|
||||||
*
|
*
|
||||||
* @return
|
* @return cassandra embedding configuration builder
|
||||||
* cassandra embedding configuration buildesr
|
|
||||||
*/
|
*/
|
||||||
public static AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder builder() {
|
public static AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder builder() {
|
||||||
return new AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder();
|
return new AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder();
|
||||||
|
@ -68,5 +68,6 @@ public class AstraDbEmbeddingConfiguration {
|
||||||
/**
|
/**
|
||||||
* Signature for the builder.
|
* Signature for the builder.
|
||||||
*/
|
*/
|
||||||
public static class AstraDbEmbeddingConfigurationBuilder{}
|
public static class AstraDbEmbeddingConfigurationBuilder {
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,111 @@
|
||||||
|
package dev.langchain4j.store.embedding.cassandra;
|
||||||
|
|
||||||
|
import com.datastax.astra.sdk.AstraClient;
|
||||||
|
import com.datastax.oss.driver.api.core.CqlSession;
|
||||||
|
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implementation of {@link EmbeddingStore} using Cassandra AstraDB.
|
||||||
|
*
|
||||||
|
* @see EmbeddingStore
|
||||||
|
* @see MetadataVectorCassandraTable
|
||||||
|
*/
|
||||||
|
public class AstraDbEmbeddingStore extends CassandraEmbeddingStoreSupport {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build the store from the configuration.
|
||||||
|
*
|
||||||
|
* @param config configuration
|
||||||
|
*/
|
||||||
|
public AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration config) {
|
||||||
|
CqlSession cqlSession = AstraClient.builder()
|
||||||
|
.withToken(config.getToken())
|
||||||
|
.withCqlKeyspace(config.getKeyspace())
|
||||||
|
.withDatabaseId(config.getDatabaseId())
|
||||||
|
.withDatabaseRegion(config.getDatabaseRegion())
|
||||||
|
.enableCql()
|
||||||
|
.enableDownloadSecureConnectBundle()
|
||||||
|
.build().cqlSession();
|
||||||
|
embeddingTable = new MetadataVectorCassandraTable(cqlSession, config.getKeyspace(), config.getTable(), config.getDimension());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new Builder();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Syntax Sugar Builder.
|
||||||
|
*/
|
||||||
|
public static class Builder {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configuration built with the builder
|
||||||
|
*/
|
||||||
|
private final AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder conf;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialization
|
||||||
|
*/
|
||||||
|
public Builder() {
|
||||||
|
conf = AstraDbEmbeddingConfiguration.builder();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Populating token.
|
||||||
|
*
|
||||||
|
* @param token token
|
||||||
|
* @return current reference
|
||||||
|
*/
|
||||||
|
public Builder token(String token) {
|
||||||
|
conf.token(token);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Populating database identifier and region.
|
||||||
|
*
|
||||||
|
* @param databaseId database Identifier
|
||||||
|
* @param databaseRegion database region
|
||||||
|
* @return current reference
|
||||||
|
*/
|
||||||
|
public Builder database(String databaseId, String databaseRegion) {
|
||||||
|
conf.databaseId(databaseId);
|
||||||
|
conf.databaseRegion(databaseRegion);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Populating model dimension.
|
||||||
|
*
|
||||||
|
* @param dimension model dimension
|
||||||
|
* @return current reference
|
||||||
|
*/
|
||||||
|
public Builder vectorDimension(int dimension) {
|
||||||
|
conf.dimension(dimension);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Populating table name.
|
||||||
|
*
|
||||||
|
* @param keyspace keyspace name
|
||||||
|
* @param table table name
|
||||||
|
* @return current reference
|
||||||
|
*/
|
||||||
|
public Builder table(String keyspace, String table) {
|
||||||
|
conf.keyspace(keyspace);
|
||||||
|
conf.table(table);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Building the Store.
|
||||||
|
*
|
||||||
|
* @return store for Astra.
|
||||||
|
*/
|
||||||
|
public AstraDbEmbeddingStore build() {
|
||||||
|
return new AstraDbEmbeddingStore(conf.build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,34 +0,0 @@
|
||||||
package dev.langchain4j.store.embedding.cassandra;
|
|
||||||
|
|
||||||
import com.datastax.astra.sdk.AstraClient;
|
|
||||||
import com.datastax.oss.driver.api.core.CqlSession;
|
|
||||||
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Implementation of {@link EmbeddingStore} using Cassandra AstraDB.
|
|
||||||
*
|
|
||||||
* @see EmbeddingStore
|
|
||||||
* @see MetadataVectorCassandraTable
|
|
||||||
*/
|
|
||||||
public class AstraDbEmbeddingStoreImpl extends CassandraEmbeddingStoreSupport {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build the store from the configuration.
|
|
||||||
*
|
|
||||||
* @param config
|
|
||||||
* configuration
|
|
||||||
*/
|
|
||||||
public AstraDbEmbeddingStoreImpl(AstraDbEmbeddingConfiguration config) {
|
|
||||||
CqlSession cqlSession = AstraClient.builder()
|
|
||||||
.withToken(config.getToken())
|
|
||||||
.withCqlKeyspace(config.getKeyspace())
|
|
||||||
.withDatabaseId(config.getDatabaseId())
|
|
||||||
.withDatabaseRegion(config.getDatabaseRegion())
|
|
||||||
.enableCql()
|
|
||||||
.enableDownloadSecureConnectBundle()
|
|
||||||
.build().cqlSession();
|
|
||||||
embeddingTable = new MetadataVectorCassandraTable(cqlSession, config.getKeyspace(), config.getTable(), config.getDimension());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -3,7 +3,6 @@ package dev.langchain4j.store.embedding.cassandra;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.experimental.SuperBuilder;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -15,10 +14,13 @@ import java.util.List;
|
||||||
*
|
*
|
||||||
* @see CassandraEmbeddingStore
|
* @see CassandraEmbeddingStore
|
||||||
*/
|
*/
|
||||||
@Getter @Builder
|
@Getter
|
||||||
|
@Builder
|
||||||
public class CassandraEmbeddingConfiguration {
|
public class CassandraEmbeddingConfiguration {
|
||||||
|
|
||||||
/** Default Cassandra Port. */
|
/**
|
||||||
|
* Default Cassandra Port.
|
||||||
|
*/
|
||||||
public static Integer DEFAULT_PORT = 9042;
|
public static Integer DEFAULT_PORT = 9042;
|
||||||
|
|
||||||
// --- Connectivity Parameters ---
|
// --- Connectivity Parameters ---
|
||||||
|
@ -74,8 +76,7 @@ public class CassandraEmbeddingConfiguration {
|
||||||
/**
|
/**
|
||||||
* Initialize the builder.
|
* Initialize the builder.
|
||||||
*
|
*
|
||||||
* @return
|
* @return cassandra embedding configuration builder
|
||||||
* cassandra embedding configuration buildesr
|
|
||||||
*/
|
*/
|
||||||
public static CassandraEmbeddingConfigurationBuilder builder() {
|
public static CassandraEmbeddingConfigurationBuilder builder() {
|
||||||
return new CassandraEmbeddingConfigurationBuilder();
|
return new CassandraEmbeddingConfigurationBuilder();
|
||||||
|
@ -84,6 +85,6 @@ public class CassandraEmbeddingConfiguration {
|
||||||
/**
|
/**
|
||||||
* Signature for the builder.
|
* Signature for the builder.
|
||||||
*/
|
*/
|
||||||
public static class CassandraEmbeddingConfigurationBuilder{}
|
public static class CassandraEmbeddingConfigurationBuilder {
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,152 @@
|
||||||
|
package dev.langchain4j.store.embedding.cassandra;
|
||||||
|
|
||||||
|
import com.datastax.oss.driver.api.core.CqlSession;
|
||||||
|
import com.datastax.oss.driver.api.core.CqlSessionBuilder;
|
||||||
|
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
|
||||||
|
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
|
||||||
|
import com.dtsx.astra.sdk.cassio.SimilarityMetric;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
|
|
||||||
|
import java.net.InetSocketAddress;
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implementation of {@link EmbeddingStore} using Cassandra (connects through contact points).
|
||||||
|
*
|
||||||
|
* @see EmbeddingStore
|
||||||
|
* @see MetadataVectorCassandraTable
|
||||||
|
*/
|
||||||
|
public class CassandraEmbeddingStore extends CassandraEmbeddingStoreSupport {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build the store from the configuration.
|
||||||
|
*
|
||||||
|
* @param config configuration
|
||||||
|
*/
|
||||||
|
public CassandraEmbeddingStore(CassandraEmbeddingConfiguration config) {
|
||||||
|
CqlSessionBuilder sessionBuilder = createCqlSessionBuilder(config);
|
||||||
|
createKeyspaceIfNotExist(sessionBuilder, config.getKeyspace());
|
||||||
|
sessionBuilder.withKeyspace(config.getKeyspace());
|
||||||
|
this.embeddingTable = new MetadataVectorCassandraTable(sessionBuilder.build(),
|
||||||
|
config.getKeyspace(), config.getTable(), config.getDimension(), SimilarityMetric.COS);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build the Cassandra session builder from the config (local datacenter, optional credentials,
|
||||||
|
* contact points). No keyspace is set here; the constructor attaches it after creating it.
|
||||||
|
*
|
||||||
|
* @param config current configuration
|
||||||
|
* @return cassandra session builder
|
||||||
|
*/
|
||||||
|
private CqlSessionBuilder createCqlSessionBuilder(CassandraEmbeddingConfiguration config) {
|
||||||
|
CqlSessionBuilder cqlSessionBuilder = CqlSession.builder();
|
||||||
|
cqlSessionBuilder.withLocalDatacenter(config.getLocalDataCenter());
|
||||||
|
if (config.getUserName() != null && config.getPassword() != null) {
|
||||||
|
cqlSessionBuilder.withAuthCredentials(config.getUserName(), config.getPassword());
|
||||||
|
}
|
||||||
|
config.getContactPoints().forEach(cp ->
|
||||||
|
cqlSessionBuilder.addContactPoint(new InetSocketAddress(cp, config.getPort())));
|
||||||
|
return cqlSessionBuilder;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create the keyspace in the target Cassandra cluster if it does not exist.
|
||||||
|
*/
|
||||||
|
private void createKeyspaceIfNotExist(CqlSessionBuilder cqlSessionBuilder, String keyspace) {
|
||||||
|
try (CqlSession adminSession = cqlSessionBuilder.build()) {
|
||||||
|
adminSession.execute(SchemaBuilder.createKeyspace(keyspace)
|
||||||
|
.ifNotExists()
|
||||||
|
.withSimpleStrategy(1)
|
||||||
|
.withDurableWrites(true)
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new Builder();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Syntax Sugar Builder.
|
||||||
|
*/
|
||||||
|
public static class Builder {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configuration built with the builder
|
||||||
|
*/
|
||||||
|
private final CassandraEmbeddingConfiguration.CassandraEmbeddingConfigurationBuilder conf;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialization
|
||||||
|
*/
|
||||||
|
public Builder() {
|
||||||
|
conf = CassandraEmbeddingConfiguration.builder();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Populating cassandra port.
|
||||||
|
*
|
||||||
|
* @param port port
|
||||||
|
* @return current reference
|
||||||
|
*/
|
||||||
|
public CassandraEmbeddingStore.Builder port(int port) {
|
||||||
|
conf.port(port);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Populating cassandra contact points.
|
||||||
|
*
|
||||||
|
* @param hosts contact point host names
|
||||||
|
* @return current reference
|
||||||
|
*/
|
||||||
|
public CassandraEmbeddingStore.Builder contactPoints(String... hosts) {
|
||||||
|
conf.contactPoints(Arrays.asList(hosts));
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Populating model dimension.
|
||||||
|
*
|
||||||
|
* @param dimension model dimension
|
||||||
|
* @return current reference
|
||||||
|
*/
|
||||||
|
public CassandraEmbeddingStore.Builder vectorDimension(int dimension) {
|
||||||
|
conf.dimension(dimension);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Populating datacenter.
|
||||||
|
*
|
||||||
|
* @param dc datacenter
|
||||||
|
* @return current reference
|
||||||
|
*/
|
||||||
|
public CassandraEmbeddingStore.Builder localDataCenter(String dc) {
|
||||||
|
conf.localDataCenter(dc);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Populating table name.
|
||||||
|
*
|
||||||
|
* @param keyspace keyspace name
|
||||||
|
* @param table table name
|
||||||
|
* @return current reference
|
||||||
|
*/
|
||||||
|
public CassandraEmbeddingStore.Builder table(String keyspace, String table) {
|
||||||
|
conf.keyspace(keyspace);
|
||||||
|
conf.table(table);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Building the Store.
|
||||||
|
*
|
||||||
|
* @return store for Cassandra.
|
||||||
|
*/
|
||||||
|
public CassandraEmbeddingStore build() {
|
||||||
|
return new CassandraEmbeddingStore(conf.build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,67 +0,0 @@
|
||||||
package dev.langchain4j.store.embedding.cassandra;
|
|
||||||
|
|
||||||
import com.datastax.oss.driver.api.core.CqlSession;
|
|
||||||
import com.datastax.oss.driver.api.core.CqlSessionBuilder;
|
|
||||||
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
|
|
||||||
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
|
|
||||||
import com.dtsx.astra.sdk.cassio.SimilarityMetric;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
|
||||||
|
|
||||||
import java.net.InetSocketAddress;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Implementation of {@link EmbeddingStore} using Cassandra AstraDB.
|
|
||||||
*
|
|
||||||
* @see EmbeddingStore
|
|
||||||
* @see MetadataVectorCassandraTable
|
|
||||||
*/
|
|
||||||
public class CassandraEmbeddingStoreImpl extends CassandraEmbeddingStoreSupport {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build the store from the configuration.
|
|
||||||
*
|
|
||||||
* @param config
|
|
||||||
* configuration
|
|
||||||
*/
|
|
||||||
public CassandraEmbeddingStoreImpl(CassandraEmbeddingConfiguration config) {
|
|
||||||
CqlSessionBuilder sessionBuilder = createCqlSessionBuilder(config);
|
|
||||||
createKeyspaceIfNotExist(sessionBuilder, config.getKeyspace());
|
|
||||||
sessionBuilder.withKeyspace(config.getKeyspace());
|
|
||||||
this.embeddingTable = new MetadataVectorCassandraTable(sessionBuilder.build(),
|
|
||||||
config.getKeyspace(), config.getTable(), config.getDimension(), SimilarityMetric.COS);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build the cassandra session from the config. At the difference of adminSession there
|
|
||||||
* a keyspace attached to it.
|
|
||||||
*
|
|
||||||
* @param config
|
|
||||||
* current configuration
|
|
||||||
* @return
|
|
||||||
* cassandra session
|
|
||||||
*/
|
|
||||||
private CqlSessionBuilder createCqlSessionBuilder(CassandraEmbeddingConfiguration config) {
|
|
||||||
CqlSessionBuilder cqlSessionBuilder = CqlSession.builder();
|
|
||||||
cqlSessionBuilder.withLocalDatacenter(config.getLocalDataCenter());
|
|
||||||
if (config.getUserName() != null && config.getPassword() != null) {
|
|
||||||
cqlSessionBuilder.withAuthCredentials(config.getUserName(), config.getPassword());
|
|
||||||
}
|
|
||||||
config.getContactPoints().forEach(cp ->
|
|
||||||
cqlSessionBuilder.addContactPoint(new InetSocketAddress(cp, config.getPort())));
|
|
||||||
return cqlSessionBuilder;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create the keyspace in cassandra Destination if not exist.
|
|
||||||
*/
|
|
||||||
private void createKeyspaceIfNotExist(CqlSessionBuilder cqlSessionBuilder, String keyspace) {
|
|
||||||
try(CqlSession adminSession = cqlSessionBuilder.build()) {
|
|
||||||
adminSession.execute(SchemaBuilder.createKeyspace(keyspace)
|
|
||||||
.ifNotExists()
|
|
||||||
.withSimpleStrategy(1)
|
|
||||||
.withDurableWrites(true)
|
|
||||||
.build());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -3,6 +3,7 @@ package dev.langchain4j.store.embedding.cassandra;
|
||||||
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
|
import com.dtsx.astra.sdk.cassio.MetadataVectorCassandraTable;
|
||||||
import com.dtsx.astra.sdk.cassio.SimilarityMetric;
|
import com.dtsx.astra.sdk.cassio.SimilarityMetric;
|
||||||
import com.dtsx.astra.sdk.cassio.SimilaritySearchQuery;
|
import com.dtsx.astra.sdk.cassio.SimilaritySearchQuery;
|
||||||
|
import com.dtsx.astra.sdk.cassio.SimilaritySearchQuery.SimilaritySearchQueryBuilder;
|
||||||
import com.dtsx.astra.sdk.cassio.SimilaritySearchResult;
|
import com.dtsx.astra.sdk.cassio.SimilaritySearchResult;
|
||||||
import dev.langchain4j.data.document.Metadata;
|
import dev.langchain4j.data.document.Metadata;
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
|
@ -18,11 +19,15 @@ import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
|
||||||
|
import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero;
|
||||||
|
import static java.util.stream.Collectors.toList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Support for CassandraEmbeddingStore with and Without Astra.
|
* Support for CassandraEmbeddingStore with and Without Astra.
|
||||||
*/
|
*/
|
||||||
@Getter
|
@Getter
|
||||||
public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<TextSegment> {
|
abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<TextSegment> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents an embedding table in Cassandra, it is a table with a vector column.
|
* Represents an embedding table in Cassandra, it is a table with a vector column.
|
||||||
|
@ -34,10 +39,8 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
|
||||||
* - the row id is generated
|
* - the row id is generated
|
||||||
* - text and metadata are not stored
|
* - text and metadata are not stored
|
||||||
*
|
*
|
||||||
* @param embedding
|
* @param embedding representation of the list of floats
|
||||||
* representation of the list of floats
|
* @return newly created row id
|
||||||
* @return
|
|
||||||
* newly created row id
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public String add(@NonNull Embedding embedding) {
|
public String add(@NonNull Embedding embedding) {
|
||||||
|
@ -49,12 +52,9 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
|
||||||
* - the row id is generated
|
* - the row id is generated
|
||||||
* - text and metadata coming from the text Segment
|
* - text and metadata coming from the text Segment
|
||||||
*
|
*
|
||||||
* @param embedding
|
* @param embedding representation of the list of floats
|
||||||
* representation of the list of floats
|
* @param textSegment text content and metadata
|
||||||
* @param textSegment
|
* @return newly created row id
|
||||||
* text content and metadata
|
|
||||||
* @return
|
|
||||||
* newly created row id
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public String add(@NonNull Embedding embedding, TextSegment textSegment) {
|
public String add(@NonNull Embedding embedding, TextSegment textSegment) {
|
||||||
|
@ -70,10 +70,8 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
|
||||||
/**
|
/**
|
||||||
* Add a new embedding to the store.
|
* Add a new embedding to the store.
|
||||||
*
|
*
|
||||||
* @param rowId
|
* @param rowId the row id
|
||||||
* the row id
|
* @param embedding representation of the list of floats
|
||||||
* @param embedding
|
|
||||||
* representation of the list of floats
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void add(@NonNull String rowId, @NonNull Embedding embedding) {
|
public void add(@NonNull String rowId, @NonNull Embedding embedding) {
|
||||||
|
@ -83,10 +81,8 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
|
||||||
/**
|
/**
|
||||||
* Add multiple embeddings as a single action.
|
* Add multiple embeddings as a single action.
|
||||||
*
|
*
|
||||||
* @param embeddingList
|
* @param embeddingList embeddings list
|
||||||
* embeddings list
|
* @return list of new row ids (same order as the input)
|
||||||
* @return
|
|
||||||
* list of new row if (same order as the input)
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public List<String> addAll(List<Embedding> embeddingList) {
|
public List<String> addAll(List<Embedding> embeddingList) {
|
||||||
|
@ -95,18 +91,15 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
|
||||||
.map(MetadataVectorCassandraTable.Record::new)
|
.map(MetadataVectorCassandraTable.Record::new)
|
||||||
.peek(embeddingTable::putAsync)
|
.peek(embeddingTable::putAsync)
|
||||||
.map(MetadataVectorCassandraTable.Record::getRowId)
|
.map(MetadataVectorCassandraTable.Record::getRowId)
|
||||||
.collect(Collectors.toList());
|
.collect(toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add multiple embeddings as a single action.
|
* Add multiple embeddings as a single action.
|
||||||
*
|
*
|
||||||
* @param embeddingList
|
* @param embeddingList embeddings
|
||||||
* embeddings
|
* @param textSegmentList text segments
|
||||||
* @param textSegmentList
|
* @return list of new row ids (same order as the input)
|
||||||
* text segments
|
|
||||||
* @return
|
|
||||||
* list of new row if (same order as the input)
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public List<String> addAll(List<Embedding> embeddingList, List<TextSegment> textSegmentList) {
|
public List<String> addAll(List<Embedding> embeddingList, List<TextSegment> textSegmentList) {
|
||||||
|
@ -124,36 +117,30 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
|
||||||
/**
|
/**
|
||||||
* Search for relevant.
|
* Search for relevant.
|
||||||
*
|
*
|
||||||
* @param embedding
|
* @param embedding current embeddings
|
||||||
* current embeddings
|
* @param maxResults max number of result
|
||||||
* @param maxResults
|
* @param minScore threshold
|
||||||
* max number of result
|
* @return list of matching elements
|
||||||
* @param minScore
|
|
||||||
* threshold
|
|
||||||
* @return
|
|
||||||
* list of matching elements
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, int maxResults, double minScore) {
|
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, int maxResults, double minScore) {
|
||||||
return embeddingTable
|
return embeddingTable
|
||||||
.similaritySearch(SimilaritySearchQuery.builder()
|
.similaritySearch(SimilaritySearchQuery.builder()
|
||||||
.embeddings(embedding.vectorAsList())
|
.embeddings(embedding.vectorAsList())
|
||||||
.recordCount(maxResults)
|
.recordCount(ensureGreaterThanZero(maxResults, "maxResults"))
|
||||||
.threshold(CosineSimilarity.fromRelevanceScore(minScore))
|
.threshold(CosineSimilarity.fromRelevanceScore(ensureBetween(minScore, 0, 1, "minScore")))
|
||||||
.distance(SimilarityMetric.COS)
|
.distance(SimilarityMetric.COS)
|
||||||
.build())
|
.build())
|
||||||
.stream()
|
.stream()
|
||||||
.map(CassandraEmbeddingStoreSupport::mapSearchResult)
|
.map(CassandraEmbeddingStoreSupport::mapSearchResult)
|
||||||
.collect(Collectors.toList());
|
.collect(toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map Search result coming from Astra.
|
* Map Search result coming from Astra.
|
||||||
*
|
*
|
||||||
* @param record
|
* @param record current record
|
||||||
* current record
|
* @return search result
|
||||||
* @return
|
|
||||||
* search result
|
|
||||||
*/
|
*/
|
||||||
private static EmbeddingMatch<TextSegment> mapSearchResult(SimilaritySearchResult<MetadataVectorCassandraTable.Record> record) {
|
private static EmbeddingMatch<TextSegment> mapSearchResult(SimilaritySearchResult<MetadataVectorCassandraTable.Record> record) {
|
||||||
return new EmbeddingMatch<>(
|
return new EmbeddingMatch<>(
|
||||||
|
@ -163,33 +150,24 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
|
||||||
record.getEmbedded().getRowId(),
|
record.getEmbedded().getRowId(),
|
||||||
// Embeddings vector
|
// Embeddings vector
|
||||||
Embedding.from(record.getEmbedded().getVector()),
|
Embedding.from(record.getEmbedded().getVector()),
|
||||||
// Text Fragment and metadata
|
// Text segment and metadata
|
||||||
TextSegment.from(record.getEmbedded().getBody(), new Metadata(record.getEmbedded().getMetadata())));
|
TextSegment.from(record.getEmbedded().getBody(), new Metadata(record.getEmbedded().getMetadata())));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Similarity Search ANN based on the embedding.
|
* Similarity Search ANN based on the embedding.
|
||||||
*
|
*
|
||||||
* @param embedding
|
* @param embedding vector
|
||||||
* vector
|
* @param maxResults max number of results
|
||||||
* @param maxResults
|
* @param minScore minimum relevance score, between 0 and 1
|
||||||
* max number of record
|
* @param metadata map key-value to build a metadata filter
|
||||||
* @param minScore
|
* @return list of matching results
|
||||||
* score minScore
|
|
||||||
* @param metadata
|
|
||||||
* map key-value to build a metadata filter
|
|
||||||
* @return
|
|
||||||
* list of matching results
|
|
||||||
*/
|
*/
|
||||||
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, Integer maxResults, Double minScore, Metadata metadata) {
|
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, int maxResults, double minScore, Metadata metadata) {
|
||||||
SimilaritySearchQuery.SimilaritySearchQueryBuilder builder =
|
SimilaritySearchQueryBuilder builder = SimilaritySearchQuery.builder()
|
||||||
SimilaritySearchQuery.builder().embeddings(embedding.vectorAsList());
|
.embeddings(embedding.vectorAsList())
|
||||||
if (maxResults == null || maxResults < 1) {
|
.recordCount(ensureGreaterThanZero(maxResults, "maxResults"))
|
||||||
throw new IllegalArgumentException("maxResults (param[1]) must not be null and greater than 0");
|
.threshold(CosineSimilarity.fromRelevanceScore(ensureBetween(minScore, 0, 1, "minScore")));
|
||||||
}
|
|
||||||
if (minScore == null || minScore < 1 || minScore > 1) {
|
|
||||||
throw new IllegalArgumentException("minScore (param[2]) must not be null and in between 0 and 1.");
|
|
||||||
}
|
|
||||||
if (metadata != null) {
|
if (metadata != null) {
|
||||||
builder.metaData(metadata.asMap());
|
builder.metaData(metadata.asMap());
|
||||||
}
|
}
|
||||||
|
@ -197,7 +175,6 @@ public abstract class CassandraEmbeddingStoreSupport implements EmbeddingStore<T
|
||||||
.similaritySearch(builder.build())
|
.similaritySearch(builder.build())
|
||||||
.stream()
|
.stream()
|
||||||
.map(CassandraEmbeddingStoreSupport::mapSearchResult)
|
.map(CassandraEmbeddingStoreSupport::mapSearchResult)
|
||||||
.collect(Collectors.toList());
|
.collect(toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,24 +4,18 @@ import com.datastax.astra.sdk.AstraClient;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* AstraDb is a version of Cassandra running in Saas Mode.
|
* AstraDb is a version of Cassandra running in Saas Mode.
|
||||||
*
|
* <p>
|
||||||
* The initialization of the CQLSession will be done through an AstraClient
|
* The initialization of the CQLSession will be done through an AstraClient
|
||||||
*
|
|
||||||
* @author Cedrick Lunven (clun)
|
|
||||||
*/
|
*/
|
||||||
public class AstraDbChatMemoryStore extends CassandraChatMemoryStore {
|
public class AstraDbChatMemoryStore extends CassandraChatMemoryStore {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructor with default table name.
|
* Constructor with default table name.
|
||||||
*
|
*
|
||||||
* @param token
|
* @param token token
|
||||||
* token
|
* @param dbId database identifier
|
||||||
* @param dbId
|
* @param dbRegion database region
|
||||||
* database idendifier
|
* @param keyspaceName keyspace name
|
||||||
* @param dbRegion
|
|
||||||
* database region
|
|
||||||
* @param keyspaceName
|
|
||||||
* keyspace name
|
|
||||||
*/
|
*/
|
||||||
public AstraDbChatMemoryStore(String token, String dbId, String dbRegion, String keyspaceName) {
|
public AstraDbChatMemoryStore(String token, String dbId, String dbRegion, String keyspaceName) {
|
||||||
this(token, dbId, dbRegion, keyspaceName, DEFAULT_TABLE_NAME);
|
this(token, dbId, dbRegion, keyspaceName, DEFAULT_TABLE_NAME);
|
||||||
|
|
|
@ -1,28 +1,19 @@
|
||||||
package dev.langchain4j.store.memory.chat.cassandra;
|
package dev.langchain4j.store.memory.chat.cassandra;
|
||||||
|
|
||||||
import com.datastax.astra.sdk.AstraClient;
|
|
||||||
import com.datastax.oss.driver.api.core.CqlSession;
|
import com.datastax.oss.driver.api.core.CqlSession;
|
||||||
import static com.dtsx.astra.sdk.cassio.ClusteredCassandraTable.Record;
|
|
||||||
|
|
||||||
import com.datastax.oss.driver.api.core.uuid.Uuids;
|
import com.datastax.oss.driver.api.core.uuid.Uuids;
|
||||||
import com.dtsx.astra.sdk.cassio.ClusteredCassandraTable;
|
import com.dtsx.astra.sdk.cassio.ClusteredCassandraTable;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
|
||||||
import dev.langchain4j.data.message.ChatMessage;
|
import dev.langchain4j.data.message.ChatMessage;
|
||||||
import dev.langchain4j.data.message.ChatMessageDeserializer;
|
import dev.langchain4j.data.message.ChatMessageDeserializer;
|
||||||
import dev.langchain4j.data.message.ChatMessageSerializer;
|
import dev.langchain4j.data.message.ChatMessageSerializer;
|
||||||
import dev.langchain4j.data.message.SystemMessage;
|
|
||||||
import dev.langchain4j.data.message.UserMessage;
|
|
||||||
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
|
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.stream.Collectors;
|
import static com.dtsx.astra.sdk.cassio.ClusteredCassandraTable.Record;
|
||||||
|
import static java.util.stream.Collectors.toList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Implementation of {@link ChatMemoryStore} using Astra DB Vector Search.
|
* Implementation of {@link ChatMemoryStore} using Astra DB Vector Search.
|
||||||
|
@ -30,8 +21,6 @@ import java.util.stream.Collectors;
|
||||||
* is a partition. Message id is a time uuid.
|
* is a partition. Message id is a time uuid.
|
||||||
*
|
*
|
||||||
* @see <a href="https://docs.datastax.com/en/astra-serverless/docs/vector-search/overview.html">Astra Vector Store Documentation</a>
|
* @see <a href="https://docs.datastax.com/en/astra-serverless/docs/vector-search/overview.html">Astra Vector Store Documentation</a>
|
||||||
* @author Cedrick Lunven (clun)
|
|
||||||
* @since 0.22.0
|
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class CassandraChatMemoryStore implements ChatMemoryStore {
|
public class CassandraChatMemoryStore implements ChatMemoryStore {
|
||||||
|
@ -46,18 +35,12 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
|
||||||
*/
|
*/
|
||||||
private final ClusteredCassandraTable messageTable;
|
private final ClusteredCassandraTable messageTable;
|
||||||
|
|
||||||
/** Object Mapper. */
|
|
||||||
private static final ObjectMapper OM = new ObjectMapper();
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructor for message store
|
* Constructor for message store
|
||||||
*
|
*
|
||||||
* @param session
|
* @param session cassandra session
|
||||||
* cassandra session
|
* @param keyspaceName keyspace name
|
||||||
* @param keyspaceName
|
* @param tableName table name
|
||||||
* keyspace name
|
|
||||||
* @param tableName
|
|
||||||
* table name
|
|
||||||
*/
|
*/
|
||||||
public CassandraChatMemoryStore(CqlSession session, String keyspaceName, String tableName) {
|
public CassandraChatMemoryStore(CqlSession session, String keyspaceName, String tableName) {
|
||||||
messageTable = new ClusteredCassandraTable(session, keyspaceName, tableName);
|
messageTable = new ClusteredCassandraTable(session, keyspaceName, tableName);
|
||||||
|
@ -66,35 +49,39 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
|
||||||
/**
|
/**
|
||||||
* Constructor for message store
|
* Constructor for message store
|
||||||
*
|
*
|
||||||
* @param session
|
* @param session cassandra session
|
||||||
* cassandra session
|
* @param keyspaceName keyspace name
|
||||||
* @param keyspaceName
|
|
||||||
* keyspace name
|
|
||||||
*/
|
*/
|
||||||
public CassandraChatMemoryStore(CqlSession session, String keyspaceName) {
|
public CassandraChatMemoryStore(CqlSession session, String keyspaceName) {
|
||||||
messageTable = new ClusteredCassandraTable(session, keyspaceName, DEFAULT_TABLE_NAME);
|
messageTable = new ClusteredCassandraTable(session, keyspaceName, DEFAULT_TABLE_NAME);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
@Override
|
@Override
|
||||||
public List<ChatMessage> getMessages(@NonNull Object memoryId) {
|
public List<ChatMessage> getMessages(@NonNull Object memoryId) {
|
||||||
return messageTable
|
return messageTable
|
||||||
.findPartition(getMemoryId(memoryId))
|
.findPartition(getMemoryId(memoryId))
|
||||||
.stream()
|
.stream()
|
||||||
.map(this::toChatMessage)
|
.map(this::toChatMessage)
|
||||||
.collect(Collectors.toList());
|
.collect(toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void updateMessages(@NonNull Object memoryId, @NonNull List<ChatMessage> list) {
|
public void updateMessages(@NonNull Object memoryId, @NonNull List<ChatMessage> messages) {
|
||||||
deleteMessages(memoryId);
|
deleteMessages(memoryId);
|
||||||
messageTable.upsertPartition(list.stream()
|
messageTable.upsertPartition(messages.stream()
|
||||||
.map(r -> this.fromChatMessage(getMemoryId(memoryId), r))
|
.map(record -> fromChatMessage(getMemoryId(memoryId), record))
|
||||||
.collect(Collectors.toList()));
|
.collect(toList()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void deleteMessages(@NonNull Object memoryId) {
|
public void deleteMessages(@NonNull Object memoryId) {
|
||||||
messageTable.deletePartition(getMemoryId(memoryId));
|
messageTable.deletePartition(getMemoryId(memoryId));
|
||||||
|
@ -103,10 +90,8 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
|
||||||
/**
|
/**
|
||||||
* Unmarshalling Cassandra row as a Message with proper sub-type.
|
* Unmarshalling Cassandra row as a Message with proper sub-type.
|
||||||
*
|
*
|
||||||
* @param record
|
* @param record cassandra record
|
||||||
* cassandra record
|
* @return chat message
|
||||||
* @return
|
|
||||||
* chat message
|
|
||||||
*/
|
*/
|
||||||
private ChatMessage toChatMessage(@NonNull Record record) {
|
private ChatMessage toChatMessage(@NonNull Record record) {
|
||||||
try {
|
try {
|
||||||
|
@ -119,12 +104,10 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Serialize the {@link ChatMessage} as a Cassandra Row.
|
* Serialize the {@link ChatMessage} as a Cassandra Row.
|
||||||
* @param memoryId
|
*
|
||||||
* chat session identifier
|
* @param memoryId chat session identifier
|
||||||
* @param chatMessage
|
* @param chatMessage chat message
|
||||||
* chat message
|
* @return cassandra row.
|
||||||
* @return
|
|
||||||
* cassandra row.
|
|
||||||
*/
|
*/
|
||||||
private Record fromChatMessage(@NonNull String memoryId, @NonNull ChatMessage chatMessage) {
|
private Record fromChatMessage(@NonNull String memoryId, @NonNull ChatMessage chatMessage) {
|
||||||
try {
|
try {
|
||||||
|
@ -145,6 +128,4 @@ public class CassandraChatMemoryStore implements ChatMemoryStore {
|
||||||
}
|
}
|
||||||
return (String) memoryId;
|
return (String) memoryId;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
|
@ -0,0 +1,96 @@
|
||||||
|
package dev.langchain4j.store.embedding.cassandra;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The configuration objects include a few validation rules.
|
||||||
|
*/
|
||||||
|
public class AstraDbEmbeddingConfigurationTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void should_build_configuration_test() {
|
||||||
|
AstraDbEmbeddingConfiguration config = AstraDbEmbeddingConfiguration.builder()
|
||||||
|
.token("token")
|
||||||
|
.databaseId("dbId")
|
||||||
|
.databaseRegion("dbRegion")
|
||||||
|
.keyspace("ks")
|
||||||
|
.dimension(20)
|
||||||
|
.table("table")
|
||||||
|
.build();
|
||||||
|
assertNotNull(config);
|
||||||
|
assertNotNull(config.getToken());
|
||||||
|
assertNotNull(config.getDatabaseId());
|
||||||
|
assertNotNull(config.getDatabaseRegion());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void should_error_if_no_table_test() {
|
||||||
|
// Table is required
|
||||||
|
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||||
|
() -> AstraDbEmbeddingConfiguration.builder()
|
||||||
|
.token("token")
|
||||||
|
.databaseId("dbId")
|
||||||
|
.databaseRegion("dbRegion")
|
||||||
|
.keyspace("ks")
|
||||||
|
.dimension(20)
|
||||||
|
.build());
|
||||||
|
assertEquals("table is marked non-null but is null", exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void should_error_if_no_keyspace_test() {
|
||||||
|
// Keyspace is required
|
||||||
|
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||||
|
() -> AstraDbEmbeddingConfiguration.builder()
|
||||||
|
.token("token")
|
||||||
|
.databaseId("dbId")
|
||||||
|
.databaseRegion("dbRegion")
|
||||||
|
.table("ks")
|
||||||
|
.dimension(20)
|
||||||
|
.build());
|
||||||
|
assertEquals("keyspace is marked non-null but is null", exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void should_error_if_no_dimension_test() {
|
||||||
|
// Dimension is required
|
||||||
|
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||||
|
() -> AstraDbEmbeddingConfiguration.builder()
|
||||||
|
.token("token")
|
||||||
|
.databaseId("dbId")
|
||||||
|
.databaseRegion("dbRegion")
|
||||||
|
.table("ks")
|
||||||
|
.keyspace("ks")
|
||||||
|
.build());
|
||||||
|
assertEquals("dimension is marked non-null but is null", exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void should_error_if_no_token_test() {
|
||||||
|
// Token is required
|
||||||
|
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||||
|
() -> AstraDbEmbeddingConfiguration.builder()
|
||||||
|
.databaseId("dbId")
|
||||||
|
.databaseRegion("dbRegion")
|
||||||
|
.table("ks")
|
||||||
|
.keyspace("ks")
|
||||||
|
.dimension(20)
|
||||||
|
.build());
|
||||||
|
assertEquals("token is marked non-null but is null", exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void should_error_if_no_database_test() {
|
||||||
|
// Database id is required
|
||||||
|
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||||
|
() -> AstraDbEmbeddingConfiguration.builder()
|
||||||
|
.token("token")
|
||||||
|
.table("ks")
|
||||||
|
.keyspace("ks")
|
||||||
|
.dimension(20)
|
||||||
|
.build());
|
||||||
|
assertEquals("databaseId is marked non-null but is null", exception.getMessage());
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,14 +1,12 @@
|
||||||
package dev.langchain4j.store.embedding.cassandra;
|
package dev.langchain4j.store.embedding.cassandra;
|
||||||
|
|
||||||
import com.datastax.astra.sdk.AstraClient;
|
import com.datastax.astra.sdk.AstraClient;
|
||||||
import com.datastax.oss.driver.api.core.CqlSession;
|
|
||||||
import com.dtsx.astra.sdk.utils.TestUtils;
|
import com.dtsx.astra.sdk.utils.TestUtils;
|
||||||
import dev.langchain4j.data.document.Metadata;
|
import dev.langchain4j.data.document.Metadata;
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import org.junit.jupiter.api.Assertions;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Disabled;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||||
|
|
||||||
|
@ -16,53 +14,59 @@ import java.util.List;
|
||||||
|
|
||||||
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
|
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
|
||||||
import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase;
|
import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase;
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Testing implementation of Embedding Store using AstraDB.
|
* Testing implementation of Embedding Store using AstraDB.
|
||||||
*/
|
*/
|
||||||
class AstraDbEmbeddingStoreTest {
|
class AstraDbEmbeddingStoreTest {
|
||||||
|
|
||||||
public static final String TEST_DB = "langchain4j";
|
private static final String TEST_KEYSPACE = "langchain4j";
|
||||||
public static final String TEST_KEYSPACE = "langchain4j";
|
private static final String TEST_INDEX = "test_embedding_store";
|
||||||
public static final String TEST_INDEX = "test_embedding_store";
|
|
||||||
|
|
||||||
@Test
|
private final String astraToken = getAstraToken();
|
||||||
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
private final String databaseId = setupDatabase("langchain4j", TEST_KEYSPACE);
|
||||||
public void testAddEmbeddingAndFindRelevant()
|
private final AstraDbEmbeddingStore astraDbEmbeddingStore = new AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration
|
||||||
throws Exception {
|
.builder()
|
||||||
|
.token(astraToken)
|
||||||
// Read Token from environment variable ASTRA_DB_APPLICATION_TOKEN
|
.databaseId(databaseId)
|
||||||
String astraToken = getAstraToken();
|
|
||||||
|
|
||||||
// Database will be created if not exist (can take 90 seconds on first run)
|
|
||||||
String databaseId = setupDatabase(TEST_DB, TEST_KEYSPACE);
|
|
||||||
|
|
||||||
// Store initialization
|
|
||||||
AstraDbEmbeddingStore astraDbEmbeddingStore = new AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration
|
|
||||||
.builder().token(astraToken).databaseId(databaseId)
|
|
||||||
.databaseRegion(TestUtils.TEST_REGION)
|
.databaseRegion(TestUtils.TEST_REGION)
|
||||||
.keyspace(TEST_KEYSPACE)
|
.keyspace(TEST_KEYSPACE)
|
||||||
.table(TEST_INDEX)
|
.table(TEST_INDEX)
|
||||||
.dimension(11).build());
|
.dimension(11)
|
||||||
|
.build());
|
||||||
|
|
||||||
// Flushing Table before Start (idem potent)
|
@BeforeEach
|
||||||
|
void truncateTable() {
|
||||||
AstraClient.builder()
|
AstraClient.builder()
|
||||||
.withToken(astraToken)
|
.withToken(getAstraToken())
|
||||||
.withCqlKeyspace(TEST_KEYSPACE)
|
.withCqlKeyspace(TEST_KEYSPACE)
|
||||||
.withDatabaseId(databaseId)
|
.withDatabaseId(databaseId)
|
||||||
.withDatabaseRegion(TestUtils.TEST_REGION)
|
.withDatabaseRegion(TestUtils.TEST_REGION)
|
||||||
.enableCql()
|
.enableCql()
|
||||||
.enableDownloadSecureConnectBundle()
|
.enableDownloadSecureConnectBundle()
|
||||||
.build().cqlSession().execute("TRUNCATE TABLE " + TEST_INDEX);
|
.build().cqlSession().execute("TRUNCATE TABLE " + TEST_INDEX);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
||||||
|
void testAddEmbeddingAndFindRelevant() {
|
||||||
|
|
||||||
Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
|
Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
|
||||||
TextSegment textSegment = TextSegment.textSegment("Text", Metadata.from("Key", "Value"));
|
TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value"));
|
||||||
String added = astraDbEmbeddingStore.add(embedding, textSegment);
|
String id = astraDbEmbeddingStore.add(embedding, textSegment);
|
||||||
Assertions.assertTrue(added != null && !added.isEmpty());
|
assertTrue(id != null && !id.isEmpty());
|
||||||
|
|
||||||
Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
|
Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
|
||||||
List<EmbeddingMatch<TextSegment>> embeddingMatches = astraDbEmbeddingStore.findRelevant(refereceEmbedding, 10);
|
List<EmbeddingMatch<TextSegment>> embeddingMatches = astraDbEmbeddingStore.findRelevant(refereceEmbedding, 10);
|
||||||
Assertions.assertEquals(1, embeddingMatches.size());
|
assertEquals(1, embeddingMatches.size());
|
||||||
}
|
|
||||||
|
|
||||||
|
EmbeddingMatch<TextSegment> embeddingMatch = embeddingMatches.get(0);
|
||||||
|
assertThat(embeddingMatch.score()).isBetween(0d, 1d);
|
||||||
|
assertThat(embeddingMatch.embeddingId()).isEqualTo(id);
|
||||||
|
assertThat(embeddingMatch.embedding()).isEqualTo(embedding);
|
||||||
|
assertThat(embeddingMatch.embedded()).isEqualTo(textSegment);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
package dev.langchain4j.store.embedding.cassandra;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import static dev.langchain4j.store.embedding.cassandra.CassandraEmbeddingConfiguration.DEFAULT_PORT;
|
||||||
|
import static java.util.Collections.singletonList;
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
public class CassandraEmbeddingConfigurationTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void should_build_configuration_test() {
|
||||||
|
CassandraEmbeddingConfiguration config = CassandraEmbeddingConfiguration.builder()
|
||||||
|
.contactPoints(singletonList("localhost"))
|
||||||
|
.port(DEFAULT_PORT)
|
||||||
|
.keyspace("ks")
|
||||||
|
.dimension(20)
|
||||||
|
.table("table")
|
||||||
|
.localDataCenter("dc1")
|
||||||
|
.build();
|
||||||
|
assertNotNull(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void should_error_if_no_datacenter_test() {
|
||||||
|
// Local data center is required
|
||||||
|
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||||
|
() -> CassandraEmbeddingConfiguration.builder()
|
||||||
|
.contactPoints(singletonList("localhost"))
|
||||||
|
.port(DEFAULT_PORT)
|
||||||
|
.keyspace("ks")
|
||||||
|
.dimension(20)
|
||||||
|
.table("table")
|
||||||
|
.build());
|
||||||
|
assertEquals("localDataCenter is marked non-null but is null", exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void should_error_if_no_table_test() {
|
||||||
|
// Table is required
|
||||||
|
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||||
|
() -> CassandraEmbeddingConfiguration.builder()
|
||||||
|
.contactPoints(singletonList("localhost"))
|
||||||
|
.port(DEFAULT_PORT)
|
||||||
|
.keyspace("ks")
|
||||||
|
.dimension(20)
|
||||||
|
.localDataCenter("dc1")
|
||||||
|
.build());
|
||||||
|
assertEquals("table is marked non-null but is null", exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void should_error_if_no_keyspace_test() {
|
||||||
|
// Keyspace is required
|
||||||
|
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||||
|
() -> CassandraEmbeddingConfiguration.builder()
|
||||||
|
.contactPoints(singletonList("localhost"))
|
||||||
|
.port(DEFAULT_PORT)
|
||||||
|
.table("ks")
|
||||||
|
.dimension(20)
|
||||||
|
.localDataCenter("dc1")
|
||||||
|
.build());
|
||||||
|
assertEquals("keyspace is marked non-null but is null", exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void should_error_if_no_dimension_test() {
|
||||||
|
// Dimension is required
|
||||||
|
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||||
|
() -> CassandraEmbeddingConfiguration.builder()
|
||||||
|
.contactPoints(singletonList("localhost"))
|
||||||
|
.port(DEFAULT_PORT)
|
||||||
|
.table("ks")
|
||||||
|
.keyspace("ks")
|
||||||
|
.localDataCenter("dc1")
|
||||||
|
.build());
|
||||||
|
assertEquals("dimension is marked non-null but is null", exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void should_error_if_no_contact_points_test() {
|
||||||
|
// Contact points are required
|
||||||
|
NullPointerException exception = assertThrows(NullPointerException.class,
|
||||||
|
() -> CassandraEmbeddingConfiguration.builder()
|
||||||
|
.port(DEFAULT_PORT)
|
||||||
|
.table("ks")
|
||||||
|
.keyspace("ks")
|
||||||
|
.dimension(20)
|
||||||
|
.localDataCenter("dc1")
|
||||||
|
.build());
|
||||||
|
assertEquals("contactPoints is marked non-null but is null", exception.getMessage());
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,49 +4,50 @@ import dev.langchain4j.data.document.Metadata;
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
|
||||||
import org.junit.jupiter.api.Assertions;
|
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Work with Cassandra Embedding Store.
|
* Work with Cassandra Embedding Store.
|
||||||
*/
|
*/
|
||||||
class CassandraEmbeddingStoreTest {
|
class CassandraEmbeddingStoreTest {
|
||||||
|
|
||||||
public static final String TEST_KEYSPACE = "langchain4j";
|
|
||||||
public static final String TEST_INDEX = "test_embedding_store";
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Disabled("To run this test, you must have a local Cassandra instance, a docker-compose is provided")
|
@Disabled("To run this test, you must have a local Cassandra instance, a docker-compose is provided")
|
||||||
public void testAddEmbeddingAndFindRelevant()
|
public void testAddEmbeddingAndFindRelevant() {
|
||||||
throws Exception {
|
|
||||||
|
|
||||||
CassandraEmbeddingStore cassandraEmbeddingStore = initStore();
|
CassandraEmbeddingStore cassandraEmbeddingStore = initStore();
|
||||||
|
|
||||||
Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
|
Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
|
||||||
TextSegment textSegment = TextSegment.textSegment("Text", Metadata.from("Key", "Value"));
|
TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value"));
|
||||||
String added = cassandraEmbeddingStore.add(embedding, textSegment);
|
String id = cassandraEmbeddingStore.add(embedding, textSegment);
|
||||||
Assertions.assertTrue(added != null && !added.isEmpty());
|
assertTrue(id != null && !id.isEmpty());
|
||||||
|
|
||||||
Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
|
Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
|
||||||
List<EmbeddingMatch<TextSegment>> embeddingMatches = cassandraEmbeddingStore.findRelevant(refereceEmbedding, 10);
|
List<EmbeddingMatch<TextSegment>> embeddingMatches = cassandraEmbeddingStore.findRelevant(refereceEmbedding, 1);
|
||||||
Assertions.assertEquals(1, embeddingMatches.size());
|
assertEquals(1, embeddingMatches.size());
|
||||||
|
|
||||||
|
EmbeddingMatch<TextSegment> embeddingMatch = embeddingMatches.get(0);
|
||||||
|
assertThat(embeddingMatch.score()).isBetween(0d, 1d);
|
||||||
|
assertThat(embeddingMatch.embeddingId()).isEqualTo(id);
|
||||||
|
assertThat(embeddingMatch.embedding()).isEqualTo(embedding);
|
||||||
|
assertThat(embeddingMatch.embedded()).isEqualTo(textSegment);
|
||||||
}
|
}
|
||||||
|
|
||||||
private CassandraEmbeddingStore initStore()
|
private CassandraEmbeddingStore initStore() {
|
||||||
throws Exception {
|
return CassandraEmbeddingStore.builder()
|
||||||
return CassandraEmbeddingStore
|
|
||||||
.builder()
|
|
||||||
.contactPoints("127.0.0.1")
|
.contactPoints("127.0.0.1")
|
||||||
.port(9042)
|
.port(9042)
|
||||||
.localDataCenter("datacenter1")
|
.localDataCenter("datacenter1")
|
||||||
.table(TEST_KEYSPACE, TEST_INDEX)
|
.table("langchain4j", "table_" + randomUUID().replace("-", ""))
|
||||||
.vectorDimension(11)
|
.vectorDimension(11)
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,6 @@ import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
|
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
|
||||||
import org.junit.jupiter.api.Assertions;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||||
|
|
||||||
|
@ -29,7 +28,6 @@ import java.nio.file.Path;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
|
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
|
||||||
import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase;
|
import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase;
|
||||||
|
@ -37,22 +35,23 @@ import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||||
import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002;
|
import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002;
|
||||||
import static java.time.Duration.ofSeconds;
|
import static java.time.Duration.ofSeconds;
|
||||||
import static java.util.stream.Collectors.joining;
|
import static java.util.stream.Collectors.joining;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
|
|
||||||
class SampleDocumentLoaderAndRagWithAstraTest {
|
class SampleDocumentLoaderAndRagWithAstraTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
||||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*")
|
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*")
|
||||||
public void shouldRagWithOpenAiAndAstra() throws InterruptedException {
|
void shouldRagWithOpenAiAndAstra() {
|
||||||
// Initialization
|
// Initialization
|
||||||
String astraToken = getAstraToken();
|
String astraToken = getAstraToken();
|
||||||
String databaseId = setupDatabase("langchain4j", "langchain4j");
|
String databaseId = setupDatabase("langchain4j", "langchain4j");
|
||||||
String openAIKey = System.getenv("OPENAI_API_KEY");
|
String openAIKey = System.getenv("OPENAI_API_KEY");
|
||||||
|
|
||||||
// Given
|
// Given
|
||||||
Assertions.assertNotNull(openAIKey);
|
assertNotNull(openAIKey);
|
||||||
Assertions.assertNotNull(databaseId);
|
assertNotNull(databaseId);
|
||||||
Assertions.assertNotNull(astraToken);
|
assertNotNull(astraToken);
|
||||||
|
|
||||||
// --- Ingesting documents ---
|
// --- Ingesting documents ---
|
||||||
|
|
||||||
|
|
|
@ -1,26 +1,23 @@
|
||||||
package dev.langchain4j.store.memory.chat.cassandra;
|
package dev.langchain4j.store.memory.chat.cassandra;
|
||||||
|
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
|
import dev.langchain4j.data.message.UserMessage;
|
||||||
import dev.langchain4j.memory.ChatMemory;
|
import dev.langchain4j.memory.ChatMemory;
|
||||||
import dev.langchain4j.memory.chat.TokenWindowChatMemory;
|
import dev.langchain4j.memory.chat.TokenWindowChatMemory;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
||||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
|
||||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||||
import dev.langchain4j.model.output.Response;
|
|
||||||
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
|
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
|
||||||
import org.junit.jupiter.api.Assertions;
|
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||||
|
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
|
||||||
import static com.dtsx.astra.sdk.utils.TestUtils.TEST_REGION;
|
import static com.dtsx.astra.sdk.utils.TestUtils.*;
|
||||||
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
|
import static dev.langchain4j.data.message.AiMessage.aiMessage;
|
||||||
import static com.dtsx.astra.sdk.utils.TestUtils.setupDatabase;
|
|
||||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||||
import static java.time.Duration.ofSeconds;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Test Cassandra Chat Memory Store with a Saas DB.
|
* Test Cassandra Chat Memory Store with a Saas DB.
|
||||||
|
@ -28,27 +25,19 @@ import static java.time.Duration.ofSeconds;
|
||||||
public class ChatMemoryStoreAstraTest {
|
public class ChatMemoryStoreAstraTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@Disabled("bug: order of retrieved messages is wrong")
|
||||||
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
||||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*")
|
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*")
|
||||||
void chatMemoryAstraTest() throws InterruptedException {
|
void chatMemoryAstraTest() {
|
||||||
|
|
||||||
// Initialization
|
// Initialization
|
||||||
String astraToken = getAstraToken();
|
String astraToken = getAstraToken();
|
||||||
String databaseId = setupDatabase("langchain4j", "langchain4j");
|
String databaseId = setupDatabase("langchain4j", "langchain4j");
|
||||||
String openAIKey = System.getenv("OPENAI_API_KEY");
|
|
||||||
|
|
||||||
// Given
|
// Given
|
||||||
Assertions.assertNotNull(openAIKey);
|
assertNotNull(databaseId);
|
||||||
Assertions.assertNotNull(databaseId);
|
assertNotNull(astraToken);
|
||||||
Assertions.assertNotNull(astraToken);
|
|
||||||
// Given
|
|
||||||
ChatLanguageModel model = OpenAiChatModel.builder()
|
|
||||||
.apiKey(openAIKey)
|
|
||||||
.modelName(GPT_3_5_TURBO)
|
|
||||||
.temperature(0.3)
|
|
||||||
.timeout(ofSeconds(120))
|
|
||||||
.logRequests(true)
|
|
||||||
.logResponses(true)
|
|
||||||
.build();
|
|
||||||
// When
|
// When
|
||||||
ChatMemoryStore chatMemoryStore =
|
ChatMemoryStore chatMemoryStore =
|
||||||
new AstraDbChatMemoryStore(astraToken, databaseId, TEST_REGION, "langchain4j");
|
new AstraDbChatMemoryStore(astraToken, databaseId, TEST_REGION, "langchain4j");
|
||||||
|
@ -61,11 +50,13 @@ public class ChatMemoryStoreAstraTest {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
// When
|
// When
|
||||||
chatMemory.add(userMessage("I will ask you a few question about ff4j. Response in a single sentence"));
|
UserMessage userMessage = userMessage("I will ask you a few question about ff4j.");
|
||||||
chatMemory.add(userMessage("Can I use it with Javascript ? "));
|
chatMemory.add(userMessage);
|
||||||
|
|
||||||
|
AiMessage aiMessage = aiMessage("Sure, go ahead!");
|
||||||
|
chatMemory.add(aiMessage);
|
||||||
|
|
||||||
// Then
|
// Then
|
||||||
Response<AiMessage> output = model.generate(chatMemory.messages());
|
assertThat(chatMemory.messages()).containsExactly(userMessage, aiMessage);
|
||||||
Assertions.assertNotNull(output.content().text());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,20 +0,0 @@
|
||||||
<configuration debug="false">
|
|
||||||
<statusListener class="ch.qos.logback.core.status.NopStatusListener" />
|
|
||||||
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
|
||||||
<encoder>
|
|
||||||
<pattern>%d{HH:mm:ss.SSS} %magenta(%-5level) %cyan(%-47logger) : %msg%n</pattern>
|
|
||||||
</encoder>
|
|
||||||
</appender>
|
|
||||||
<logger name="com.datastax.astra.sdk" level="INFO" additivity="false">
|
|
||||||
<appender-ref ref="STDOUT" />
|
|
||||||
</logger>
|
|
||||||
<logger name="dev.langchain4j.store" level="INFO" additivity="false">
|
|
||||||
<appender-ref ref="STDOUT" />
|
|
||||||
</logger>
|
|
||||||
<logger name="com.datastax.oss.driver" level="ERROR" additivity="false">
|
|
||||||
<appender-ref ref="STDOUT" />
|
|
||||||
</logger>
|
|
||||||
<root level="WARN">
|
|
||||||
<appender-ref ref="STDOUT" />
|
|
||||||
</root>
|
|
||||||
</configuration>
|
|
|
@ -15,6 +15,7 @@ import java.util.List;
|
||||||
|
|
||||||
import static java.util.Arrays.asList;
|
import static java.util.Arrays.asList;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
|
||||||
@Disabled("needs Elasticsearch to be running locally")
|
@Disabled("needs Elasticsearch to be running locally")
|
||||||
class ElasticsearchEmbeddingStoreTest {
|
class ElasticsearchEmbeddingStoreTest {
|
||||||
|
@ -77,7 +78,7 @@ class ElasticsearchEmbeddingStoreTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testAddNotEqualSizeEmbeddingAndEmbedded() {
|
void testAddNotEqualSizeEmbeddingAndEmbedded() {
|
||||||
Throwable ex = Assertions.assertThrows(IllegalArgumentException.class, () -> store.addAll(asList(
|
Throwable ex = assertThrows(IllegalArgumentException.class, () -> store.addAll(asList(
|
||||||
Embedding.from(asList(0.3f, 0.87f, 0.90f, 0.24f)),
|
Embedding.from(asList(0.3f, 0.87f, 0.90f, 0.24f)),
|
||||||
Embedding.from(asList(0.54f, 0.34f, 0.67f, 0.24f, 0.55f)),
|
Embedding.from(asList(0.54f, 0.34f, 0.67f, 0.24f, 0.55f)),
|
||||||
Embedding.from(asList(0.80f, 0.45f, 0.779f, 0.5556f))
|
Embedding.from(asList(0.80f, 0.45f, 0.779f, 0.5556f))
|
||||||
|
|
|
@ -1,167 +0,0 @@
|
||||||
package dev.langchain4j.store.embedding;
|
|
||||||
|
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
|
||||||
|
|
||||||
import java.lang.reflect.InvocationTargetException;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import java.util.stream.IntStream;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parent class holding the boilerplate code that delegates each call to the concrete implementation.
|
|
||||||
*/
|
|
||||||
public abstract class AbstractEmbeddingStore<T> implements EmbeddingStore<T> {
|
|
||||||
|
|
||||||
/** Class name of the concrete implementation, reported in error messages when loading fails. */
|
|
||||||
protected String implementationClassName;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Concrete Implementation (delegate Pattern).
|
|
||||||
*/
|
|
||||||
protected EmbeddingStore<T> delegateImplementation;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Initialize the concrete implementation.
|
|
||||||
*/
|
|
||||||
protected abstract EmbeddingStore<T> loadImplementation()
|
|
||||||
throws ClassNotFoundException, NoSuchMethodException, InstantiationException,
|
|
||||||
IllegalAccessException, InvocationTargetException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Getter for the concrete implementation, loading it on first access.
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
* delegate implementation
|
|
||||||
*/
|
|
||||||
protected EmbeddingStore<T> getDelegateImplementation() {
|
|
||||||
if (delegateImplementation == null) {
|
|
||||||
try {
|
|
||||||
this.delegateImplementation = loadImplementation();
|
|
||||||
} catch(ClassNotFoundException cnf) {
|
|
||||||
throw new RuntimeException(
|
|
||||||
"Class '" + implementationClassName + "' not found, please checks your dependencies.", cnf);
|
|
||||||
} catch(NoSuchMethodException e) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Class '" + implementationClassName + "' does not have constructor with expected parameters", e);
|
|
||||||
} catch (InvocationTargetException | InstantiationException | IllegalAccessException e) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Constructor of class '" + implementationClassName + "' cannot be invoked check visibility or code.", e);
|
|
||||||
} catch(Exception e) {
|
|
||||||
throw new RuntimeException("Unexpected error while loading implementation", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return delegateImplementation;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Add embedding (vector only) to the store. The id is generated by the store.
|
|
||||||
*
|
|
||||||
* @param embedding
|
|
||||||
* The embedding (vector) to be added to the store.
|
|
||||||
* @return
|
|
||||||
* unique id of the embedding
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public String add(Embedding embedding) {
|
|
||||||
return getDelegateImplementation().add(embedding);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Add embedding (vector only) to the store enforcing its identifier.
|
|
||||||
*
|
|
||||||
* @param embeddingId
|
|
||||||
* unique identifier of the embedding
|
|
||||||
* @param embedding
|
|
||||||
* The embedding (vector) to be added to the store.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void add(String embeddingId, Embedding embedding) {
|
|
||||||
Objects.requireNonNull(embeddingId, "embeddingId (param[0]) must not be null");
|
|
||||||
Objects.requireNonNull(embedding, "embedding (param[1]) must not be null");
|
|
||||||
getDelegateImplementation().add(embeddingId, embedding);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Add embedding to the store with text and metadata. The id is generated by the store.
|
|
||||||
*
|
|
||||||
* @param embedding
|
|
||||||
* The embedding (vector) to be added to the store.
|
|
||||||
* @param textSegment
|
|
||||||
* Text and metadata
|
|
||||||
* @return
|
|
||||||
* unique id of the embedding
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public String add(Embedding embedding, T textSegment) {
|
|
||||||
Objects.requireNonNull(embedding, "embedding (param[0]) must not be null");
|
|
||||||
return getDelegateImplementation().add(embedding, textSegment);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Add a list of embeddings to the store.
|
|
||||||
*
|
|
||||||
* @param embeddings
|
|
||||||
* list of embeddings
|
|
||||||
* @return
|
|
||||||
* list of ids
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public List<String> addAll(List<Embedding> embeddings) {
|
|
||||||
Objects.requireNonNull(embeddings, "embeddings must not be null");
|
|
||||||
return embeddings.stream().map(this::add).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Add a list of embeddings with their matching text segments to the store.
|
|
||||||
*
|
|
||||||
* @param embeddings
|
|
||||||
* list of embeddings
|
|
||||||
* @param textSegments
|
|
||||||
* list of text segments
|
|
||||||
* @return
|
|
||||||
* list of ids
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public List<String> addAll(List<Embedding> embeddings, List<T> textSegments) {
|
|
||||||
Objects.requireNonNull(embeddings, "embeddings (param[0] must not be null");
|
|
||||||
Objects.requireNonNull(textSegments, "textSegments (param[1] must not be null");
|
|
||||||
if (embeddings.size() != textSegments.size()) {
|
|
||||||
throw new IllegalArgumentException("embeddings and textSegment lists must have the same size");
|
|
||||||
}
|
|
||||||
return getDelegateImplementation().addAll(embeddings, textSegments);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Search for relevant embeddings.
|
|
||||||
*
|
|
||||||
* @param referenceEmbedding
|
|
||||||
* The embedding used as a reference. Returned embeddings should be relevant (closest) to this one.
|
|
||||||
* @param maxResults
|
|
||||||
* The maximum number of embeddings to be returned.
|
|
||||||
* @return
|
|
||||||
* List of relevant embeddings.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public List<EmbeddingMatch<T>> findRelevant(Embedding referenceEmbedding, int maxResults) {
|
|
||||||
return getDelegateImplementation().findRelevant(referenceEmbedding, maxResults);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Search for relevant embeddings with a minimum similarity score.
|
|
||||||
*
|
|
||||||
* @param referenceEmbedding
|
|
||||||
* The embedding used as a reference. Returned embeddings should be relevant (closest) to this one.
|
|
||||||
* @param maxResults
|
|
||||||
* The maximum number of embeddings to be returned.
|
|
||||||
* @param minScore
|
|
||||||
* The minimum similarity score of the returned embeddings.
|
|
||||||
* @return
|
|
||||||
* List of relevant embeddings.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public List<EmbeddingMatch<T>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
|
|
||||||
return getDelegateImplementation().findRelevant(referenceEmbedding, maxResults, minScore);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,152 +0,0 @@
|
||||||
package dev.langchain4j.store.embedding.cassandra;
|
|
||||||
|
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
|
||||||
import dev.langchain4j.store.embedding.AbstractEmbeddingStore;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
import java.lang.reflect.InvocationTargetException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents an embedding store using Cassandra AstraDB.
|
|
||||||
*
|
|
||||||
* @author Cedrick Lunven (clun)
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
public class AstraDbEmbeddingStore extends AbstractEmbeddingStore<TextSegment> {
|
|
||||||
|
|
||||||
/** Default implementation, but it can be overridden. */
|
|
||||||
private static final String DEFAULT_IMPLEMENTATION =
|
|
||||||
"dev.langchain4j.store.embedding.cassandra.AstraDbEmbeddingStoreImpl";
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Store Configuration.
|
|
||||||
*/
|
|
||||||
private final AstraDbEmbeddingConfiguration configuration;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Constructor using the default implementation class.
|
|
||||||
*
|
|
||||||
* @param config
|
|
||||||
* load configuration
|
|
||||||
*/
|
|
||||||
public AstraDbEmbeddingStore(AstraDbEmbeddingConfiguration config) {
|
|
||||||
this(DEFAULT_IMPLEMENTATION, config);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Constructor with an explicit implementation class name.
|
|
||||||
*
|
|
||||||
* @param config
|
|
||||||
* load configuration
|
|
||||||
*/
|
|
||||||
public AstraDbEmbeddingStore(String impl, AstraDbEmbeddingConfiguration config) {
|
|
||||||
this.configuration = config;
|
|
||||||
this.implementationClassName = impl;
|
|
||||||
getDelegateImplementation();
|
|
||||||
}
|
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
|
||||||
@Override
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
protected EmbeddingStore<TextSegment> loadImplementation()
|
|
||||||
throws ClassNotFoundException, NoSuchMethodException, InstantiationException,
|
|
||||||
IllegalAccessException, InvocationTargetException {
|
|
||||||
return (EmbeddingStore<TextSegment>) Class
|
|
||||||
.forName(implementationClassName)
|
|
||||||
.getConstructor(AstraDbEmbeddingConfiguration.class)
|
|
||||||
.newInstance(configuration);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Builder builder() {
|
|
||||||
return new Builder();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Syntax Sugar Builder.
|
|
||||||
*/
|
|
||||||
public static class Builder {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Configuration built with the builder
|
|
||||||
*/
|
|
||||||
private final AstraDbEmbeddingConfiguration.AstraDbEmbeddingConfigurationBuilder conf;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Initialization
|
|
||||||
*/
|
|
||||||
public Builder() {
|
|
||||||
conf = AstraDbEmbeddingConfiguration.builder();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Populating token.
|
|
||||||
*
|
|
||||||
* @param token
|
|
||||||
* token
|
|
||||||
* @return
|
|
||||||
* current reference
|
|
||||||
*/
|
|
||||||
public Builder token(String token) {
|
|
||||||
conf.token(token);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Populating database identifier and region.
|
|
||||||
*
|
|
||||||
* @param databaseId
|
|
||||||
* database Identifier
|
|
||||||
* @param databaseRegion
|
|
||||||
* database region
|
|
||||||
* @return
|
|
||||||
* current reference
|
|
||||||
*/
|
|
||||||
public Builder database(String databaseId, String databaseRegion) {
|
|
||||||
conf.databaseId(databaseId);
|
|
||||||
conf.databaseRegion(databaseRegion);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Populating model dimension.
|
|
||||||
*
|
|
||||||
* @param dimension
|
|
||||||
* model dimension
|
|
||||||
* @return
|
|
||||||
* current reference
|
|
||||||
*/
|
|
||||||
public Builder vectorDimension(int dimension) {
|
|
||||||
conf.dimension(dimension);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Populating table name.
|
|
||||||
*
|
|
||||||
* @param keyspace
|
|
||||||
* keyspace name
|
|
||||||
* @param table
|
|
||||||
* table name
|
|
||||||
* @return
|
|
||||||
* current reference
|
|
||||||
*/
|
|
||||||
public Builder table(String keyspace, String table) {
|
|
||||||
conf.keyspace(keyspace);
|
|
||||||
conf.table(table);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Building the Store.
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
* store for Astra.
|
|
||||||
*/
|
|
||||||
public AstraDbEmbeddingStore build() {
|
|
||||||
return new AstraDbEmbeddingStore(conf.build());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,161 +0,0 @@
|
||||||
package dev.langchain4j.store.embedding.cassandra;
|
|
||||||
|
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
|
||||||
import dev.langchain4j.store.embedding.AbstractEmbeddingStore;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
|
||||||
|
|
||||||
import java.lang.reflect.InvocationTargetException;
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
import static dev.langchain4j.store.embedding.cassandra.CassandraEmbeddingConfiguration.CassandraEmbeddingConfigurationBuilder;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents an embedding store using Cassandra.
|
|
||||||
*/
|
|
||||||
public class CassandraEmbeddingStore extends AbstractEmbeddingStore<TextSegment> {
|
|
||||||
|
|
||||||
/** Default implementation, but it can be overridden. */
|
|
||||||
private static final String DEFAULT_IMPLEMENTATION =
|
|
||||||
"dev.langchain4j.store.embedding.cassandra.CassandraEmbeddingStoreImpl";
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Store Configuration.
|
|
||||||
*/
|
|
||||||
private final CassandraEmbeddingConfiguration configuration;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Constructor using the default implementation class.
|
|
||||||
*
|
|
||||||
* @param config
|
|
||||||
* load configuration
|
|
||||||
*/
|
|
||||||
public CassandraEmbeddingStore(CassandraEmbeddingConfiguration config) {
|
|
||||||
this(DEFAULT_IMPLEMENTATION, config);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Constructor with an explicit implementation class name.
|
|
||||||
*
|
|
||||||
* @param config
|
|
||||||
* load configuration
|
|
||||||
*/
|
|
||||||
public CassandraEmbeddingStore(String impl, CassandraEmbeddingConfiguration config) {
|
|
||||||
this.configuration = config;
|
|
||||||
this.implementationClassName = impl;
|
|
||||||
getDelegateImplementation();
|
|
||||||
}
|
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
|
||||||
@Override
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
protected EmbeddingStore<TextSegment> loadImplementation()
|
|
||||||
throws ClassNotFoundException, NoSuchMethodException, InstantiationException,
|
|
||||||
IllegalAccessException, InvocationTargetException {
|
|
||||||
return (EmbeddingStore<TextSegment>) Class
|
|
||||||
.forName(implementationClassName)
|
|
||||||
.getConstructor(CassandraEmbeddingConfiguration.class)
|
|
||||||
.newInstance(configuration);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static CassandraEmbeddingStore.Builder builder() {
|
|
||||||
return new CassandraEmbeddingStore.Builder();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Syntax Sugar Builder.
|
|
||||||
*/
|
|
||||||
public static class Builder {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Configuration built with the builder
|
|
||||||
*/
|
|
||||||
private final CassandraEmbeddingConfigurationBuilder conf;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Initialization
|
|
||||||
*/
|
|
||||||
public Builder() {
|
|
||||||
conf = CassandraEmbeddingConfiguration.builder();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Populating cassandra port.
|
|
||||||
*
|
|
||||||
* @param port
|
|
||||||
* port
|
|
||||||
* @return
|
|
||||||
* current reference
|
|
||||||
*/
|
|
||||||
public CassandraEmbeddingStore.Builder port(int port) {
|
|
||||||
conf.port(port);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Populating cassandra contact points.
|
|
||||||
*
|
|
||||||
* @param hosts
|
|
||||||
* contact point hosts
|
|
||||||
* @return
|
|
||||||
* current reference
|
|
||||||
*/
|
|
||||||
public CassandraEmbeddingStore.Builder contactPoints(String... hosts) {
|
|
||||||
conf.contactPoints(Arrays.asList(hosts));
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Populating model dimension.
|
|
||||||
*
|
|
||||||
* @param dimension
|
|
||||||
* model dimension
|
|
||||||
* @return
|
|
||||||
* current reference
|
|
||||||
*/
|
|
||||||
public CassandraEmbeddingStore.Builder vectorDimension(int dimension) {
|
|
||||||
conf.dimension(dimension);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Populating datacenter.
|
|
||||||
*
|
|
||||||
* @param dc
|
|
||||||
* datacenter
|
|
||||||
* @return
|
|
||||||
* current reference
|
|
||||||
*/
|
|
||||||
public CassandraEmbeddingStore.Builder localDataCenter(String dc) {
|
|
||||||
conf.localDataCenter(dc);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Populating table name.
|
|
||||||
*
|
|
||||||
* @param keyspace
|
|
||||||
* keyspace name
|
|
||||||
* @param table
|
|
||||||
* table name
|
|
||||||
* @return
|
|
||||||
* current reference
|
|
||||||
*/
|
|
||||||
public CassandraEmbeddingStore.Builder table(String keyspace, String table) {
|
|
||||||
conf.keyspace(keyspace);
|
|
||||||
conf.table(table);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Building the Store.
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
* store for Cassandra.
|
|
||||||
*/
|
|
||||||
public CassandraEmbeddingStore build() {
|
|
||||||
return new CassandraEmbeddingStore(conf.build());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,80 +0,0 @@
|
||||||
package dev.langchain4j.store.embedding.cassandra;
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.Assertions;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The configuration objects include a few validation rules.
|
|
||||||
*/
|
|
||||||
public class AstraDbEmbeddingConfigurationTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void should_build_configuration_test() {
|
|
||||||
AstraDbEmbeddingConfiguration config = AstraDbEmbeddingConfiguration.builder()
|
|
||||||
.token("token").databaseId("dbId").databaseRegion("dbRegion")
|
|
||||||
.keyspace("ks").dimension(20).table("table")
|
|
||||||
.build();
|
|
||||||
Assertions.assertNotNull(config);
|
|
||||||
Assertions.assertNotNull(config.getToken());
|
|
||||||
Assertions.assertNotNull(config.getDatabaseId());
|
|
||||||
Assertions.assertNotNull(config.getDatabaseRegion());
|
|
||||||
}
|
|
||||||
@Test
|
|
||||||
public void should_error_if_no_table_test() {
|
|
||||||
// Table is required
|
|
||||||
NullPointerException exception = Assertions.assertThrows(NullPointerException.class,
|
|
||||||
() -> AstraDbEmbeddingConfiguration.builder()
|
|
||||||
.token("token").databaseId("dbId").databaseRegion("dbRegion")
|
|
||||||
.keyspace("ks").dimension(20)
|
|
||||||
.build());
|
|
||||||
Assertions. assertEquals("table is marked non-null but is null", exception.getMessage());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void should_error_if_no_keyspace_test() {
|
|
||||||
// Keyspace is required
|
|
||||||
NullPointerException exception = Assertions.assertThrows(NullPointerException.class,
|
|
||||||
() -> AstraDbEmbeddingConfiguration.builder()
|
|
||||||
.token("token").databaseId("dbId").databaseRegion("dbRegion")
|
|
||||||
.table("ks").dimension(20)
|
|
||||||
.build());
|
|
||||||
Assertions. assertEquals("keyspace is marked non-null but is null", exception.getMessage());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void should_error_if_no_dimension_test() {
|
|
||||||
// Dimension is required
|
|
||||||
NullPointerException exception = Assertions.assertThrows(NullPointerException.class,
|
|
||||||
() -> AstraDbEmbeddingConfiguration.builder()
|
|
||||||
.token("token").databaseId("dbId").databaseRegion("dbRegion")
|
|
||||||
.table("ks").keyspace("ks")
|
|
||||||
.build());
|
|
||||||
Assertions. assertEquals("dimension is marked non-null but is null", exception.getMessage());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void should_error_if_no_token_test() {
|
|
||||||
// Token is required
|
|
||||||
NullPointerException exception = Assertions.assertThrows(NullPointerException.class,
|
|
||||||
() -> AstraDbEmbeddingConfiguration.builder()
|
|
||||||
.databaseId("dbId").databaseRegion("dbRegion")
|
|
||||||
.table("ks").keyspace("ks").dimension(20)
|
|
||||||
.build());
|
|
||||||
Assertions. assertEquals("token is marked non-null but is null", exception.getMessage());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void should_error_if_no_database_test() {
|
|
||||||
// Database id is required
|
|
||||||
NullPointerException exception = Assertions.assertThrows(NullPointerException.class,
|
|
||||||
() -> AstraDbEmbeddingConfiguration.builder()
|
|
||||||
.token("token")
|
|
||||||
.table("ks").keyspace("ks").dimension(20)
|
|
||||||
.build());
|
|
||||||
Assertions. assertEquals("databaseId is marked non-null but is null", exception.getMessage());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,84 +0,0 @@
|
||||||
package dev.langchain4j.store.embedding.cassandra;
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.Assertions;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
|
|
||||||
public class CassandraEmbeddingConfigurationTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void should_build_configuration_test() {
|
|
||||||
CassandraEmbeddingConfiguration config = CassandraEmbeddingConfiguration.builder()
|
|
||||||
.contactPoints(Collections.singletonList("localhost"))
|
|
||||||
.port(CassandraEmbeddingConfiguration.DEFAULT_PORT)
|
|
||||||
.keyspace("ks").dimension(20).table("table")
|
|
||||||
.localDataCenter("dc1")
|
|
||||||
.build();
|
|
||||||
Assertions.assertNotNull(config);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void should_error_if_no_datacenter_test() {
|
|
||||||
// Local datacenter is required
|
|
||||||
NullPointerException exception = Assertions.assertThrows(NullPointerException.class,
|
|
||||||
() -> CassandraEmbeddingConfiguration.builder()
|
|
||||||
.contactPoints(Collections.singletonList("localhost"))
|
|
||||||
.port(CassandraEmbeddingConfiguration.DEFAULT_PORT)
|
|
||||||
.keyspace("ks").dimension(20).table("table")
|
|
||||||
.build());
|
|
||||||
Assertions. assertEquals("localDataCenter is marked non-null but is null", exception.getMessage());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void should_error_if_no_table_test() {
|
|
||||||
// Table is required
|
|
||||||
NullPointerException exception = Assertions.assertThrows(NullPointerException.class,
|
|
||||||
() -> CassandraEmbeddingConfiguration.builder()
|
|
||||||
.contactPoints(Collections.singletonList("localhost"))
|
|
||||||
.port(CassandraEmbeddingConfiguration.DEFAULT_PORT)
|
|
||||||
.keyspace("ks").dimension(20)
|
|
||||||
.localDataCenter("dc1")
|
|
||||||
.build());
|
|
||||||
Assertions. assertEquals("table is marked non-null but is null", exception.getMessage());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void should_error_if_no_keyspace_test() {
|
|
||||||
// Keyspace is required
|
|
||||||
NullPointerException exception = Assertions.assertThrows(NullPointerException.class,
|
|
||||||
() ->CassandraEmbeddingConfiguration.builder()
|
|
||||||
.contactPoints(Collections.singletonList("localhost"))
|
|
||||||
.port(CassandraEmbeddingConfiguration.DEFAULT_PORT)
|
|
||||||
.table("ks").dimension(20)
|
|
||||||
.localDataCenter("dc1")
|
|
||||||
.build());
|
|
||||||
Assertions. assertEquals("keyspace is marked non-null but is null", exception.getMessage());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void should_error_if_no_dimension_test() {
|
|
||||||
// Dimension is required
|
|
||||||
NullPointerException exception = Assertions.assertThrows(NullPointerException.class,
|
|
||||||
() -> CassandraEmbeddingConfiguration.builder()
|
|
||||||
.contactPoints(Collections.singletonList("localhost"))
|
|
||||||
.port(CassandraEmbeddingConfiguration.DEFAULT_PORT)
|
|
||||||
.table("ks").keyspace("ks")
|
|
||||||
.localDataCenter("dc1")
|
|
||||||
.build());
|
|
||||||
Assertions. assertEquals("dimension is marked non-null but is null", exception.getMessage());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void should_error_if_no_contact_points_test() {
|
|
||||||
// Contact points are required
|
|
||||||
NullPointerException exception = Assertions.assertThrows(NullPointerException.class,
|
|
||||||
() -> CassandraEmbeddingConfiguration.builder()
|
|
||||||
.port(CassandraEmbeddingConfiguration.DEFAULT_PORT)
|
|
||||||
.table("ks").keyspace("ks").dimension(20)
|
|
||||||
.localDataCenter("dc1")
|
|
||||||
.build());
|
|
||||||
Assertions. assertEquals("contactPoints is marked non-null but is null", exception.getMessage());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
Loading…
Reference in New Issue