From d908f5158ae0abdb0571e68e003311d67d57beca Mon Sep 17 00:00:00 2001
From: jiangsier-xyz <126842484+jiangsier-xyz@users.noreply.github.com>
Date: Sat, 19 Aug 2023 02:49:50 +0800
Subject: [PATCH] Integrate the Qwen series models via dashscope-sdk. (#99)
Qwen series models are provided by Alibaba Cloud. They are much better
in Asia languages then other LLMs.
DashScope is a model service platform. Qwen models are its primary
supported models. But it also supports other series like LLaMA2, Dolly,
ChatGLM, BiLLa(based on LLaMA)...These may be integrated sometime in the
future.
---
langchain4j-core/pom.xml | 2 +-
langchain4j-dashscope/pom.xml | 76 ++++++++++
.../model/dashscope/QwenChatModel.java | 137 ++++++++++++++++++
.../model/dashscope/QwenEmbeddingModel.java | 127 ++++++++++++++++
.../model/dashscope/QwenLanguageModel.java | 53 +++++++
.../model/dashscope/QwenModelName.java | 15 ++
.../model/dashscope/QwenParamHelper.java | 44 ++++++
.../dashscope/QwenStreamingChatModel.java | 109 ++++++++++++++
.../dashscope/QwenStreamingLanguageModel.java | 54 +++++++
.../model/dashscope/QwenChatModelIT.java | 27 ++++
.../model/dashscope/QwenEmbeddingModelIT.java | 88 +++++++++++
.../model/dashscope/QwenLanguageModelIT.java | 24 +++
.../dashscope/QwenStreamingChatModelIT.java | 53 +++++++
.../QwenStreamingLanguageModelIT.java | 53 +++++++
.../model/dashscope/QwenTestHelper.java | 43 ++++++
langchain4j-parent/pom.xml | 3 +-
langchain4j-pinecone/pom.xml | 4 +-
langchain4j-spring-boot-starter/pom.xml | 4 +-
langchain4j-weaviate/pom.xml | 4 +-
langchain4j/pom.xml | 4 +-
pom.xml | 12 +-
21 files changed, 922 insertions(+), 14 deletions(-)
create mode 100644 langchain4j-dashscope/pom.xml
create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenChatModel.java
create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java
create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenLanguageModel.java
create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenModelName.java
create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenParamHelper.java
create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingChatModel.java
create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModel.java
create mode 100644 langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenChatModelIT.java
create mode 100644 langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java
create mode 100644 langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenLanguageModelIT.java
create mode 100644 langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingChatModelIT.java
create mode 100644 langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModelIT.java
create mode 100644 langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTestHelper.java
diff --git a/langchain4j-core/pom.xml b/langchain4j-core/pom.xml
index dc48c31bf..2e2953b8b 100644
--- a/langchain4j-core/pom.xml
+++ b/langchain4j-core/pom.xml
@@ -7,7 +7,7 @@
dev.langchain4j
langchain4j-parent
- 0.20.0
+ ${revision}
../langchain4j-parent/pom.xml
diff --git a/langchain4j-dashscope/pom.xml b/langchain4j-dashscope/pom.xml
new file mode 100644
index 000000000..fbf02cfd2
--- /dev/null
+++ b/langchain4j-dashscope/pom.xml
@@ -0,0 +1,76 @@
+
+
+ 4.0.0
+
+
+ dev.langchain4j
+ langchain4j-parent
+ ${revision}
+ ../langchain4j-parent/pom.xml
+
+
+
+ langchain4j-dashscope
+ jar
+
+ LangChain4j integration with DashScope
+ It uses the dashscope-sdk-java library, which has an Apache-2.0 license.
+
+
+
+
+ dev.langchain4j
+ langchain4j-core
+ ${project.version}
+
+
+
+ com.alibaba
+ dashscope-sdk-java
+ 2.1.1
+
+
+
+ org.junit.jupiter
+ junit-jupiter-engine
+ test
+
+
+
+ org.junit.jupiter
+ junit-jupiter-params
+ test
+
+
+
+ org.assertj
+ assertj-core
+ test
+
+
+
+
+
+
+ org.honton.chas
+ license-maven-plugin
+
+
+ true
+
+
+
+
+
+
+
+ Apache-2.0
+ https://www.apache.org/licenses/LICENSE-2.0.txt
+ repo
+ A business-friendly OSS license
+
+
+
+
\ No newline at end of file
diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenChatModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenChatModel.java
new file mode 100644
index 000000000..f93d4e361
--- /dev/null
+++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenChatModel.java
@@ -0,0 +1,137 @@
+package dev.langchain4j.model.dashscope;
+
+import com.alibaba.dashscope.aigc.generation.Generation;
+import com.alibaba.dashscope.aigc.generation.GenerationOutput;
+import com.alibaba.dashscope.aigc.generation.GenerationResult;
+import com.alibaba.dashscope.aigc.generation.models.QwenParam;
+import com.alibaba.dashscope.exception.InputRequiredException;
+import com.alibaba.dashscope.exception.NoApiKeyException;
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.internal.Utils;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+
+import java.util.List;
+import java.util.Optional;
+
+public class QwenChatModel implements ChatLanguageModel {
+ protected final Generation gen;
+ protected final String apiKey;
+ protected final String modelName;
+ protected final Double topP;
+ protected final Double topK;
+ protected final Boolean enableSearch;
+ protected final Integer seed;
+
+ protected QwenChatModel(String apiKey,
+ String modelName,
+ Double topP,
+ Double topK,
+ Boolean enableSearch,
+ Integer seed) {
+
+ this.apiKey = apiKey;
+ this.modelName = modelName;
+ this.topP = topP;
+ this.topK = topK;
+ this.enableSearch = enableSearch;
+ this.seed = seed;
+ gen = new Generation();
+ }
+
+ @Override
+ public AiMessage sendMessages(List messages) {
+ return AiMessage.aiMessage(sendMessage(QwenParamHelper.toQwenPrompt(messages)));
+ }
+
+ protected String sendMessage(String prompt) {
+ QwenParam param = QwenParam.builder()
+ .apiKey(apiKey)
+ .model(modelName)
+ .topP(topP)
+ .topK(topK)
+ .enableSearch(enableSearch)
+ .seed(seed)
+ .prompt(prompt)
+ .build();
+
+ try {
+ GenerationResult result = gen.call(param);
+ return Optional.of(result)
+ .map(GenerationResult::getOutput)
+ .map(GenerationOutput::getText)
+ .orElse("Oops, something wrong...[request id: " + result.getRequestId() + "]");
+ } catch (NoApiKeyException | InputRequiredException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public AiMessage sendMessages(List messages, List toolSpecifications) {
+ throw new IllegalArgumentException("Tools are currently not supported for qwen models");
+ }
+
+ @Override
+ public AiMessage sendMessages(List messages, ToolSpecification toolSpecification) {
+ throw new IllegalArgumentException("Tools are currently not supported for qwen models");
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+ public static class Builder {
+ protected String apiKey;
+ protected String modelName;
+ protected Double topP;
+ protected Double topK;
+ protected Boolean enableSearch;
+ protected Integer seed;
+
+ public Builder apiKey(String apiKey) {
+ this.apiKey = apiKey;
+ return this;
+ }
+
+ public Builder modelName(String modelName) {
+ this.modelName = modelName;
+ return this;
+ }
+
+ public Builder topP(Double topP) {
+ this.topP = topP;
+ return this;
+ }
+
+ public Builder topK(Double topK) {
+ this.topK = topK;
+ return this;
+ }
+
+ public Builder enableSearch(Boolean enableSearch) {
+ this.enableSearch = enableSearch;
+ return this;
+ }
+
+ public Builder seed(Integer seed) {
+ this.seed = seed;
+ return this;
+ }
+
+ protected void ensureOptions() {
+ if (Utils.isNullOrBlank(apiKey)) {
+ throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
+ }
+ modelName = Utils.isNullOrBlank(modelName) ? QwenModelName.QWEN_V1 : modelName;
+ topP = topP == null ? 0.8 : topP;
+ topK = topK == null ? 100.0 : topK;
+ enableSearch = enableSearch == null ? Boolean.FALSE : enableSearch;
+ seed = seed == null ? 1234 : seed;
+ }
+
+ public QwenChatModel build() {
+ ensureOptions();
+ return new QwenChatModel(apiKey, modelName, topP, topK, enableSearch, seed);
+ }
+ }
+}
diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java
new file mode 100644
index 000000000..b4e706c9e
--- /dev/null
+++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java
@@ -0,0 +1,127 @@
+package dev.langchain4j.model.dashscope;
+
+import com.alibaba.dashscope.embeddings.*;
+import com.alibaba.dashscope.exception.NoApiKeyException;
+import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.internal.Utils;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+public class QwenEmbeddingModel implements EmbeddingModel {
+ public static final String TYPE_KEY = "type";
+ public static final String TYPE_QUERY = "query";
+ public static final String TYPE_DOCUMENT = "document";
+ private final String apiKey;
+ private final String modelName;
+ private final TextEmbedding embedding;
+
+ protected QwenEmbeddingModel(String apiKey, String modelName) {
+ this.apiKey = apiKey;
+ this.modelName = modelName;
+ embedding = new TextEmbedding();
+ }
+
+ private boolean containsDocuments(List textSegments) {
+ return textSegments.stream()
+ .map(TextSegment::metadata)
+ .map(metadata -> metadata.get(TYPE_KEY))
+ .filter(TYPE_DOCUMENT::equalsIgnoreCase)
+ .anyMatch(Utils::isNullOrBlank);
+ }
+
+ private boolean containsQueries(List textSegments) {
+ return textSegments.stream()
+ .map(TextSegment::metadata)
+ .map(metadata -> metadata.get(TYPE_KEY))
+ .filter(TYPE_QUERY::equalsIgnoreCase)
+ .anyMatch(Utils::isNullOrBlank);
+ }
+
+ private List embedTexts(List textSegments, TextEmbeddingParam.TextType textType) {
+ TextEmbeddingParam param = TextEmbeddingParam.builder()
+ .apiKey(apiKey)
+ .model(modelName)
+ .textType(textType)
+ .texts(textSegments.stream()
+ .map(TextSegment::text)
+ .collect(Collectors.toList()))
+ .build();
+ try {
+ return Optional.of(embedding.call(param))
+ .map(TextEmbeddingResult::getOutput)
+ .map(TextEmbeddingOutput::getEmbeddings)
+ .orElse(Collections.emptyList())
+ .stream()
+ .map(TextEmbeddingResultItem::getEmbedding)
+ .map(doubleList -> doubleList.stream().map(Double::floatValue).collect(Collectors.toList()))
+ .map(Embedding::from)
+ .collect(Collectors.toList());
+ } catch (NoApiKeyException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public List embedAll(List textSegments) {
+ boolean queries = containsQueries(textSegments);
+
+ if (!queries) {
+ // default all documents
+ return embedTexts(textSegments, TextEmbeddingParam.TextType.DOCUMENT);
+ } else {
+ boolean documents = containsDocuments(textSegments);
+ if (!documents) {
+ return embedTexts(textSegments, TextEmbeddingParam.TextType.QUERY);
+ } else {
+ // This is a mixed collection of queries and documents. Embed one by one.
+ List embeddings = new ArrayList<>(textSegments.size());
+ for (TextSegment textSegment: textSegments) {
+ List result;
+ if (TYPE_QUERY.equalsIgnoreCase(textSegment.metadata().get(TYPE_KEY))) {
+ result = embedTexts(Collections.singletonList(textSegment), TextEmbeddingParam.TextType.QUERY);
+ } else {
+ result = embedTexts(Collections.singletonList(textSegment), TextEmbeddingParam.TextType.DOCUMENT);
+ }
+ embeddings.addAll(result);
+ }
+ return embeddings;
+ }
+ }
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+ public static class Builder {
+ private String apiKey;
+ private String modelName;
+
+ public Builder apiKey(String apiKey) {
+ this.apiKey = apiKey;
+ return this;
+ }
+
+ public Builder modelName(String modelName) {
+ this.modelName = modelName;
+ return this;
+ }
+
+ protected void ensureOptions() {
+ if (Utils.isNullOrBlank(apiKey)) {
+ throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
+ }
+ modelName = Utils.isNullOrBlank(modelName) ? QwenModelName.TEXT_EMBEDDING_V1 : modelName;
+ }
+
+ public QwenEmbeddingModel build() {
+ ensureOptions();
+ return new QwenEmbeddingModel(apiKey, modelName);
+ }
+ }
+}
diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenLanguageModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenLanguageModel.java
new file mode 100644
index 000000000..940d658ce
--- /dev/null
+++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenLanguageModel.java
@@ -0,0 +1,53 @@
+package dev.langchain4j.model.dashscope;
+
+import dev.langchain4j.model.language.LanguageModel;
+
+public class QwenLanguageModel extends QwenChatModel implements LanguageModel {
+ protected QwenLanguageModel(String apiKey,
+ String modelName,
+ Double topP,
+ Double topK,
+ Boolean enableSearch,
+ Integer seed) {
+ super(apiKey, modelName, topP, topK, enableSearch, seed);
+ }
+ @Override
+ public String process(String text) {
+ return sendMessage(text);
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static class Builder extends QwenChatModel.Builder {
+ public Builder apiKey(String apiKey) {
+ return (Builder) super.apiKey(apiKey);
+ }
+
+ public Builder modelName(String modelName) {
+ return (Builder) super.modelName(modelName);
+ }
+
+ public Builder topP(Double topP) {
+ return (Builder) super.topP(topP);
+ }
+
+ public Builder topK(Double topK) {
+ return (Builder) super.topK(topK);
+ }
+
+ public Builder enableSearch(Boolean enableSearch) {
+ return (Builder) super.enableSearch(enableSearch);
+ }
+
+ public Builder seed(Integer seed) {
+ return (Builder) super.seed(seed);
+ }
+
+ public QwenLanguageModel build() {
+ ensureOptions();
+ return new QwenLanguageModel(apiKey, modelName, topP, topK, enableSearch, seed);
+ }
+ }
+}
diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenModelName.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenModelName.java
new file mode 100644
index 000000000..f9c312231
--- /dev/null
+++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenModelName.java
@@ -0,0 +1,15 @@
+package dev.langchain4j.model.dashscope;
+
+/**
+ * The LLMs provided by Alibaba Cloud, performs better than most LLMs in Asia languages.
+ */
+public class QwenModelName {
+ // Use with QwenChatModel and QwenLanguageModel
+ public static final String QWEN_V1 = "qwen-v1"; // Qwen base model, 4k context.
+ public static final String QWEN_PLUS_V1 = "qwen-plus-v1"; // Qwen plus model, 8k context.
+ public static final String QWEN_7B_CHAT_V1 = "qwen-7b-chat-v1"; // Qwen open sourced 7-billion-parameters version, 4k context.
+ public static final String QWEN_SPARK_V1 = "qwen-spark-v1"; // Qwen sft for conversation scene, 4k context.
+
+ // Use with QwenEmbeddingModel
+ public static final String TEXT_EMBEDDING_V1 = "text-embedding-v1";
+}
\ No newline at end of file
diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenParamHelper.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenParamHelper.java
new file mode 100644
index 000000000..724d725df
--- /dev/null
+++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenParamHelper.java
@@ -0,0 +1,44 @@
+package dev.langchain4j.model.dashscope;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.SystemMessage;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class QwenParamHelper {
+ private static final String SYSTEM_PREFIX = "<|system|>:";
+ private static final String ASSISTANT_PREFIX = "<|assistant|>:";
+ private static final String USER_PREFIX = "<|user|>:";
+
+ /* Qwen prompt format:
+ * <|system|>: ...
+ *
+ * <|user|>: ...
+ *
+ * <|assistant|>: ...
+ *
+ * <|user|>: ...
+ * ...
+ */
+ public static String toQwenPrompt(List messages) {
+ return messages.stream()
+ .map(QwenParamHelper::toQwenMessage)
+ .collect(Collectors.joining("\n\n"));
+ }
+
+ public static String toQwenMessage(ChatMessage message) {
+ return prefixFrom(message) + message.text();
+ }
+
+ public static String prefixFrom(ChatMessage message) {
+ if (message instanceof AiMessage) {
+ return ASSISTANT_PREFIX;
+ } else if (message instanceof SystemMessage) {
+ return SYSTEM_PREFIX;
+ } else {
+ return USER_PREFIX;
+ }
+ }
+}
diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingChatModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingChatModel.java
new file mode 100644
index 000000000..11b6666b9
--- /dev/null
+++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingChatModel.java
@@ -0,0 +1,109 @@
+package dev.langchain4j.model.dashscope;
+
+import com.alibaba.dashscope.aigc.generation.GenerationResult;
+import com.alibaba.dashscope.aigc.generation.models.QwenParam;
+import com.alibaba.dashscope.common.ResultCallback;
+import com.alibaba.dashscope.exception.InputRequiredException;
+import com.alibaba.dashscope.exception.NoApiKeyException;
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+
+import java.util.List;
+
+public class QwenStreamingChatModel extends QwenChatModel implements StreamingChatLanguageModel {
+ protected QwenStreamingChatModel(String apiKey,
+ String modelName,
+ Double topP,
+ Double topK,
+ Boolean enableSearch,
+ Integer seed) {
+ super(apiKey, modelName, topP, topK, enableSearch, seed);
+ }
+
+ @Override
+ public void sendMessages(List messages, StreamingResponseHandler handler) {
+ sendMessage(QwenParamHelper.toQwenPrompt(messages), handler);
+ }
+
+ protected void sendMessage(String prompt, StreamingResponseHandler handler) {
+ QwenParam param = QwenParam.builder()
+ .apiKey(apiKey)
+ .model(modelName)
+ .topP(topP)
+ .topK(topK)
+ .enableSearch(enableSearch)
+ .seed(seed)
+ .prompt(prompt)
+ .build();
+
+ try {
+ gen.call(param, new ResultCallback() {
+ @Override
+ public void onEvent(GenerationResult result) {
+ handler.onNext(result.getOutput().getText());
+ }
+ @Override
+ public void onComplete() {
+ handler.onComplete();
+ }
+ @Override
+ public void onError(Exception e) {
+ handler.onError(e);
+ }
+ });
+ } catch (NoApiKeyException | InputRequiredException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void sendMessages(List messages,
+ List toolSpecifications,
+ StreamingResponseHandler handler) {
+ throw new IllegalArgumentException("Tools are currently not supported for qwen models");
+ }
+
+ @Override
+ public void sendMessages(List messages,
+ ToolSpecification toolSpecification,
+ StreamingResponseHandler handler) {
+ throw new IllegalArgumentException("Tools are currently not supported for qwen models");
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static class Builder extends QwenChatModel.Builder {
+ public Builder apiKey(String apiKey) {
+ return (Builder) super.apiKey(apiKey);
+ }
+
+ public Builder modelName(String modelName) {
+ return (Builder) super.modelName(modelName);
+ }
+
+ public Builder topP(Double topP) {
+ return (Builder) super.topP(topP);
+ }
+
+ public Builder topK(Double topK) {
+ return (Builder) super.topK(topK);
+ }
+
+ public Builder enableSearch(Boolean enableSearch) {
+ return (Builder) super.enableSearch(enableSearch);
+ }
+
+ public Builder seed(Integer seed) {
+ return (Builder) super.seed(seed);
+ }
+
+ public QwenStreamingChatModel build() {
+ ensureOptions();
+ return new QwenStreamingChatModel(apiKey, modelName, topP, topK, enableSearch, seed);
+ }
+ }
+}
diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModel.java
new file mode 100644
index 000000000..0797c3bab
--- /dev/null
+++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModel.java
@@ -0,0 +1,54 @@
+package dev.langchain4j.model.dashscope;
+
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.language.StreamingLanguageModel;
+
+public class QwenStreamingLanguageModel extends QwenStreamingChatModel implements StreamingLanguageModel {
+ protected QwenStreamingLanguageModel(String apiKey,
+ String modelName,
+ Double topP,
+ Double topK,
+ Boolean enableSearch,
+ Integer seed) {
+ super(apiKey, modelName, topP, topK, enableSearch, seed);
+ }
+
+ @Override
+ public void process(String text, StreamingResponseHandler handler) {
+ sendMessage(text, handler);
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static class Builder extends QwenStreamingChatModel.Builder {
+ public Builder apiKey(String apiKey) {
+ return (Builder) super.apiKey(apiKey);
+ }
+
+ public Builder modelName(String modelName) {
+ return (Builder) super.modelName(modelName);
+ }
+
+ public Builder topP(Double topP) {
+ return (Builder) super.topP(topP);
+ }
+
+ public Builder topK(Double topK) {
+ return (Builder) super.topK(topK);
+ }
+
+ public Builder enableSearch(Boolean enableSearch) {
+ return (Builder) super.enableSearch(enableSearch);
+ }
+
+ public Builder seed(Integer seed) {
+ return (Builder) super.seed(seed);
+ }
+ public QwenStreamingLanguageModel build() {
+ ensureOptions();
+ return new QwenStreamingLanguageModel(apiKey, modelName, topP, topK, enableSearch, seed);
+ }
+ }
+}
diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenChatModelIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenChatModelIT.java
new file mode 100644
index 000000000..7cf87b423
--- /dev/null
+++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenChatModelIT.java
@@ -0,0 +1,27 @@
+package dev.langchain4j.model.dashscope;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.internal.Utils;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class QwenChatModelIT {
+
+ @ParameterizedTest
+ @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider")
+ public void should_send_messages_and_receive_response(String modelName) {
+ String apiKey = QwenTestHelper.apiKey();
+ if (Utils.isNullOrBlank(apiKey)) {
+ return;
+ }
+ ChatLanguageModel model = QwenChatModel.builder()
+ .apiKey(apiKey)
+ .modelName(modelName)
+ .build();
+ AiMessage answer = model.sendMessages(QwenTestHelper.chatMessages());
+ assertThat(answer.text()).containsIgnoringCase("rain");
+ }
+}
diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java
new file mode 100644
index 000000000..89bca6008
--- /dev/null
+++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java
@@ -0,0 +1,88 @@
+package dev.langchain4j.model.dashscope;
+
+import dev.langchain4j.data.document.Metadata;
+import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.internal.Utils;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import java.util.List;
+
+import static dev.langchain4j.data.segment.TextSegment.textSegment;
+import static java.util.Arrays.asList;
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class QwenEmbeddingModelIT {
+ private EmbeddingModel getModel(String modelName) {
+ String apiKey = QwenTestHelper.apiKey();
+ if (Utils.isNullOrBlank(apiKey)) {
+ return null;
+ }
+ return QwenEmbeddingModel.builder()
+ .apiKey(apiKey)
+ .modelName(modelName)
+ .build();
+ }
+ @ParameterizedTest
+ @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
+ void should_embed_one_text(String modelName) {
+ EmbeddingModel model = getModel(modelName);
+ if (model == null) {
+ return;
+ }
+ Embedding embedding = model.embed("hello");
+ assertThat(embedding.vector()).isNotEmpty();
+ }
+
+ @ParameterizedTest
+ @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
+ void should_embed_documents(String modelName) {
+ EmbeddingModel model = getModel(modelName);
+ if (model == null) {
+ return;
+ }
+ List embeddings = model.embedAll(asList(
+ textSegment("hello"),
+ textSegment("how are you?")
+ ));
+
+ assertThat(embeddings).hasSize(2);
+ assertThat(embeddings.get(0).vector()).isNotEmpty();
+ assertThat(embeddings.get(1).vector()).isNotEmpty();
+ }
+
+ @ParameterizedTest
+ @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
+ void should_embed_queries(String modelName) {
+ EmbeddingModel model = getModel(modelName);
+ if (model == null) {
+ return;
+ }
+ List embeddings = model.embedAll(asList(
+ textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)),
+ textSegment("how are you?", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY))
+ ));
+
+ assertThat(embeddings).hasSize(2);
+ assertThat(embeddings.get(0).vector()).isNotEmpty();
+ assertThat(embeddings.get(1).vector()).isNotEmpty();
+ }
+
+ @ParameterizedTest
+ @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
+ void should_embed_mix_segments(String modelName) {
+ EmbeddingModel model = getModel(modelName);
+ if (model == null) {
+ return;
+ }
+ List embeddings = model.embedAll(asList(
+ textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)),
+ textSegment("how are you?")
+ ));
+
+ assertThat(embeddings).hasSize(2);
+ assertThat(embeddings.get(0).vector()).isNotEmpty();
+ assertThat(embeddings.get(1).vector()).isNotEmpty();
+ }
+}
diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenLanguageModelIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenLanguageModelIT.java
new file mode 100644
index 000000000..6f1a447ed
--- /dev/null
+++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenLanguageModelIT.java
@@ -0,0 +1,24 @@
+package dev.langchain4j.model.dashscope;
+
+import dev.langchain4j.internal.Utils;
+import dev.langchain4j.model.language.LanguageModel;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class QwenLanguageModelIT {
+ @ParameterizedTest
+ @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider")
+ public void should_send_messages_and_receive_response(String modelName) {
+ String apiKey = QwenTestHelper.apiKey();
+ if (Utils.isNullOrBlank(apiKey)) {
+ return;
+ }
+ LanguageModel model = QwenLanguageModel.builder()
+ .apiKey(apiKey)
+ .modelName(modelName)
+ .build();
+ assertThat(model.process("Please say 'hello' to me")).containsIgnoringCase("hello");
+ }
+}
diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingChatModelIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingChatModelIT.java
new file mode 100644
index 000000000..85cee2302
--- /dev/null
+++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingChatModelIT.java
@@ -0,0 +1,53 @@
+package dev.langchain4j.model.dashscope;
+
+import dev.langchain4j.internal.Utils;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class QwenStreamingChatModelIT {
+ @ParameterizedTest
+ @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider")
+ public void should_send_messages_and_receive_response(String modelName) throws ExecutionException, InterruptedException, TimeoutException {
+ String apiKey = QwenTestHelper.apiKey();
+ if (Utils.isNullOrBlank(apiKey)) {
+ return;
+ }
+ StreamingChatLanguageModel model = QwenStreamingChatModel.builder()
+ .apiKey(apiKey)
+ .modelName(modelName)
+ .build();
+
+ CompletableFuture future = new CompletableFuture<>();
+ model.sendMessages(
+ QwenTestHelper.chatMessages(),
+ new StreamingResponseHandler() {
+ final StringBuilder answerBuilder = new StringBuilder();
+ @Override
+ public void onNext(String partialResult) {
+ answerBuilder.append(partialResult);
+ System.out.println("onPartialResult: '" + partialResult + "'");
+ }
+ @Override
+ public void onComplete() {
+ future.complete(answerBuilder.toString());
+ System.out.println("onComplete");
+ }
+ @Override
+ public void onError(Throwable error) {
+ future.completeExceptionally(error);
+ }
+ });
+
+ String answer = future.get(30, SECONDS);
+ assertThat(answer).containsIgnoringCase("rain");
+ }
+}
diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModelIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModelIT.java
new file mode 100644
index 000000000..e19558668
--- /dev/null
+++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModelIT.java
@@ -0,0 +1,53 @@
+package dev.langchain4j.model.dashscope;
+
+import dev.langchain4j.internal.Utils;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.language.StreamingLanguageModel;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class QwenStreamingLanguageModelIT {
+ @ParameterizedTest
+ @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider")
+ public void should_send_messages_and_receive_response(String modelName) throws ExecutionException, InterruptedException, TimeoutException {
+ String apiKey = QwenTestHelper.apiKey();
+ if (Utils.isNullOrBlank(apiKey)) {
+ return;
+ }
+ StreamingLanguageModel model = QwenStreamingLanguageModel.builder()
+ .apiKey(apiKey)
+ .modelName(modelName)
+ .build();
+
+ CompletableFuture future = new CompletableFuture<>();
+ model.process(
+ "Please say 'hello' to me",
+ new StreamingResponseHandler() {
+ final StringBuilder answerBuilder = new StringBuilder();
+ @Override
+ public void onNext(String partialResult) {
+ answerBuilder.append(partialResult);
+ System.out.println("onPartialResult: '" + partialResult + "'");
+ }
+ @Override
+ public void onComplete() {
+ future.complete(answerBuilder.toString());
+ System.out.println("onComplete");
+ }
+ @Override
+ public void onError(Throwable error) {
+ future.completeExceptionally(error);
+ }
+ });
+
+ String answer = future.get(30, SECONDS);
+ assertThat(answer).containsIgnoringCase("hello");
+ }
+}
diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTestHelper.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTestHelper.java
new file mode 100644
index 000000000..e666f2f50
--- /dev/null
+++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTestHelper.java
@@ -0,0 +1,43 @@
+package dev.langchain4j.model.dashscope;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.UserMessage;
+import org.junit.jupiter.params.provider.Arguments;
+
+import java.util.LinkedList;
+import java.util.List;
+import java.util.stream.Stream;
+
+public class QwenTestHelper {
+ public static Stream chatModelNameProvider() {
+ return Stream.of(
+ Arguments.of(QwenModelName.QWEN_V1),
+ Arguments.of(QwenModelName.QWEN_PLUS_V1),
+ Arguments.of(QwenModelName.QWEN_7B_CHAT_V1),
+ Arguments.of(QwenModelName.QWEN_SPARK_V1)
+ );
+ }
+
+ public static Stream embeddingModelNameProvider() {
+ return Stream.of(
+ Arguments.of(QwenModelName.TEXT_EMBEDDING_V1)
+ );
+ }
+
+ public static String apiKey() {
+ return System.getenv("DASHSCOPE_API_KEY");
+ }
+
+ public static List chatMessages() {
+ List messages = new LinkedList<>();
+ messages.add(SystemMessage.from("Your name is Jack." +
+ " You like to answer other people's questions briefly." +
+ " It's rainy today."));
+ messages.add(UserMessage.from("Hello. What's your name?"));
+ messages.add(AiMessage.from("Jack."));
+ messages.add(UserMessage.from("How about the weather today?"));
+ return messages;
+ }
+}
diff --git a/langchain4j-parent/pom.xml b/langchain4j-parent/pom.xml
index 1a26ed7cc..0f349b753 100644
--- a/langchain4j-parent/pom.xml
+++ b/langchain4j-parent/pom.xml
@@ -6,7 +6,7 @@
dev.langchain4j
langchain4j-parent
- 0.20.0
+ ${revision}
pom
langchain4j parent POM
@@ -14,6 +14,7 @@
https://github.com/langchain4j/langchain4j
+ 0.20.0
1.8
1.8
UTF-8
diff --git a/langchain4j-pinecone/pom.xml b/langchain4j-pinecone/pom.xml
index 65729561c..06f68815c 100644
--- a/langchain4j-pinecone/pom.xml
+++ b/langchain4j-pinecone/pom.xml
@@ -7,7 +7,7 @@
dev.langchain4j
langchain4j-parent
- 0.20.0
+ ${revision}
../langchain4j-parent/pom.xml
@@ -23,7 +23,7 @@
dev.langchain4j
langchain4j-core
- 0.20.0
+ ${project.version}
diff --git a/langchain4j-spring-boot-starter/pom.xml b/langchain4j-spring-boot-starter/pom.xml
index 152f277fa..d16aea193 100644
--- a/langchain4j-spring-boot-starter/pom.xml
+++ b/langchain4j-spring-boot-starter/pom.xml
@@ -7,7 +7,7 @@
dev.langchain4j
langchain4j-parent
- 0.20.0
+ ${revision}
../langchain4j-parent/pom.xml
@@ -22,7 +22,7 @@
dev.langchain4j
langchain4j
- 0.20.0
+ ${project.version}
diff --git a/langchain4j-weaviate/pom.xml b/langchain4j-weaviate/pom.xml
index c6397184e..082b23559 100644
--- a/langchain4j-weaviate/pom.xml
+++ b/langchain4j-weaviate/pom.xml
@@ -7,7 +7,7 @@
dev.langchain4j
langchain4j-parent
- 0.20.0
+ ${revision}
../langchain4j-parent/pom.xml
@@ -25,7 +25,7 @@
dev.langchain4j
langchain4j-core
- 0.20.0
+ ${project.version}
diff --git a/langchain4j/pom.xml b/langchain4j/pom.xml
index 96258bd43..95f7eb388 100644
--- a/langchain4j/pom.xml
+++ b/langchain4j/pom.xml
@@ -7,7 +7,7 @@
dev.langchain4j
langchain4j-parent
- 0.20.0
+ ${revision}
../langchain4j-parent/pom.xml
@@ -24,7 +24,7 @@
dev.langchain4j
langchain4j-core
- 0.20.0
+ ${project.version}
diff --git a/pom.xml b/pom.xml
index 98f2d902e..1b91ec872 100644
--- a/pom.xml
+++ b/pom.xml
@@ -10,17 +10,21 @@
pom
+
langchain4j-parent
langchain4j-core
langchain4j
-
- langchain4j-pinecone
- langchain4j-weaviate
-
langchain4j-spring-boot-starter
+
+ langchain4j-dashscope
+
+
langchain4j-milvus
+ langchain4j-pinecone
+ langchain4j-weaviate
+
\ No newline at end of file