diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java index f6c61aaff..876959f9f 100644 --- a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java @@ -24,6 +24,7 @@ public class QwenEmbeddingModel implements EmbeddingModel { public static final String TYPE_KEY = "type"; public static final String TYPE_QUERY = "query"; public static final String TYPE_DOCUMENT = "document"; + private static final int MAX_BATCH_SIZE = 25; private final String apiKey; private final String modelName; @@ -53,7 +54,30 @@ public class QwenEmbeddingModel implements EmbeddingModel { .anyMatch(TYPE_QUERY::equalsIgnoreCase); } - private Response> embedTexts(List textSegments, TextEmbeddingParam.TextType textType) { + private Response> embedTexts(List textSegments, + TextEmbeddingParam.TextType textType) { + int size = textSegments.size(); + if (size < MAX_BATCH_SIZE) { + return batchEmbedTexts(textSegments, textType); + } + + List allEmbeddings = new ArrayList<>(size); + TokenUsage allUsage = null; + int fromIndex = 0; + int toIndex = MAX_BATCH_SIZE; + while (fromIndex < size) { + List batchTextSegments = textSegments.subList(fromIndex, toIndex); + Response> batchResponse = batchEmbedTexts(batchTextSegments, textType); + allEmbeddings.addAll(batchResponse.content()); + allUsage = TokenUsage.sum(allUsage, batchResponse.tokenUsage()); + fromIndex = toIndex; + toIndex = Math.min(size, fromIndex + MAX_BATCH_SIZE); + } + + return Response.from(allEmbeddings, allUsage); + } + + private Response> batchEmbedTexts(List textSegments, TextEmbeddingParam.TextType textType) { TextEmbeddingParam param = TextEmbeddingParam.builder() .apiKey(apiKey) .model(modelName) diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java index c3e67a47b..9b06c4f9d 100644 --- a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java +++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java @@ -7,9 +7,14 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static dev.langchain4j.data.segment.TextSegment.textSegment; +import static dev.langchain4j.model.dashscope.QwenEmbeddingModel.TYPE_KEY; +import static dev.langchain4j.model.dashscope.QwenEmbeddingModel.TYPE_QUERY; import static dev.langchain4j.model.dashscope.QwenTestHelper.apiKey; import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; @@ -52,8 +57,8 @@ public class QwenEmbeddingModelIT { void should_embed_queries(String modelName) { EmbeddingModel model = getModel(modelName); List embeddings = model.embedAll(asList( - textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)), - textSegment("how are you?", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)) + textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY)), + textSegment("how are you?", Metadata.from(TYPE_KEY, TYPE_QUERY)) )).content(); assertThat(embeddings).hasSize(2); @@ -66,7 +71,7 @@ public class QwenEmbeddingModelIT { void should_embed_mix_segments(String modelName) { EmbeddingModel model = getModel(modelName); List embeddings = model.embedAll(asList( - textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)), + textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY)), textSegment("how are you?") )).content(); @@ -74,4 +79,39 @@ public class QwenEmbeddingModelIT { assertThat(embeddings.get(0).vector()).isNotEmpty(); assertThat(embeddings.get(1).vector()).isNotEmpty(); } + + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider") + void should_embed_large_amounts_of_documents(String modelName) { + EmbeddingModel model = getModel(modelName); + List embeddings = model.embedAll( + Collections.nCopies(50, textSegment("hello"))).content(); + + assertThat(embeddings).hasSize(50); + } + + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider") + void should_embed_large_amounts_of_queries(String modelName) { + EmbeddingModel model = getModel(modelName); + List embeddings = model.embedAll( + Collections.nCopies(50, textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY))) + ).content(); + + assertThat(embeddings).hasSize(50); + } + + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider") + void should_embed_large_amounts_of_mix_segments(String modelName) { + EmbeddingModel model = getModel(modelName); + List embeddings = model.embedAll( + Stream.concat( + Collections.nCopies(50, textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY))).stream(), + Collections.nCopies(50, textSegment("how are you?")).stream() + ).collect(Collectors.toList()) + ).content(); + + assertThat(embeddings).hasSize(100); + } }