Integrate the Qwen series models via dashscope-sdk. (#99)

Qwen series models are provided by Alibaba Cloud. They are much better
in Asia languages then other LLMs.

DashScope is a model service platform. Qwen models are its primary
supported models. But it also supports other series like LLaMA2, Dolly,
ChatGLM, BiLLa(based on LLaMA)...These may be integrated sometime in the
future.
This commit is contained in:
jiangsier-xyz 2023-08-19 02:49:50 +08:00 committed by GitHub
parent ec4a673b52
commit d908f5158a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 922 additions and 14 deletions

View File

@ -7,7 +7,7 @@
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.20.0</version>
<version>${revision}</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>

View File

@ -0,0 +1,76 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>${revision}</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
<artifactId>langchain4j-dashscope</artifactId>
<packaging>jar</packaging>
<name>LangChain4j integration with DashScope</name>
<description>It uses the dashscope-sdk-java library, which has an Apache-2.0 license.</description>
<dependencies>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>dashscope-sdk-java</artifactId>
<version>2.1.1</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.honton.chas</groupId>
<artifactId>license-maven-plugin</artifactId>
<configuration>
<!-- The weaviate client has a BSD-3 license, see https://github.com/weaviate/java-client/blob/main/LICENSE -->
<skipCompliance>true</skipCompliance>
</configuration>
</plugin>
</plugins>
</build>
<licenses>
<license>
<name>Apache-2.0</name>
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
<comments>A business-friendly OSS license</comments>
</license>
</licenses>
</project>

View File

@ -0,0 +1,137 @@
package dev.langchain4j.model.dashscope;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationOutput;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import java.util.List;
import java.util.Optional;
public class QwenChatModel implements ChatLanguageModel {
protected final Generation gen;
protected final String apiKey;
protected final String modelName;
protected final Double topP;
protected final Double topK;
protected final Boolean enableSearch;
protected final Integer seed;
protected QwenChatModel(String apiKey,
String modelName,
Double topP,
Double topK,
Boolean enableSearch,
Integer seed) {
this.apiKey = apiKey;
this.modelName = modelName;
this.topP = topP;
this.topK = topK;
this.enableSearch = enableSearch;
this.seed = seed;
gen = new Generation();
}
@Override
public AiMessage sendMessages(List<ChatMessage> messages) {
return AiMessage.aiMessage(sendMessage(QwenParamHelper.toQwenPrompt(messages)));
}
protected String sendMessage(String prompt) {
QwenParam param = QwenParam.builder()
.apiKey(apiKey)
.model(modelName)
.topP(topP)
.topK(topK)
.enableSearch(enableSearch)
.seed(seed)
.prompt(prompt)
.build();
try {
GenerationResult result = gen.call(param);
return Optional.of(result)
.map(GenerationResult::getOutput)
.map(GenerationOutput::getText)
.orElse("Oops, something wrong...[request id: " + result.getRequestId() + "]");
} catch (NoApiKeyException | InputRequiredException e) {
throw new RuntimeException(e);
}
}
@Override
public AiMessage sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
throw new IllegalArgumentException("Tools are currently not supported for qwen models");
}
@Override
public AiMessage sendMessages(List<ChatMessage> messages, ToolSpecification toolSpecification) {
throw new IllegalArgumentException("Tools are currently not supported for qwen models");
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
protected String apiKey;
protected String modelName;
protected Double topP;
protected Double topK;
protected Boolean enableSearch;
protected Integer seed;
public Builder apiKey(String apiKey) {
this.apiKey = apiKey;
return this;
}
public Builder modelName(String modelName) {
this.modelName = modelName;
return this;
}
public Builder topP(Double topP) {
this.topP = topP;
return this;
}
public Builder topK(Double topK) {
this.topK = topK;
return this;
}
public Builder enableSearch(Boolean enableSearch) {
this.enableSearch = enableSearch;
return this;
}
public Builder seed(Integer seed) {
this.seed = seed;
return this;
}
protected void ensureOptions() {
if (Utils.isNullOrBlank(apiKey)) {
throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
}
modelName = Utils.isNullOrBlank(modelName) ? QwenModelName.QWEN_V1 : modelName;
topP = topP == null ? 0.8 : topP;
topK = topK == null ? 100.0 : topK;
enableSearch = enableSearch == null ? Boolean.FALSE : enableSearch;
seed = seed == null ? 1234 : seed;
}
public QwenChatModel build() {
ensureOptions();
return new QwenChatModel(apiKey, modelName, topP, topK, enableSearch, seed);
}
}
}

View File

@ -0,0 +1,127 @@
package dev.langchain4j.model.dashscope;
import com.alibaba.dashscope.embeddings.*;
import com.alibaba.dashscope.exception.NoApiKeyException;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.embedding.EmbeddingModel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
public class QwenEmbeddingModel implements EmbeddingModel {
public static final String TYPE_KEY = "type";
public static final String TYPE_QUERY = "query";
public static final String TYPE_DOCUMENT = "document";
private final String apiKey;
private final String modelName;
private final TextEmbedding embedding;
protected QwenEmbeddingModel(String apiKey, String modelName) {
this.apiKey = apiKey;
this.modelName = modelName;
embedding = new TextEmbedding();
}
private boolean containsDocuments(List<TextSegment> textSegments) {
return textSegments.stream()
.map(TextSegment::metadata)
.map(metadata -> metadata.get(TYPE_KEY))
.filter(TYPE_DOCUMENT::equalsIgnoreCase)
.anyMatch(Utils::isNullOrBlank);
}
private boolean containsQueries(List<TextSegment> textSegments) {
return textSegments.stream()
.map(TextSegment::metadata)
.map(metadata -> metadata.get(TYPE_KEY))
.filter(TYPE_QUERY::equalsIgnoreCase)
.anyMatch(Utils::isNullOrBlank);
}
private List<Embedding> embedTexts(List<TextSegment> textSegments, TextEmbeddingParam.TextType textType) {
TextEmbeddingParam param = TextEmbeddingParam.builder()
.apiKey(apiKey)
.model(modelName)
.textType(textType)
.texts(textSegments.stream()
.map(TextSegment::text)
.collect(Collectors.toList()))
.build();
try {
return Optional.of(embedding.call(param))
.map(TextEmbeddingResult::getOutput)
.map(TextEmbeddingOutput::getEmbeddings)
.orElse(Collections.emptyList())
.stream()
.map(TextEmbeddingResultItem::getEmbedding)
.map(doubleList -> doubleList.stream().map(Double::floatValue).collect(Collectors.toList()))
.map(Embedding::from)
.collect(Collectors.toList());
} catch (NoApiKeyException e) {
throw new RuntimeException(e);
}
}
@Override
public List<Embedding> embedAll(List<TextSegment> textSegments) {
boolean queries = containsQueries(textSegments);
if (!queries) {
// default all documents
return embedTexts(textSegments, TextEmbeddingParam.TextType.DOCUMENT);
} else {
boolean documents = containsDocuments(textSegments);
if (!documents) {
return embedTexts(textSegments, TextEmbeddingParam.TextType.QUERY);
} else {
// This is a mixed collection of queries and documents. Embed one by one.
List<Embedding> embeddings = new ArrayList<>(textSegments.size());
for (TextSegment textSegment: textSegments) {
List<Embedding> result;
if (TYPE_QUERY.equalsIgnoreCase(textSegment.metadata().get(TYPE_KEY))) {
result = embedTexts(Collections.singletonList(textSegment), TextEmbeddingParam.TextType.QUERY);
} else {
result = embedTexts(Collections.singletonList(textSegment), TextEmbeddingParam.TextType.DOCUMENT);
}
embeddings.addAll(result);
}
return embeddings;
}
}
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private String apiKey;
private String modelName;
public Builder apiKey(String apiKey) {
this.apiKey = apiKey;
return this;
}
public Builder modelName(String modelName) {
this.modelName = modelName;
return this;
}
protected void ensureOptions() {
if (Utils.isNullOrBlank(apiKey)) {
throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
}
modelName = Utils.isNullOrBlank(modelName) ? QwenModelName.TEXT_EMBEDDING_V1 : modelName;
}
public QwenEmbeddingModel build() {
ensureOptions();
return new QwenEmbeddingModel(apiKey, modelName);
}
}
}

View File

@ -0,0 +1,53 @@
package dev.langchain4j.model.dashscope;
import dev.langchain4j.model.language.LanguageModel;
public class QwenLanguageModel extends QwenChatModel implements LanguageModel {
protected QwenLanguageModel(String apiKey,
String modelName,
Double topP,
Double topK,
Boolean enableSearch,
Integer seed) {
super(apiKey, modelName, topP, topK, enableSearch, seed);
}
@Override
public String process(String text) {
return sendMessage(text);
}
public static Builder builder() {
return new Builder();
}
public static class Builder extends QwenChatModel.Builder {
public Builder apiKey(String apiKey) {
return (Builder) super.apiKey(apiKey);
}
public Builder modelName(String modelName) {
return (Builder) super.modelName(modelName);
}
public Builder topP(Double topP) {
return (Builder) super.topP(topP);
}
public Builder topK(Double topK) {
return (Builder) super.topK(topK);
}
public Builder enableSearch(Boolean enableSearch) {
return (Builder) super.enableSearch(enableSearch);
}
public Builder seed(Integer seed) {
return (Builder) super.seed(seed);
}
public QwenLanguageModel build() {
ensureOptions();
return new QwenLanguageModel(apiKey, modelName, topP, topK, enableSearch, seed);
}
}
}

View File

@ -0,0 +1,15 @@
package dev.langchain4j.model.dashscope;
/**
* The LLMs provided by Alibaba Cloud, performs better than most LLMs in Asia languages.
*/
public class QwenModelName {
// Use with QwenChatModel and QwenLanguageModel
public static final String QWEN_V1 = "qwen-v1"; // Qwen base model, 4k context.
public static final String QWEN_PLUS_V1 = "qwen-plus-v1"; // Qwen plus model, 8k context.
public static final String QWEN_7B_CHAT_V1 = "qwen-7b-chat-v1"; // Qwen open sourced 7-billion-parameters version, 4k context.
public static final String QWEN_SPARK_V1 = "qwen-spark-v1"; // Qwen sft for conversation scene, 4k context.
// Use with QwenEmbeddingModel
public static final String TEXT_EMBEDDING_V1 = "text-embedding-v1";
}

View File

@ -0,0 +1,44 @@
package dev.langchain4j.model.dashscope;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import java.util.List;
import java.util.stream.Collectors;
public class QwenParamHelper {
private static final String SYSTEM_PREFIX = "<|system|>:";
private static final String ASSISTANT_PREFIX = "<|assistant|>:";
private static final String USER_PREFIX = "<|user|>:";
/* Qwen prompt format:
* <|system|>: ...
*
* <|user|>: ...
*
* <|assistant|>: ...
*
* <|user|>: ...
* ...
*/
public static String toQwenPrompt(List<ChatMessage> messages) {
return messages.stream()
.map(QwenParamHelper::toQwenMessage)
.collect(Collectors.joining("\n\n"));
}
public static String toQwenMessage(ChatMessage message) {
return prefixFrom(message) + message.text();
}
public static String prefixFrom(ChatMessage message) {
if (message instanceof AiMessage) {
return ASSISTANT_PREFIX;
} else if (message instanceof SystemMessage) {
return SYSTEM_PREFIX;
} else {
return USER_PREFIX;
}
}
}

View File

@ -0,0 +1,109 @@
package dev.langchain4j.model.dashscope;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
import com.alibaba.dashscope.common.ResultCallback;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import java.util.List;
public class QwenStreamingChatModel extends QwenChatModel implements StreamingChatLanguageModel {
protected QwenStreamingChatModel(String apiKey,
String modelName,
Double topP,
Double topK,
Boolean enableSearch,
Integer seed) {
super(apiKey, modelName, topP, topK, enableSearch, seed);
}
@Override
public void sendMessages(List<ChatMessage> messages, StreamingResponseHandler handler) {
sendMessage(QwenParamHelper.toQwenPrompt(messages), handler);
}
protected void sendMessage(String prompt, StreamingResponseHandler handler) {
QwenParam param = QwenParam.builder()
.apiKey(apiKey)
.model(modelName)
.topP(topP)
.topK(topK)
.enableSearch(enableSearch)
.seed(seed)
.prompt(prompt)
.build();
try {
gen.call(param, new ResultCallback<GenerationResult>() {
@Override
public void onEvent(GenerationResult result) {
handler.onNext(result.getOutput().getText());
}
@Override
public void onComplete() {
handler.onComplete();
}
@Override
public void onError(Exception e) {
handler.onError(e);
}
});
} catch (NoApiKeyException | InputRequiredException e) {
throw new RuntimeException(e);
}
}
@Override
public void sendMessages(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
StreamingResponseHandler handler) {
throw new IllegalArgumentException("Tools are currently not supported for qwen models");
}
@Override
public void sendMessages(List<ChatMessage> messages,
ToolSpecification toolSpecification,
StreamingResponseHandler handler) {
throw new IllegalArgumentException("Tools are currently not supported for qwen models");
}
public static Builder builder() {
return new Builder();
}
public static class Builder extends QwenChatModel.Builder {
public Builder apiKey(String apiKey) {
return (Builder) super.apiKey(apiKey);
}
public Builder modelName(String modelName) {
return (Builder) super.modelName(modelName);
}
public Builder topP(Double topP) {
return (Builder) super.topP(topP);
}
public Builder topK(Double topK) {
return (Builder) super.topK(topK);
}
public Builder enableSearch(Boolean enableSearch) {
return (Builder) super.enableSearch(enableSearch);
}
public Builder seed(Integer seed) {
return (Builder) super.seed(seed);
}
public QwenStreamingChatModel build() {
ensureOptions();
return new QwenStreamingChatModel(apiKey, modelName, topP, topK, enableSearch, seed);
}
}
}

View File

@ -0,0 +1,54 @@
package dev.langchain4j.model.dashscope;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.language.StreamingLanguageModel;
public class QwenStreamingLanguageModel extends QwenStreamingChatModel implements StreamingLanguageModel {
protected QwenStreamingLanguageModel(String apiKey,
String modelName,
Double topP,
Double topK,
Boolean enableSearch,
Integer seed) {
super(apiKey, modelName, topP, topK, enableSearch, seed);
}
@Override
public void process(String text, StreamingResponseHandler handler) {
sendMessage(text, handler);
}
public static Builder builder() {
return new Builder();
}
public static class Builder extends QwenStreamingChatModel.Builder {
public Builder apiKey(String apiKey) {
return (Builder) super.apiKey(apiKey);
}
public Builder modelName(String modelName) {
return (Builder) super.modelName(modelName);
}
public Builder topP(Double topP) {
return (Builder) super.topP(topP);
}
public Builder topK(Double topK) {
return (Builder) super.topK(topK);
}
public Builder enableSearch(Boolean enableSearch) {
return (Builder) super.enableSearch(enableSearch);
}
public Builder seed(Integer seed) {
return (Builder) super.seed(seed);
}
public QwenStreamingLanguageModel build() {
ensureOptions();
return new QwenStreamingLanguageModel(apiKey, modelName, topP, topK, enableSearch, seed);
}
}
}

View File

@ -0,0 +1,27 @@
package dev.langchain4j.model.dashscope;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import static org.assertj.core.api.Assertions.assertThat;
public class QwenChatModelIT {
@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider")
public void should_send_messages_and_receive_response(String modelName) {
String apiKey = QwenTestHelper.apiKey();
if (Utils.isNullOrBlank(apiKey)) {
return;
}
ChatLanguageModel model = QwenChatModel.builder()
.apiKey(apiKey)
.modelName(modelName)
.build();
AiMessage answer = model.sendMessages(QwenTestHelper.chatMessages());
assertThat(answer.text()).containsIgnoringCase("rain");
}
}

View File

@ -0,0 +1,88 @@
package dev.langchain4j.model.dashscope;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.embedding.EmbeddingModel;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.List;
import static dev.langchain4j.data.segment.TextSegment.textSegment;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
public class QwenEmbeddingModelIT {
private EmbeddingModel getModel(String modelName) {
String apiKey = QwenTestHelper.apiKey();
if (Utils.isNullOrBlank(apiKey)) {
return null;
}
return QwenEmbeddingModel.builder()
.apiKey(apiKey)
.modelName(modelName)
.build();
}
@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
void should_embed_one_text(String modelName) {
EmbeddingModel model = getModel(modelName);
if (model == null) {
return;
}
Embedding embedding = model.embed("hello");
assertThat(embedding.vector()).isNotEmpty();
}
@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
void should_embed_documents(String modelName) {
EmbeddingModel model = getModel(modelName);
if (model == null) {
return;
}
List<Embedding> embeddings = model.embedAll(asList(
textSegment("hello"),
textSegment("how are you?")
));
assertThat(embeddings).hasSize(2);
assertThat(embeddings.get(0).vector()).isNotEmpty();
assertThat(embeddings.get(1).vector()).isNotEmpty();
}
@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
void should_embed_queries(String modelName) {
EmbeddingModel model = getModel(modelName);
if (model == null) {
return;
}
List<Embedding> embeddings = model.embedAll(asList(
textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)),
textSegment("how are you?", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY))
));
assertThat(embeddings).hasSize(2);
assertThat(embeddings.get(0).vector()).isNotEmpty();
assertThat(embeddings.get(1).vector()).isNotEmpty();
}
@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
void should_embed_mix_segments(String modelName) {
EmbeddingModel model = getModel(modelName);
if (model == null) {
return;
}
List<Embedding> embeddings = model.embedAll(asList(
textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)),
textSegment("how are you?")
));
assertThat(embeddings).hasSize(2);
assertThat(embeddings.get(0).vector()).isNotEmpty();
assertThat(embeddings.get(1).vector()).isNotEmpty();
}
}

View File

@ -0,0 +1,24 @@
package dev.langchain4j.model.dashscope;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.language.LanguageModel;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import static org.assertj.core.api.Assertions.assertThat;
public class QwenLanguageModelIT {
@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider")
public void should_send_messages_and_receive_response(String modelName) {
String apiKey = QwenTestHelper.apiKey();
if (Utils.isNullOrBlank(apiKey)) {
return;
}
LanguageModel model = QwenLanguageModel.builder()
.apiKey(apiKey)
.modelName(modelName)
.build();
assertThat(model.process("Please say 'hello' to me")).containsIgnoringCase("hello");
}
}

View File

@ -0,0 +1,53 @@
package dev.langchain4j.model.dashscope;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
public class QwenStreamingChatModelIT {
@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider")
public void should_send_messages_and_receive_response(String modelName) throws ExecutionException, InterruptedException, TimeoutException {
String apiKey = QwenTestHelper.apiKey();
if (Utils.isNullOrBlank(apiKey)) {
return;
}
StreamingChatLanguageModel model = QwenStreamingChatModel.builder()
.apiKey(apiKey)
.modelName(modelName)
.build();
CompletableFuture<String> future = new CompletableFuture<>();
model.sendMessages(
QwenTestHelper.chatMessages(),
new StreamingResponseHandler() {
final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String partialResult) {
answerBuilder.append(partialResult);
System.out.println("onPartialResult: '" + partialResult + "'");
}
@Override
public void onComplete() {
future.complete(answerBuilder.toString());
System.out.println("onComplete");
}
@Override
public void onError(Throwable error) {
future.completeExceptionally(error);
}
});
String answer = future.get(30, SECONDS);
assertThat(answer).containsIgnoringCase("rain");
}
}

View File

@ -0,0 +1,53 @@
package dev.langchain4j.model.dashscope;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.language.StreamingLanguageModel;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
public class QwenStreamingLanguageModelIT {
@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider")
public void should_send_messages_and_receive_response(String modelName) throws ExecutionException, InterruptedException, TimeoutException {
String apiKey = QwenTestHelper.apiKey();
if (Utils.isNullOrBlank(apiKey)) {
return;
}
StreamingLanguageModel model = QwenStreamingLanguageModel.builder()
.apiKey(apiKey)
.modelName(modelName)
.build();
CompletableFuture<String> future = new CompletableFuture<>();
model.process(
"Please say 'hello' to me",
new StreamingResponseHandler() {
final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String partialResult) {
answerBuilder.append(partialResult);
System.out.println("onPartialResult: '" + partialResult + "'");
}
@Override
public void onComplete() {
future.complete(answerBuilder.toString());
System.out.println("onComplete");
}
@Override
public void onError(Throwable error) {
future.completeExceptionally(error);
}
});
String answer = future.get(30, SECONDS);
assertThat(answer).containsIgnoringCase("hello");
}
}

View File

@ -0,0 +1,43 @@
package dev.langchain4j.model.dashscope;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import org.junit.jupiter.params.provider.Arguments;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Stream;
public class QwenTestHelper {
public static Stream<Arguments> chatModelNameProvider() {
return Stream.of(
Arguments.of(QwenModelName.QWEN_V1),
Arguments.of(QwenModelName.QWEN_PLUS_V1),
Arguments.of(QwenModelName.QWEN_7B_CHAT_V1),
Arguments.of(QwenModelName.QWEN_SPARK_V1)
);
}
public static Stream<Arguments> embeddingModelNameProvider() {
return Stream.of(
Arguments.of(QwenModelName.TEXT_EMBEDDING_V1)
);
}
public static String apiKey() {
return System.getenv("DASHSCOPE_API_KEY");
}
public static List<ChatMessage> chatMessages() {
List<ChatMessage> messages = new LinkedList<>();
messages.add(SystemMessage.from("Your name is Jack." +
" You like to answer other people's questions briefly." +
" It's rainy today."));
messages.add(UserMessage.from("Hello. What's your name?"));
messages.add(AiMessage.from("Jack."));
messages.add(UserMessage.from("How about the weather today?"));
return messages;
}
}

View File

@ -6,7 +6,7 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.20.0</version>
<version>${revision}</version>
<packaging>pom</packaging>
<name>langchain4j parent POM</name>
@ -14,6 +14,7 @@
<url>https://github.com/langchain4j/langchain4j</url>
<properties>
<revision>0.20.0</revision>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>

View File

@ -7,7 +7,7 @@
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.20.0</version>
<version>${revision}</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
@ -23,7 +23,7 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>0.20.0</version>
<version>${project.version}</version>
</dependency>
<dependency>

View File

@ -7,7 +7,7 @@
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.20.0</version>
<version>${revision}</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
@ -22,7 +22,7 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>0.20.0</version>
<version>${project.version}</version>
</dependency>
<dependency>

View File

@ -7,7 +7,7 @@
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.20.0</version>
<version>${revision}</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
@ -25,7 +25,7 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>0.20.0</version>
<version>${project.version}</version>
</dependency>
<dependency>

View File

@ -7,7 +7,7 @@
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.20.0</version>
<version>${revision}</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
@ -24,7 +24,7 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>0.20.0</version>
<version>${project.version}</version>
</dependency>
<dependency>

10
pom.xml
View File

@ -10,17 +10,21 @@
<packaging>pom</packaging>
<modules>
<module>langchain4j-parent</module>
<module>langchain4j-core</module>
<module>langchain4j</module>
<module>langchain4j-spring-boot-starter</module>
<!-- model providers -->
<module>langchain4j-dashscope</module>
<!-- embedding stores -->
<module>langchain4j-milvus</module>
<module>langchain4j-pinecone</module>
<module>langchain4j-weaviate</module>
<module>langchain4j-spring-boot-starter</module>
<module>langchain4j-milvus</module>
</modules>
</project>