Add langchain4j-tablestore Integration: TablestoreEmbeddingStore/TablestoreChatMemoryStore (#1650)

## Change
Add langchain4j-tablestore Integration: TablestoreEmbeddingStore /
TablestoreChatMemoryStore

## General checklist
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [x] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)

## Checklist for adding new embedding store integration
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] I have added a `{NameOfIntegration}EmbeddingStoreIT` that extends
from either `EmbeddingStoreIT` or `EmbeddingStoreWithFilteringIT`
- [x] I have added my new module in the
[BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)
This commit is contained in:
ScriptShi 2024-09-18 17:41:53 +08:00 committed by GitHub
parent 9ea2e27337
commit be7454a7c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 1893 additions and 0 deletions

View File

@ -25,6 +25,7 @@ sidebar_position: 0
| [Pinecone](/integrations/embedding-stores/pinecone) | ✅ | ✅ | ✅ |
| [Qdrant](/integrations/embedding-stores/qdrant) | ✅ | ✅ | |
| [Redis](/integrations/embedding-stores/redis) | ✅ | | |
| [Tablestore](/integrations/embedding-stores/tablestore) | ✅ |✅ |✅ |
| [Vearch](/integrations/embedding-stores/vearch) | ✅ | | |
| [Vespa](/integrations/embedding-stores/vespa) | | | |
| [Weaviate](/integrations/embedding-stores/weaviate) | ✅ | | ✅ |

View File

@ -0,0 +1,26 @@
---
sidebar_position: 16
---
# Tablestore
https://www.aliyun.com/product/ots
## Maven Dependency
```xml
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-tablestore</artifactId>
<version>0.35.0</version>
</dependency>
```
## APIs
- `TablestoreEmbeddingStore`
## Examples
- [TablestoreEmbeddingStoreExampleIT](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-tablestore/src/test/java/dev/langchain4j/store/embedding/tablestore/TablestoreEmbeddingStoreExampleIT.java)

View File

@ -257,6 +257,12 @@
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-tablestore</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-vearch</artifactId>

View File

@ -0,0 +1,91 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.35.0-SNAPSHOT</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
<artifactId>langchain4j-tablestore</artifactId>
<name>LangChain4j :: Integration :: Tablestore</name>
<dependencies>
<dependency>
<groupId>com.aliyun.openservices</groupId>
<artifactId>tablestore</artifactId>
<version>5.17.3</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>
<!-- junit-jupiter-params should be declared explicitly
to run parameterized tests inherited from EmbeddingStore*IT-->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
<version>3.25.3</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,58 @@
package dev.langchain4j.store.embedding.tablestore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.util.Objects;
import java.util.StringJoiner;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
public class IsTextMatch implements Filter {
private final String key;
private final String comparisonValue;
public IsTextMatch(String key, String comparisonValue) {
this.key = ensureNotBlank(key, "key");
this.comparisonValue = ensureNotNull(comparisonValue, "comparisonValue with key '" + key + "'");
}
public String key() {
return key;
}
public String comparisonValue() {
return comparisonValue;
}
@Override
public boolean test(Object object) {
throw new UnsupportedOperationException("only used in search filters");
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof IsTextMatch)) {
return false;
}
IsTextMatch that = (IsTextMatch) o;
return Objects.equals(key, that.key) && Objects.equals(comparisonValue, that.comparisonValue);
}
@Override
public int hashCode() {
return Objects.hash(key, comparisonValue);
}
@Override
public String toString() {
return new StringJoiner(", ", IsTextMatch.class.getSimpleName() + "[", "]")
.add("key='" + key + "'")
.add("comparisonValue='" + comparisonValue + "'")
.toString();
}
}

View File

@ -0,0 +1,58 @@
package dev.langchain4j.store.embedding.tablestore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.util.Objects;
import java.util.StringJoiner;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
public class IsTextMatchPhrase implements Filter {
private final String key;
private final String comparisonValue;
public IsTextMatchPhrase(String key, String comparisonValue) {
this.key = ensureNotBlank(key, "key");
this.comparisonValue = ensureNotNull(comparisonValue, "comparisonValue with key '" + key + "'");
}
public String key() {
return key;
}
public String comparisonValue() {
return comparisonValue;
}
@Override
public boolean test(Object object) {
throw new UnsupportedOperationException("only used in search filters");
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof IsTextMatchPhrase)) {
return false;
}
IsTextMatchPhrase that = (IsTextMatchPhrase) o;
return Objects.equals(key, that.key) && Objects.equals(comparisonValue, that.comparisonValue);
}
@Override
public int hashCode() {
return Objects.hash(key, comparisonValue);
}
@Override
public String toString() {
return new StringJoiner(", ", IsTextMatchPhrase.class.getSimpleName() + "[", "]")
.add("key='" + key + "'")
.add("comparisonValue='" + comparisonValue + "'")
.toString();
}
}

View File

@ -0,0 +1,523 @@
package dev.langchain4j.store.embedding.tablestore;
import com.alicloud.openservices.tablestore.SyncClient;
import com.alicloud.openservices.tablestore.core.utils.ValueUtil;
import com.alicloud.openservices.tablestore.model.CapacityUnit;
import com.alicloud.openservices.tablestore.model.Column;
import com.alicloud.openservices.tablestore.model.ColumnType;
import com.alicloud.openservices.tablestore.model.ColumnValue;
import com.alicloud.openservices.tablestore.model.CreateTableRequest;
import com.alicloud.openservices.tablestore.model.DeleteRowRequest;
import com.alicloud.openservices.tablestore.model.DeleteTableRequest;
import com.alicloud.openservices.tablestore.model.Direction;
import com.alicloud.openservices.tablestore.model.GetRangeRequest;
import com.alicloud.openservices.tablestore.model.GetRangeResponse;
import com.alicloud.openservices.tablestore.model.ListTableResponse;
import com.alicloud.openservices.tablestore.model.PrimaryKey;
import com.alicloud.openservices.tablestore.model.PrimaryKeyBuilder;
import com.alicloud.openservices.tablestore.model.PrimaryKeySchema;
import com.alicloud.openservices.tablestore.model.PrimaryKeyType;
import com.alicloud.openservices.tablestore.model.PrimaryKeyValue;
import com.alicloud.openservices.tablestore.model.PutRowRequest;
import com.alicloud.openservices.tablestore.model.RangeRowQueryCriteria;
import com.alicloud.openservices.tablestore.model.ReservedThroughput;
import com.alicloud.openservices.tablestore.model.Row;
import com.alicloud.openservices.tablestore.model.RowDeleteChange;
import com.alicloud.openservices.tablestore.model.RowPutChange;
import com.alicloud.openservices.tablestore.model.TableMeta;
import com.alicloud.openservices.tablestore.model.TableOptions;
import com.alicloud.openservices.tablestore.model.search.CreateSearchIndexRequest;
import com.alicloud.openservices.tablestore.model.search.DeleteSearchIndexRequest;
import com.alicloud.openservices.tablestore.model.search.FieldSchema;
import com.alicloud.openservices.tablestore.model.search.FieldType;
import com.alicloud.openservices.tablestore.model.search.IndexSchema;
import com.alicloud.openservices.tablestore.model.search.ListSearchIndexRequest;
import com.alicloud.openservices.tablestore.model.search.ListSearchIndexResponse;
import com.alicloud.openservices.tablestore.model.search.SearchHit;
import com.alicloud.openservices.tablestore.model.search.SearchIndexInfo;
import com.alicloud.openservices.tablestore.model.search.SearchQuery;
import com.alicloud.openservices.tablestore.model.search.SearchRequest;
import com.alicloud.openservices.tablestore.model.search.SearchResponse;
import com.alicloud.openservices.tablestore.model.search.query.KnnVectorQuery;
import com.alicloud.openservices.tablestore.model.search.query.Query;
import com.alicloud.openservices.tablestore.model.search.query.QueryBuilders;
import com.alicloud.openservices.tablestore.model.search.sort.ScoreSort;
import com.alicloud.openservices.tablestore.model.search.sort.Sort;
import com.alicloud.openservices.tablestore.model.search.vector.VectorDataType;
import com.alicloud.openservices.tablestore.model.search.vector.VectorMetricType;
import com.alicloud.openservices.tablestore.model.search.vector.VectorOptions;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.Consumer;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
public class TablestoreEmbeddingStore implements EmbeddingStore<TextSegment> {
private final Logger log = LoggerFactory.getLogger(getClass());
private final SyncClient client;
private final String tableName;
private final String searchIndexName;
private final String pkName;
private final String textField;
private final String embeddingField;
private final int vectorDimension;
private final VectorMetricType vectorMetricType;
private final List<FieldSchema> metadataSchemaList;
private static final String DEFAULT_TABLE_NAME = "langchain4j_embedding_store_ots_v1";
private static final String DEFAULT_INDEX_NAME = "langchain4j_embedding_ots_index_v1";
private static final String DEFAULT_TABLE_PK_NAME = "id";
private static final String DEFAULT_TEXT_FIELD_NAME = "default_content";
private static final String DEFAULT_VECTOR_FIELD_NAME = "default_embedding";
private static final VectorMetricType DEFAULT_VECTOR_METRIC_TYPE = VectorMetricType.COSINE;
public TablestoreEmbeddingStore(SyncClient client, int vectorDimension) {
this(client, vectorDimension, Collections.emptyList());
}
public TablestoreEmbeddingStore(SyncClient client, int vectorDimension, List<FieldSchema> metadataSchemaList) {
this(client, DEFAULT_TABLE_NAME, DEFAULT_INDEX_NAME, DEFAULT_TABLE_PK_NAME, DEFAULT_TEXT_FIELD_NAME, DEFAULT_VECTOR_FIELD_NAME, vectorDimension, DEFAULT_VECTOR_METRIC_TYPE, metadataSchemaList);
}
public TablestoreEmbeddingStore(SyncClient client, String tableName, String searchIndexName, String pkName, String textField, String embeddingField, int vectorDimension, VectorMetricType vectorMetricType, List<FieldSchema> metadataSchemaList) {
this.client = ValidationUtils.ensureNotNull(client, "client");
this.tableName = ValidationUtils.ensureNotBlank(tableName, "tableName");
this.searchIndexName = ValidationUtils.ensureNotBlank(searchIndexName, "searchIndexName");
this.pkName = ValidationUtils.ensureNotBlank(pkName, "pkName");
this.textField = ValidationUtils.ensureNotBlank(textField, "textField");
this.embeddingField = ValidationUtils.ensureNotBlank(embeddingField, "embeddingField");
this.vectorDimension = ValidationUtils.ensureGreaterThanZero(vectorDimension, "vectorDimension");
this.vectorMetricType = ValidationUtils.ensureNotNull(vectorMetricType, "vectorMetricType");
ValidationUtils.ensureNotNull(metadataSchemaList, "metadataSchemaList");
List<FieldSchema> tmpMetaList = new ArrayList<>();
tmpMetaList.add(new FieldSchema(textField, FieldType.TEXT).setIndex(true).setAnalyzer(FieldSchema.Analyzer.MaxWord));
tmpMetaList.add(new FieldSchema(embeddingField, FieldType.VECTOR).setIndex(true).setVectorOptions(new VectorOptions(VectorDataType.FLOAT_32, vectorDimension, vectorMetricType)));
for (FieldSchema fieldSchema : metadataSchemaList) {
if (fieldSchema.getFieldName().equals(textField)) {
throw Exceptions.illegalArgument("the custom meta data field name matches the system text field:{}", textField);
}
if (fieldSchema.getFieldName().equals(embeddingField)) {
throw Exceptions.illegalArgument("the custom meta data field name matches the system embedding field:{}", embeddingField);
}
tmpMetaList.add(fieldSchema);
}
this.metadataSchemaList = Collections.unmodifiableList(tmpMetaList);
}
public void init() {
createTableIfNotExist();
createSearchIndexIfNotExist();
}
public SyncClient getClient() {
return client;
}
public String getTableName() {
return tableName;
}
public String getSearchIndexName() {
return searchIndexName;
}
public String getPkName() {
return pkName;
}
public String getTextField() {
return textField;
}
public String getEmbeddingField() {
return embeddingField;
}
public int getVectorDimension() {
return vectorDimension;
}
public VectorMetricType getVectorMetricType() {
return vectorMetricType;
}
public List<FieldSchema> getMetadataSchemaList() {
return metadataSchemaList;
}
@Override
public String add(Embedding embedding) {
String id = UUID.randomUUID().toString();
innerAdd(id, embedding, null);
return id;
}
@Override
public void add(String id, Embedding embedding) {
innerAdd(id, embedding, null);
}
@Override
public String add(Embedding embedding, TextSegment textSegment) {
String id = UUID.randomUUID().toString();
innerAdd(id, embedding, textSegment);
return id;
}
@Override
public List<String> addAll(List<Embedding> embeddings) {
return addAll(embeddings, null);
}
@Override
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
if (embedded != null) {
ValidationUtils.ensureEq(embeddings.size(), embedded.size(), "the size of embeddings should be the same as the size of embedded");
}
List<String> ids = new ArrayList<>(embeddings.size());
List<Exception> exceptions = new ArrayList<>();
for (int i = 0; i < embeddings.size(); i++) {
Embedding embedding = embeddings.get(i);
TextSegment textSegment = null;
if (embedded != null) {
textSegment = embedded.get(i);
}
try {
String id = UUID.randomUUID().toString();
innerAdd(id, embedding, textSegment);
ids.add(id);
} catch (Exception e) {
exceptions.add(e);
}
}
if (!exceptions.isEmpty()) {
IllegalStateException exception = new IllegalStateException("Add all embeddings with error, failed:" + exceptions.size());
for (Exception e : exceptions) {
exception.addSuppressed(e);
}
throw exception;
}
return ids;
}
@Override
public void remove(String id) {
ensureNotBlank(id, "id");
innerDelete(id);
}
@Override
public void removeAll(Collection<String> ids) {
if (ids == null || ids.isEmpty()) {
throw Exceptions.illegalArgument("ids cannot be null or empty");
}
log.debug("remove all:{}", ids);
List<Exception> exceptions = new ArrayList<>();
for (String id : ids) {
try {
remove(id);
} catch (Exception e) {
exceptions.add(e);
}
}
if (!exceptions.isEmpty()) {
IllegalStateException exception = new IllegalStateException("remove all embeddings with error, failed:" + exceptions.size());
for (Exception e : exceptions) {
exception.addSuppressed(e);
}
throw exception;
}
}
@Override
public void removeAll(Filter filter) {
if (filter == null) {
throw Exceptions.illegalArgument("filter cannot be null");
}
forEachAllData(Collections.emptyList(), (row -> {
Metadata metadata = rowToMetadata(row);
if (filter.test(metadata)) {
remove(row.getPrimaryKey().getPrimaryKeyColumn(pkName).getValue().asString());
}
}));
}
@Override
public void removeAll() {
log.debug("remove all");
forEachAllData(Collections.emptyList(), (row) -> {
this.innerDelete(row.getPrimaryKey().getPrimaryKeyColumn(pkName).getValue().asString());
});
}
@Override
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
log.debug("search ([...{}...], {}, {})", request.queryEmbedding().vector().length, request.maxResults(), request.minScore());
KnnVectorQuery knnVectorQuery = QueryBuilders.knnVector(embeddingField, request.maxResults(), request.queryEmbedding().vector())
.filter(mapFilterToQuery(request.filter()))
.build();
SearchQuery searchQuery = SearchQuery.newBuilder()
.query(knnVectorQuery)
.getTotalCount(false)
.limit(request.maxResults())
.offset(0)
.sort(new Sort(Collections.singletonList(new ScoreSort())))
.build();
SearchRequest searchRequest = SearchRequest.newBuilder()
.tableName(tableName)
.indexName(searchIndexName)
.searchQuery(searchQuery)
.returnAllColumns(true)
.build();
SearchResponse response = client.search(searchRequest);
log.debug("search requestId:{}", response.getRequestId());
return searchResponseToEmbeddingSearchResult(request, response);
}
protected Query mapFilterToQuery(Filter filter) {
return TablestoreMetadataFilterMapper.map(filter);
}
private EmbeddingSearchResult<TextSegment> searchResponseToEmbeddingSearchResult(EmbeddingSearchRequest request, SearchResponse response) {
List<SearchHit> searchHits = response.getSearchHits();
List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>(searchHits.size());
for (SearchHit hit : searchHits) {
Double score = hit.getScore();
if (score < request.minScore()) {
continue;
}
Row row = hit.getRow();
String text = null;
if (row.getLatestColumn(textField) != null) {
text = row.getLatestColumn(textField).getValue().asString();
}
float[] embedding = null;
if (row.getLatestColumn(embeddingField) != null) {
String embeddingString = row.getLatestColumn(embeddingField).getValue().asString();
embedding = TablestoreUtils.parseEmbeddingString(embeddingString);
}
Metadata metadata = rowToMetadata(row);
TextSegment textSegment = null;
if (text != null && embedding != null) {
textSegment = new TextSegment(text, metadata);
}
EmbeddingMatch<TextSegment> match = new EmbeddingMatch<TextSegment>(
score,
row.getPrimaryKey().getPrimaryKeyColumn(pkName).getValue().asString(),
new Embedding(embedding),
textSegment
);
matches.add(match);
}
return new EmbeddingSearchResult<>(matches);
}
private void createTableIfNotExist() {
if (tableExists()) {
log.info("table:{} already exists", tableName);
return;
}
TableMeta tableMeta = new TableMeta(this.tableName);
tableMeta.addPrimaryKeyColumn(new PrimaryKeySchema(pkName, PrimaryKeyType.STRING));
TableOptions tableOptions = new TableOptions(-1, 1);
CreateTableRequest request = new CreateTableRequest(tableMeta, tableOptions);
request.setReservedThroughput(new ReservedThroughput(new CapacityUnit(0, 0)));
client.createTable(request);
log.info("create table:{}", tableName);
}
private void createSearchIndexIfNotExist() {
if (searchindexExists()) {
log.info("index:{} already exists", searchIndexName);
return;
}
CreateSearchIndexRequest request = new CreateSearchIndexRequest();
request.setTableName(tableName);
request.setIndexName(searchIndexName);
IndexSchema indexSchema = new IndexSchema();
indexSchema.setFieldSchemas(metadataSchemaList);
request.setIndexSchema(indexSchema);
client.createSearchIndex(request);
log.info("create index:{}", searchIndexName);
}
protected void deleteTableAndIndex() {
List<SearchIndexInfo> searchIndexInfos = listSearchIndex();
deleteIndex(searchIndexInfos);
deleteTable();
}
private boolean tableExists() {
ListTableResponse listTableResponse = client.listTable();
return listTableResponse.getTableNames().contains(tableName);
}
private boolean searchindexExists() {
List<SearchIndexInfo> searchIndexInfos = listSearchIndex();
for (SearchIndexInfo indexInfo : searchIndexInfos) {
if (indexInfo.getIndexName().equals(searchIndexName)) {
return true;
}
}
return false;
}
private void deleteIndex(List<SearchIndexInfo> indexNames) {
indexNames.forEach(info -> {
DeleteSearchIndexRequest request = new DeleteSearchIndexRequest();
request.setTableName(info.getTableName());
request.setIndexName(info.getIndexName());
client.deleteSearchIndex(request);
log.info("delete table:{}, index:{}", info.getTableName(), info.getIndexName());
});
}
private void deleteTable() {
DeleteTableRequest request = new DeleteTableRequest(tableName);
client.deleteTable(request);
log.info("delete table:{}", tableName);
}
private List<SearchIndexInfo> listSearchIndex() {
ListSearchIndexRequest request = new ListSearchIndexRequest();
request.setTableName(tableName);
ListSearchIndexResponse listSearchIndexResponse = client.listSearchIndex(request);
return listSearchIndexResponse.getIndexInfos();
}
protected void innerAdd(String id, Embedding embedding, TextSegment textSegment) {
ValidationUtils.ensureNotNull(embedding, "embedding");
PrimaryKeyBuilder primaryKeyBuilder = PrimaryKeyBuilder.createPrimaryKeyBuilder();
primaryKeyBuilder.addPrimaryKeyColumn(this.pkName, PrimaryKeyValue.fromString(id));
PrimaryKey primaryKey = primaryKeyBuilder.build();
RowPutChange rowPutChange = new RowPutChange(this.tableName, primaryKey);
String embeddinged = TablestoreUtils.embeddingToString(embedding.vector());
rowPutChange.addColumn(new Column(this.embeddingField, ColumnValue.fromString(embeddinged)));
if (textSegment != null) {
String text = textSegment.text();
if (text != null) {
rowPutChange.addColumn(new Column(this.textField, ColumnValue.fromString(text)));
}
Metadata metadata = textSegment.metadata();
if (metadata != null) {
Map<String, Object> map = metadata.toMap();
for (Map.Entry<String, Object> entry : map.entrySet()) {
String key = entry.getKey();
Object value = entry.getValue();
if (this.textField.equals(key)) {
throw Exceptions.illegalArgument("there is a metadata(%s,%s) that is consistent with the name of the text field:%s", key, value, this.textField);
}
if (this.embeddingField.equals(key)) {
throw Exceptions.illegalArgument("there is a metadata(%s,%s) that is consistent with the name of the vector field:%s", key, value, this.embeddingField);
}
if (value instanceof Float) {
rowPutChange.addColumn(new Column(key, ColumnValue.fromDouble((Float) value)));
} else if (value instanceof UUID) {
rowPutChange.addColumn(new Column(key, ColumnValue.fromString(((UUID) value).toString())));
} else {
rowPutChange.addColumn(new Column(key, ValueUtil.toColumnValue(value)));
}
}
}
}
try {
client.putRow(new PutRowRequest(rowPutChange));
if (log.isDebugEnabled()) {
log.debug("add id:{}, textSegment:{}, embedding:{}", id, textSegment, TablestoreUtils.maxLogOrNull(embedding.toString()));
}
} catch (Exception e) {
throw new RuntimeException(String.format("add embedding data failed, id:%s, textSegment:%s,embedding:%s", id, textSegment, embedding), e);
}
}
protected void innerDelete(String id) {
PrimaryKeyBuilder primaryKeyBuilder = PrimaryKeyBuilder.createPrimaryKeyBuilder();
primaryKeyBuilder.addPrimaryKeyColumn(this.pkName, PrimaryKeyValue.fromString(id));
PrimaryKey primaryKey = primaryKeyBuilder.build();
RowDeleteChange rowDeleteChange = new RowDeleteChange(this.tableName, primaryKey);
try {
client.deleteRow(new DeleteRowRequest(rowDeleteChange));
log.debug("delete id:{}", id);
} catch (Exception e) {
throw new RuntimeException(String.format("delete embedding data failed, id:%s", id), e);
}
}
private void forEachAllData(Collection<String> columnsToGet, Consumer<Row> rowConsumer) {
RangeRowQueryCriteria rangeRowQueryCriteria = new RangeRowQueryCriteria(this.tableName);
PrimaryKeyBuilder start = PrimaryKeyBuilder.createPrimaryKeyBuilder();
start.addPrimaryKeyColumn(this.pkName, PrimaryKeyValue.INF_MIN);
PrimaryKeyBuilder end = PrimaryKeyBuilder.createPrimaryKeyBuilder();
end.addPrimaryKeyColumn(this.pkName, PrimaryKeyValue.INF_MAX);
rangeRowQueryCriteria.setInclusiveStartPrimaryKey(start.build());
rangeRowQueryCriteria.setExclusiveEndPrimaryKey(end.build());
rangeRowQueryCriteria.setMaxVersions(1);
rangeRowQueryCriteria.setLimit(5000);
rangeRowQueryCriteria.addColumnsToGet(columnsToGet);
rangeRowQueryCriteria.setDirection(Direction.FORWARD);
GetRangeRequest getRangeRequest = new GetRangeRequest(rangeRowQueryCriteria);
GetRangeResponse getRangeResponse;
while (true) {
getRangeResponse = client.getRange(getRangeRequest);
for (Row row : getRangeResponse.getRows()) {
rowConsumer.accept(row);
}
if (getRangeResponse.getNextStartPrimaryKey() != null) {
rangeRowQueryCriteria.setInclusiveStartPrimaryKey(getRangeResponse.getNextStartPrimaryKey());
} else {
break;
}
}
}
private Metadata rowToMetadata(Row row) {
Metadata metadata = new Metadata();
for (Column column : row.getColumns()) {
if (column.getName().equals(embeddingField)) {
continue;
}
if (column.getName().equals(textField)) {
continue;
}
ColumnType columnType = column.getValue().getType();
switch (columnType) {
case STRING:
metadata.put(column.getName(), column.getValue().asString());
break;
case INTEGER:
metadata.put(column.getName(), column.getValue().asLong());
break;
case DOUBLE:
metadata.put(column.getName(), column.getValue().asDouble());
break;
default:
log.warn("unsupported columnType:{}, key:{}, value:{}", columnType, column.getName(), column.getValue());
}
}
return metadata;
}
}

View File

@ -0,0 +1,150 @@
package dev.langchain4j.store.embedding.tablestore;
import com.alicloud.openservices.tablestore.model.search.query.Query;
import com.alicloud.openservices.tablestore.model.search.query.QueryBuilders;
import com.alicloud.openservices.tablestore.model.search.query.TermsQuery;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsIn;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsNotIn;
import dev.langchain4j.store.embedding.filter.logical.And;
import dev.langchain4j.store.embedding.filter.logical.Not;
import dev.langchain4j.store.embedding.filter.logical.Or;
import java.util.UUID;
class TablestoreMetadataFilterMapper {
static Query map(Filter filter) {
if (filter == null) {
return QueryBuilders.matchAll().build();
}
if (filter instanceof IsEqualTo) {
return mapEqual((IsEqualTo) filter);
} else if (filter instanceof IsNotEqualTo) {
return mapNotEqual((IsNotEqualTo) filter);
} else if (filter instanceof IsTextMatch) {
return mapMatch((IsTextMatch) filter);
} else if (filter instanceof IsTextMatchPhrase) {
return mapMatchPhrase((IsTextMatchPhrase) filter);
} else if (filter instanceof IsGreaterThan) {
return mapGreaterThan((IsGreaterThan) filter);
} else if (filter instanceof IsGreaterThanOrEqualTo) {
return mapGreaterThanOrEqual((IsGreaterThanOrEqualTo) filter);
} else if (filter instanceof IsLessThan) {
return mapLessThan((IsLessThan) filter);
} else if (filter instanceof IsLessThanOrEqualTo) {
return mapLessThanOrEqual((IsLessThanOrEqualTo) filter);
} else if (filter instanceof IsIn) {
return mapIn((IsIn) filter);
} else if (filter instanceof IsNotIn) {
return mapNotIn((IsNotIn) filter);
} else if (filter instanceof And) {
return mapAnd((And) filter);
} else if (filter instanceof Not) {
return mapNot((Not) filter);
} else if (filter instanceof Or) {
return mapOr((Or) filter);
} else {
throw new UnsupportedOperationException("Unsupported filter type: " + filter.getClass().getName());
}
}
private static Object transformType(Object object) {
if (object instanceof Float) {
object = ((Float) object).doubleValue();
}
if (object instanceof UUID) {
object = ((UUID) object).toString();
}
return object;
}
private static Query mapEqual(IsEqualTo isEqualTo) {
return QueryBuilders.term(isEqualTo.key(), transformType(isEqualTo.comparisonValue()))
.build();
}
private static Query mapMatch(IsTextMatch isTextMatch) {
return QueryBuilders.match(isTextMatch.key(), isTextMatch.comparisonValue())
.build();
}
private static Query mapMatchPhrase(IsTextMatchPhrase isTextMatchPhrase) {
return QueryBuilders.matchPhrase(isTextMatchPhrase.key(), isTextMatchPhrase.comparisonValue())
.build();
}
private static Query mapNotEqual(IsNotEqualTo isNotEqualTo) {
return QueryBuilders.bool()
.mustNot(QueryBuilders.term(isNotEqualTo.key(), transformType(isNotEqualTo.comparisonValue())))
.build();
}
private static Query mapGreaterThan(IsGreaterThan isGreaterThan) {
return QueryBuilders.range(isGreaterThan.key())
.greaterThan(transformType(isGreaterThan.comparisonValue()))
.build();
}
private static Query mapGreaterThanOrEqual(IsGreaterThanOrEqualTo isGreaterThanOrEqualTo) {
return QueryBuilders.range(isGreaterThanOrEqualTo.key())
.greaterThanOrEqual(transformType(isGreaterThanOrEqualTo.comparisonValue()))
.build();
}
private static Query mapLessThan(IsLessThan isLessThan) {
return QueryBuilders.range(isLessThan.key())
.lessThan(transformType(isLessThan.comparisonValue()))
.build();
}
private static Query mapLessThanOrEqual(IsLessThanOrEqualTo isLessThanOrEqualTo) {
return QueryBuilders.range(isLessThanOrEqualTo.key())
.lessThanOrEqual(transformType(isLessThanOrEqualTo.comparisonValue()))
.build();
}
private static Query mapIn(IsIn isIn) {
TermsQuery.Builder builder = QueryBuilders.terms(isIn.key());
for (Object object : isIn.comparisonValues()) {
builder.addTerm(transformType(object));
}
return builder.build();
}
private static Query mapNotIn(IsNotIn isNotIn) {
TermsQuery.Builder builder = QueryBuilders.terms(isNotIn.key());
for (Object object : isNotIn.comparisonValues()) {
builder.addTerm(transformType(object));
}
return QueryBuilders.bool()
.mustNot(builder.build())
.build();
}
private static Query mapAnd(And and) {
return QueryBuilders.bool()
.must(map(and.left()))
.must(map(and.right()))
.build();
}
private static Query mapNot(Not not) {
return QueryBuilders.bool()
.mustNot(map(not.expression()))
.build();
}
private static Query mapOr(Or or) {
return QueryBuilders.bool()
.should(map(or.left()))
.should(map(or.right()))
.build();
}
}

View File

@ -0,0 +1,33 @@
package dev.langchain4j.store.embedding.tablestore;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import dev.langchain4j.internal.ValidationUtils;
class TablestoreUtils {
private static final int MAX_DEBUG_LOG_LENGTH = 100;
private static final Gson GSON = new GsonBuilder().create();
protected static float[] parseEmbeddingString(String embeddingString) {
ValidationUtils.ensureNotBlank(embeddingString, "embeddingString");
return GSON.fromJson(embeddingString, float[].class);
}
protected static String embeddingToString(float[] embedding) {
ValidationUtils.ensureNotNull(embedding, "embedding");
return GSON.toJson(embedding);
}
protected static String maxLogOrNull(String str) {
if (str == null) {
return null;
}
int max = MAX_DEBUG_LOG_LENGTH;
if (str.length() <= max) {
return str;
}
return str.substring(0, max) + "......";
}
}

View File

@ -0,0 +1,235 @@
package dev.langchain4j.store.memory.chat.tablestore;
import com.alicloud.openservices.tablestore.SyncClient;
import com.alicloud.openservices.tablestore.model.CapacityUnit;
import com.alicloud.openservices.tablestore.model.Column;
import com.alicloud.openservices.tablestore.model.ColumnValue;
import com.alicloud.openservices.tablestore.model.CreateTableRequest;
import com.alicloud.openservices.tablestore.model.DeleteRowRequest;
import com.alicloud.openservices.tablestore.model.Direction;
import com.alicloud.openservices.tablestore.model.GetRangeRequest;
import com.alicloud.openservices.tablestore.model.GetRangeResponse;
import com.alicloud.openservices.tablestore.model.ListTableResponse;
import com.alicloud.openservices.tablestore.model.PrimaryKey;
import com.alicloud.openservices.tablestore.model.PrimaryKeyBuilder;
import com.alicloud.openservices.tablestore.model.PrimaryKeySchema;
import com.alicloud.openservices.tablestore.model.PrimaryKeyType;
import com.alicloud.openservices.tablestore.model.PrimaryKeyValue;
import com.alicloud.openservices.tablestore.model.PutRowRequest;
import com.alicloud.openservices.tablestore.model.RangeRowQueryCriteria;
import com.alicloud.openservices.tablestore.model.ReservedThroughput;
import com.alicloud.openservices.tablestore.model.Row;
import com.alicloud.openservices.tablestore.model.RowDeleteChange;
import com.alicloud.openservices.tablestore.model.RowPutChange;
import com.alicloud.openservices.tablestore.model.TableMeta;
import com.alicloud.openservices.tablestore.model.TableOptions;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageDeserializer;
import dev.langchain4j.data.message.ChatMessageSerializer;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Consumer;
public class TablestoreChatMemoryStore implements ChatMemoryStore {
private final Logger log = LoggerFactory.getLogger(getClass());
private final SyncClient client;
private final String tableName;
private final String pkName1;
private final String pkName2;
private final String chatMessageFieldName;
private static final String DEFAULT_TABLE_NAME = "langchain4j_chat_memory_store_ots_v1";
private static final String DEFAULT_TABLE_PK_1_NAME = "memory_id";
private static final String DEFAULT_TABLE_PK_2_NAME = "seq_no";
private static final String DEFAULT_CHAT_MESSAGE_FIELD_NAME = "chat_message";
public TablestoreChatMemoryStore(SyncClient client) {
this(client, DEFAULT_TABLE_NAME, DEFAULT_TABLE_PK_1_NAME, DEFAULT_TABLE_PK_2_NAME, DEFAULT_CHAT_MESSAGE_FIELD_NAME);
}
public TablestoreChatMemoryStore(SyncClient client, String tableName, String pkName1, String pkName2, String chatMessageFieldName) {
this.client = client;
this.tableName = tableName;
this.pkName1 = pkName1;
this.pkName2 = pkName2;
this.chatMessageFieldName = chatMessageFieldName;
}
public void init() {
createTableIfNotExist();
}
/**
* Clear all message.
*/
public void clear() {
forEachAllData(PrimaryKeyValue.INF_MIN, PrimaryKeyValue.INF_MAX, row -> {
String id = row.getPrimaryKey().getPrimaryKeyColumn(pkName1).getValue().asString();
long seqNo = row.getPrimaryKey().getPrimaryKeyColumn(pkName2).getValue().asLong();
innerDelete(id, seqNo);
}
);
}
@Override
public List<ChatMessage> getMessages(Object memoryId) {
String memoryIdStr = getMemoryId(memoryId);
log.debug("get messages, memoryIdStr:{}", memoryIdStr);
List<ChatMessage> messages = new ArrayList<>();
forEachAllData(PrimaryKeyValue.fromString(memoryIdStr), row -> {
Column column = row.getLatestColumn(chatMessageFieldName);
if (column != null) {
String jsonString = column.getValue().asString();
try {
ChatMessage chatMessage = ChatMessageDeserializer.messageFromJson(jsonString);
messages.add(chatMessage);
} catch (Exception e) {
String id = row.getPrimaryKey().getPrimaryKeyColumn(pkName1).getValue().asString();
long seqNo = row.getPrimaryKey().getPrimaryKeyColumn(pkName2).getValue().asLong();
throw new RuntimeException(String.format("unable to parse message body, memoryId:%s, seqNo:%s", id, seqNo), e);
}
}
});
return messages;
}
@Override
public void updateMessages(Object memoryId, List<ChatMessage> messages) {
String memoryIdStr = getMemoryId(memoryId);
log.debug("update messages, memoryIdStr:{}", memoryIdStr);
ValidationUtils.ensureNotEmpty(messages, "messages");
deleteMessages(memoryId);
List<Exception> exceptions = new ArrayList<>();
for (int i = 0; i < messages.size(); i++) {
ChatMessage message = messages.get(i);
try {
innerAdd(memoryIdStr, i, ChatMessageSerializer.messageToJson(message));
} catch (Exception e) {
exceptions.add(e);
}
}
if (!exceptions.isEmpty()) {
IllegalStateException exception = new IllegalStateException("update messages with error, failed:" + exceptions.size());
for (Exception e : exceptions) {
exception.addSuppressed(e);
}
throw exception;
}
}
@Override
public void deleteMessages(Object memoryId) {
String memoryIdStr = getMemoryId(memoryId);
log.debug("delete messages, memoryIdStr:{}", memoryIdStr);
forEachAllData(PrimaryKeyValue.fromString(memoryIdStr), row -> {
String id = row.getPrimaryKey().getPrimaryKeyColumn(pkName1).getValue().asString();
long seqNo = row.getPrimaryKey().getPrimaryKeyColumn(pkName2).getValue().asLong();
innerDelete(id, seqNo);
});
}
private void innerDelete(String memoryId, long seqNo) {
ValidationUtils.ensureNotNull(memoryId, "memoryId");
ValidationUtils.ensureNotNull(seqNo, "seqNo");
PrimaryKeyBuilder primaryKeyBuilder = PrimaryKeyBuilder.createPrimaryKeyBuilder();
primaryKeyBuilder.addPrimaryKeyColumn(this.pkName1, PrimaryKeyValue.fromString(memoryId));
primaryKeyBuilder.addPrimaryKeyColumn(this.pkName2, PrimaryKeyValue.fromLong(seqNo));
PrimaryKey primaryKey = primaryKeyBuilder.build();
RowDeleteChange rowDeleteChange = new RowDeleteChange(this.tableName, primaryKey);
try {
client.deleteRow(new DeleteRowRequest(rowDeleteChange));
log.debug("delete memoryId:{}, seqNo:{}", memoryId, seqNo);
} catch (Exception e) {
throw new RuntimeException(String.format("delete embedding data failed, memoryId:%s, seqNo:%s", memoryId, seqNo), e);
}
}
private void innerAdd(String memoryId, int seqNo, String chatMessage) {
ValidationUtils.ensureNotNull(memoryId, "memoryId");
ValidationUtils.ensureNotNull(seqNo, "seqNo");
ValidationUtils.ensureNotNull(chatMessage, "chatMessage");
PrimaryKeyBuilder primaryKeyBuilder = PrimaryKeyBuilder.createPrimaryKeyBuilder();
primaryKeyBuilder.addPrimaryKeyColumn(this.pkName1, PrimaryKeyValue.fromString(memoryId));
primaryKeyBuilder.addPrimaryKeyColumn(this.pkName2, PrimaryKeyValue.fromLong(seqNo));
PrimaryKey primaryKey = primaryKeyBuilder.build();
RowPutChange rowPutChange = new RowPutChange(this.tableName, primaryKey);
rowPutChange.addColumn(new Column(chatMessageFieldName, ColumnValue.fromString(chatMessage)));
try {
client.putRow(new PutRowRequest(rowPutChange));
if (log.isDebugEnabled()) {
log.debug("add memoryId:{}, seqNo:{}, chatMessage:{}", memoryId, seqNo, chatMessage);
}
} catch (Exception e) {
throw new RuntimeException(String.format("add embedding data failed, memoryId:%s, seqNo:%s, chatMessage:%s", memoryId, seqNo, chatMessage), e);
}
}
private String getMemoryId(Object memoryId) {
boolean isNullOrEmpty = memoryId == null || memoryId.toString().trim().isEmpty();
if (isNullOrEmpty) {
throw new IllegalArgumentException("memoryId cannot be null or empty");
}
return memoryId.toString();
}
private void createTableIfNotExist() {
if (tableExists()) {
log.info("table:{} already exists", tableName);
return;
}
TableMeta tableMeta = new TableMeta(this.tableName);
tableMeta.addPrimaryKeyColumn(new PrimaryKeySchema(pkName1, PrimaryKeyType.STRING));
tableMeta.addPrimaryKeyColumn(new PrimaryKeySchema(pkName2, PrimaryKeyType.INTEGER));
TableOptions tableOptions = new TableOptions(-1, 1);
CreateTableRequest request = new CreateTableRequest(tableMeta, tableOptions);
request.setReservedThroughput(new ReservedThroughput(new CapacityUnit(0, 0)));
client.createTable(request);
log.info("create table:{}", tableName);
}
private boolean tableExists() {
ListTableResponse listTableResponse = client.listTable();
return listTableResponse.getTableNames().contains(tableName);
}
private void forEachAllData(PrimaryKeyValue memoryId, Consumer<Row> rowConsumer) {
forEachAllData(memoryId, memoryId, rowConsumer);
}
private void forEachAllData(PrimaryKeyValue memoryIdStart, PrimaryKeyValue memoryIdEnd, Consumer<Row> rowConsumer) {
RangeRowQueryCriteria rangeRowQueryCriteria = new RangeRowQueryCriteria(this.tableName);
PrimaryKeyBuilder start = PrimaryKeyBuilder.createPrimaryKeyBuilder();
start.addPrimaryKeyColumn(this.pkName1, memoryIdStart);
start.addPrimaryKeyColumn(this.pkName2, PrimaryKeyValue.INF_MIN);
PrimaryKeyBuilder end = PrimaryKeyBuilder.createPrimaryKeyBuilder();
end.addPrimaryKeyColumn(this.pkName1, memoryIdEnd);
end.addPrimaryKeyColumn(this.pkName2, PrimaryKeyValue.INF_MAX);
rangeRowQueryCriteria.setInclusiveStartPrimaryKey(start.build());
rangeRowQueryCriteria.setExclusiveEndPrimaryKey(end.build());
rangeRowQueryCriteria.setMaxVersions(1);
rangeRowQueryCriteria.setLimit(5000);
rangeRowQueryCriteria.addColumnsToGet(Collections.singletonList(chatMessageFieldName));
rangeRowQueryCriteria.setDirection(Direction.FORWARD);
GetRangeRequest getRangeRequest = new GetRangeRequest(rangeRowQueryCriteria);
GetRangeResponse getRangeResponse;
while (true) {
getRangeResponse = client.getRange(getRangeRequest);
for (Row row : getRangeResponse.getRows()) {
rowConsumer.accept(row);
}
if (getRangeResponse.getNextStartPrimaryKey() != null) {
rangeRowQueryCriteria.setInclusiveStartPrimaryKey(getRangeResponse.getNextStartPrimaryKey());
} else {
break;
}
}
}
}

View File

@ -0,0 +1,113 @@
package dev.langchain4j.store.embedding.tablestore;
import com.alicloud.openservices.tablestore.SyncClient;
import com.alicloud.openservices.tablestore.model.search.FieldSchema;
import com.alicloud.openservices.tablestore.model.search.FieldType;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import java.util.Arrays;
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ENDPOINT", matches = ".+")
@EnabledIfEnvironmentVariable(named = "TABLESTORE_INSTANCE_NAME", matches = ".+")
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_ID", matches = ".+")
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_SECRET", matches = ".+")
class TablestoreEmbeddingStoreExampleIT {
@Test
void test_simple() {
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
/*
* Step 1: create a TablestoreEmbeddingStore.
*/
String endpoint = System.getenv("TABLESTORE_ENDPOINT");
String instanceName = System.getenv("TABLESTORE_INSTANCE_NAME");
String accessKeyId = System.getenv("TABLESTORE_ACCESS_KEY_ID");
String accessKeySecret = System.getenv("TABLESTORE_ACCESS_KEY_SECRET");
TablestoreEmbeddingStore embeddingStore = new TablestoreEmbeddingStore(
new SyncClient(endpoint,
accessKeyId,
accessKeySecret,
instanceName),
384,
Arrays.asList(
new FieldSchema("meta_example_keyword", FieldType.KEYWORD),
new FieldSchema("meta_example_long", FieldType.LONG),
new FieldSchema("meta_example_double", FieldType.DOUBLE),
new FieldSchema("meta_example_text", FieldType.TEXT).setAnalyzer(FieldSchema.Analyzer.MaxWord)
)
);
/*
* Step 2: init.
*
* Note: It only needs to be executed once, and the first execution requires
* waiting for table and index initialization
*/
embeddingStore.init();
/*
* Step 3: Add some docs.
*/
TextSegment segment1 = TextSegment.from(
"I like football.",
new Metadata().put("meta_example_keyword", "a")
.put("meta_example_long", 123)
.put("meta_example_double", 1.5)
.put("meta_example_text", "dog cat")
);
Embedding embedding1 = embeddingModel.embed(segment1).content();
embeddingStore.add(embedding1, segment1);
TextSegment segment2 = TextSegment.from(
"The weather is good today.",
new Metadata().put("meta_example_keyword", "b")
.put("meta_example_long", 456)
.put("meta_example_double", 5.6)
.put("meta_example_text", "foo boo")
);
Embedding embedding2 = embeddingModel.embed(segment2).content();
embeddingStore.add(embedding2, segment2);
/*
* Step 4: Search
*/
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(embeddingModel.embed("What is your favourite sport?").content())
.filter(new IsLessThan("meta_example_double", 0.5))
.maxResults(100)
.build();
EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
// get result detail.
for (EmbeddingMatch<TextSegment> match : result.matches()) {
String embeddingId = match.embeddingId();
Double score = match.score();
Embedding embedding = match.embedding();
TextSegment embedded = match.embedded();
String text = embedded.text();
Metadata metadata = embedded.metadata();
Assertions.assertNotNull(embeddingId);
Assertions.assertNotNull(score);
Assertions.assertNotNull(embedding);
Assertions.assertNotNull(text);
Assertions.assertNotNull(metadata);
}
/*
* Step 5: Delete docs.
*/
embeddingStore.remove("id_example");
embeddingStore.removeAll();
}
}

View File

@ -0,0 +1,300 @@
package dev.langchain4j.store.embedding.tablestore;
import com.alicloud.openservices.tablestore.SyncClient;
import com.alicloud.openservices.tablestore.model.search.FieldSchema;
import com.alicloud.openservices.tablestore.model.search.FieldType;
import com.alicloud.openservices.tablestore.model.search.SearchQuery;
import com.alicloud.openservices.tablestore.model.search.SearchRequest;
import com.alicloud.openservices.tablestore.model.search.SearchResponse;
import com.alicloud.openservices.tablestore.model.search.query.MatchAllQuery;
import com.alicloud.openservices.tablestore.model.search.query.Query;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo;
import org.apache.commons.lang3.function.TriConsumer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ENDPOINT", matches = ".+")
@EnabledIfEnvironmentVariable(named = "TABLESTORE_INSTANCE_NAME", matches = ".+")
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_ID", matches = ".+")
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_SECRET", matches = ".+")
class TablestoreEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT {
private final Logger log = LoggerFactory.getLogger(getClass());
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
private final static long WAIT_FOR_REPLICA_TIME_IN_MILLS = TimeUnit.SECONDS.toMillis(3);
private final TablestoreEmbeddingStore embeddingStore;
private final AtomicLong trackDocsForTest = new AtomicLong(0);
TablestoreEmbeddingStoreIT() {
String endpoint = System.getenv("TABLESTORE_ENDPOINT");
String instanceName = System.getenv("TABLESTORE_INSTANCE_NAME");
String accessKeyId = System.getenv("TABLESTORE_ACCESS_KEY_ID");
String accessKeySecret = System.getenv("TABLESTORE_ACCESS_KEY_SECRET");
this.embeddingStore = new TablestoreEmbeddingStore(
new SyncClient(endpoint,
accessKeyId,
accessKeySecret,
instanceName),
384,
Arrays.asList(
new FieldSchema("name", FieldType.KEYWORD),
new FieldSchema("name2", FieldType.KEYWORD),
new FieldSchema("key", FieldType.KEYWORD),
new FieldSchema("key2", FieldType.KEYWORD),
new FieldSchema("city", FieldType.KEYWORD),
new FieldSchema("country", FieldType.KEYWORD),
new FieldSchema("age", FieldType.LONG),
new FieldSchema("age2", FieldType.LONG),
new FieldSchema("meta_example_double", FieldType.DOUBLE),
new FieldSchema("meta_example_text_max_word", FieldType.TEXT).setAnalyzer(FieldSchema.Analyzer.MaxWord),
new FieldSchema("meta_example_text_fuzzy", FieldType.TEXT).setAnalyzer(FieldSchema.Analyzer.Fuzzy)
)
) {
// Override for test
@Override
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
if (request.maxResults() > 100) {
request = new EmbeddingSearchRequest(request.queryEmbedding(), 100, request.minScore(), request.filter());
}
return super.search(request);
}
// Override for test
@Override
protected void innerAdd(String id, Embedding embedding, TextSegment textSegment) {
super.innerAdd(id, embedding, textSegment);
trackDocsForTest.incrementAndGet();
}
// Override for test
@Override
protected void innerDelete(String id) {
super.innerDelete(id);
trackDocsForTest.decrementAndGet();
}
// Override for test: exclude the use of incorrect field types in base class testing
@Override
protected Query mapFilterToQuery(Filter filter) {
if (filter instanceof IsEqualTo) {
if (((IsEqualTo) filter).comparisonValue() instanceof Number) {
Assumptions.abort("keyword not support number");
}
}
if (filter instanceof IsNotEqualTo) {
if (((IsNotEqualTo) filter).comparisonValue() instanceof Number) {
Assumptions.abort("keyword not support number");
}
}
if (filter instanceof IsLessThan) {
IsLessThan t = (IsLessThan) filter;
if (t.key().contains("key") && t.comparisonValue() instanceof Number) {
Assumptions.abort("keyword not support number");
}
}
if (filter instanceof IsLessThanOrEqualTo) {
IsLessThanOrEqualTo t = (IsLessThanOrEqualTo) filter;
if (t.key().contains("key") && t.comparisonValue() instanceof Number) {
Assumptions.abort("keyword not support number");
}
}
if (filter instanceof IsGreaterThan) {
IsGreaterThan t = (IsGreaterThan) filter;
if (t.key().contains("key") && t.comparisonValue() instanceof Number) {
Assumptions.abort("keyword not support number");
}
}
if (filter instanceof IsGreaterThanOrEqualTo) {
IsGreaterThanOrEqualTo t = (IsGreaterThanOrEqualTo) filter;
if (t.key().contains("key") && t.comparisonValue() instanceof Number) {
Assumptions.abort("keyword not support number");
}
}
return super.mapFilterToQuery(filter);
}
};
this.embeddingStore.init();
this.embeddingStore.removeAll();
ensureSearchDataReady(0);
}
protected void awaitUntilPersisted() {
ensureSearchDataReady(trackDocsForTest.get());
}
@BeforeEach
@AfterEach
void setUp() {
trackDocsForTest.set(0);
clearStore();
}
@Override
protected void clearStore() {
embeddingStore.removeAll();
}
@Override
protected EmbeddingStore<TextSegment> embeddingStore() {
return embeddingStore;
}
@Override
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}
@Test
void test_match_and_match_phrase_and_double_range() {
Embedding embedding = embeddingModel().embed("ok").content();
embeddingStore.add(
embedding,
new TextSegment(
"ok",
new Metadata()
.put("meta_example_double", 1d)
.put("meta_example_text_max_word", "a b c ab ac")
.put("meta_example_text_fuzzy", "a b c abac")
)
);
awaitUntilPersisted();
TriConsumer<String, String, Integer> matchTester = (field, value, expectSize) -> {
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(embedding)
.filter(new IsTextMatch(field, value))
.maxResults(100)
.build();
// when
List<EmbeddingMatch<TextSegment>> matches = embeddingStore().search(embeddingSearchRequest).matches();
// then
assertThat(matches).hasSize(expectSize);
};
matchTester.accept("meta_example_text_max_word", "a b c", 1);
matchTester.accept("meta_example_text_max_word", "ac", 1);
matchTester.accept("meta_example_text_max_word", "abc", 0);
matchTester.accept("meta_example_text_max_word", "abac", 0);
matchTester.accept("meta_example_text_max_word", "ab", 1);
matchTester.accept("meta_example_text_fuzzy", "a b c", 1);
matchTester.accept("meta_example_text_fuzzy", "ac", 1);
matchTester.accept("meta_example_text_fuzzy", "abc", 0);
matchTester.accept("meta_example_text_fuzzy", "abac", 1);
matchTester.accept("meta_example_text_fuzzy", "ab", 1);
{
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(embedding)
.filter(new IsTextMatchPhrase("meta_example_text_fuzzy", "a b c abac"))
.maxResults(100)
.build();
// when
List<EmbeddingMatch<TextSegment>> matches = embeddingStore().search(embeddingSearchRequest).matches();
// then
assertThat(matches).hasSize(1);
}
{
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(embedding)
.filter(new IsGreaterThan("meta_example_double", 0.5))
.maxResults(100)
.build();
// when
List<EmbeddingMatch<TextSegment>> matches = embeddingStore().search(embeddingSearchRequest).matches();
// then
assertThat(matches).hasSize(1);
}
{
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(embedding)
.filter(new IsLessThan("meta_example_double", 0.5))
.maxResults(100)
.build();
// when
List<EmbeddingMatch<TextSegment>> matches = embeddingStore().search(embeddingSearchRequest).matches();
// then
assertThat(matches).hasSize(0);
}
{
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(embedding)
.filter(new IsLessThan("meta_example_double", 1.5))
.maxResults(100)
.build();
// when
List<EmbeddingMatch<TextSegment>> matches = embeddingStore().search(embeddingSearchRequest).matches();
// then
assertThat(matches).hasSize(1);
}
}
@SuppressWarnings("BusyWait")
// For test stability
private void ensureSearchDataReady(long expectTotalHit) {
long begin = System.currentTimeMillis();
while (true) {
SearchQuery searchQuery = new SearchQuery();
searchQuery.setQuery(new MatchAllQuery());
searchQuery.setLimit(0);
SearchRequest searchRequest = new SearchRequest(embeddingStore.getTableName(), embeddingStore.getSearchIndexName(), searchQuery);
searchQuery.setGetTotalCount(true);
SearchResponse resp = embeddingStore.getClient().search(searchRequest);
assertTrue(resp.isAllSuccess());
if (resp.getTotalCount() == expectTotalHit) {
log.info("ensureSearchDataReady totalHit:{}, expect:{}", resp.getTotalCount(), expectTotalHit);
log.info("DataSyncTimeInMs:" + (System.currentTimeMillis() - begin));
break;
} else if (resp.getTotalCount() != 0) {
log.info("ensureSearchDataReady totalHit:{}, expect:{}", resp.getTotalCount(), expectTotalHit);
}
if (System.currentTimeMillis() - begin > TimeUnit.SECONDS.toMillis(120)) {
fail("ensureSearchDataReady timeout");
}
try {
Thread.sleep(500);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
// wait for replica
try {
Thread.sleep(WAIT_FOR_REPLICA_TIME_IN_MILLS);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -0,0 +1,66 @@
package dev.langchain4j.store.embedding.tablestore;
import com.alicloud.openservices.tablestore.SyncClient;
import com.alicloud.openservices.tablestore.model.search.FieldSchema;
import com.alicloud.openservices.tablestore.model.search.FieldType;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import java.util.Arrays;
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ENDPOINT", matches = ".+")
@EnabledIfEnvironmentVariable(named = "TABLESTORE_INSTANCE_NAME", matches = ".+")
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_ID", matches = ".+")
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_SECRET", matches = ".+")
class TablestoreEmbeddingStoreRemovalIT extends EmbeddingStoreWithRemovalIT {
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
private final TablestoreEmbeddingStore embeddingStore;
TablestoreEmbeddingStoreRemovalIT() {
String endpoint = System.getenv("TABLESTORE_ENDPOINT");
String instanceName = System.getenv("TABLESTORE_INSTANCE_NAME");
String accessKeyId = System.getenv("TABLESTORE_ACCESS_KEY_ID");
String accessKeySecret = System.getenv("TABLESTORE_ACCESS_KEY_SECRET");
this.embeddingStore = new TablestoreEmbeddingStore(
new SyncClient(endpoint,
accessKeyId,
accessKeySecret,
instanceName),
384,
Arrays.asList(
new FieldSchema("meta_example_keyword", FieldType.KEYWORD),
new FieldSchema("meta_example_long", FieldType.LONG),
new FieldSchema("meta_example_double", FieldType.DOUBLE),
new FieldSchema("meta_example_text", FieldType.TEXT).setAnalyzer(FieldSchema.Analyzer.MaxWord)
)
) {
@Override
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
if (request.maxResults() > 100) {
request = new EmbeddingSearchRequest(request.queryEmbedding(), 100, request.minScore(), request.filter());
}
return super.search(request);
}
};
this.embeddingStore.init();
this.embeddingStore.removeAll();
}
@Override
protected EmbeddingStore<TextSegment> embeddingStore() {
return embeddingStore;
}
@Override
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}
}

View File

@ -0,0 +1,40 @@
package dev.langchain4j.store.embedding.tablestore;
import dev.langchain4j.data.document.Metadata;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.lang.reflect.Field;
import java.util.Set;
import static dev.langchain4j.store.embedding.tablestore.TablestoreUtils.embeddingToString;
import static dev.langchain4j.store.embedding.tablestore.TablestoreUtils.parseEmbeddingString;
class TablestoreEmbeddingStoreTest {
@Test
void test_parseEmbeddingString() {
float[] floats = parseEmbeddingString(" [1,2,3,4, 5.678, 9.12345 , -0.0123] ");
float[] expect = new float[]{1, 2, 3, 4, 5.678f, 9.12345f, -0.0123f};
Assertions.assertArrayEquals(expect, floats, 0.000001f);
}
@Test
void test_embeddingToString() {
float[] expect = new float[]{1, 2, 3, 4, 5.678f, 9.12345f, -0.0123f};
String embeddingToString = embeddingToString(expect);
Assertions.assertEquals("[1.0,2.0,3.0,4.0,5.678,9.12345,-0.0123]", embeddingToString);
}
@Test
void testj_supported_value_types() throws Exception {
Field field = Metadata.class.getDeclaredField("SUPPORTED_VALUE_TYPES");
field.setAccessible(true);
@SuppressWarnings("unchecked")
Set<Class<?>> supportedValueTypes = (Set<Class<?>>) field.get(new Metadata());
Assertions.assertEquals(10, supportedValueTypes.size(), "when Metadata#SUPPORTED_VALUE_TYPES add new types, we should modify:\n" +
"1. write logic: rowToMetadata.\n" +
"2. read logic: innerAdd");
}
}

View File

@ -0,0 +1,178 @@
package dev.langchain4j.store.memory.chat.tablestore;
import com.alicloud.openservices.tablestore.SyncClient;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ENDPOINT", matches = ".+")
@EnabledIfEnvironmentVariable(named = "TABLESTORE_INSTANCE_NAME", matches = ".+")
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_ID", matches = ".+")
@EnabledIfEnvironmentVariable(named = "TABLESTORE_ACCESS_KEY_SECRET", matches = ".+")
class TablestoreChatMemoryStoreIT {
private final TablestoreChatMemoryStore chatMemoryStore;
private final static String USER_ID = "someUserId";
public TablestoreChatMemoryStoreIT() {
String endpoint = System.getenv("TABLESTORE_ENDPOINT");
String instanceName = System.getenv("TABLESTORE_INSTANCE_NAME");
String accessKeyId = System.getenv("TABLESTORE_ACCESS_KEY_ID");
String accessKeySecret = System.getenv("TABLESTORE_ACCESS_KEY_SECRET");
chatMemoryStore = new TablestoreChatMemoryStore(new SyncClient(
endpoint,
accessKeyId,
accessKeySecret,
instanceName)
);
chatMemoryStore.init();
}
@BeforeEach
@AfterEach
void setUp() {
chatMemoryStore.clear();
List<ChatMessage> messages = chatMemoryStore.getMessages(USER_ID);
assertThat(messages).isEmpty();
}
@Test
void test_should_insert_items() {
// When
String chatSessionId = "chat-" + UUID.randomUUID();
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryStore(chatMemoryStore)
.maxMessages(100)
.id(chatSessionId)
.build();
// When
UserMessage userMessage = userMessage("How are you?");
chatMemory.add(userMessage);
AiMessage aiMessage = aiMessage("I am fine! Thank you!");
chatMemory.add(aiMessage);
// Then
assertThat(chatMemory.messages()).containsExactly(userMessage, aiMessage);
}
@Test
void should_set_messages_into_tablestore() {
// given
List<ChatMessage> messages = chatMemoryStore.getMessages(USER_ID);
assertThat(messages).isEmpty();
// when
List<ChatMessage> chatMessages = new ArrayList<>();
chatMessages.add(new SystemMessage("You are a large language model working with Langchain4j"));
List<Content> userMsgContents = new ArrayList<>();
userMsgContents.add(new ImageContent("someCatImageUrl"));
chatMessages.add(new UserMessage("What do you see in this image?", userMsgContents));
chatMemoryStore.updateMessages(USER_ID, chatMessages);
// then
messages = chatMemoryStore.getMessages(USER_ID);
assertThat(messages).hasSize(2);
}
@Test
void should_delete_messages_from_tablestore() {
// given
List<ChatMessage> chatMessages = new ArrayList<>();
chatMessages.add(new SystemMessage("You are a large language model working with Langchain4j"));
chatMemoryStore.updateMessages(USER_ID, chatMessages);
List<ChatMessage> messages = chatMemoryStore.getMessages(USER_ID);
assertThat(messages).hasSize(1);
// when
chatMemoryStore.deleteMessages(USER_ID);
// then
messages = chatMemoryStore.getMessages(USER_ID);
assertThat(messages).isEmpty();
}
@Test
void getMessages_memoryId_null() {
assertThatThrownBy(() -> chatMemoryStore.getMessages(null))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("memoryId cannot be null or empty");
}
@Test
void getMessages_memoryId_empty() {
assertThatThrownBy(() -> chatMemoryStore.getMessages(" "))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("memoryId cannot be null or empty");
}
@Test
void updateMessages_messages_null() {
assertThatThrownBy(() -> chatMemoryStore.updateMessages(USER_ID, null))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("messages cannot be null or empty");
}
@Test
void updateMessages_messages_empty() {
List<ChatMessage> chatMessages = new ArrayList<>();
assertThatThrownBy(() -> chatMemoryStore.updateMessages(USER_ID, chatMessages))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("messages cannot be null or empty");
}
@Test
void updateMessages_memoryId_null() {
List<ChatMessage> chatMessages = new ArrayList<>();
chatMessages.add(new SystemMessage("You are a large language model working with Langchain4j"));
assertThatThrownBy(() -> chatMemoryStore.updateMessages(null, chatMessages))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("memoryId cannot be null or empty");
}
@Test
void updateMessages_memoryId_empty() {
List<ChatMessage> chatMessages = new ArrayList<>();
chatMessages.add(new SystemMessage("You are a large language model working with Langchain4j"));
assertThatThrownBy(() -> chatMemoryStore.updateMessages(" ", chatMessages))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("memoryId cannot be null or empty");
}
@Test
void deleteMessages_memoryId_null() {
assertThatThrownBy(() -> chatMemoryStore.deleteMessages(null))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("memoryId cannot be null or empty");
}
@Test
void deleteMessages_memoryId_empty() {
assertThatThrownBy(() -> chatMemoryStore.deleteMessages(" "))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("memoryId cannot be null or empty");
}
}

View File

@ -0,0 +1,14 @@
<configuration>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>[%d{yyyy-MM-dd HH:mm:ss.SSS}][%thread][%level][%logger][%line] - %msg%n</pattern>
</encoder>
</appender>
<root level="INFO">
<appender-ref ref="STDOUT"/>
</root>
<logger name="dev.ai4j" level="DEBUG"/>
<logger name="dev.langchain4j" level="DEBUG"/>
</configuration>

View File

@ -65,6 +65,7 @@
<module>langchain4j-pinecone</module>
<module>langchain4j-qdrant</module>
<module>langchain4j-redis</module>
<module>langchain4j-tablestore</module>
<module>langchain4j-vearch</module>
<module>langchain4j-vespa</module>
<module>langchain4j-weaviate</module>