Support Embedding for Large Amounts of Texts (#1142)

fix: #1140

<!-- Thank you so much for your contribution! -->
<!-- Please fill in all the sections below. -->

<!-- Please open the PR as a draft initially. Once it is reviewed and
approved, we will ask you to add documentation and examples. -->
<!-- Please note that PRs with breaking changes will be rejected. -->
<!-- Please note that PRs without tests will be rejected. -->

<!-- Please note that PRs will be reviewed based on the priority of the
issues they address. -->
<!-- We ask for your patience. We are doing our best to review your PR
as quickly as possible. -->
<!-- Please refrain from pinging and asking when it will be reviewed.
Thank you for understanding! -->


## Issue
<!-- Please paste the link to the issue this PR is addressing. For
example: https://github.com/langchain4j/langchain4j/issues/1012 -->


## Change
<!-- Please describe the changes you made. -->


## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)


## Checklist for adding new model integration
<!-- Please double-check the following points and mark them like this:
[X] -->
- [ ] I have added my new module in the
[BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)


## Checklist for adding new embedding store integration
<!-- Please double-check the following points and mark them like this:
[X] -->
- [ ] I have added a `{NameOfIntegration}EmbeddingStoreIT` that extends
from either `EmbeddingStoreIT` or `EmbeddingStoreWithFilteringIT`
- [ ] I have added my new module in the
[BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)


## Checklist for changing existing embedding store integration
<!-- Please double-check the following points and mark them like this:
[X] -->
- [ ] I have manually verified that the
`{NameOfIntegration}EmbeddingStore` works correctly with the data
persisted using the latest released version of LangChain4j
This commit is contained in:
jiangsier-xyz 2024-05-24 20:43:13 +08:00 committed by GitHub
parent 463a3a3280
commit 4398395e58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 68 additions and 4 deletions

View File

@ -24,6 +24,7 @@ public class QwenEmbeddingModel implements EmbeddingModel {
public static final String TYPE_KEY = "type"; public static final String TYPE_KEY = "type";
public static final String TYPE_QUERY = "query"; public static final String TYPE_QUERY = "query";
public static final String TYPE_DOCUMENT = "document"; public static final String TYPE_DOCUMENT = "document";
private static final int MAX_BATCH_SIZE = 25;
private final String apiKey; private final String apiKey;
private final String modelName; private final String modelName;
@ -53,7 +54,30 @@ public class QwenEmbeddingModel implements EmbeddingModel {
.anyMatch(TYPE_QUERY::equalsIgnoreCase); .anyMatch(TYPE_QUERY::equalsIgnoreCase);
} }
private Response<List<Embedding>> embedTexts(List<TextSegment> textSegments, TextEmbeddingParam.TextType textType) { private Response<List<Embedding>> embedTexts(List<TextSegment> textSegments,
TextEmbeddingParam.TextType textType) {
int size = textSegments.size();
if (size < MAX_BATCH_SIZE) {
return batchEmbedTexts(textSegments, textType);
}
List<Embedding> allEmbeddings = new ArrayList<>(size);
TokenUsage allUsage = null;
int fromIndex = 0;
int toIndex = MAX_BATCH_SIZE;
while (fromIndex < size) {
List<TextSegment> batchTextSegments = textSegments.subList(fromIndex, toIndex);
Response<List<Embedding>> batchResponse = batchEmbedTexts(batchTextSegments, textType);
allEmbeddings.addAll(batchResponse.content());
allUsage = TokenUsage.sum(allUsage, batchResponse.tokenUsage());
fromIndex = toIndex;
toIndex = Math.min(size, fromIndex + MAX_BATCH_SIZE);
}
return Response.from(allEmbeddings, allUsage);
}
private Response<List<Embedding>> batchEmbedTexts(List<TextSegment> textSegments, TextEmbeddingParam.TextType textType) {
TextEmbeddingParam param = TextEmbeddingParam.builder() TextEmbeddingParam param = TextEmbeddingParam.builder()
.apiKey(apiKey) .apiKey(apiKey)
.model(modelName) .model(modelName)

View File

@ -7,9 +7,14 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static dev.langchain4j.data.segment.TextSegment.textSegment; import static dev.langchain4j.data.segment.TextSegment.textSegment;
import static dev.langchain4j.model.dashscope.QwenEmbeddingModel.TYPE_KEY;
import static dev.langchain4j.model.dashscope.QwenEmbeddingModel.TYPE_QUERY;
import static dev.langchain4j.model.dashscope.QwenTestHelper.apiKey; import static dev.langchain4j.model.dashscope.QwenTestHelper.apiKey;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -52,8 +57,8 @@ public class QwenEmbeddingModelIT {
void should_embed_queries(String modelName) { void should_embed_queries(String modelName) {
EmbeddingModel model = getModel(modelName); EmbeddingModel model = getModel(modelName);
List<Embedding> embeddings = model.embedAll(asList( List<Embedding> embeddings = model.embedAll(asList(
textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)), textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY)),
textSegment("how are you?", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)) textSegment("how are you?", Metadata.from(TYPE_KEY, TYPE_QUERY))
)).content(); )).content();
assertThat(embeddings).hasSize(2); assertThat(embeddings).hasSize(2);
@ -66,7 +71,7 @@ public class QwenEmbeddingModelIT {
void should_embed_mix_segments(String modelName) { void should_embed_mix_segments(String modelName) {
EmbeddingModel model = getModel(modelName); EmbeddingModel model = getModel(modelName);
List<Embedding> embeddings = model.embedAll(asList( List<Embedding> embeddings = model.embedAll(asList(
textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)), textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY)),
textSegment("how are you?") textSegment("how are you?")
)).content(); )).content();
@ -74,4 +79,39 @@ public class QwenEmbeddingModelIT {
assertThat(embeddings.get(0).vector()).isNotEmpty(); assertThat(embeddings.get(0).vector()).isNotEmpty();
assertThat(embeddings.get(1).vector()).isNotEmpty(); assertThat(embeddings.get(1).vector()).isNotEmpty();
} }
@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
void should_embed_large_amounts_of_documents(String modelName) {
EmbeddingModel model = getModel(modelName);
List<Embedding> embeddings = model.embedAll(
Collections.nCopies(50, textSegment("hello"))).content();
assertThat(embeddings).hasSize(50);
}
@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
void should_embed_large_amounts_of_queries(String modelName) {
EmbeddingModel model = getModel(modelName);
List<Embedding> embeddings = model.embedAll(
Collections.nCopies(50, textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY)))
).content();
assertThat(embeddings).hasSize(50);
}
@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
void should_embed_large_amounts_of_mix_segments(String modelName) {
EmbeddingModel model = getModel(modelName);
List<Embedding> embeddings = model.embedAll(
Stream.concat(
Collections.nCopies(50, textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY))).stream(),
Collections.nCopies(50, textSegment("how are you?")).stream()
).collect(Collectors.toList())
).content();
assertThat(embeddings).hasSize(100);
}
} }