Add langchain4j-tablestore Integration: TablestoreEmbeddingStore/TablestoreChatMemoryStore (#1650)
## Change Add langchain4j-tablestore Integration: TablestoreEmbeddingStore / TablestoreChatMemoryStore ## General checklist - [x] There are no breaking changes - [x] I have added unit and integration tests for my change - [x] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [x] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green <!-- Before adding documentation and example(s) (below), please wait until the PR is reviewed and approved. --> - [x] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable) ## Checklist for adding new embedding store integration <!-- Please double-check the following points and mark them like this: [X] --> - [x] I have added a `{NameOfIntegration}EmbeddingStoreIT` that extends from either `EmbeddingStoreIT` or `EmbeddingStoreWithFilteringIT` - [x] I have added my new module in the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)
This commit is contained in:
parent
9ea2e27337
commit
be7454a7c6
|
@ -25,6 +25,7 @@ sidebar_position: 0
|
|||
| [Pinecone](/integrations/embedding-stores/pinecone) | ✅ | ✅ | ✅ |
|
||||
| [Qdrant](/integrations/embedding-stores/qdrant) | ✅ | ✅ | |
|
||||
| [Redis](/integrations/embedding-stores/redis) | ✅ | | |
|
||||
| [Tablestore](/integrations/embedding-stores/tablestore) | ✅ |✅ |✅ |
|
||||
| [Vearch](/integrations/embedding-stores/vearch) | ✅ | | |
|
||||
| [Vespa](/integrations/embedding-stores/vespa) | | | |
|
||||
| [Weaviate](/integrations/embedding-stores/weaviate) | ✅ | | ✅ |
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
---
|
||||
sidebar_position: 16
|
||||
---
|
||||
|
||||
# Tablestore
|
||||
|
||||
https://www.aliyun.com/product/ots
|
||||
|
||||
## Maven Dependency
|
||||
|
||||
```xml
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-tablestore</artifactId>
|
||||
<version>0.35.0</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
## APIs
|
||||
|
||||
- `TablestoreEmbeddingStore`
|
||||
|
||||
## Examples
|
||||
|
||||
- [TablestoreEmbeddingStoreExampleIT](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-tablestore/src/test/java/dev/langchain4j/store/embedding/tablestore/TablestoreEmbeddingStoreExampleIT.java)
|
|
@ -257,6 +257,12 @@
|
|||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-tablestore</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-vearch</artifactId>
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.35.0-SNAPSHOT</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
<artifactId>langchain4j-tablestore</artifactId>
|
||||
<name>LangChain4j :: Integration :: Tablestore</name>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.aliyun.openservices</groupId>
|
||||
<artifactId>tablestore</artifactId>
|
||||
<version>5.17.3</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>ch.qos.logback</groupId>
|
||||
<artifactId>logback-classic</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<classifier>tests</classifier>
|
||||
<type>test-jar</type>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<!-- junit-jupiter-params should be declared explicitly
|
||||
to run parameterized tests inherited from EmbeddingStore*IT-->
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-params</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.assertj</groupId>
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
<artifactId>protobuf-java-util</artifactId>
|
||||
<version>3.25.3</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.awaitility</groupId>
|
||||
<artifactId>awaitility</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
|
@ -0,0 +1,58 @@
|
|||
package dev.langchain4j.store.embedding.tablestore;
|
||||
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.StringJoiner;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
|
||||
public class IsTextMatch implements Filter {
|
||||
|
||||
private final String key;
|
||||
private final String comparisonValue;
|
||||
|
||||
public IsTextMatch(String key, String comparisonValue) {
|
||||
this.key = ensureNotBlank(key, "key");
|
||||
this.comparisonValue = ensureNotNull(comparisonValue, "comparisonValue with key '" + key + "'");
|
||||
}
|
||||
|
||||
public String key() {
|
||||
return key;
|
||||
}
|
||||
|
||||
public String comparisonValue() {
|
||||
return comparisonValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean test(Object object) {
|
||||
throw new UnsupportedOperationException("only used in search filters");
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (!(o instanceof IsTextMatch)) {
|
||||
return false;
|
||||
}
|
||||
IsTextMatch that = (IsTextMatch) o;
|
||||
return Objects.equals(key, that.key) && Objects.equals(comparisonValue, that.comparisonValue);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(key, comparisonValue);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return new StringJoiner(", ", IsTextMatch.class.getSimpleName() + "[", "]")
|
||||
.add("key='" + key + "'")
|
||||
.add("comparisonValue='" + comparisonValue + "'")
|
||||
.toString();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
package dev.langchain4j.store.embedding.tablestore;
|
||||
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.StringJoiner;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
|
||||
public class IsTextMatchPhrase implements Filter {
|
||||
|
||||
private final String key;
|
||||
private final String comparisonValue;
|
||||
|
||||
public IsTextMatchPhrase(String key, String comparisonValue) {
|
||||
this.key = ensureNotBlank(key, "key");
|
||||
this.comparisonValue = ensureNotNull(comparisonValue, "comparisonValue with key '" + key + "'");
|
||||
}
|
||||
|
||||
public String key() {
|
||||
return key;
|
||||
}
|
||||
|
||||
public String comparisonValue() {
|
||||
return comparisonValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean test(Object object) {
|
||||
throw new UnsupportedOperationException("only used in search filters");
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (!(o instanceof IsTextMatchPhrase)) {
|
||||
return false;
|
||||
}
|
||||
IsTextMatchPhrase that = (IsTextMatchPhrase) o;
|
||||
return Objects.equals(key, that.key) && Objects.equals(comparisonValue, that.comparisonValue);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(key, comparisonValue);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return new StringJoiner(", ", IsTextMatchPhrase.class.getSimpleName() + "[", "]")
|
||||
.add("key='" + key + "'")
|
||||
.add("comparisonValue='" + comparisonValue + "'")
|
||||
.toString();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,523 @@
|
|||
package dev.langchain4j.store.embedding.tablestore;
|
||||
|
||||
import com.alicloud.openservices.tablestore.SyncClient;
|
||||
import com.alicloud.openservices.tablestore.core.utils.ValueUtil;
|
||||
import com.alicloud.openservices.tablestore.model.CapacityUnit;
|
||||
import com.alicloud.openservices.tablestore.model.Column;
|
||||
import com.alicloud.openservices.tablestore.model.ColumnType;
|
||||
import com.alicloud.openservices.tablestore.model.ColumnValue;
|
||||
import com.alicloud.openservices.tablestore.model.CreateTableRequest;
|
||||
import com.alicloud.openservices.tablestore.model.DeleteRowRequest;
|
||||
import com.alicloud.openservices.tablestore.model.DeleteTableRequest;
|
||||
import com.alicloud.openservices.tablestore.model.Direction;
|
||||
import com.alicloud.openservices.tablestore.model.GetRangeRequest;
|
||||
import com.alicloud.openservices.tablestore.model.GetRangeResponse;
|
||||
import com.alicloud.openservices.tablestore.model.ListTableResponse;
|
||||
import com.alicloud.openservices.tablestore.model.PrimaryKey;
|
||||
import com.alicloud.openservices.tablestore.model.PrimaryKeyBuilder;
|
||||
import com.alicloud.openservices.tablestore.model.PrimaryKeySchema;
|
||||
import com.alicloud.openservices.tablestore.model.PrimaryKeyType;
|
||||
import com.alicloud.openservices.tablestore.model.PrimaryKeyValue;
|
||||
import com.alicloud.openservices.tablestore.model.PutRowRequest;
|
||||
import com.alicloud.openservices.tablestore.model.RangeRowQueryCriteria;
|
||||
import com.alicloud.openservices.tablestore.model.ReservedThroughput;
|
||||
import com.alicloud.openservices.tablestore.model.Row;
|
||||
import com.alicloud.openservices.tablestore.model.RowDeleteChange;
|
||||
import com.alicloud.openservices.tablestore.model.RowPutChange;
|
||||
import com.alicloud.openservices.tablestore.model.TableMeta;
|
||||
import com.alicloud.openservices.tablestore.model.TableOptions;
|
||||
import com.alicloud.openservices.tablestore.model.search.CreateSearchIndexRequest;
|
||||
import com.alicloud.openservices.tablestore.model.search.DeleteSearchIndexRequest;
|
||||
import com.alicloud.openservices.tablestore.model.search.FieldSchema;
|
||||
import com.alicloud.openservices.tablestore.model.search.FieldType;
|
||||
import com.alicloud.openservices.tablestore.model.search.IndexSchema;
|
||||
import com.alicloud.openservices.tablestore.model.search.ListSearchIndexRequest;
|
||||
import com.alicloud.openservices.tablestore.model.search.ListSearchIndexResponse;
|
||||
import com.alicloud.openservices.tablestore.model.search.SearchHit;
|
||||
import com.alicloud.openservices.tablestore.model.search.SearchIndexInfo;
|
||||
import com.alicloud.openservices.tablestore.model.search.SearchQuery;
|
||||
import com.alicloud.openservices.tablestore.model.search.SearchRequest;
|
||||
import com.alicloud.openservices.tablestore.model.search.SearchResponse;
|
||||
import com.alicloud.openservices.tablestore.model.search.query.KnnVectorQuery;
|
||||
import com.alicloud.openservices.tablestore.model.search.query.Query;
|
||||
import com.alicloud.openservices.tablestore.model.search.query.QueryBuilders;
|
||||
import com.alicloud.openservices.tablestore.model.search.sort.ScoreSort;
|
||||
import com.alicloud.openservices.tablestore.model.search.sort.Sort;
|
||||
import com.alicloud.openservices.tablestore.model.search.vector.VectorDataType;
|
||||
import com.alicloud.openservices.tablestore.model.search.vector.VectorMetricType;
|
||||
import com.alicloud.openservices.tablestore.model.search.vector.VectorOptions;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.internal.Exceptions;
|
||||
import dev.langchain4j.internal.ValidationUtils;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
|
||||
public class TablestoreEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
private final Logger log = LoggerFactory.getLogger(getClass());
|
||||
private final SyncClient client;
|
||||
private final String tableName;
|
||||
private final String searchIndexName;
|
||||
private final String pkName;
|
||||
private final String textField;
|
||||
private final String embeddingField;
|
||||
private final int vectorDimension;
|
||||
private final VectorMetricType vectorMetricType;
|
||||
private final List<FieldSchema> metadataSchemaList;
|
||||
|
||||
private static final String DEFAULT_TABLE_NAME = "langchain4j_embedding_store_ots_v1";
|
||||
private static final String DEFAULT_INDEX_NAME = "langchain4j_embedding_ots_index_v1";
|
||||
private static final String DEFAULT_TABLE_PK_NAME = "id";
|
||||
private static final String DEFAULT_TEXT_FIELD_NAME = "default_content";
|
||||
private static final String DEFAULT_VECTOR_FIELD_NAME = "default_embedding";
|
||||
private static final VectorMetricType DEFAULT_VECTOR_METRIC_TYPE = VectorMetricType.COSINE;
|
||||
|
||||
|
||||
public TablestoreEmbeddingStore(SyncClient client, int vectorDimension) {
|
||||
this(client, vectorDimension, Collections.emptyList());
|
||||
}
|
||||
|
||||
public TablestoreEmbeddingStore(SyncClient client, int vectorDimension, List<FieldSchema> metadataSchemaList) {
|
||||
this(client, DEFAULT_TABLE_NAME, DEFAULT_INDEX_NAME, DEFAULT_TABLE_PK_NAME, DEFAULT_TEXT_FIELD_NAME, DEFAULT_VECTOR_FIELD_NAME, vectorDimension, DEFAULT_VECTOR_METRIC_TYPE, metadataSchemaList);
|
||||
}
|
||||
|
||||
public TablestoreEmbeddingStore(SyncClient client, String tableName, String searchIndexName, String pkName, String textField, String embeddingField, int vectorDimension, VectorMetricType vectorMetricType, List<FieldSchema> metadataSchemaList) {
|
||||
this.client = ValidationUtils.ensureNotNull(client, "client");
|
||||
this.tableName = ValidationUtils.ensureNotBlank(tableName, "tableName");
|
||||
this.searchIndexName = ValidationUtils.ensureNotBlank(searchIndexName, "searchIndexName");
|
||||
this.pkName = ValidationUtils.ensureNotBlank(pkName, "pkName");
|
||||
this.textField = ValidationUtils.ensureNotBlank(textField, "textField");
|
||||
this.embeddingField = ValidationUtils.ensureNotBlank(embeddingField, "embeddingField");
|
||||
this.vectorDimension = ValidationUtils.ensureGreaterThanZero(vectorDimension, "vectorDimension");
|
||||
this.vectorMetricType = ValidationUtils.ensureNotNull(vectorMetricType, "vectorMetricType");
|
||||
ValidationUtils.ensureNotNull(metadataSchemaList, "metadataSchemaList");
|
||||
List<FieldSchema> tmpMetaList = new ArrayList<>();
|
||||
tmpMetaList.add(new FieldSchema(textField, FieldType.TEXT).setIndex(true).setAnalyzer(FieldSchema.Analyzer.MaxWord));
|
||||
tmpMetaList.add(new FieldSchema(embeddingField, FieldType.VECTOR).setIndex(true).setVectorOptions(new VectorOptions(VectorDataType.FLOAT_32, vectorDimension, vectorMetricType)));
|
||||
for (FieldSchema fieldSchema : metadataSchemaList) {
|
||||
if (fieldSchema.getFieldName().equals(textField)) {
|
||||
throw Exceptions.illegalArgument("the custom meta data field name matches the system text field:{}", textField);
|
||||
}
|
||||
if (fieldSchema.getFieldName().equals(embeddingField)) {
|
||||
throw Exceptions.illegalArgument("the custom meta data field name matches the system embedding field:{}", embeddingField);
|
||||
}
|
||||
tmpMetaList.add(fieldSchema);
|
||||
}
|
||||
this.metadataSchemaList = Collections.unmodifiableList(tmpMetaList);
|
||||
}
|
||||
|
||||
public void init() {
|
||||
createTableIfNotExist();
|
||||
createSearchIndexIfNotExist();
|
||||
}
|
||||
|
||||
public SyncClient getClient() {
|
||||
return client;
|
||||
}
|
||||
|
||||
public String getTableName() {
|
||||
return tableName;
|
||||
}
|
||||
|
||||
public String getSearchIndexName() {
|
||||
return searchIndexName;
|
||||
}
|
||||
|
||||
public String getPkName() {
|
||||
return pkName;
|
||||
}
|
||||
|
||||
public String getTextField() {
|
||||
return textField;
|
||||
}
|
||||
|
||||
public String getEmbeddingField() {
|
||||
return embeddingField;
|
||||
}
|
||||
|
||||
public int getVectorDimension() {
|
||||
return vectorDimension;
|
||||
}
|
||||
|
||||
public VectorMetricType getVectorMetricType() {
|
||||
return vectorMetricType;
|
||||
}
|
||||
|
||||
public List<FieldSchema> getMetadataSchemaList() {
|
||||
return metadataSchemaList;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding) {
|
||||
String id = UUID.randomUUID().toString();
|
||||
innerAdd(id, embedding, null);
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(String id, Embedding embedding) {
|
||||
innerAdd(id, embedding, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding, TextSegment textSegment) {
|
||||
String id = UUID.randomUUID().toString();
|
||||
innerAdd(id, embedding, textSegment);
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings) {
|
||||
return addAll(embeddings, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
|
||||
if (embedded != null) {
|
||||
ValidationUtils.ensureEq(embeddings.size(), embedded.size(), "the size of embeddings should be the same as the size of embedded");
|
||||
}
|
||||
List<String> ids = new ArrayList<>(embeddings.size());
|
||||
List<Exception> exceptions = new ArrayList<>();
|
||||
for (int i = 0; i < embeddings.size(); i++) {
|
||||
Embedding embedding = embeddings.get(i);
|
||||
TextSegment textSegment = null;
|
||||
if (embedded != null) {
|
||||
textSegment = embedded.get(i);
|
||||
}
|
||||
try {
|
||||
String id = UUID.randomUUID().toString();
|
||||
innerAdd(id, embedding, textSegment);
|
||||
ids.add(id);
|
||||
} catch (Exception e) {
|
||||
exceptions.add(e);
|
||||
}
|
||||
}
|
||||
if (!exceptions.isEmpty()) {
|
||||
IllegalStateException exception = new IllegalStateException("Add all embeddings with error, failed:" + exceptions.size());
|
||||
for (Exception e : exceptions) {
|
||||
exception.addSuppressed(e);
|
||||
}
|
||||
throw exception;
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void remove(String id) {
|
||||
ensureNotBlank(id, "id");
|
||||
innerDelete(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeAll(Collection<String> ids) {
|
||||
if (ids == null || ids.isEmpty()) {
|
||||
throw Exceptions.illegalArgument("ids cannot be null or empty");
|
||||
}
|
||||
log.debug("remove all:{}", ids);
|
||||
List<Exception> exceptions = new ArrayList<>();
|
||||
for (String id : ids) {
|
||||
try {
|
||||
remove(id);
|
||||
} catch (Exception e) {
|
||||
exceptions.add(e);
|
||||
}
|
||||
}
|
||||
if (!exceptions.isEmpty()) {
|
||||
IllegalStateException exception = new IllegalStateException("remove all embeddings with error, failed:" + exceptions.size());
|
||||
for (Exception e : exceptions) {
|
||||
exception.addSuppressed(e);
|
||||
}
|
||||
throw exception;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeAll(Filter filter) {
|
||||
if (filter == null) {
|
||||
throw Exceptions.illegalArgument("filter cannot be null");
|
||||
}
|
||||
forEachAllData(Collections.emptyList(), (row -> {
|
||||
Metadata metadata = rowToMetadata(row);
|
||||
if (filter.test(metadata)) {
|
||||
remove(row.getPrimaryKey().getPrimaryKeyColumn(pkName).getValue().asString());
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeAll() {
|
||||
log.debug("remove all");
|
||||
forEachAllData(Collections.emptyList(), (row) -> {
|
||||
this.innerDelete(row.getPrimaryKey().getPrimaryKeyColumn(pkName).getValue().asString());
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
|
||||
log.debug("search ([...{}...], {}, {})", request.queryEmbedding().vector().length, request.maxResults(), request.minScore());
|
||||
KnnVectorQuery knnVectorQuery = QueryBuilders.knnVector(embeddingField, request.maxResults(), request.queryEmbedding().vector())
|
||||
.filter(mapFilterToQuery(request.filter()))
|
||||
.build();
|
||||
SearchQuery searchQuery = SearchQuery.newBuilder()
|
||||
.query(knnVectorQuery)
|
||||
.getTotalCount(false)
|
||||
.limit(request.maxResults())
|
||||
.offset(0)
|
||||
.sort(new Sort(Collections.singletonList(new ScoreSort())))
|
||||
.build();
|
||||
SearchRequest searchRequest = SearchRequest.newBuilder()
|
||||
.tableName(tableName)
|
||||
.indexName(searchIndexName)
|
||||
.searchQuery(searchQuery)
|
||||
.returnAllColumns(true)
|
||||
.build();
|
||||
SearchResponse response = client.search(searchRequest);
|
||||
log.debug("search requestId:{}", response.getRequestId());
|
||||
return searchResponseToEmbeddingSearchResult(request, response);
|
||||
}
|
||||
|
||||
protected Query mapFilterToQuery(Filter filter) {
|
||||
return TablestoreMetadataFilterMapper.map(filter);
|
||||
}
|
||||
|
||||
private EmbeddingSearchResult<TextSegment> searchResponseToEmbeddingSearchResult(EmbeddingSearchRequest request, SearchResponse response) {
|
||||
List<SearchHit> searchHits = response.getSearchHits();
|
||||
List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>(searchHits.size());
|
||||
for (SearchHit hit : searchHits) {
|
||||
Double score = hit.getScore();
|
||||
if (score < request.minScore()) {
|
||||
continue;
|
||||
}
|
||||
Row row = hit.getRow();
|
||||
|
||||
String text = null;
|
||||
if (row.getLatestColumn(textField) != null) {
|
||||
text = row.getLatestColumn(textField).getValue().asString();
|
||||
}
|
||||
|
||||
float[] embedding = null;
|
||||
if (row.getLatestColumn(embeddingField) != null) {
|
||||
String embeddingString = row.getLatestColumn(embeddingField).getValue().asString();
|
||||
embedding = TablestoreUtils.parseEmbeddingString(embeddingString);
|
||||
}
|
||||
|
||||
Metadata metadata = rowToMetadata(row);
|
||||
|
||||
TextSegment textSegment = null;
|
||||
if (text != null && embedding != null) {
|
||||
textSegment = new TextSegment(text, metadata);
|
||||
}
|
||||
|
||||
EmbeddingMatch<TextSegment> match = new EmbeddingMatch<TextSegment>(
|
||||
score,
|
||||
row.getPrimaryKey().getPrimaryKeyColumn(pkName).getValue().asString(),
|
||||
new Embedding(embedding),
|
||||
textSegment
|
||||
);
|
||||
matches.add(match);
|
||||
}
|
||||
return new EmbeddingSearchResult<>(matches);
|
||||
}
|
||||
|
||||
private void createTableIfNotExist() {
|
||||
if (tableExists()) {
|
||||
log.info("table:{} already exists", tableName);
|
||||
return;
|
||||
}
|
||||
TableMeta tableMeta = new TableMeta(this.tableName);
|
||||
tableMeta.addPrimaryKeyColumn(new PrimaryKeySchema(pkName, PrimaryKeyType.STRING));
|
||||
TableOptions tableOptions = new TableOptions(-1, 1);
|
||||
CreateTableRequest request = new CreateTableRequest(tableMeta, tableOptions);
|
||||
request.setReservedThroughput(new ReservedThroughput(new CapacityUnit(0, 0)));
|
||||
client.createTable(request);
|
||||
log.info("create table:{}", tableName);
|
||||
}
|
||||
|
||||
private void createSearchIndexIfNotExist() {
|
||||
if (searchindexExists()) {
|
||||
log.info("index:{} already exists", searchIndexName);
|
||||
return;
|
||||
}
|
||||
CreateSearchIndexRequest request = new CreateSearchIndexRequest();
|
||||
request.setTableName(tableName);
|
||||
request.setIndexName(searchIndexName);
|
||||
IndexSchema indexSchema = new IndexSchema();
|
||||
indexSchema.setFieldSchemas(metadataSchemaList);
|
||||
request.setIndexSchema(indexSchema);
|
||||
client.createSearchIndex(request);
|
||||
log.info("create index:{}", searchIndexName);
|
||||
}
|
||||
|
||||
protected void deleteTableAndIndex() {
|
||||
List<SearchIndexInfo> searchIndexInfos = listSearchIndex();
|
||||
deleteIndex(searchIndexInfos);
|
||||
deleteTable();
|
||||
}
|
||||
|
||||
private boolean tableExists() {
|
||||
ListTableResponse listTableResponse = client.listTable();
|
||||
return listTableResponse.getTableNames().contains(tableName);
|
||||
}
|
||||
|
||||
private boolean searchindexExists() {
|
||||
List<SearchIndexInfo> searchIndexInfos = listSearchIndex();
|
||||
for (SearchIndexInfo indexInfo : searchIndexInfos) {
|
||||
if (indexInfo.getIndexName().equals(searchIndexName)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private void deleteIndex(List<SearchIndexInfo> indexNames) {
|
||||
indexNames.forEach(info -> {
|
||||
DeleteSearchIndexRequest request = new DeleteSearchIndexRequest();
|
||||
request.setTableName(info.getTableName());
|
||||
request.setIndexName(info.getIndexName());
|
||||
client.deleteSearchIndex(request);
|
||||
log.info("delete table:{}, index:{}", info.getTableName(), info.getIndexName());
|
||||
});
|
||||
}
|
||||
|
||||
private void deleteTable() {
|
||||
DeleteTableRequest request = new DeleteTableRequest(tableName);
|
||||
client.deleteTable(request);
|
||||
log.info("delete table:{}", tableName);
|
||||
}
|
||||
|
||||
private List<SearchIndexInfo> listSearchIndex() {
|
||||
ListSearchIndexRequest request = new ListSearchIndexRequest();
|
||||
request.setTableName(tableName);
|
||||
ListSearchIndexResponse listSearchIndexResponse = client.listSearchIndex(request);
|
||||
return listSearchIndexResponse.getIndexInfos();
|
||||
}
|
||||
|
||||
protected void innerAdd(String id, Embedding embedding, TextSegment textSegment) {
|
||||
ValidationUtils.ensureNotNull(embedding, "embedding");
|
||||
PrimaryKeyBuilder primaryKeyBuilder = PrimaryKeyBuilder.createPrimaryKeyBuilder();
|
||||
primaryKeyBuilder.addPrimaryKeyColumn(this.pkName, PrimaryKeyValue.fromString(id));
|
||||
PrimaryKey primaryKey = primaryKeyBuilder.build();
|
||||
RowPutChange rowPutChange = new RowPutChange(this.tableName, primaryKey);
|
||||
String embeddinged = TablestoreUtils.embeddingToString(embedding.vector());
|
||||
rowPutChange.addColumn(new Column(this.embeddingField, ColumnValue.fromString(embeddinged)));
|
||||
if (textSegment != null) {
|
||||
String text = textSegment.text();
|
||||
if (text != null) {
|
||||
rowPutChange.addColumn(new Column(this.textField, ColumnValue.fromString(text)));
|
||||
}
|
||||
Metadata metadata = textSegment.metadata();
|
||||
if (metadata != null) {
|
||||
Map<String, Object> map = metadata.toMap();
|
||||
for (Map.Entry<String, Object> entry : map.entrySet()) {
|
||||
String key = entry.getKey();
|
||||
Object value = entry.getValue();
|
||||
if (this.textField.equals(key)) {
|
||||
throw Exceptions.illegalArgument("there is a metadata(%s,%s) that is consistent with the name of the text field:%s", key, value, this.textField);
|
||||
}
|
||||
if (this.embeddingField.equals(key)) {
|
||||
throw Exceptions.illegalArgument("there is a metadata(%s,%s) that is consistent with the name of the vector field:%s", key, value, this.embeddingField);
|
||||
}
|
||||
if (value instanceof Float) {
|
||||
rowPutChange.addColumn(new Column(key, ColumnValue.fromDouble((Float) value)));
|
||||
} else if (value instanceof UUID) {
|
||||
rowPutChange.addColumn(new Column(key, ColumnValue.fromString(((UUID) value).toString())));
|
||||
} else {
|
||||
rowPutChange.addColumn(new Column(key, ValueUtil.toColumnValue(value)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
try {
|
||||
client.putRow(new PutRowRequest(rowPutChange));
|
||||
if (log.isDebugEnabled()) {
|
||||
log.debug("add id:{}, textSegment:{}, embedding:{}", id, textSegment, TablestoreUtils.maxLogOrNull(embedding.toString()));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(String.format("add embedding data failed, id:%s, textSegment:%s,embedding:%s", id, textSegment, embedding), e);
|
||||
}
|
||||
}
|
||||
|
||||
protected void innerDelete(String id) {
|
||||
PrimaryKeyBuilder primaryKeyBuilder = PrimaryKeyBuilder.createPrimaryKeyBuilder();
|
||||
primaryKeyBuilder.addPrimaryKeyColumn(this.pkName, PrimaryKeyValue.fromString(id));
|
||||
PrimaryKey primaryKey = primaryKeyBuilder.build();
|
||||
RowDeleteChange rowDeleteChange = new RowDeleteChange(this.tableName, primaryKey);
|
||||
try {
|
||||
client.deleteRow(new DeleteRowRequest(rowDeleteChange));
|
||||
log.debug("delete id:{}", id);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(String.format("delete embedding data failed, id:%s", id), e);
|
||||
}
|
||||
}
|
||||
|
||||
private void forEachAllData(Collection<String> columnsToGet, Consumer<Row> rowConsumer) {
|
||||
RangeRowQueryCriteria rangeRowQueryCriteria = new RangeRowQueryCriteria(this.tableName);
|
||||
PrimaryKeyBuilder start = PrimaryKeyBuilder.createPrimaryKeyBuilder();
|
||||
start.addPrimaryKeyColumn(this.pkName, PrimaryKeyValue.INF_MIN);
|
||||
PrimaryKeyBuilder end = PrimaryKeyBuilder.createPrimaryKeyBuilder();
|
||||
end.addPrimaryKeyColumn(this.pkName, PrimaryKeyValue.INF_MAX);
|
||||
rangeRowQueryCriteria.setInclusiveStartPrimaryKey(start.build());
|
||||
rangeRowQueryCriteria.setExclusiveEndPrimaryKey(end.build());
|
||||
rangeRowQueryCriteria.setMaxVersions(1);
|
||||
rangeRowQueryCriteria.setLimit(5000);
|
||||
rangeRowQueryCriteria.addColumnsToGet(columnsToGet);
|
||||
rangeRowQueryCriteria.setDirection(Direction.FORWARD);
|
||||
GetRangeRequest getRangeRequest = new GetRangeRequest(rangeRowQueryCriteria);
|
||||
GetRangeResponse getRangeResponse;
|
||||
while (true) {
|
||||
getRangeResponse = client.getRange(getRangeRequest);
|
||||
for (Row row : getRangeResponse.getRows()) {
|
||||
rowConsumer.accept(row);
|
||||
}
|
||||
if (getRangeResponse.getNextStartPrimaryKey() != null) {
|
||||
rangeRowQueryCriteria.setInclusiveStartPrimaryKey(getRangeResponse.getNextStartPrimaryKey());
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Metadata rowToMetadata(Row row) {
|
||||
Metadata metadata = new Metadata();
|
||||
for (Column column : row.getColumns()) {
|
||||
if (column.getName().equals(embeddingField)) {
|
||||
continue;
|
||||
}
|
||||
if (column.getName().equals(textField)) {
|
||||
continue;
|
||||
}
|
||||
ColumnType columnType = column.getValue().getType();
|
||||
switch (columnType) {
|
||||
case STRING:
|
||||
metadata.put(column.getName(), column.getValue().asString());
|
||||
break;
|
||||
case INTEGER:
|
||||
metadata.put(column.getName(), column.getValue().asLong());
|
||||
break;
|
||||
case DOUBLE:
|
||||
metadata.put(column.getName(), column.getValue().asDouble());
|
||||
break;
|
||||
default:
|
||||
log.warn("unsupported columnType:{}, key:{}, value:{}", columnType, column.getName(), column.getValue());
|
||||
}
|
||||
}
|
||||
return metadata;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
package dev.langchain4j.store.embedding.tablestore;
|
||||
|
||||
import com.alicloud.openservices.tablestore.model.search.query.Query;
|
||||
import com.alicloud.openservices.tablestore.model.search.query.QueryBuilders;
|
||||
import com.alicloud.openservices.tablestore.model.search.query.TermsQuery;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsIn;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsNotIn;
|
||||
import dev.langchain4j.store.embedding.filter.logical.And;
|
||||
import dev.langchain4j.store.embedding.filter.logical.Not;
|
||||
import dev.langchain4j.store.embedding.filter.logical.Or;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
class TablestoreMetadataFilterMapper {
|
||||
|
||||
static Query map(Filter filter) {
|
||||
if (filter == null) {
|
||||
return QueryBuilders.matchAll().build();
|
||||
}
|
||||
if (filter instanceof IsEqualTo) {
|
||||
return mapEqual((IsEqualTo) filter);
|
||||
} else if (filter instanceof IsNotEqualTo) {
|
||||
return mapNotEqual((IsNotEqualTo) filter);
|
||||
} else if (filter instanceof IsTextMatch) {
|
||||
return mapMatch((IsTextMatch) filter);
|
||||
} else if (filter instanceof IsTextMatchPhrase) {
|
||||
return mapMatchPhrase((IsTextMatchPhrase) filter);
|
||||
} else if (filter instanceof IsGreaterThan) {
|
||||
return mapGreaterThan((IsGreaterThan) filter);
|
||||
} else if (filter instanceof IsGreaterThanOrEqualTo) {
|
||||
return mapGreaterThanOrEqual((IsGreaterThanOrEqualTo) filter);
|
||||
} else if (filter instanceof IsLessThan) {
|
||||
return mapLessThan((IsLessThan) filter);
|
||||
} else if (filter instanceof IsLessThanOrEqualTo) {
|
||||
return mapLessThanOrEqual((IsLessThanOrEqualTo) filter);
|
||||
} else if (filter instanceof IsIn) {
|
||||
return mapIn((IsIn) filter);
|
||||
} else if (filter instanceof IsNotIn) {
|
||||
return mapNotIn((IsNotIn) filter);
|
||||
} else if (filter instanceof And) {
|
||||
return mapAnd((And) filter);
|
||||
} else if (filter instanceof Not) {
|
||||
return mapNot((Not) filter);
|
||||
} else if (filter instanceof Or) {
|
||||
return mapOr((Or) filter);
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unsupported filter type: " + filter.getClass().getName());
|
||||
}
|
||||
}
|
||||
|
||||
private static Object transformType(Object object) {
|
||||
if (object instanceof Float) {
|
||||
object = ((Float) object).doubleValue();
|
||||
}
|
||||
if (object instanceof UUID) {
|
||||
object = ((UUID) object).toString();
|
||||
}
|
||||
return object;
|
||||
}
|
||||
|
||||
private static Query mapEqual(IsEqualTo isEqualTo) {
|
||||
return QueryBuilders.term(isEqualTo.key(), transformType(isEqualTo.comparisonValue()))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static Query mapMatch(IsTextMatch isTextMatch) {
|
||||
return QueryBuilders.match(isTextMatch.key(), isTextMatch.comparisonValue())
|
||||
.build();
|
||||
}
|
||||
|
||||
private static Query mapMatchPhrase(IsTextMatchPhrase isTextMatchPhrase) {
|
||||
return QueryBuilders.matchPhrase(isTextMatchPhrase.key(), isTextMatchPhrase.comparisonValue())
|
||||
.build();
|
||||
}
|
||||
|
||||
private static Query mapNotEqual(IsNotEqualTo isNotEqualTo) {
|
||||
return QueryBuilders.bool()
|
||||
.mustNot(QueryBuilders.term(isNotEqualTo.key(), transformType(isNotEqualTo.comparisonValue())))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static Query mapGreaterThan(IsGreaterThan isGreaterThan) {
|
||||
return QueryBuilders.range(isGreaterThan.key())
|
||||
.greaterThan(transformType(isGreaterThan.comparisonValue()))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static Query mapGreaterThanOrEqual(IsGreaterThanOrEqualTo isGreaterThanOrEqualTo) {
|
||||
return QueryBuilders.range(isGreaterThanOrEqualTo.key())
|
||||
.greaterThanOrEqual(transformType(isGreaterThanOrEqualTo.comparisonValue()))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static Query mapLessThan(IsLessThan isLessThan) {
|
||||
return QueryBuilders.range(isLessThan.key())
|
||||
.lessThan(transformType(isLessThan.comparisonValue()))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static Query mapLessThanOrEqual(IsLessThanOrEqualTo isLessThanOrEqualTo) {
|
||||
return QueryBuilders.range(isLessThanOrEqualTo.key())
|
||||
.lessThanOrEqual(transformType(isLessThanOrEqualTo.comparisonValue()))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static Query mapIn(IsIn isIn) {
|
||||
TermsQuery.Builder builder = QueryBuilders.terms(isIn.key());
|
||||
for (Object object : isIn.comparisonValues()) {
|
||||
builder.addTerm(transformType(object));
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
private static Query mapNotIn(IsNotIn isNotIn) {
|
||||
TermsQuery.Builder builder = QueryBuilders.terms(isNotIn.key());
|
||||
for (Object object : isNotIn.comparisonValues()) {
|
||||
builder.addTerm(transformType(object));
|
||||
}
|
||||
return QueryBuilders.bool()
|
||||
.mustNot(builder.build())
|
||||
.build();
|
||||
}
|
||||
|
||||
private static Query mapAnd(And and) {
|
||||
return QueryBuilders.bool()
|
||||
.must(map(and.left()))
|
||||
.must(map(and.right()))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static Query mapNot(Not not) {
|
||||
return QueryBuilders.bool()
|
||||
.mustNot(map(not.expression()))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static Query mapOr(Or or) {
|
||||
return QueryBuilders.bool()
|
||||
.should(map(or.left()))
|
||||
.should(map(or.right()))
|
||||
.build();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package dev.langchain4j.store.embedding.tablestore;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.GsonBuilder;
|
||||
import dev.langchain4j.internal.ValidationUtils;
|
||||
|
||||
class TablestoreUtils {
|
||||
|
||||
private static final int MAX_DEBUG_LOG_LENGTH = 100;
|
||||
private static final Gson GSON = new GsonBuilder().create();
|
||||
|
||||
protected static float[] parseEmbeddingString(String embeddingString) {
|
||||
ValidationUtils.ensureNotBlank(embeddingString, "embeddingString");
|
||||
return GSON.fromJson(embeddingString, float[].class);
|
||||
}
|
||||
|
||||
protected static String embeddingToString(float[] embedding) {
|
||||
ValidationUtils.ensureNotNull(embedding, "embedding");
|
||||
return GSON.toJson(embedding);
|
||||
}
|
||||
|
||||
protected static String maxLogOrNull(String str) {
|
||||
if (str == null) {
|
||||
return null;
|
||||
}
|
||||
int max = MAX_DEBUG_LOG_LENGTH;
|
||||
if (str.length() <= max) {
|
||||
return str;
|
||||
}
|
||||
return str.substring(0, max) + "......";
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,235 @@
|
|||
package dev.langchain4j.store.memory.chat.tablestore;
|
||||
|
||||
import com.alicloud.openservices.tablestore.SyncClient;
|
||||
import com.alicloud.openservices.tablestore.model.CapacityUnit;
|
||||
import com.alicloud.openservices.tablestore.model.Column;
|
||||
import com.alicloud.openservices.tablestore.model.ColumnValue;
|
||||
import com.alicloud.openservices.tablestore.model.CreateTableRequest;
|
||||
import com.alicloud.openservices.tablestore.model.DeleteRowRequest;
|
||||
import com.alicloud.openservices.tablestore.model.Direction;
|
||||
import com.alicloud.openservices.tablestore.model.GetRangeRequest;
|
||||
import com.alicloud.openservices.tablestore.model.GetRangeResponse;
|
||||
import com.alicloud.openservices.tablestore.model.ListTableResponse;
|
||||
import com.alicloud.openservices.tablestore.model.PrimaryKey;
|
||||
import com.alicloud.openservices.tablestore.model.PrimaryKeyBuilder;
|
||||
import com.alicloud.openservices.tablestore.model.PrimaryKeySchema;
|
||||
import com.alicloud.openservices.tablestore.model.PrimaryKeyType;
|
||||
import com.alicloud.openservices.tablestore.model.PrimaryKeyValue;
|
||||
import com.alicloud.openservices.tablestore.model.PutRowRequest;
|
||||
import com.alicloud.openservices.tablestore.model.RangeRowQueryCriteria;
|
||||
import com.alicloud.openservices.tablestore.model.ReservedThroughput;
|
||||
import com.alicloud.openservices.tablestore.model.Row;
|
||||
import com.alicloud.openservices.tablestore.model.RowDeleteChange;
|
||||
import com.alicloud.openservices.tablestore.model.RowPutChange;
|
||||
import com.alicloud.openservices.tablestore.model.TableMeta;
|
||||
import com.alicloud.openservices.tablestore.model.TableOptions;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.ChatMessageDeserializer;
|
||||
import dev.langchain4j.data.message.ChatMessageSerializer;
|
||||
import dev.langchain4j.internal.ValidationUtils;
|
||||
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public class TablestoreChatMemoryStore implements ChatMemoryStore {
|
||||
|
||||
private final Logger log = LoggerFactory.getLogger(getClass());
|
||||
private final SyncClient client;
|
||||
private final String tableName;
|
||||
private final String pkName1;
|
||||
private final String pkName2;
|
||||
private final String chatMessageFieldName;
|
||||
|
||||
private static final String DEFAULT_TABLE_NAME = "langchain4j_chat_memory_store_ots_v1";
|
||||
private static final String DEFAULT_TABLE_PK_1_NAME = "memory_id";
|
||||
private static final String DEFAULT_TABLE_PK_2_NAME = "seq_no";
|
||||
private static final String DEFAULT_CHAT_MESSAGE_FIELD_NAME = "chat_message";
|
||||
|
||||
public TablestoreChatMemoryStore(SyncClient client) {
|
||||
this(client, DEFAULT_TABLE_NAME, DEFAULT_TABLE_PK_1_NAME, DEFAULT_TABLE_PK_2_NAME, DEFAULT_CHAT_MESSAGE_FIELD_NAME);
|
||||
}
|
||||
|
||||
public TablestoreChatMemoryStore(SyncClient client, String tableName, String pkName1, String pkName2, String chatMessageFieldName) {
|
||||
this.client = client;
|
||||
this.tableName = tableName;
|
||||
this.pkName1 = pkName1;
|
||||
this.pkName2 = pkName2;
|
||||
this.chatMessageFieldName = chatMessageFieldName;
|
||||
}
|
||||
|
||||
public void init() {
|
||||
createTableIfNotExist();
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all message.
|
||||
*/
|
||||
public void clear() {
|
||||
forEachAllData(PrimaryKeyValue.INF_MIN, PrimaryKeyValue.INF_MAX, row -> {
|
||||
String id = row.getPrimaryKey().getPrimaryKeyColumn(pkName1).getValue().asString();
|
||||
long seqNo = row.getPrimaryKey().getPrimaryKeyColumn(pkName2).getValue().asLong();
|
||||
innerDelete(id, seqNo);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMessage> getMessages(Object memoryId) {
|
||||
String memoryIdStr = getMemoryId(memoryId);
|
||||
log.debug("get messages, memoryIdStr:{}", memoryIdStr);
|
||||
List<ChatMessage> messages = new ArrayList<>();
|
||||
forEachAllData(PrimaryKeyValue.fromString(memoryIdStr), row -> {
|
||||
Column column = row.getLatestColumn(chatMessageFieldName);
|
||||
if (column != null) {
|
||||
String jsonString = column.getValue().asString();
|
||||
try {
|
||||
ChatMessage chatMessage = ChatMessageDeserializer.messageFromJson(jsonString);
|
||||
messages.add(chatMessage);
|
||||
} catch (Exception e) {
|
||||
String id = row.getPrimaryKey().getPrimaryKeyColumn(pkName1).getValue().asString();
|
||||
long seqNo = row.getPrimaryKey().getPrimaryKeyColumn(pkName2).getValue().asLong();
|
||||
throw new RuntimeException(String.format("unable to parse message body, memoryId:%s, seqNo:%s", id, seqNo), e);
|
||||
}
|
||||
}
|
||||
});
|
||||
return messages;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateMessages(Object memoryId, List<ChatMessage> messages) {
|
||||
String memoryIdStr = getMemoryId(memoryId);
|
||||
log.debug("update messages, memoryIdStr:{}", memoryIdStr);
|
||||
ValidationUtils.ensureNotEmpty(messages, "messages");
|
||||
deleteMessages(memoryId);
|
||||
List<Exception> exceptions = new ArrayList<>();
|
||||
for (int i = 0; i < messages.size(); i++) {
|
||||
ChatMessage message = messages.get(i);
|
||||
try {
|
||||
innerAdd(memoryIdStr, i, ChatMessageSerializer.messageToJson(message));
|
||||
} catch (Exception e) {
|
||||
exceptions.add(e);
|
||||
}
|
||||
}
|
||||
if (!exceptions.isEmpty()) {
|
||||
IllegalStateException exception = new IllegalStateException("update messages with error, failed:" + exceptions.size());
|
||||
for (Exception e : exceptions) {
|
||||
exception.addSuppressed(e);
|
||||
}
|
||||
throw exception;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteMessages(Object memoryId) {
|
||||
String memoryIdStr = getMemoryId(memoryId);
|
||||
log.debug("delete messages, memoryIdStr:{}", memoryIdStr);
|
||||
forEachAllData(PrimaryKeyValue.fromString(memoryIdStr), row -> {
|
||||
String id = row.getPrimaryKey().getPrimaryKeyColumn(pkName1).getValue().asString();
|
||||
long seqNo = row.getPrimaryKey().getPrimaryKeyColumn(pkName2).getValue().asLong();
|
||||
innerDelete(id, seqNo);
|
||||
});
|
||||
}
|
||||
|
||||
private void innerDelete(String memoryId, long seqNo) {
|
||||
ValidationUtils.ensureNotNull(memoryId, "memoryId");
|
||||
ValidationUtils.ensureNotNull(seqNo, "seqNo");
|
||||
PrimaryKeyBuilder primaryKeyBuilder = PrimaryKeyBuilder.createPrimaryKeyBuilder();
|
||||
primaryKeyBuilder.addPrimaryKeyColumn(this.pkName1, PrimaryKeyValue.fromString(memoryId));
|
||||
primaryKeyBuilder.addPrimaryKeyColumn(this.pkName2, PrimaryKeyValue.fromLong(seqNo));
|
||||
PrimaryKey primaryKey = primaryKeyBuilder.build();
|
||||
RowDeleteChange rowDeleteChange = new RowDeleteChange(this.tableName, primaryKey);
|
||||
try {
|
||||
client.deleteRow(new DeleteRowRequest(rowDeleteChange));
|
||||
log.debug("delete memoryId:{}, seqNo:{}", memoryId, seqNo);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(String.format("delete embedding data failed, memoryId:%s, seqNo:%s", memoryId, seqNo), e);
|
||||
}
|
||||
}
|
||||
|
||||
private void innerAdd(String memoryId, int seqNo, String chatMessage) {
|
||||
ValidationUtils.ensureNotNull(memoryId, "memoryId");
|
||||
ValidationUtils.ensureNotNull(seqNo, "seqNo");
|
||||
ValidationUtils.ensureNotNull(chatMessage, "chatMessage");
|
||||
PrimaryKeyBuilder primaryKeyBuilder = PrimaryKeyBuilder.createPrimaryKeyBuilder();
|
||||
primaryKeyBuilder.addPrimaryKeyColumn(this.pkName1, PrimaryKeyValue.fromString(memoryId));
|
||||
primaryKeyBuilder.addPrimaryKeyColumn(this.pkName2, PrimaryKeyValue.fromLong(seqNo));
|
||||
PrimaryKey primaryKey = primaryKeyBuilder.build();
|
||||
RowPutChange rowPutChange = new RowPutChange(this.tableName, primaryKey);
|
||||
rowPutChange.addColumn(new Column(chatMessageFieldName, ColumnValue.fromString(chatMessage)));
|
||||
try {
|
||||
client.putRow(new PutRowRequest(rowPutChange));
|
||||
if (log.isDebugEnabled()) {
|
||||
log.debug("add memoryId:{}, seqNo:{}, chatMessage:{}", memoryId, seqNo, chatMessage);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(String.format("add embedding data failed, memoryId:%s, seqNo:%s, chatMessage:%s", memoryId, seqNo, chatMessage), e);
|
||||
}
|
||||
}
|
||||
|
||||
private String getMemoryId(Object memoryId) {
|
||||
boolean isNullOrEmpty = memoryId == null || memoryId.toString().trim().isEmpty();
|
||||
if (isNullOrEmpty) {
|
||||
throw new IllegalArgumentException("memoryId cannot be null or empty");
|
||||
}
|
||||
return memoryId.toString();
|
||||
}
|
||||
|
||||
private void createTableIfNotExist() {
|
||||
if (tableExists()) {
|
||||
log.info("table:{} already exists", tableName);
|
||||
return;
|
||||
}
|
||||
TableMeta tableMeta = new TableMeta(this.tableName);
|
||||
tableMeta.addPrimaryKeyColumn(new PrimaryKeySchema(pkName1, PrimaryKeyType.STRING));
|
||||
tableMeta.addPrimaryKeyColumn(new PrimaryKeySchema(pkName2, PrimaryKeyType.INTEGER));
|
||||
TableOptions tableOptions = new TableOptions(-1, 1);
|
||||
CreateTableRequest request = new CreateTableRequest(tableMeta, tableOptions);
|
||||
request.setReservedThroughput(new ReservedThroughput(new CapacityUnit(0, 0)));
|
||||
client.createTable(request);
|
||||
log.info("create table:{}", tableName);
|
||||
}
|
||||
|
||||
|
||||
private boolean tableExists() {
|
||||
ListTableResponse listTableResponse = client.listTable();
|
||||
return listTableResponse.getTableNames().contains(tableName);
|
||||
}
|
||||
|
||||
private void forEachAllData(PrimaryKeyValue memoryId, Consumer<Row> rowConsumer) {
|
||||
forEachAllData(memoryId, memoryId, rowConsumer);
|
||||
}
|
||||
|
||||
private void forEachAllData(PrimaryKeyValue memoryIdStart, PrimaryKeyValue memoryIdEnd, Consumer<Row> rowConsumer) {
|
||||
RangeRowQueryCriteria rangeRowQueryCriteria = new RangeRowQueryCriteria(this.tableName);
|
||||
PrimaryKeyBuilder start = PrimaryKeyBuilder.createPrimaryKeyBuilder();
|
||||
start.addPrimaryKeyColumn(this.pkName1, memoryIdStart);
|
||||
start.addPrimaryKeyColumn(this.pkName2, PrimaryKeyValue.INF_MIN);
|
||||
PrimaryKeyBuilder end = PrimaryKeyBuilder.createPrimaryKeyBuilder();
|
||||
end.addPrimaryKeyColumn(this.pkName1, memoryIdEnd);
|
||||
end.addPrimaryKeyColumn(this.pkName2, PrimaryKeyValue.INF_MAX);
|
||||
rangeRowQueryCriteria.setInclusiveStartPrimaryKey(start.build());
|
||||
rangeRowQueryCriteria.setExclusiveEndPrimaryKey(end.build());
|
||||
rangeRowQueryCriteria.setMaxVersions(1);
|
||||
rangeRowQueryCriteria.setLimit(5000);
|
||||
rangeRowQueryCriteria.addColumnsToGet(Collections.singletonList(chatMessageFieldName));
|
||||
rangeRowQueryCriteria.setDirection(Direction.FORWARD);
|
||||
GetRangeRequest getRangeRequest = new GetRangeRequest(rangeRowQueryCriteria);
|
||||
GetRangeResponse getRangeResponse;
|
||||
while (true) {
|
||||
getRangeResponse = client.getRange(getRangeRequest);
|
||||
for (Row row : getRangeResponse.getRows()) {
|
||||
rowConsumer.accept(row);
|
||||
}
|
||||
if (getRangeResponse.getNextStartPrimaryKey() != null) {
|
||||
rangeRowQueryCriteria.setInclusiveStartPrimaryKey(getRangeResponse.getNextStartPrimaryKey());
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
package dev.langchain4j.store.embedding.tablestore;
|
||||
|
||||
import com.alicloud.openservices.tablestore.SyncClient;
|
||||
import com.alicloud.openservices.tablestore.model.search.FieldSchema;
|
||||
import com.alicloud.openservices.tablestore.model.search.FieldType;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ENDPOINT", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_INSTANCE_NAME", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_ID", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_SECRET", matches = ".+")
|
||||
class TablestoreEmbeddingStoreExampleIT {
|
||||
|
||||
@Test
|
||||
void test_simple() {
|
||||
|
||||
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
/*
|
||||
* Step 1: create a TablestoreEmbeddingStore.
|
||||
*/
|
||||
String endpoint = System.getenv("TABLESTORE_ENDPOINT");
|
||||
String instanceName = System.getenv("TABLESTORE_INSTANCE_NAME");
|
||||
String accessKeyId = System.getenv("TABLESTORE_ACCESS_KEY_ID");
|
||||
String accessKeySecret = System.getenv("TABLESTORE_ACCESS_KEY_SECRET");
|
||||
TablestoreEmbeddingStore embeddingStore = new TablestoreEmbeddingStore(
|
||||
new SyncClient(endpoint,
|
||||
accessKeyId,
|
||||
accessKeySecret,
|
||||
instanceName),
|
||||
384,
|
||||
Arrays.asList(
|
||||
new FieldSchema("meta_example_keyword", FieldType.KEYWORD),
|
||||
new FieldSchema("meta_example_long", FieldType.LONG),
|
||||
new FieldSchema("meta_example_double", FieldType.DOUBLE),
|
||||
new FieldSchema("meta_example_text", FieldType.TEXT).setAnalyzer(FieldSchema.Analyzer.MaxWord)
|
||||
)
|
||||
);
|
||||
/*
|
||||
* Step 2: init.
|
||||
*
|
||||
* Note: It only needs to be executed once, and the first execution requires
|
||||
* waiting for table and index initialization
|
||||
*/
|
||||
embeddingStore.init();
|
||||
|
||||
/*
|
||||
* Step 3: Add some docs.
|
||||
*/
|
||||
TextSegment segment1 = TextSegment.from(
|
||||
"I like football.",
|
||||
new Metadata().put("meta_example_keyword", "a")
|
||||
.put("meta_example_long", 123)
|
||||
.put("meta_example_double", 1.5)
|
||||
.put("meta_example_text", "dog cat")
|
||||
);
|
||||
Embedding embedding1 = embeddingModel.embed(segment1).content();
|
||||
embeddingStore.add(embedding1, segment1);
|
||||
|
||||
TextSegment segment2 = TextSegment.from(
|
||||
"The weather is good today.",
|
||||
new Metadata().put("meta_example_keyword", "b")
|
||||
.put("meta_example_long", 456)
|
||||
.put("meta_example_double", 5.6)
|
||||
.put("meta_example_text", "foo boo")
|
||||
);
|
||||
Embedding embedding2 = embeddingModel.embed(segment2).content();
|
||||
embeddingStore.add(embedding2, segment2);
|
||||
|
||||
/*
|
||||
* Step 4: Search
|
||||
*/
|
||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(embeddingModel.embed("What is your favourite sport?").content())
|
||||
.filter(new IsLessThan("meta_example_double", 0.5))
|
||||
.maxResults(100)
|
||||
.build();
|
||||
EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
|
||||
// get result detail.
|
||||
for (EmbeddingMatch<TextSegment> match : result.matches()) {
|
||||
String embeddingId = match.embeddingId();
|
||||
Double score = match.score();
|
||||
Embedding embedding = match.embedding();
|
||||
TextSegment embedded = match.embedded();
|
||||
String text = embedded.text();
|
||||
Metadata metadata = embedded.metadata();
|
||||
Assertions.assertNotNull(embeddingId);
|
||||
Assertions.assertNotNull(score);
|
||||
Assertions.assertNotNull(embedding);
|
||||
Assertions.assertNotNull(text);
|
||||
Assertions.assertNotNull(metadata);
|
||||
}
|
||||
|
||||
/*
|
||||
* Step 5: Delete docs.
|
||||
*/
|
||||
embeddingStore.remove("id_example");
|
||||
embeddingStore.removeAll();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,300 @@
|
|||
package dev.langchain4j.store.embedding.tablestore;
|
||||
|
||||
import com.alicloud.openservices.tablestore.SyncClient;
|
||||
import com.alicloud.openservices.tablestore.model.search.FieldSchema;
|
||||
import com.alicloud.openservices.tablestore.model.search.FieldType;
|
||||
import com.alicloud.openservices.tablestore.model.search.SearchQuery;
|
||||
import com.alicloud.openservices.tablestore.model.search.SearchRequest;
|
||||
import com.alicloud.openservices.tablestore.model.search.SearchResponse;
|
||||
import com.alicloud.openservices.tablestore.model.search.query.MatchAllQuery;
|
||||
import com.alicloud.openservices.tablestore.model.search.query.Query;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo;
|
||||
import org.apache.commons.lang3.function.TriConsumer;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Assumptions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ENDPOINT", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_INSTANCE_NAME", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_ID", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_SECRET", matches = ".+")
|
||||
class TablestoreEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT {
|
||||
private final Logger log = LoggerFactory.getLogger(getClass());
|
||||
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
private final static long WAIT_FOR_REPLICA_TIME_IN_MILLS = TimeUnit.SECONDS.toMillis(3);
|
||||
|
||||
private final TablestoreEmbeddingStore embeddingStore;
|
||||
|
||||
private final AtomicLong trackDocsForTest = new AtomicLong(0);
|
||||
|
||||
TablestoreEmbeddingStoreIT() {
|
||||
String endpoint = System.getenv("TABLESTORE_ENDPOINT");
|
||||
String instanceName = System.getenv("TABLESTORE_INSTANCE_NAME");
|
||||
String accessKeyId = System.getenv("TABLESTORE_ACCESS_KEY_ID");
|
||||
String accessKeySecret = System.getenv("TABLESTORE_ACCESS_KEY_SECRET");
|
||||
this.embeddingStore = new TablestoreEmbeddingStore(
|
||||
new SyncClient(endpoint,
|
||||
accessKeyId,
|
||||
accessKeySecret,
|
||||
instanceName),
|
||||
384,
|
||||
Arrays.asList(
|
||||
new FieldSchema("name", FieldType.KEYWORD),
|
||||
new FieldSchema("name2", FieldType.KEYWORD),
|
||||
new FieldSchema("key", FieldType.KEYWORD),
|
||||
new FieldSchema("key2", FieldType.KEYWORD),
|
||||
new FieldSchema("city", FieldType.KEYWORD),
|
||||
new FieldSchema("country", FieldType.KEYWORD),
|
||||
new FieldSchema("age", FieldType.LONG),
|
||||
new FieldSchema("age2", FieldType.LONG),
|
||||
new FieldSchema("meta_example_double", FieldType.DOUBLE),
|
||||
new FieldSchema("meta_example_text_max_word", FieldType.TEXT).setAnalyzer(FieldSchema.Analyzer.MaxWord),
|
||||
new FieldSchema("meta_example_text_fuzzy", FieldType.TEXT).setAnalyzer(FieldSchema.Analyzer.Fuzzy)
|
||||
)
|
||||
) {
|
||||
// Override for test
|
||||
@Override
|
||||
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
|
||||
if (request.maxResults() > 100) {
|
||||
request = new EmbeddingSearchRequest(request.queryEmbedding(), 100, request.minScore(), request.filter());
|
||||
}
|
||||
return super.search(request);
|
||||
}
|
||||
|
||||
// Override for test
|
||||
@Override
|
||||
protected void innerAdd(String id, Embedding embedding, TextSegment textSegment) {
|
||||
super.innerAdd(id, embedding, textSegment);
|
||||
trackDocsForTest.incrementAndGet();
|
||||
}
|
||||
|
||||
// Override for test
|
||||
@Override
|
||||
protected void innerDelete(String id) {
|
||||
super.innerDelete(id);
|
||||
trackDocsForTest.decrementAndGet();
|
||||
}
|
||||
|
||||
// Override for test: exclude the use of incorrect field types in base class testing
|
||||
@Override
|
||||
protected Query mapFilterToQuery(Filter filter) {
|
||||
if (filter instanceof IsEqualTo) {
|
||||
if (((IsEqualTo) filter).comparisonValue() instanceof Number) {
|
||||
Assumptions.abort("keyword not support number");
|
||||
}
|
||||
}
|
||||
if (filter instanceof IsNotEqualTo) {
|
||||
if (((IsNotEqualTo) filter).comparisonValue() instanceof Number) {
|
||||
Assumptions.abort("keyword not support number");
|
||||
}
|
||||
}
|
||||
if (filter instanceof IsLessThan) {
|
||||
IsLessThan t = (IsLessThan) filter;
|
||||
if (t.key().contains("key") && t.comparisonValue() instanceof Number) {
|
||||
Assumptions.abort("keyword not support number");
|
||||
}
|
||||
}
|
||||
if (filter instanceof IsLessThanOrEqualTo) {
|
||||
IsLessThanOrEqualTo t = (IsLessThanOrEqualTo) filter;
|
||||
if (t.key().contains("key") && t.comparisonValue() instanceof Number) {
|
||||
Assumptions.abort("keyword not support number");
|
||||
}
|
||||
}
|
||||
if (filter instanceof IsGreaterThan) {
|
||||
IsGreaterThan t = (IsGreaterThan) filter;
|
||||
if (t.key().contains("key") && t.comparisonValue() instanceof Number) {
|
||||
Assumptions.abort("keyword not support number");
|
||||
}
|
||||
}
|
||||
if (filter instanceof IsGreaterThanOrEqualTo) {
|
||||
IsGreaterThanOrEqualTo t = (IsGreaterThanOrEqualTo) filter;
|
||||
if (t.key().contains("key") && t.comparisonValue() instanceof Number) {
|
||||
Assumptions.abort("keyword not support number");
|
||||
}
|
||||
}
|
||||
return super.mapFilterToQuery(filter);
|
||||
}
|
||||
};
|
||||
this.embeddingStore.init();
|
||||
this.embeddingStore.removeAll();
|
||||
ensureSearchDataReady(0);
|
||||
}
|
||||
|
||||
protected void awaitUntilPersisted() {
|
||||
ensureSearchDataReady(trackDocsForTest.get());
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
@AfterEach
|
||||
void setUp() {
|
||||
trackDocsForTest.set(0);
|
||||
clearStore();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void clearStore() {
|
||||
embeddingStore.removeAll();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingModel embeddingModel() {
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
@Test
|
||||
void test_match_and_match_phrase_and_double_range() {
|
||||
Embedding embedding = embeddingModel().embed("ok").content();
|
||||
embeddingStore.add(
|
||||
embedding,
|
||||
new TextSegment(
|
||||
"ok",
|
||||
new Metadata()
|
||||
.put("meta_example_double", 1d)
|
||||
.put("meta_example_text_max_word", "a b c ab ac")
|
||||
.put("meta_example_text_fuzzy", "a b c abac")
|
||||
)
|
||||
);
|
||||
awaitUntilPersisted();
|
||||
TriConsumer<String, String, Integer> matchTester = (field, value, expectSize) -> {
|
||||
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(embedding)
|
||||
.filter(new IsTextMatch(field, value))
|
||||
.maxResults(100)
|
||||
.build();
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> matches = embeddingStore().search(embeddingSearchRequest).matches();
|
||||
|
||||
// then
|
||||
assertThat(matches).hasSize(expectSize);
|
||||
};
|
||||
|
||||
matchTester.accept("meta_example_text_max_word", "a b c", 1);
|
||||
matchTester.accept("meta_example_text_max_word", "ac", 1);
|
||||
matchTester.accept("meta_example_text_max_word", "abc", 0);
|
||||
matchTester.accept("meta_example_text_max_word", "abac", 0);
|
||||
matchTester.accept("meta_example_text_max_word", "ab", 1);
|
||||
|
||||
matchTester.accept("meta_example_text_fuzzy", "a b c", 1);
|
||||
matchTester.accept("meta_example_text_fuzzy", "ac", 1);
|
||||
matchTester.accept("meta_example_text_fuzzy", "abc", 0);
|
||||
matchTester.accept("meta_example_text_fuzzy", "abac", 1);
|
||||
matchTester.accept("meta_example_text_fuzzy", "ab", 1);
|
||||
|
||||
{
|
||||
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(embedding)
|
||||
.filter(new IsTextMatchPhrase("meta_example_text_fuzzy", "a b c abac"))
|
||||
.maxResults(100)
|
||||
.build();
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> matches = embeddingStore().search(embeddingSearchRequest).matches();
|
||||
// then
|
||||
assertThat(matches).hasSize(1);
|
||||
}
|
||||
|
||||
{
|
||||
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(embedding)
|
||||
.filter(new IsGreaterThan("meta_example_double", 0.5))
|
||||
.maxResults(100)
|
||||
.build();
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> matches = embeddingStore().search(embeddingSearchRequest).matches();
|
||||
// then
|
||||
assertThat(matches).hasSize(1);
|
||||
}
|
||||
{
|
||||
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(embedding)
|
||||
.filter(new IsLessThan("meta_example_double", 0.5))
|
||||
.maxResults(100)
|
||||
.build();
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> matches = embeddingStore().search(embeddingSearchRequest).matches();
|
||||
// then
|
||||
assertThat(matches).hasSize(0);
|
||||
}
|
||||
{
|
||||
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(embedding)
|
||||
.filter(new IsLessThan("meta_example_double", 1.5))
|
||||
.maxResults(100)
|
||||
.build();
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> matches = embeddingStore().search(embeddingSearchRequest).matches();
|
||||
// then
|
||||
assertThat(matches).hasSize(1);
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("BusyWait")
|
||||
// For test stability
|
||||
private void ensureSearchDataReady(long expectTotalHit) {
|
||||
long begin = System.currentTimeMillis();
|
||||
while (true) {
|
||||
SearchQuery searchQuery = new SearchQuery();
|
||||
searchQuery.setQuery(new MatchAllQuery());
|
||||
searchQuery.setLimit(0);
|
||||
SearchRequest searchRequest = new SearchRequest(embeddingStore.getTableName(), embeddingStore.getSearchIndexName(), searchQuery);
|
||||
searchQuery.setGetTotalCount(true);
|
||||
SearchResponse resp = embeddingStore.getClient().search(searchRequest);
|
||||
assertTrue(resp.isAllSuccess());
|
||||
if (resp.getTotalCount() == expectTotalHit) {
|
||||
log.info("ensureSearchDataReady totalHit:{}, expect:{}", resp.getTotalCount(), expectTotalHit);
|
||||
log.info("DataSyncTimeInMs:" + (System.currentTimeMillis() - begin));
|
||||
break;
|
||||
} else if (resp.getTotalCount() != 0) {
|
||||
log.info("ensureSearchDataReady totalHit:{}, expect:{}", resp.getTotalCount(), expectTotalHit);
|
||||
}
|
||||
if (System.currentTimeMillis() - begin > TimeUnit.SECONDS.toMillis(120)) {
|
||||
fail("ensureSearchDataReady timeout");
|
||||
}
|
||||
|
||||
try {
|
||||
Thread.sleep(500);
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
// wait for replica
|
||||
try {
|
||||
Thread.sleep(WAIT_FOR_REPLICA_TIME_IN_MILLS);
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
package dev.langchain4j.store.embedding.tablestore;
|
||||
|
||||
import com.alicloud.openservices.tablestore.SyncClient;
|
||||
import com.alicloud.openservices.tablestore.model.search.FieldSchema;
|
||||
import com.alicloud.openservices.tablestore.model.search.FieldType;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ENDPOINT", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_INSTANCE_NAME", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_ID", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_SECRET", matches = ".+")
|
||||
class TablestoreEmbeddingStoreRemovalIT extends EmbeddingStoreWithRemovalIT {
|
||||
|
||||
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
private final TablestoreEmbeddingStore embeddingStore;
|
||||
|
||||
TablestoreEmbeddingStoreRemovalIT() {
|
||||
String endpoint = System.getenv("TABLESTORE_ENDPOINT");
|
||||
String instanceName = System.getenv("TABLESTORE_INSTANCE_NAME");
|
||||
String accessKeyId = System.getenv("TABLESTORE_ACCESS_KEY_ID");
|
||||
String accessKeySecret = System.getenv("TABLESTORE_ACCESS_KEY_SECRET");
|
||||
this.embeddingStore = new TablestoreEmbeddingStore(
|
||||
new SyncClient(endpoint,
|
||||
accessKeyId,
|
||||
accessKeySecret,
|
||||
instanceName),
|
||||
384,
|
||||
Arrays.asList(
|
||||
new FieldSchema("meta_example_keyword", FieldType.KEYWORD),
|
||||
new FieldSchema("meta_example_long", FieldType.LONG),
|
||||
new FieldSchema("meta_example_double", FieldType.DOUBLE),
|
||||
new FieldSchema("meta_example_text", FieldType.TEXT).setAnalyzer(FieldSchema.Analyzer.MaxWord)
|
||||
)
|
||||
) {
|
||||
@Override
|
||||
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
|
||||
if (request.maxResults() > 100) {
|
||||
request = new EmbeddingSearchRequest(request.queryEmbedding(), 100, request.minScore(), request.filter());
|
||||
}
|
||||
return super.search(request);
|
||||
}
|
||||
};
|
||||
this.embeddingStore.init();
|
||||
this.embeddingStore.removeAll();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingModel embeddingModel() {
|
||||
return embeddingModel;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
package dev.langchain4j.store.embedding.tablestore;
|
||||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.util.Set;
|
||||
|
||||
import static dev.langchain4j.store.embedding.tablestore.TablestoreUtils.embeddingToString;
|
||||
import static dev.langchain4j.store.embedding.tablestore.TablestoreUtils.parseEmbeddingString;
|
||||
|
||||
|
||||
class TablestoreEmbeddingStoreTest {
|
||||
|
||||
@Test
|
||||
void test_parseEmbeddingString() {
|
||||
float[] floats = parseEmbeddingString(" [1,2,3,4, 5.678, 9.12345 , -0.0123] ");
|
||||
float[] expect = new float[]{1, 2, 3, 4, 5.678f, 9.12345f, -0.0123f};
|
||||
Assertions.assertArrayEquals(expect, floats, 0.000001f);
|
||||
}
|
||||
|
||||
@Test
|
||||
void test_embeddingToString() {
|
||||
float[] expect = new float[]{1, 2, 3, 4, 5.678f, 9.12345f, -0.0123f};
|
||||
String embeddingToString = embeddingToString(expect);
|
||||
Assertions.assertEquals("[1.0,2.0,3.0,4.0,5.678,9.12345,-0.0123]", embeddingToString);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testj_supported_value_types() throws Exception {
|
||||
Field field = Metadata.class.getDeclaredField("SUPPORTED_VALUE_TYPES");
|
||||
field.setAccessible(true);
|
||||
@SuppressWarnings("unchecked")
|
||||
Set<Class<?>> supportedValueTypes = (Set<Class<?>>) field.get(new Metadata());
|
||||
Assertions.assertEquals(10, supportedValueTypes.size(), "when Metadata#SUPPORTED_VALUE_TYPES add new types, we should modify:\n" +
|
||||
"1. write logic: rowToMetadata.\n" +
|
||||
"2. read logic: innerAdd");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,178 @@
|
|||
package dev.langchain4j.store.memory.chat.tablestore;
|
||||
|
||||
import com.alicloud.openservices.tablestore.SyncClient;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.Content;
|
||||
import dev.langchain4j.data.message.ImageContent;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.memory.ChatMemory;
|
||||
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static dev.langchain4j.data.message.AiMessage.aiMessage;
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ENDPOINT", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_INSTANCE_NAME", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_ID", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_SECRET", matches = ".+")
|
||||
class TablestoreChatMemoryStoreIT {
|
||||
|
||||
private final TablestoreChatMemoryStore chatMemoryStore;
|
||||
private final static String USER_ID = "someUserId";
|
||||
|
||||
|
||||
public TablestoreChatMemoryStoreIT() {
|
||||
String endpoint = System.getenv("TABLESTORE_ENDPOINT");
|
||||
String instanceName = System.getenv("TABLESTORE_INSTANCE_NAME");
|
||||
String accessKeyId = System.getenv("TABLESTORE_ACCESS_KEY_ID");
|
||||
String accessKeySecret = System.getenv("TABLESTORE_ACCESS_KEY_SECRET");
|
||||
|
||||
chatMemoryStore = new TablestoreChatMemoryStore(new SyncClient(
|
||||
endpoint,
|
||||
accessKeyId,
|
||||
accessKeySecret,
|
||||
instanceName)
|
||||
);
|
||||
|
||||
chatMemoryStore.init();
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
@AfterEach
|
||||
void setUp() {
|
||||
chatMemoryStore.clear();
|
||||
List<ChatMessage> messages = chatMemoryStore.getMessages(USER_ID);
|
||||
assertThat(messages).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void test_should_insert_items() {
|
||||
// When
|
||||
String chatSessionId = "chat-" + UUID.randomUUID();
|
||||
|
||||
ChatMemory chatMemory = MessageWindowChatMemory.builder()
|
||||
.chatMemoryStore(chatMemoryStore)
|
||||
.maxMessages(100)
|
||||
.id(chatSessionId)
|
||||
.build();
|
||||
|
||||
// When
|
||||
UserMessage userMessage = userMessage("How are you?");
|
||||
chatMemory.add(userMessage);
|
||||
|
||||
AiMessage aiMessage = aiMessage("I am fine! Thank you!");
|
||||
chatMemory.add(aiMessage);
|
||||
|
||||
// Then
|
||||
assertThat(chatMemory.messages()).containsExactly(userMessage, aiMessage);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_set_messages_into_tablestore() {
|
||||
// given
|
||||
List<ChatMessage> messages = chatMemoryStore.getMessages(USER_ID);
|
||||
assertThat(messages).isEmpty();
|
||||
|
||||
// when
|
||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||
chatMessages.add(new SystemMessage("You are a large language model working with Langchain4j"));
|
||||
List<Content> userMsgContents = new ArrayList<>();
|
||||
userMsgContents.add(new ImageContent("someCatImageUrl"));
|
||||
chatMessages.add(new UserMessage("What do you see in this image?", userMsgContents));
|
||||
chatMemoryStore.updateMessages(USER_ID, chatMessages);
|
||||
|
||||
// then
|
||||
messages = chatMemoryStore.getMessages(USER_ID);
|
||||
assertThat(messages).hasSize(2);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_delete_messages_from_tablestore() {
|
||||
// given
|
||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||
chatMessages.add(new SystemMessage("You are a large language model working with Langchain4j"));
|
||||
chatMemoryStore.updateMessages(USER_ID, chatMessages);
|
||||
List<ChatMessage> messages = chatMemoryStore.getMessages(USER_ID);
|
||||
assertThat(messages).hasSize(1);
|
||||
|
||||
// when
|
||||
chatMemoryStore.deleteMessages(USER_ID);
|
||||
|
||||
// then
|
||||
messages = chatMemoryStore.getMessages(USER_ID);
|
||||
assertThat(messages).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void getMessages_memoryId_null() {
|
||||
assertThatThrownBy(() -> chatMemoryStore.getMessages(null))
|
||||
.isExactlyInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("memoryId cannot be null or empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
void getMessages_memoryId_empty() {
|
||||
assertThatThrownBy(() -> chatMemoryStore.getMessages(" "))
|
||||
.isExactlyInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("memoryId cannot be null or empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
void updateMessages_messages_null() {
|
||||
assertThatThrownBy(() -> chatMemoryStore.updateMessages(USER_ID, null))
|
||||
.isExactlyInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("messages cannot be null or empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
void updateMessages_messages_empty() {
|
||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||
assertThatThrownBy(() -> chatMemoryStore.updateMessages(USER_ID, chatMessages))
|
||||
.isExactlyInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("messages cannot be null or empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
void updateMessages_memoryId_null() {
|
||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||
chatMessages.add(new SystemMessage("You are a large language model working with Langchain4j"));
|
||||
assertThatThrownBy(() -> chatMemoryStore.updateMessages(null, chatMessages))
|
||||
.isExactlyInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("memoryId cannot be null or empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
void updateMessages_memoryId_empty() {
|
||||
List<ChatMessage> chatMessages = new ArrayList<>();
|
||||
chatMessages.add(new SystemMessage("You are a large language model working with Langchain4j"));
|
||||
assertThatThrownBy(() -> chatMemoryStore.updateMessages(" ", chatMessages))
|
||||
.isExactlyInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("memoryId cannot be null or empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
void deleteMessages_memoryId_null() {
|
||||
assertThatThrownBy(() -> chatMemoryStore.deleteMessages(null))
|
||||
.isExactlyInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("memoryId cannot be null or empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
void deleteMessages_memoryId_empty() {
|
||||
assertThatThrownBy(() -> chatMemoryStore.deleteMessages(" "))
|
||||
.isExactlyInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("memoryId cannot be null or empty");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
<configuration>
|
||||
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<encoder>
|
||||
<pattern>[%d{yyyy-MM-dd HH:mm:ss.SSS}][%thread][%level][%logger][%line] - %msg%n</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<root level="INFO">
|
||||
<appender-ref ref="STDOUT"/>
|
||||
</root>
|
||||
|
||||
<logger name="dev.ai4j" level="DEBUG"/>
|
||||
<logger name="dev.langchain4j" level="DEBUG"/>
|
||||
</configuration>
|
Loading…
Reference in New Issue