Support Embedding Large Numbers of Text Segments (#1142)
fix: #1140 <!-- Thank you so much for your contribution! --> <!-- Please fill in all the sections below. --> <!-- Please open the PR as a draft initially. Once it is reviewed and approved, we will ask you to add documentation and examples. --> <!-- Please note that PRs with breaking changes will be rejected. --> <!-- Please note that PRs without tests will be rejected. --> <!-- Please note that PRs will be reviewed based on the priority of the issues they address. --> <!-- We ask for your patience. We are doing our best to review your PR as quickly as possible. --> <!-- Please refrain from pinging and asking when it will be reviewed. Thank you for understanding! --> ## Issue <!-- Please paste the link to the issue this PR is addressing. For example: https://github.com/langchain4j/langchain4j/issues/1012 --> ## Change <!-- Please describe the changes you made. --> ## General checklist <!-- Please double-check the following points and mark them like this: [X] --> - [x] There are no breaking changes - [x] I have added unit and integration tests for my change - [x] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [ ] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green <!-- Before adding documentation and example(s) (below), please wait until the PR is reviewed and approved. 
--> - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) ## Checklist for adding new model integration <!-- Please double-check the following points and mark them like this: [X] --> - [ ] I have added my new module in the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml) ## Checklist for adding new embedding store integration <!-- Please double-check the following points and mark them like this: [X] --> - [ ] I have added a `{NameOfIntegration}EmbeddingStoreIT` that extends from either `EmbeddingStoreIT` or `EmbeddingStoreWithFilteringIT` - [ ] I have added my new module in the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml) ## Checklist for changing existing embedding store integration <!-- Please double-check the following points and mark them like this: [X] --> - [ ] I have manually verified that the `{NameOfIntegration}EmbeddingStore` works correctly with the data persisted using the latest released version of LangChain4j
This commit is contained in:
parent
463a3a3280
commit
4398395e58
|
@ -24,6 +24,7 @@ public class QwenEmbeddingModel implements EmbeddingModel {
|
||||||
public static final String TYPE_KEY = "type";
|
public static final String TYPE_KEY = "type";
|
||||||
public static final String TYPE_QUERY = "query";
|
public static final String TYPE_QUERY = "query";
|
||||||
public static final String TYPE_DOCUMENT = "document";
|
public static final String TYPE_DOCUMENT = "document";
|
||||||
|
// Upper bound on the number of text segments handed to a single batchEmbedTexts call;
// embedTexts slices longer inputs into consecutive chunks of at most this size.
// NOTE(review): presumably mirrors the per-request input limit of the DashScope embedding API - confirm.
private static final int MAX_BATCH_SIZE = 25;
|
||||||
|
|
||||||
private final String apiKey;
|
private final String apiKey;
|
||||||
private final String modelName;
|
private final String modelName;
|
||||||
|
@ -53,7 +54,30 @@ public class QwenEmbeddingModel implements EmbeddingModel {
|
||||||
.anyMatch(TYPE_QUERY::equalsIgnoreCase);
|
.anyMatch(TYPE_QUERY::equalsIgnoreCase);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Response<List<Embedding>> embedTexts(List<TextSegment> textSegments, TextEmbeddingParam.TextType textType) {
|
private Response<List<Embedding>> embedTexts(List<TextSegment> textSegments,
|
||||||
|
TextEmbeddingParam.TextType textType) {
|
||||||
|
int size = textSegments.size();
|
||||||
|
if (size < MAX_BATCH_SIZE) {
|
||||||
|
return batchEmbedTexts(textSegments, textType);
|
||||||
|
}
|
||||||
|
|
||||||
|
List<Embedding> allEmbeddings = new ArrayList<>(size);
|
||||||
|
TokenUsage allUsage = null;
|
||||||
|
int fromIndex = 0;
|
||||||
|
int toIndex = MAX_BATCH_SIZE;
|
||||||
|
while (fromIndex < size) {
|
||||||
|
List<TextSegment> batchTextSegments = textSegments.subList(fromIndex, toIndex);
|
||||||
|
Response<List<Embedding>> batchResponse = batchEmbedTexts(batchTextSegments, textType);
|
||||||
|
allEmbeddings.addAll(batchResponse.content());
|
||||||
|
allUsage = TokenUsage.sum(allUsage, batchResponse.tokenUsage());
|
||||||
|
fromIndex = toIndex;
|
||||||
|
toIndex = Math.min(size, fromIndex + MAX_BATCH_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Response.from(allEmbeddings, allUsage);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Response<List<Embedding>> batchEmbedTexts(List<TextSegment> textSegments, TextEmbeddingParam.TextType textType) {
|
||||||
TextEmbeddingParam param = TextEmbeddingParam.builder()
|
TextEmbeddingParam param = TextEmbeddingParam.builder()
|
||||||
.apiKey(apiKey)
|
.apiKey(apiKey)
|
||||||
.model(modelName)
|
.model(modelName)
|
||||||
|
|
|
@ -7,9 +7,14 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
import static dev.langchain4j.data.segment.TextSegment.textSegment;
|
import static dev.langchain4j.data.segment.TextSegment.textSegment;
|
||||||
|
import static dev.langchain4j.model.dashscope.QwenEmbeddingModel.TYPE_KEY;
|
||||||
|
import static dev.langchain4j.model.dashscope.QwenEmbeddingModel.TYPE_QUERY;
|
||||||
import static dev.langchain4j.model.dashscope.QwenTestHelper.apiKey;
|
import static dev.langchain4j.model.dashscope.QwenTestHelper.apiKey;
|
||||||
import static java.util.Arrays.asList;
|
import static java.util.Arrays.asList;
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
@ -52,8 +57,8 @@ public class QwenEmbeddingModelIT {
|
||||||
void should_embed_queries(String modelName) {
|
void should_embed_queries(String modelName) {
|
||||||
EmbeddingModel model = getModel(modelName);
|
EmbeddingModel model = getModel(modelName);
|
||||||
List<Embedding> embeddings = model.embedAll(asList(
|
List<Embedding> embeddings = model.embedAll(asList(
|
||||||
textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)),
|
textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY)),
|
||||||
textSegment("how are you?", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY))
|
textSegment("how are you?", Metadata.from(TYPE_KEY, TYPE_QUERY))
|
||||||
)).content();
|
)).content();
|
||||||
|
|
||||||
assertThat(embeddings).hasSize(2);
|
assertThat(embeddings).hasSize(2);
|
||||||
|
@ -66,7 +71,7 @@ public class QwenEmbeddingModelIT {
|
||||||
void should_embed_mix_segments(String modelName) {
|
void should_embed_mix_segments(String modelName) {
|
||||||
EmbeddingModel model = getModel(modelName);
|
EmbeddingModel model = getModel(modelName);
|
||||||
List<Embedding> embeddings = model.embedAll(asList(
|
List<Embedding> embeddings = model.embedAll(asList(
|
||||||
textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)),
|
textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY)),
|
||||||
textSegment("how are you?")
|
textSegment("how are you?")
|
||||||
)).content();
|
)).content();
|
||||||
|
|
||||||
|
@ -74,4 +79,39 @@ public class QwenEmbeddingModelIT {
|
||||||
assertThat(embeddings.get(0).vector()).isNotEmpty();
|
assertThat(embeddings.get(0).vector()).isNotEmpty();
|
||||||
assertThat(embeddings.get(1).vector()).isNotEmpty();
|
assertThat(embeddings.get(1).vector()).isNotEmpty();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
|
||||||
|
void should_embed_large_amounts_of_documents(String modelName) {
|
||||||
|
EmbeddingModel model = getModel(modelName);
|
||||||
|
List<Embedding> embeddings = model.embedAll(
|
||||||
|
Collections.nCopies(50, textSegment("hello"))).content();
|
||||||
|
|
||||||
|
assertThat(embeddings).hasSize(50);
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
|
||||||
|
void should_embed_large_amounts_of_queries(String modelName) {
|
||||||
|
EmbeddingModel model = getModel(modelName);
|
||||||
|
List<Embedding> embeddings = model.embedAll(
|
||||||
|
Collections.nCopies(50, textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY)))
|
||||||
|
).content();
|
||||||
|
|
||||||
|
assertThat(embeddings).hasSize(50);
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
|
||||||
|
void should_embed_large_amounts_of_mix_segments(String modelName) {
|
||||||
|
EmbeddingModel model = getModel(modelName);
|
||||||
|
List<Embedding> embeddings = model.embedAll(
|
||||||
|
Stream.concat(
|
||||||
|
Collections.nCopies(50, textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY))).stream(),
|
||||||
|
Collections.nCopies(50, textSegment("how are you?")).stream()
|
||||||
|
).collect(Collectors.toList())
|
||||||
|
).content();
|
||||||
|
|
||||||
|
assertThat(embeddings).hasSize(100);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue