Integrate the Qwen series models via dashscope-sdk. (#99)
Qwen series models are provided by Alibaba Cloud. They are much better in Asia languages then other LLMs. DashScope is a model service platform. Qwen models are its primary supported models. But it also supports other series like LLaMA2, Dolly, ChatGLM, BiLLa(based on LLaMA)...These may be integrated sometime in the future.
This commit is contained in:
parent
ec4a673b52
commit
d908f5158a
|
@ -7,7 +7,7 @@
|
|||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.20.0</version>
|
||||
<version>${revision}</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>${revision}</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
<artifactId>langchain4j-dashscope</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>LangChain4j integration with DashScope</name>
|
||||
<description>It uses the dashscope-sdk-java library, which has an Apache-2.0 license.</description>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.alibaba</groupId>
|
||||
<artifactId>dashscope-sdk-java</artifactId>
|
||||
<version>2.1.1</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-params</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.assertj</groupId>
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.honton.chas</groupId>
|
||||
<artifactId>license-maven-plugin</artifactId>
|
||||
<configuration>
|
||||
<!-- The weaviate client has a BSD-3 license, see https://github.com/weaviate/java-client/blob/main/LICENSE -->
|
||||
<skipCompliance>true</skipCompliance>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
|
||||
<licenses>
|
||||
<license>
|
||||
<name>Apache-2.0</name>
|
||||
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
|
||||
<distribution>repo</distribution>
|
||||
<comments>A business-friendly OSS license</comments>
|
||||
</license>
|
||||
</licenses>
|
||||
|
||||
</project>
|
|
@ -0,0 +1,137 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import com.alibaba.dashscope.aigc.generation.Generation;
|
||||
import com.alibaba.dashscope.aigc.generation.GenerationOutput;
|
||||
import com.alibaba.dashscope.aigc.generation.GenerationResult;
|
||||
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
|
||||
import com.alibaba.dashscope.exception.InputRequiredException;
|
||||
import com.alibaba.dashscope.exception.NoApiKeyException;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.internal.Utils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
public class QwenChatModel implements ChatLanguageModel {
|
||||
protected final Generation gen;
|
||||
protected final String apiKey;
|
||||
protected final String modelName;
|
||||
protected final Double topP;
|
||||
protected final Double topK;
|
||||
protected final Boolean enableSearch;
|
||||
protected final Integer seed;
|
||||
|
||||
protected QwenChatModel(String apiKey,
|
||||
String modelName,
|
||||
Double topP,
|
||||
Double topK,
|
||||
Boolean enableSearch,
|
||||
Integer seed) {
|
||||
|
||||
this.apiKey = apiKey;
|
||||
this.modelName = modelName;
|
||||
this.topP = topP;
|
||||
this.topK = topK;
|
||||
this.enableSearch = enableSearch;
|
||||
this.seed = seed;
|
||||
gen = new Generation();
|
||||
}
|
||||
|
||||
@Override
|
||||
public AiMessage sendMessages(List<ChatMessage> messages) {
|
||||
return AiMessage.aiMessage(sendMessage(QwenParamHelper.toQwenPrompt(messages)));
|
||||
}
|
||||
|
||||
protected String sendMessage(String prompt) {
|
||||
QwenParam param = QwenParam.builder()
|
||||
.apiKey(apiKey)
|
||||
.model(modelName)
|
||||
.topP(topP)
|
||||
.topK(topK)
|
||||
.enableSearch(enableSearch)
|
||||
.seed(seed)
|
||||
.prompt(prompt)
|
||||
.build();
|
||||
|
||||
try {
|
||||
GenerationResult result = gen.call(param);
|
||||
return Optional.of(result)
|
||||
.map(GenerationResult::getOutput)
|
||||
.map(GenerationOutput::getText)
|
||||
.orElse("Oops, something wrong...[request id: " + result.getRequestId() + "]");
|
||||
} catch (NoApiKeyException | InputRequiredException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public AiMessage sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
|
||||
throw new IllegalArgumentException("Tools are currently not supported for qwen models");
|
||||
}
|
||||
|
||||
@Override
|
||||
public AiMessage sendMessages(List<ChatMessage> messages, ToolSpecification toolSpecification) {
|
||||
throw new IllegalArgumentException("Tools are currently not supported for qwen models");
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
public static class Builder {
|
||||
protected String apiKey;
|
||||
protected String modelName;
|
||||
protected Double topP;
|
||||
protected Double topK;
|
||||
protected Boolean enableSearch;
|
||||
protected Integer seed;
|
||||
|
||||
public Builder apiKey(String apiKey) {
|
||||
this.apiKey = apiKey;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder modelName(String modelName) {
|
||||
this.modelName = modelName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder topP(Double topP) {
|
||||
this.topP = topP;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder topK(Double topK) {
|
||||
this.topK = topK;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder enableSearch(Boolean enableSearch) {
|
||||
this.enableSearch = enableSearch;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder seed(Integer seed) {
|
||||
this.seed = seed;
|
||||
return this;
|
||||
}
|
||||
|
||||
protected void ensureOptions() {
|
||||
if (Utils.isNullOrBlank(apiKey)) {
|
||||
throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
|
||||
}
|
||||
modelName = Utils.isNullOrBlank(modelName) ? QwenModelName.QWEN_V1 : modelName;
|
||||
topP = topP == null ? 0.8 : topP;
|
||||
topK = topK == null ? 100.0 : topK;
|
||||
enableSearch = enableSearch == null ? Boolean.FALSE : enableSearch;
|
||||
seed = seed == null ? 1234 : seed;
|
||||
}
|
||||
|
||||
public QwenChatModel build() {
|
||||
ensureOptions();
|
||||
return new QwenChatModel(apiKey, modelName, topP, topK, enableSearch, seed);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,127 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import com.alibaba.dashscope.embeddings.*;
|
||||
import com.alibaba.dashscope.exception.NoApiKeyException;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.internal.Utils;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class QwenEmbeddingModel implements EmbeddingModel {
|
||||
public static final String TYPE_KEY = "type";
|
||||
public static final String TYPE_QUERY = "query";
|
||||
public static final String TYPE_DOCUMENT = "document";
|
||||
private final String apiKey;
|
||||
private final String modelName;
|
||||
private final TextEmbedding embedding;
|
||||
|
||||
protected QwenEmbeddingModel(String apiKey, String modelName) {
|
||||
this.apiKey = apiKey;
|
||||
this.modelName = modelName;
|
||||
embedding = new TextEmbedding();
|
||||
}
|
||||
|
||||
private boolean containsDocuments(List<TextSegment> textSegments) {
|
||||
return textSegments.stream()
|
||||
.map(TextSegment::metadata)
|
||||
.map(metadata -> metadata.get(TYPE_KEY))
|
||||
.filter(TYPE_DOCUMENT::equalsIgnoreCase)
|
||||
.anyMatch(Utils::isNullOrBlank);
|
||||
}
|
||||
|
||||
private boolean containsQueries(List<TextSegment> textSegments) {
|
||||
return textSegments.stream()
|
||||
.map(TextSegment::metadata)
|
||||
.map(metadata -> metadata.get(TYPE_KEY))
|
||||
.filter(TYPE_QUERY::equalsIgnoreCase)
|
||||
.anyMatch(Utils::isNullOrBlank);
|
||||
}
|
||||
|
||||
private List<Embedding> embedTexts(List<TextSegment> textSegments, TextEmbeddingParam.TextType textType) {
|
||||
TextEmbeddingParam param = TextEmbeddingParam.builder()
|
||||
.apiKey(apiKey)
|
||||
.model(modelName)
|
||||
.textType(textType)
|
||||
.texts(textSegments.stream()
|
||||
.map(TextSegment::text)
|
||||
.collect(Collectors.toList()))
|
||||
.build();
|
||||
try {
|
||||
return Optional.of(embedding.call(param))
|
||||
.map(TextEmbeddingResult::getOutput)
|
||||
.map(TextEmbeddingOutput::getEmbeddings)
|
||||
.orElse(Collections.emptyList())
|
||||
.stream()
|
||||
.map(TextEmbeddingResultItem::getEmbedding)
|
||||
.map(doubleList -> doubleList.stream().map(Double::floatValue).collect(Collectors.toList()))
|
||||
.map(Embedding::from)
|
||||
.collect(Collectors.toList());
|
||||
} catch (NoApiKeyException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Embedding> embedAll(List<TextSegment> textSegments) {
|
||||
boolean queries = containsQueries(textSegments);
|
||||
|
||||
if (!queries) {
|
||||
// default all documents
|
||||
return embedTexts(textSegments, TextEmbeddingParam.TextType.DOCUMENT);
|
||||
} else {
|
||||
boolean documents = containsDocuments(textSegments);
|
||||
if (!documents) {
|
||||
return embedTexts(textSegments, TextEmbeddingParam.TextType.QUERY);
|
||||
} else {
|
||||
// This is a mixed collection of queries and documents. Embed one by one.
|
||||
List<Embedding> embeddings = new ArrayList<>(textSegments.size());
|
||||
for (TextSegment textSegment: textSegments) {
|
||||
List<Embedding> result;
|
||||
if (TYPE_QUERY.equalsIgnoreCase(textSegment.metadata().get(TYPE_KEY))) {
|
||||
result = embedTexts(Collections.singletonList(textSegment), TextEmbeddingParam.TextType.QUERY);
|
||||
} else {
|
||||
result = embedTexts(Collections.singletonList(textSegment), TextEmbeddingParam.TextType.DOCUMENT);
|
||||
}
|
||||
embeddings.addAll(result);
|
||||
}
|
||||
return embeddings;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
public static class Builder {
|
||||
private String apiKey;
|
||||
private String modelName;
|
||||
|
||||
public Builder apiKey(String apiKey) {
|
||||
this.apiKey = apiKey;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder modelName(String modelName) {
|
||||
this.modelName = modelName;
|
||||
return this;
|
||||
}
|
||||
|
||||
protected void ensureOptions() {
|
||||
if (Utils.isNullOrBlank(apiKey)) {
|
||||
throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
|
||||
}
|
||||
modelName = Utils.isNullOrBlank(modelName) ? QwenModelName.TEXT_EMBEDDING_V1 : modelName;
|
||||
}
|
||||
|
||||
public QwenEmbeddingModel build() {
|
||||
ensureOptions();
|
||||
return new QwenEmbeddingModel(apiKey, modelName);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import dev.langchain4j.model.language.LanguageModel;
|
||||
|
||||
public class QwenLanguageModel extends QwenChatModel implements LanguageModel {
|
||||
protected QwenLanguageModel(String apiKey,
|
||||
String modelName,
|
||||
Double topP,
|
||||
Double topK,
|
||||
Boolean enableSearch,
|
||||
Integer seed) {
|
||||
super(apiKey, modelName, topP, topK, enableSearch, seed);
|
||||
}
|
||||
@Override
|
||||
public String process(String text) {
|
||||
return sendMessage(text);
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder extends QwenChatModel.Builder {
|
||||
public Builder apiKey(String apiKey) {
|
||||
return (Builder) super.apiKey(apiKey);
|
||||
}
|
||||
|
||||
public Builder modelName(String modelName) {
|
||||
return (Builder) super.modelName(modelName);
|
||||
}
|
||||
|
||||
public Builder topP(Double topP) {
|
||||
return (Builder) super.topP(topP);
|
||||
}
|
||||
|
||||
public Builder topK(Double topK) {
|
||||
return (Builder) super.topK(topK);
|
||||
}
|
||||
|
||||
public Builder enableSearch(Boolean enableSearch) {
|
||||
return (Builder) super.enableSearch(enableSearch);
|
||||
}
|
||||
|
||||
public Builder seed(Integer seed) {
|
||||
return (Builder) super.seed(seed);
|
||||
}
|
||||
|
||||
public QwenLanguageModel build() {
|
||||
ensureOptions();
|
||||
return new QwenLanguageModel(apiKey, modelName, topP, topK, enableSearch, seed);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
/**
|
||||
* The LLMs provided by Alibaba Cloud, performs better than most LLMs in Asia languages.
|
||||
*/
|
||||
public class QwenModelName {
|
||||
// Use with QwenChatModel and QwenLanguageModel
|
||||
public static final String QWEN_V1 = "qwen-v1"; // Qwen base model, 4k context.
|
||||
public static final String QWEN_PLUS_V1 = "qwen-plus-v1"; // Qwen plus model, 8k context.
|
||||
public static final String QWEN_7B_CHAT_V1 = "qwen-7b-chat-v1"; // Qwen open sourced 7-billion-parameters version, 4k context.
|
||||
public static final String QWEN_SPARK_V1 = "qwen-spark-v1"; // Qwen sft for conversation scene, 4k context.
|
||||
|
||||
// Use with QwenEmbeddingModel
|
||||
public static final String TEXT_EMBEDDING_V1 = "text-embedding-v1";
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class QwenParamHelper {
|
||||
private static final String SYSTEM_PREFIX = "<|system|>:";
|
||||
private static final String ASSISTANT_PREFIX = "<|assistant|>:";
|
||||
private static final String USER_PREFIX = "<|user|>:";
|
||||
|
||||
/* Qwen prompt format:
|
||||
* <|system|>: ...
|
||||
*
|
||||
* <|user|>: ...
|
||||
*
|
||||
* <|assistant|>: ...
|
||||
*
|
||||
* <|user|>: ...
|
||||
* ...
|
||||
*/
|
||||
public static String toQwenPrompt(List<ChatMessage> messages) {
|
||||
return messages.stream()
|
||||
.map(QwenParamHelper::toQwenMessage)
|
||||
.collect(Collectors.joining("\n\n"));
|
||||
}
|
||||
|
||||
public static String toQwenMessage(ChatMessage message) {
|
||||
return prefixFrom(message) + message.text();
|
||||
}
|
||||
|
||||
public static String prefixFrom(ChatMessage message) {
|
||||
if (message instanceof AiMessage) {
|
||||
return ASSISTANT_PREFIX;
|
||||
} else if (message instanceof SystemMessage) {
|
||||
return SYSTEM_PREFIX;
|
||||
} else {
|
||||
return USER_PREFIX;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,109 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import com.alibaba.dashscope.aigc.generation.GenerationResult;
|
||||
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
|
||||
import com.alibaba.dashscope.common.ResultCallback;
|
||||
import com.alibaba.dashscope.exception.InputRequiredException;
|
||||
import com.alibaba.dashscope.exception.NoApiKeyException;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class QwenStreamingChatModel extends QwenChatModel implements StreamingChatLanguageModel {
|
||||
protected QwenStreamingChatModel(String apiKey,
|
||||
String modelName,
|
||||
Double topP,
|
||||
Double topK,
|
||||
Boolean enableSearch,
|
||||
Integer seed) {
|
||||
super(apiKey, modelName, topP, topK, enableSearch, seed);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendMessages(List<ChatMessage> messages, StreamingResponseHandler handler) {
|
||||
sendMessage(QwenParamHelper.toQwenPrompt(messages), handler);
|
||||
}
|
||||
|
||||
protected void sendMessage(String prompt, StreamingResponseHandler handler) {
|
||||
QwenParam param = QwenParam.builder()
|
||||
.apiKey(apiKey)
|
||||
.model(modelName)
|
||||
.topP(topP)
|
||||
.topK(topK)
|
||||
.enableSearch(enableSearch)
|
||||
.seed(seed)
|
||||
.prompt(prompt)
|
||||
.build();
|
||||
|
||||
try {
|
||||
gen.call(param, new ResultCallback<GenerationResult>() {
|
||||
@Override
|
||||
public void onEvent(GenerationResult result) {
|
||||
handler.onNext(result.getOutput().getText());
|
||||
}
|
||||
@Override
|
||||
public void onComplete() {
|
||||
handler.onComplete();
|
||||
}
|
||||
@Override
|
||||
public void onError(Exception e) {
|
||||
handler.onError(e);
|
||||
}
|
||||
});
|
||||
} catch (NoApiKeyException | InputRequiredException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendMessages(List<ChatMessage> messages,
|
||||
List<ToolSpecification> toolSpecifications,
|
||||
StreamingResponseHandler handler) {
|
||||
throw new IllegalArgumentException("Tools are currently not supported for qwen models");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendMessages(List<ChatMessage> messages,
|
||||
ToolSpecification toolSpecification,
|
||||
StreamingResponseHandler handler) {
|
||||
throw new IllegalArgumentException("Tools are currently not supported for qwen models");
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder extends QwenChatModel.Builder {
|
||||
public Builder apiKey(String apiKey) {
|
||||
return (Builder) super.apiKey(apiKey);
|
||||
}
|
||||
|
||||
public Builder modelName(String modelName) {
|
||||
return (Builder) super.modelName(modelName);
|
||||
}
|
||||
|
||||
public Builder topP(Double topP) {
|
||||
return (Builder) super.topP(topP);
|
||||
}
|
||||
|
||||
public Builder topK(Double topK) {
|
||||
return (Builder) super.topK(topK);
|
||||
}
|
||||
|
||||
public Builder enableSearch(Boolean enableSearch) {
|
||||
return (Builder) super.enableSearch(enableSearch);
|
||||
}
|
||||
|
||||
public Builder seed(Integer seed) {
|
||||
return (Builder) super.seed(seed);
|
||||
}
|
||||
|
||||
public QwenStreamingChatModel build() {
|
||||
ensureOptions();
|
||||
return new QwenStreamingChatModel(apiKey, modelName, topP, topK, enableSearch, seed);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.language.StreamingLanguageModel;
|
||||
|
||||
public class QwenStreamingLanguageModel extends QwenStreamingChatModel implements StreamingLanguageModel {
|
||||
protected QwenStreamingLanguageModel(String apiKey,
|
||||
String modelName,
|
||||
Double topP,
|
||||
Double topK,
|
||||
Boolean enableSearch,
|
||||
Integer seed) {
|
||||
super(apiKey, modelName, topP, topK, enableSearch, seed);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(String text, StreamingResponseHandler handler) {
|
||||
sendMessage(text, handler);
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder extends QwenStreamingChatModel.Builder {
|
||||
public Builder apiKey(String apiKey) {
|
||||
return (Builder) super.apiKey(apiKey);
|
||||
}
|
||||
|
||||
public Builder modelName(String modelName) {
|
||||
return (Builder) super.modelName(modelName);
|
||||
}
|
||||
|
||||
public Builder topP(Double topP) {
|
||||
return (Builder) super.topP(topP);
|
||||
}
|
||||
|
||||
public Builder topK(Double topK) {
|
||||
return (Builder) super.topK(topK);
|
||||
}
|
||||
|
||||
public Builder enableSearch(Boolean enableSearch) {
|
||||
return (Builder) super.enableSearch(enableSearch);
|
||||
}
|
||||
|
||||
public Builder seed(Integer seed) {
|
||||
return (Builder) super.seed(seed);
|
||||
}
|
||||
public QwenStreamingLanguageModel build() {
|
||||
ensureOptions();
|
||||
return new QwenStreamingLanguageModel(apiKey, modelName, topP, topK, enableSearch, seed);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.internal.Utils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class QwenChatModelIT {
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider")
|
||||
public void should_send_messages_and_receive_response(String modelName) {
|
||||
String apiKey = QwenTestHelper.apiKey();
|
||||
if (Utils.isNullOrBlank(apiKey)) {
|
||||
return;
|
||||
}
|
||||
ChatLanguageModel model = QwenChatModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
AiMessage answer = model.sendMessages(QwenTestHelper.chatMessages());
|
||||
assertThat(answer.text()).containsIgnoringCase("rain");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.internal.Utils;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.data.segment.TextSegment.textSegment;
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class QwenEmbeddingModelIT {
|
||||
private EmbeddingModel getModel(String modelName) {
|
||||
String apiKey = QwenTestHelper.apiKey();
|
||||
if (Utils.isNullOrBlank(apiKey)) {
|
||||
return null;
|
||||
}
|
||||
return QwenEmbeddingModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
}
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
|
||||
void should_embed_one_text(String modelName) {
|
||||
EmbeddingModel model = getModel(modelName);
|
||||
if (model == null) {
|
||||
return;
|
||||
}
|
||||
Embedding embedding = model.embed("hello");
|
||||
assertThat(embedding.vector()).isNotEmpty();
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
|
||||
void should_embed_documents(String modelName) {
|
||||
EmbeddingModel model = getModel(modelName);
|
||||
if (model == null) {
|
||||
return;
|
||||
}
|
||||
List<Embedding> embeddings = model.embedAll(asList(
|
||||
textSegment("hello"),
|
||||
textSegment("how are you?")
|
||||
));
|
||||
|
||||
assertThat(embeddings).hasSize(2);
|
||||
assertThat(embeddings.get(0).vector()).isNotEmpty();
|
||||
assertThat(embeddings.get(1).vector()).isNotEmpty();
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
|
||||
void should_embed_queries(String modelName) {
|
||||
EmbeddingModel model = getModel(modelName);
|
||||
if (model == null) {
|
||||
return;
|
||||
}
|
||||
List<Embedding> embeddings = model.embedAll(asList(
|
||||
textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)),
|
||||
textSegment("how are you?", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY))
|
||||
));
|
||||
|
||||
assertThat(embeddings).hasSize(2);
|
||||
assertThat(embeddings.get(0).vector()).isNotEmpty();
|
||||
assertThat(embeddings.get(1).vector()).isNotEmpty();
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
|
||||
void should_embed_mix_segments(String modelName) {
|
||||
EmbeddingModel model = getModel(modelName);
|
||||
if (model == null) {
|
||||
return;
|
||||
}
|
||||
List<Embedding> embeddings = model.embedAll(asList(
|
||||
textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)),
|
||||
textSegment("how are you?")
|
||||
));
|
||||
|
||||
assertThat(embeddings).hasSize(2);
|
||||
assertThat(embeddings.get(0).vector()).isNotEmpty();
|
||||
assertThat(embeddings.get(1).vector()).isNotEmpty();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import dev.langchain4j.internal.Utils;
|
||||
import dev.langchain4j.model.language.LanguageModel;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class QwenLanguageModelIT {
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider")
|
||||
public void should_send_messages_and_receive_response(String modelName) {
|
||||
String apiKey = QwenTestHelper.apiKey();
|
||||
if (Utils.isNullOrBlank(apiKey)) {
|
||||
return;
|
||||
}
|
||||
LanguageModel model = QwenLanguageModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
assertThat(model.process("Please say 'hello' to me")).containsIgnoringCase("hello");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import dev.langchain4j.internal.Utils;
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
|
||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class QwenStreamingChatModelIT {
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider")
|
||||
public void should_send_messages_and_receive_response(String modelName) throws ExecutionException, InterruptedException, TimeoutException {
|
||||
String apiKey = QwenTestHelper.apiKey();
|
||||
if (Utils.isNullOrBlank(apiKey)) {
|
||||
return;
|
||||
}
|
||||
StreamingChatLanguageModel model = QwenStreamingChatModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
|
||||
CompletableFuture<String> future = new CompletableFuture<>();
|
||||
model.sendMessages(
|
||||
QwenTestHelper.chatMessages(),
|
||||
new StreamingResponseHandler() {
|
||||
final StringBuilder answerBuilder = new StringBuilder();
|
||||
@Override
|
||||
public void onNext(String partialResult) {
|
||||
answerBuilder.append(partialResult);
|
||||
System.out.println("onPartialResult: '" + partialResult + "'");
|
||||
}
|
||||
@Override
|
||||
public void onComplete() {
|
||||
future.complete(answerBuilder.toString());
|
||||
System.out.println("onComplete");
|
||||
}
|
||||
@Override
|
||||
public void onError(Throwable error) {
|
||||
future.completeExceptionally(error);
|
||||
}
|
||||
});
|
||||
|
||||
String answer = future.get(30, SECONDS);
|
||||
assertThat(answer).containsIgnoringCase("rain");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import dev.langchain4j.internal.Utils;
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.language.StreamingLanguageModel;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
|
||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class QwenStreamingLanguageModelIT {
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#chatModelNameProvider")
|
||||
public void should_send_messages_and_receive_response(String modelName) throws ExecutionException, InterruptedException, TimeoutException {
|
||||
String apiKey = QwenTestHelper.apiKey();
|
||||
if (Utils.isNullOrBlank(apiKey)) {
|
||||
return;
|
||||
}
|
||||
StreamingLanguageModel model = QwenStreamingLanguageModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
|
||||
CompletableFuture<String> future = new CompletableFuture<>();
|
||||
model.process(
|
||||
"Please say 'hello' to me",
|
||||
new StreamingResponseHandler() {
|
||||
final StringBuilder answerBuilder = new StringBuilder();
|
||||
@Override
|
||||
public void onNext(String partialResult) {
|
||||
answerBuilder.append(partialResult);
|
||||
System.out.println("onPartialResult: '" + partialResult + "'");
|
||||
}
|
||||
@Override
|
||||
public void onComplete() {
|
||||
future.complete(answerBuilder.toString());
|
||||
System.out.println("onComplete");
|
||||
}
|
||||
@Override
|
||||
public void onError(Throwable error) {
|
||||
future.completeExceptionally(error);
|
||||
}
|
||||
});
|
||||
|
||||
String answer = future.get(30, SECONDS);
|
||||
assertThat(answer).containsIgnoringCase("hello");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class QwenTestHelper {
|
||||
public static Stream<Arguments> chatModelNameProvider() {
|
||||
return Stream.of(
|
||||
Arguments.of(QwenModelName.QWEN_V1),
|
||||
Arguments.of(QwenModelName.QWEN_PLUS_V1),
|
||||
Arguments.of(QwenModelName.QWEN_7B_CHAT_V1),
|
||||
Arguments.of(QwenModelName.QWEN_SPARK_V1)
|
||||
);
|
||||
}
|
||||
|
||||
public static Stream<Arguments> embeddingModelNameProvider() {
|
||||
return Stream.of(
|
||||
Arguments.of(QwenModelName.TEXT_EMBEDDING_V1)
|
||||
);
|
||||
}
|
||||
|
||||
public static String apiKey() {
|
||||
return System.getenv("DASHSCOPE_API_KEY");
|
||||
}
|
||||
|
||||
public static List<ChatMessage> chatMessages() {
|
||||
List<ChatMessage> messages = new LinkedList<>();
|
||||
messages.add(SystemMessage.from("Your name is Jack." +
|
||||
" You like to answer other people's questions briefly." +
|
||||
" It's rainy today."));
|
||||
messages.add(UserMessage.from("Hello. What's your name?"));
|
||||
messages.add(AiMessage.from("Jack."));
|
||||
messages.add(UserMessage.from("How about the weather today?"));
|
||||
return messages;
|
||||
}
|
||||
}
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.20.0</version>
|
||||
<version>${revision}</version>
|
||||
<packaging>pom</packaging>
|
||||
|
||||
<name>langchain4j parent POM</name>
|
||||
|
@ -14,6 +14,7 @@
|
|||
<url>https://github.com/langchain4j/langchain4j</url>
|
||||
|
||||
<properties>
|
||||
<revision>0.20.0</revision>
|
||||
<maven.compiler.source>1.8</maven.compiler.source>
|
||||
<maven.compiler.target>1.8</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.20.0</version>
|
||||
<version>${revision}</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
@ -23,7 +23,7 @@
|
|||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>0.20.0</version>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.20.0</version>
|
||||
<version>${revision}</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
@ -22,7 +22,7 @@
|
|||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j</artifactId>
|
||||
<version>0.20.0</version>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.20.0</version>
|
||||
<version>${revision}</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
@ -25,7 +25,7 @@
|
|||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>0.20.0</version>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.20.0</version>
|
||||
<version>${revision}</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
@ -24,7 +24,7 @@
|
|||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>0.20.0</version>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
|
12
pom.xml
12
pom.xml
|
@ -10,17 +10,21 @@
|
|||
<packaging>pom</packaging>
|
||||
|
||||
<modules>
|
||||
|
||||
<module>langchain4j-parent</module>
|
||||
|
||||
<module>langchain4j-core</module>
|
||||
<module>langchain4j</module>
|
||||
|
||||
<module>langchain4j-pinecone</module>
|
||||
<module>langchain4j-weaviate</module>
|
||||
|
||||
<module>langchain4j-spring-boot-starter</module>
|
||||
|
||||
<!-- model providers -->
|
||||
<module>langchain4j-dashscope</module>
|
||||
|
||||
<!-- embedding stores -->
|
||||
<module>langchain4j-milvus</module>
|
||||
<module>langchain4j-pinecone</module>
|
||||
<module>langchain4j-weaviate</module>
|
||||
|
||||
</modules>
|
||||
|
||||
</project>
|
Loading…
Reference in New Issue