From d908f5158ae0abdb0571e68e003311d67d57beca Mon Sep 17 00:00:00 2001 From: jiangsier-xyz <126842484+jiangsier-xyz@users.noreply.github.com> Date: Sat, 19 Aug 2023 02:49:50 +0800 Subject: [PATCH] Integrate the Qwen series models via dashscope-sdk. (#99) Qwen series models are provided by Alibaba Cloud. They are much better in Asia languages then other LLMs. DashScope is a model service platform. Qwen models are its primary supported models. But it also supports other series like LLaMA2, Dolly, ChatGLM, BiLLa(based on LLaMA)...These may be integrated sometime in the future. --- langchain4j-core/pom.xml | 2 +- langchain4j-dashscope/pom.xml | 76 ++++++++++ .../model/dashscope/QwenChatModel.java | 137 ++++++++++++++++++ .../model/dashscope/QwenEmbeddingModel.java | 127 ++++++++++++++++ .../model/dashscope/QwenLanguageModel.java | 53 +++++++ .../model/dashscope/QwenModelName.java | 15 ++ .../model/dashscope/QwenParamHelper.java | 44 ++++++ .../dashscope/QwenStreamingChatModel.java | 109 ++++++++++++++ .../dashscope/QwenStreamingLanguageModel.java | 54 +++++++ .../model/dashscope/QwenChatModelIT.java | 27 ++++ .../model/dashscope/QwenEmbeddingModelIT.java | 88 +++++++++++ .../model/dashscope/QwenLanguageModelIT.java | 24 +++ .../dashscope/QwenStreamingChatModelIT.java | 53 +++++++ .../QwenStreamingLanguageModelIT.java | 53 +++++++ .../model/dashscope/QwenTestHelper.java | 43 ++++++ langchain4j-parent/pom.xml | 3 +- langchain4j-pinecone/pom.xml | 4 +- langchain4j-spring-boot-starter/pom.xml | 4 +- langchain4j-weaviate/pom.xml | 4 +- langchain4j/pom.xml | 4 +- pom.xml | 12 +- 21 files changed, 922 insertions(+), 14 deletions(-) create mode 100644 langchain4j-dashscope/pom.xml create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenChatModel.java create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenLanguageModel.java create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenModelName.java create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenParamHelper.java create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingChatModel.java create mode 100644 langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModel.java create mode 100644 langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenChatModelIT.java create mode 100644 langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java create mode 100644 langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenLanguageModelIT.java create mode 100644 langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingChatModelIT.java create mode 100644 langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModelIT.java create mode 100644 langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTestHelper.java diff --git a/langchain4j-core/pom.xml b/langchain4j-core/pom.xml index dc48c31bf..2e2953b8b 100644 --- a/langchain4j-core/pom.xml +++ b/langchain4j-core/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.20.0 + ${revision} ../langchain4j-parent/pom.xml diff --git a/langchain4j-dashscope/pom.xml b/langchain4j-dashscope/pom.xml new file mode 100644 index 000000000..fbf02cfd2 --- /dev/null +++ b/langchain4j-dashscope/pom.xml @@ -0,0 +1,76 @@ + + + 4.0.0 + + + dev.langchain4j + langchain4j-parent + ${revision} + ../langchain4j-parent/pom.xml + + + + langchain4j-dashscope + jar + + LangChain4j integration with DashScope + It uses the dashscope-sdk-java library, which has an Apache-2.0 license. + + + + + dev.langchain4j + langchain4j-core + ${project.version} + + + + com.alibaba + dashscope-sdk-java + 2.1.1 + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.junit.jupiter + junit-jupiter-params + test + + + + org.assertj + assertj-core + test + + + + + + + org.honton.chas + license-maven-plugin + + + true + + + + + + + + Apache-2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + repo + A business-friendly OSS license + + + + \ No newline at end of file diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenChatModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenChatModel.java new file mode 100644 index 000000000..f93d4e361 --- /dev/null +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenChatModel.java @@ -0,0 +1,137 @@ +package dev.langchain4j.model.dashscope; + +import com.alibaba.dashscope.aigc.generation.Generation; +import com.alibaba.dashscope.aigc.generation.GenerationOutput; +import com.alibaba.dashscope.aigc.generation.GenerationResult; +import com.alibaba.dashscope.aigc.generation.models.QwenParam; +import com.alibaba.dashscope.exception.InputRequiredException; +import com.alibaba.dashscope.exception.NoApiKeyException; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.internal.Utils; +import dev.langchain4j.model.chat.ChatLanguageModel; + +import java.util.List; +import java.util.Optional; + +public class QwenChatModel implements ChatLanguageModel { + protected final Generation gen; + protected final String apiKey; + protected final String modelName; + protected final Double topP; + protected final Double topK; + protected final Boolean enableSearch; + protected final Integer seed; + + protected QwenChatModel(String apiKey, + String modelName, + Double topP, + Double topK, + Boolean enableSearch, + Integer seed) { + + this.apiKey = apiKey; + this.modelName = modelName; + this.topP = topP; + this.topK = topK; + this.enableSearch = enableSearch; + this.seed = seed; + gen = new Generation(); + } + + @Override + public AiMessage sendMessages(List messages) { + return AiMessage.aiMessage(sendMessage(QwenParamHelper.toQwenPrompt(messages))); + } + + protected String sendMessage(String prompt) { + QwenParam param = QwenParam.builder() + .apiKey(apiKey) + .model(modelName) + .topP(topP) + .topK(topK) + .enableSearch(enableSearch) + .seed(seed) + .prompt(prompt) + .build(); + + try { + GenerationResult result = gen.call(param); + return Optional.of(result) + .map(GenerationResult::getOutput) + .map(GenerationOutput::getText) + .orElse("Oops, something wrong...[request id: " + result.getRequestId() + "]"); + } catch (NoApiKeyException | InputRequiredException e) { + throw new RuntimeException(e); + } + } + + @Override + public AiMessage sendMessages(List messages, List toolSpecifications) { + throw new IllegalArgumentException("Tools are currently not supported for qwen models"); + } + + @Override + public AiMessage sendMessages(List messages, ToolSpecification toolSpecification) { + throw new IllegalArgumentException("Tools are currently not supported for qwen models"); + } + + public static Builder builder() { + return new Builder(); + } + public static class Builder { + protected String apiKey; + protected String modelName; + protected Double topP; + protected Double topK; + protected Boolean enableSearch; + protected Integer seed; + + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public Builder topP(Double topP) { + this.topP = topP; + return this; + } + + public Builder topK(Double topK) { + this.topK = topK; + return this; + } + + public Builder enableSearch(Boolean enableSearch) { + this.enableSearch = enableSearch; + return this; + } + + public Builder seed(Integer seed) { + this.seed = seed; + return this; + } + + protected void ensureOptions() { + if (Utils.isNullOrBlank(apiKey)) { + throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey"); + } + modelName = Utils.isNullOrBlank(modelName) ? QwenModelName.QWEN_V1 : modelName; + topP = topP == null ? 0.8 : topP; + topK = topK == null ? 100.0 : topK; + enableSearch = enableSearch == null ? Boolean.FALSE : enableSearch; + seed = seed == null ? 1234 : seed; + } + + public QwenChatModel build() { + ensureOptions(); + return new QwenChatModel(apiKey, modelName, topP, topK, enableSearch, seed); + } + } +} diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java new file mode 100644 index 000000000..b4e706c9e --- /dev/null +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java @@ -0,0 +1,127 @@ +package dev.langchain4j.model.dashscope; + +import com.alibaba.dashscope.embeddings.*; +import com.alibaba.dashscope.exception.NoApiKeyException; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.internal.Utils; +import dev.langchain4j.model.embedding.EmbeddingModel; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +public class QwenEmbeddingModel implements EmbeddingModel { + public static final String TYPE_KEY = "type"; + public static final String TYPE_QUERY = "query"; + public static final String TYPE_DOCUMENT = "document"; + private final String apiKey; + private final String modelName; + private final TextEmbedding embedding; + + protected QwenEmbeddingModel(String apiKey, String modelName) { + this.apiKey = apiKey; + this.modelName = modelName; + embedding = new TextEmbedding(); + } + + private boolean containsDocuments(List textSegments) { + return textSegments.stream() + .map(TextSegment::metadata) + .map(metadata -> metadata.get(TYPE_KEY)) + .filter(TYPE_DOCUMENT::equalsIgnoreCase) + .anyMatch(Utils::isNullOrBlank); + } + + private boolean containsQueries(List textSegments) { + return textSegments.stream() + .map(TextSegment::metadata) + .map(metadata -> metadata.get(TYPE_KEY)) + .filter(TYPE_QUERY::equalsIgnoreCase) + .anyMatch(Utils::isNullOrBlank); + } + + private List embedTexts(List textSegments, TextEmbeddingParam.TextType textType) { + TextEmbeddingParam param = TextEmbeddingParam.builder() + .apiKey(apiKey) + .model(modelName) + .textType(textType) + .texts(textSegments.stream() + .map(TextSegment::text) + .collect(Collectors.toList())) + .build(); + try { + return Optional.of(embedding.call(param)) + .map(TextEmbeddingResult::getOutput) + .map(TextEmbeddingOutput::getEmbeddings) + .orElse(Collections.emptyList()) + .stream() + .map(TextEmbeddingResultItem::getEmbedding) + .map(doubleList -> doubleList.stream().map(Double::floatValue).collect(Collectors.toList())) + .map(Embedding::from) + .collect(Collectors.toList()); + } catch (NoApiKeyException e) { + throw new RuntimeException(e); + } + } + + @Override + public List embedAll(List textSegments) { + boolean queries = containsQueries(textSegments); + + if (!queries) { + // default all documents + return embedTexts(textSegments, TextEmbeddingParam.TextType.DOCUMENT); + } else { + boolean documents = containsDocuments(textSegments); + if (!documents) { + return embedTexts(textSegments, TextEmbeddingParam.TextType.QUERY); + } else { + // This is a mixed collection of queries and documents. Embed one by one. + List embeddings = new ArrayList<>(textSegments.size()); + for (TextSegment textSegment: textSegments) { + List result; + if (TYPE_QUERY.equalsIgnoreCase(textSegment.metadata().get(TYPE_KEY))) { + result = embedTexts(Collections.singletonList(textSegment), TextEmbeddingParam.TextType.QUERY); + } else { + result = embedTexts(Collections.singletonList(textSegment), TextEmbeddingParam.TextType.DOCUMENT); + } + embeddings.addAll(result); + } + return embeddings; + } + } + } + + public static Builder builder() { + return new Builder(); + } + public static class Builder { + private String apiKey; + private String modelName; + + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + protected void ensureOptions() { + if (Utils.isNullOrBlank(apiKey)) { + throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey"); + } + modelName = Utils.isNullOrBlank(modelName) ? QwenModelName.TEXT_EMBEDDING_V1 : modelName; + } + + public QwenEmbeddingModel build() { + ensureOptions(); + return new QwenEmbeddingModel(apiKey, modelName); + } + } +} diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenLanguageModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenLanguageModel.java new file mode 100644 index 000000000..940d658ce --- /dev/null +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenLanguageModel.java @@ -0,0 +1,53 @@ +package dev.langchain4j.model.dashscope; + +import dev.langchain4j.model.language.LanguageModel; + +public class QwenLanguageModel extends QwenChatModel implements LanguageModel { + protected QwenLanguageModel(String apiKey, + String modelName, + Double topP, + Double topK, + Boolean enableSearch, + Integer seed) { + super(apiKey, modelName, topP, topK, enableSearch, seed); + } + @Override + public String process(String text) { + return sendMessage(text); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder extends QwenChatModel.Builder { + public Builder apiKey(String apiKey) { + return (Builder) super.apiKey(apiKey); + } + + public Builder modelName(String modelName) { + return (Builder) super.modelName(modelName); + } + + public Builder topP(Double topP) { + return (Builder) super.topP(topP); + } + + public Builder topK(Double topK) { + return (Builder) super.topK(topK); + } + + public Builder enableSearch(Boolean enableSearch) { + return (Builder) super.enableSearch(enableSearch); + } + + public Builder seed(Integer seed) { + return (Builder) super.seed(seed); + } + + public QwenLanguageModel build() { + ensureOptions(); + return new QwenLanguageModel(apiKey, modelName, topP, topK, enableSearch, seed); + } + } +} diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenModelName.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenModelName.java new file mode 100644 index 000000000..f9c312231 --- /dev/null +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenModelName.java @@ -0,0 +1,15 @@ +package dev.langchain4j.model.dashscope; + +/** + * The LLMs provided by Alibaba Cloud, performs better than most LLMs in Asia languages. + */ +public class QwenModelName { + // Use with QwenChatModel and QwenLanguageModel + public static final String QWEN_V1 = "qwen-v1"; // Qwen base model, 4k context. + public static final String QWEN_PLUS_V1 = "qwen-plus-v1"; // Qwen plus model, 8k context. + public static final String QWEN_7B_CHAT_V1 = "qwen-7b-chat-v1"; // Qwen open sourced 7-billion-parameters version, 4k context. + public static final String QWEN_SPARK_V1 = "qwen-spark-v1"; // Qwen sft for conversation scene, 4k context. + + // Use with QwenEmbeddingModel + public static final String TEXT_EMBEDDING_V1 = "text-embedding-v1"; +} \ No newline at end of file diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenParamHelper.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenParamHelper.java new file mode 100644 index 000000000..724d725df --- /dev/null +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenParamHelper.java @@ -0,0 +1,44 @@ +package dev.langchain4j.model.dashscope; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; + +import java.util.List; +import java.util.stream.Collectors; + +public class QwenParamHelper { + private static final String SYSTEM_PREFIX = "<|system|>:"; + private static final String ASSISTANT_PREFIX = "<|assistant|>:"; + private static final String USER_PREFIX = "<|user|>:"; + + /* Qwen prompt format: + * <|system|>: ... + * + * <|user|>: ... + * + * <|assistant|>: ... + * + * <|user|>: ... + * ... + */ + public static String toQwenPrompt(List messages) { + return messages.stream() + .map(QwenParamHelper::toQwenMessage) + .collect(Collectors.joining("\n\n")); + } + + public static String toQwenMessage(ChatMessage message) { + return prefixFrom(message) + message.text(); + } + + public static String prefixFrom(ChatMessage message) { + if (message instanceof AiMessage) { + return ASSISTANT_PREFIX; + } else if (message instanceof SystemMessage) { + return SYSTEM_PREFIX; + } else { + return USER_PREFIX; + } + } +} diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingChatModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingChatModel.java new file mode 100644 index 000000000..11b6666b9 --- /dev/null +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingChatModel.java @@ -0,0 +1,109 @@ +package dev.langchain4j.model.dashscope; + +import com.alibaba.dashscope.aigc.generation.GenerationResult; +import com.alibaba.dashscope.aigc.generation.models.QwenParam; +import com.alibaba.dashscope.common.ResultCallback; +import com.alibaba.dashscope.exception.InputRequiredException; +import com.alibaba.dashscope.exception.NoApiKeyException; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; + +import java.util.List; + +public class QwenStreamingChatModel extends QwenChatModel implements StreamingChatLanguageModel { + protected QwenStreamingChatModel(String apiKey, + String modelName, + Double topP, + Double topK, + Boolean enableSearch, + Integer seed) { + super(apiKey, modelName, topP, topK, enableSearch, seed); + } + + @Override + public void sendMessages(List messages, StreamingResponseHandler handler) { + sendMessage(QwenParamHelper.toQwenPrompt(messages), handler); + } + + protected void sendMessage(String prompt, StreamingResponseHandler handler) { + QwenParam param = QwenParam.builder() + .apiKey(apiKey) + .model(modelName) + .topP(topP) + .topK(topK) + .enableSearch(enableSearch) + .seed(seed) + .prompt(prompt) + .build(); + + try { + gen.call(param, new ResultCallback() { + @Override + public void onEvent(GenerationResult result) { + handler.onNext(result.getOutput().getText()); + } + @Override + public void onComplete() { + handler.onComplete(); + } + @Override + public void onError(Exception e) { + handler.onError(e); + } + }); + } catch (NoApiKeyException | InputRequiredException e) { + throw new RuntimeException(e); + } + } + + @Override + public void sendMessages(List messages, + List toolSpecifications, + StreamingResponseHandler handler) { + throw new IllegalArgumentException("Tools are currently not supported for qwen models"); + } + + @Override + public void sendMessages(List messages, + ToolSpecification toolSpecification, + StreamingResponseHandler handler) { + throw new IllegalArgumentException("Tools are currently not supported for qwen models"); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder extends QwenChatModel.Builder { + public Builder apiKey(String apiKey) { + return (Builder) super.apiKey(apiKey); + } + + public Builder modelName(String modelName) { + return (Builder) super.modelName(modelName); + } + + public Builder topP(Double topP) { + return (Builder) super.topP(topP); + } + + public Builder topK(Double topK) { + return (Builder) super.topK(topK); + } + + public Builder enableSearch(Boolean enableSearch) { + return (Builder) super.enableSearch(enableSearch); + } + + public Builder seed(Integer seed) { + return (Builder) super.seed(seed); + } + + public QwenStreamingChatModel build() { + ensureOptions(); + return new QwenStreamingChatModel(apiKey, modelName, topP, topK, enableSearch, seed); + } + } +} diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModel.java new file mode 100644 index 000000000..0797c3bab --- /dev/null +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModel.java @@ -0,0 +1,54 @@ +package dev.langchain4j.model.dashscope; + +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.language.StreamingLanguageModel; + +public class QwenStreamingLanguageModel extends QwenStreamingChatModel implements StreamingLanguageModel { + protected QwenStreamingLanguageModel(String apiKey, + String modelName, + Double topP, + Double topK, + Boolean enableSearch, + Integer seed) { + super(apiKey, modelName, topP, topK, enableSearch, seed); + } + + @Override + public void process(String text, StreamingResponseHandler handler) { + sendMessage(text, handler); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder extends QwenStreamingChatModel.Builder { + public Builder apiKey(String apiKey) { + return (Builder) super.apiKey(apiKey); + } + + public Builder modelName(String modelName) { + return (Builder) super.modelName(modelName); + } + + public Builder topP(Double topP) { + return (Builder) super.topP(topP); + } + + public Builder topK(Double topK) { + return (Builder) super.topK(topK); + } + + public Builder enableSearch(Boolean enableSearch) { + return (Builder) super.enableSearch(enableSearch); + } + + public Builder seed(Integer seed) { + return (Builder) super.seed(seed); + } + public QwenStreamingLanguageModel build() { + ensureOptions(); + return new QwenStreamingLanguageModel(apiKey, modelName, topP, topK, enableSearch, seed); + } + } +} diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenChatModelIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenChatModelIT.java new file mode 100644 index 000000000..7cf87b423 --- /dev/null +++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenChatModelIT.java @@ -0,0 +1,27 @@ +package dev.langchain4j.model.dashscope; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.internal.Utils; +import dev.langchain4j.model.chat.ChatLanguageModel; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import static org.assertj.core.api.Assertions.assertThat; + +public class QwenChatModelIT { + + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider") + public void should_send_messages_and_receive_response(String modelName) { + String apiKey = QwenTestHelper.apiKey(); + if (Utils.isNullOrBlank(apiKey)) { + return; + } + ChatLanguageModel model = QwenChatModel.builder() + .apiKey(apiKey) + .modelName(modelName) + .build(); + AiMessage answer = model.sendMessages(QwenTestHelper.chatMessages()); + assertThat(answer.text()).containsIgnoringCase("rain"); + } +} diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java new file mode 100644 index 000000000..89bca6008 --- /dev/null +++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java @@ -0,0 +1,88 @@ +package dev.langchain4j.model.dashscope; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.internal.Utils; +import dev.langchain4j.model.embedding.EmbeddingModel; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.List; + +import static dev.langchain4j.data.segment.TextSegment.textSegment; +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; + +public class QwenEmbeddingModelIT { + private EmbeddingModel getModel(String modelName) { + String apiKey = QwenTestHelper.apiKey(); + if (Utils.isNullOrBlank(apiKey)) { + return null; + } + return QwenEmbeddingModel.builder() + .apiKey(apiKey) + .modelName(modelName) + .build(); + } + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider") + void should_embed_one_text(String modelName) { + EmbeddingModel model = getModel(modelName); + if (model == null) { + return; + } + Embedding embedding = model.embed("hello"); + assertThat(embedding.vector()).isNotEmpty(); + } + + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider") + void should_embed_documents(String modelName) { + EmbeddingModel model = getModel(modelName); + if (model == null) { + return; + } + List embeddings = model.embedAll(asList( + textSegment("hello"), + textSegment("how are you?") + )); + + assertThat(embeddings).hasSize(2); + assertThat(embeddings.get(0).vector()).isNotEmpty(); + assertThat(embeddings.get(1).vector()).isNotEmpty(); + } + + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider") + void should_embed_queries(String modelName) { + EmbeddingModel model = getModel(modelName); + if (model == null) { + return; + } + List embeddings = model.embedAll(asList( + textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)), + textSegment("how are you?", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)) + )); + + assertThat(embeddings).hasSize(2); + assertThat(embeddings.get(0).vector()).isNotEmpty(); + assertThat(embeddings.get(1).vector()).isNotEmpty(); + } + + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider") + void should_embed_mix_segments(String modelName) { + EmbeddingModel model = getModel(modelName); + if (model == null) { + return; + } + List embeddings = model.embedAll(asList( + textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)), + textSegment("how are you?") + )); + + assertThat(embeddings).hasSize(2); + assertThat(embeddings.get(0).vector()).isNotEmpty(); + assertThat(embeddings.get(1).vector()).isNotEmpty(); + } +} diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenLanguageModelIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenLanguageModelIT.java new file mode 100644 index 000000000..6f1a447ed --- /dev/null +++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenLanguageModelIT.java @@ -0,0 +1,24 @@ +package dev.langchain4j.model.dashscope; + +import dev.langchain4j.internal.Utils; +import dev.langchain4j.model.language.LanguageModel; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import static org.assertj.core.api.Assertions.assertThat; + +public class QwenLanguageModelIT { + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider") + public void should_send_messages_and_receive_response(String modelName) { + String apiKey = QwenTestHelper.apiKey(); + if (Utils.isNullOrBlank(apiKey)) { + return; + } + LanguageModel model = QwenLanguageModel.builder() + .apiKey(apiKey) + .modelName(modelName) + .build(); + assertThat(model.process("Please say 'hello' to me")).containsIgnoringCase("hello"); + } +} diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingChatModelIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingChatModelIT.java new file mode 100644 index 000000000..85cee2302 --- /dev/null +++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingChatModelIT.java @@ -0,0 +1,53 @@ +package dev.langchain4j.model.dashscope; + +import dev.langchain4j.internal.Utils; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; + +public class QwenStreamingChatModelIT { + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider") + public void should_send_messages_and_receive_response(String modelName) throws ExecutionException, InterruptedException, TimeoutException { + String apiKey = QwenTestHelper.apiKey(); + if (Utils.isNullOrBlank(apiKey)) { + return; + } + StreamingChatLanguageModel model = QwenStreamingChatModel.builder() + .apiKey(apiKey) + .modelName(modelName) + .build(); + + CompletableFuture future = new CompletableFuture<>(); + model.sendMessages( + QwenTestHelper.chatMessages(), + new StreamingResponseHandler() { + final StringBuilder answerBuilder = new StringBuilder(); + @Override + public void onNext(String partialResult) { + answerBuilder.append(partialResult); + System.out.println("onPartialResult: '" + partialResult + "'"); + } + @Override + public void onComplete() { + future.complete(answerBuilder.toString()); + System.out.println("onComplete"); + } + @Override + public void onError(Throwable error) { + future.completeExceptionally(error); + } + }); + + String answer = future.get(30, SECONDS); + assertThat(answer).containsIgnoringCase("rain"); + } +} diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModelIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModelIT.java new file mode 100644 index 000000000..e19558668 --- /dev/null +++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModelIT.java @@ -0,0 +1,53 @@ +package dev.langchain4j.model.dashscope; + +import dev.langchain4j.internal.Utils; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.language.StreamingLanguageModel; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; + +public class QwenStreamingLanguageModelIT { + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider") + public void should_send_messages_and_receive_response(String modelName) throws ExecutionException, InterruptedException, TimeoutException { + String apiKey = QwenTestHelper.apiKey(); + if (Utils.isNullOrBlank(apiKey)) { + return; + } + StreamingLanguageModel model = QwenStreamingLanguageModel.builder() + .apiKey(apiKey) + .modelName(modelName) + .build(); + + CompletableFuture future = new CompletableFuture<>(); + model.process( + "Please say 'hello' to me", + new StreamingResponseHandler() { + final StringBuilder answerBuilder = new StringBuilder(); + @Override + public void onNext(String partialResult) { + answerBuilder.append(partialResult); + System.out.println("onPartialResult: '" + partialResult + "'"); + } + @Override + public void onComplete() { + future.complete(answerBuilder.toString()); + System.out.println("onComplete"); + } + @Override + public void onError(Throwable error) { + future.completeExceptionally(error); + } + }); + + String answer = future.get(30, SECONDS); + assertThat(answer).containsIgnoringCase("hello"); + } +} diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTestHelper.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTestHelper.java new file mode 100644 index 000000000..e666f2f50 --- /dev/null +++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTestHelper.java @@ -0,0 +1,43 @@ +package dev.langchain4j.model.dashscope; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import org.junit.jupiter.params.provider.Arguments; + +import java.util.LinkedList; +import java.util.List; +import java.util.stream.Stream; + +public class QwenTestHelper { + public static Stream chatModelNameProvider() { + return Stream.of( + Arguments.of(QwenModelName.QWEN_V1), + Arguments.of(QwenModelName.QWEN_PLUS_V1), + Arguments.of(QwenModelName.QWEN_7B_CHAT_V1), + Arguments.of(QwenModelName.QWEN_SPARK_V1) + ); + } + + public static Stream embeddingModelNameProvider() { + return Stream.of( + Arguments.of(QwenModelName.TEXT_EMBEDDING_V1) + ); + } + + public static String apiKey() { + return System.getenv("DASHSCOPE_API_KEY"); + } + + public static List chatMessages() { + List messages = new LinkedList<>(); + messages.add(SystemMessage.from("Your name is Jack." + + " You like to answer other people's questions briefly." + + " It's rainy today.")); + messages.add(UserMessage.from("Hello. What's your name?")); + messages.add(AiMessage.from("Jack.")); + messages.add(UserMessage.from("How about the weather today?")); + return messages; + } +} diff --git a/langchain4j-parent/pom.xml b/langchain4j-parent/pom.xml index 1a26ed7cc..0f349b753 100644 --- a/langchain4j-parent/pom.xml +++ b/langchain4j-parent/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j langchain4j-parent - 0.20.0 + ${revision} pom langchain4j parent POM @@ -14,6 +14,7 @@ https://github.com/langchain4j/langchain4j + 0.20.0 1.8 1.8 UTF-8 diff --git a/langchain4j-pinecone/pom.xml b/langchain4j-pinecone/pom.xml index 65729561c..06f68815c 100644 --- a/langchain4j-pinecone/pom.xml +++ b/langchain4j-pinecone/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.20.0 + ${revision} ../langchain4j-parent/pom.xml @@ -23,7 +23,7 @@ dev.langchain4j langchain4j-core - 0.20.0 + ${project.version} diff --git a/langchain4j-spring-boot-starter/pom.xml b/langchain4j-spring-boot-starter/pom.xml index 152f277fa..d16aea193 100644 --- a/langchain4j-spring-boot-starter/pom.xml +++ b/langchain4j-spring-boot-starter/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.20.0 + ${revision} ../langchain4j-parent/pom.xml @@ -22,7 +22,7 @@ dev.langchain4j langchain4j - 0.20.0 + ${project.version} diff --git a/langchain4j-weaviate/pom.xml b/langchain4j-weaviate/pom.xml index c6397184e..082b23559 100644 --- a/langchain4j-weaviate/pom.xml +++ b/langchain4j-weaviate/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.20.0 + ${revision} ../langchain4j-parent/pom.xml @@ -25,7 +25,7 @@ dev.langchain4j langchain4j-core - 0.20.0 + ${project.version} diff --git a/langchain4j/pom.xml b/langchain4j/pom.xml index 96258bd43..95f7eb388 100644 --- a/langchain4j/pom.xml +++ b/langchain4j/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.20.0 + ${revision} ../langchain4j-parent/pom.xml @@ -24,7 +24,7 @@ dev.langchain4j langchain4j-core - 0.20.0 + ${project.version} diff --git a/pom.xml b/pom.xml index 98f2d902e..1b91ec872 100644 --- a/pom.xml +++ b/pom.xml @@ -10,17 +10,21 @@ pom + langchain4j-parent langchain4j-core langchain4j - - langchain4j-pinecone - langchain4j-weaviate - langchain4j-spring-boot-starter + + langchain4j-dashscope + + langchain4j-milvus + langchain4j-pinecone + langchain4j-weaviate + \ No newline at end of file