DashScope: Support Wanx Models (for text-generated images) (#1710)
## Change Alibaba uses Wanx models to support text-to-image features (not Qwen), and provides services on DashScope. See: https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-wanxiang Integrate them into langchain4j-dashscope as ImageModel. ## General checklist <!-- Please double-check the following points and mark them like this: [X] --> - [X] There are no breaking changes - [X] I have added unit and integration tests for my change - [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [ ] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green
This commit is contained in:
parent
4ee7b8af8b
commit
c14c86c408
|
@ -10,7 +10,7 @@ sidebar_position: 0
|
|||
| [Anthropic](/integrations/language-models/anthropic) | ✅ | ✅ | text, image | | ✅ |
|
||||
| [Azure OpenAI](/integrations/language-models/azure-open-ai) | ✅ | ✅ | text, image | | |
|
||||
| [ChatGLM](/integrations/language-models/chatglm) | | | text | | |
|
||||
| [DashScope](/integrations/language-models/dashscope) | ✅ | ✅ | text, image | | |
|
||||
| [DashScope](/integrations/language-models/dashscope) | ✅ | ✅ | text, image, audio | | |
|
||||
| [Google AI Gemini](/integrations/language-models/google-ai-gemini) | | ✅ | text, image, audio, video, PDF | | |
|
||||
| [Google Vertex AI Gemini](/integrations/language-models/google-vertex-ai-gemini) | ✅ | ✅ | text, image, audio, video, PDF | | |
|
||||
| [Google Vertex AI PaLM 2](/integrations/language-models/google-palm) | | | text | | ✅ |
|
||||
|
|
|
@ -198,7 +198,7 @@ class QwenHelper {
|
|||
}
|
||||
}
|
||||
|
||||
private static String saveDataAsTemporaryFile(String base64Data, String mimeType) {
|
||||
static String saveDataAsTemporaryFile(String base64Data, String mimeType) {
|
||||
String tmpDir = System.getProperty("java.io.tmpdir", "/tmp");
|
||||
String tmpFileName = UUID.randomUUID().toString();
|
||||
if (Utils.isNotNullOrBlank(mimeType)) {
|
||||
|
|
|
@ -24,7 +24,7 @@ public class QwenModelName {
|
|||
public static final String QWEN_VL_PLUS = "qwen-vl-plus"; // Qwen multi-modal model, supports image and text information.
|
||||
public static final String QWEN_VL_MAX = "qwen-vl-max"; // Qwen multi-modal model, offers optimal performance on a wider range of complex tasks.
|
||||
public static final String QWEN_AUDIO_CHAT = "qwen-audio-chat"; // Qwen open sourced speech model, sft for chatting.
|
||||
public static final String QWEN2_AUDIO_INSTRUCT = "qwen2-audio-instruct"; // Qwen open sourced speech model (v2), sft for instruction
|
||||
public static final String QWEN2_AUDIO_INSTRUCT = "qwen2-audio-instruct"; // Qwen open sourced speech model (v2)
|
||||
|
||||
// Use with QwenEmbeddingModel
|
||||
public static final String TEXT_EMBEDDING_V1 = "text-embedding-v1"; // Support: en, zh, es, fr, pt, id
|
||||
|
|
|
@ -9,6 +9,7 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
|||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.model.Tokenizer;
|
||||
import dev.langchain4j.model.dashscope.spi.QwenTokenizerBuilderFactory;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.Collections;
|
||||
|
@ -16,6 +17,7 @@ import java.util.Collections;
|
|||
import static dev.langchain4j.internal.Utils.*;
|
||||
import static dev.langchain4j.model.dashscope.QwenHelper.toQwenMessages;
|
||||
import static dev.langchain4j.model.dashscope.QwenModelName.QWEN_PLUS;
|
||||
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
||||
|
||||
public class QwenTokenizer implements Tokenizer {
|
||||
private final String apiKey;
|
||||
|
@ -94,4 +96,18 @@ public class QwenTokenizer implements Tokenizer {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
public static QwenTokenizer.QwenTokenizerBuilder builder() {
|
||||
for (QwenTokenizerBuilderFactory factory : loadFactories(QwenTokenizerBuilderFactory.class)) {
|
||||
return factory.get();
|
||||
}
|
||||
return new QwenTokenizer.QwenTokenizerBuilder();
|
||||
}
|
||||
|
||||
public static class QwenTokenizerBuilder {
|
||||
public QwenTokenizerBuilder() {
|
||||
// This is public so it can be extended
|
||||
// By default with Lombok it becomes package private
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisOutput;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
|
||||
import com.alibaba.dashscope.exception.NoApiKeyException;
|
||||
import com.alibaba.dashscope.utils.OSSUtils;
|
||||
import dev.langchain4j.data.image.Image;
|
||||
import dev.langchain4j.internal.Utils;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class WanxHelper {
|
||||
static List<Image> imagesFrom(ImageSynthesisResult result) {
|
||||
return Optional.of(result)
|
||||
.map(ImageSynthesisResult::getOutput)
|
||||
.map(ImageSynthesisOutput::getResults)
|
||||
.orElse(Collections.emptyList())
|
||||
.stream()
|
||||
.map(resultMap -> resultMap.get("url"))
|
||||
.map(url -> Image.builder().url(url).build())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
static String imageUrl(Image image, String model, String apiKey) {
|
||||
String imageUrl;
|
||||
|
||||
if (image.url() != null) {
|
||||
imageUrl = image.url().toString();
|
||||
} else if (Utils.isNotNullOrBlank(image.base64Data())) {
|
||||
String filePath = saveDataAsTemporaryFile(image.base64Data(), image.mimeType());
|
||||
try {
|
||||
imageUrl = OSSUtils.upload(model, filePath, apiKey);
|
||||
} catch (NoApiKeyException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
} else {
|
||||
throw new IllegalArgumentException("Failed to get image url from " + image);
|
||||
}
|
||||
|
||||
return imageUrl;
|
||||
}
|
||||
|
||||
static String saveDataAsTemporaryFile(String base64Data, String mimeType) {
|
||||
String tmpDir = System.getProperty("java.io.tmpdir", "/tmp");
|
||||
String tmpFileName = UUID.randomUUID().toString();
|
||||
if (Utils.isNotNullOrBlank(mimeType)) {
|
||||
// e.g. "image/png", "image/jpeg"...
|
||||
int lastSlashIndex = mimeType.lastIndexOf("/");
|
||||
if (lastSlashIndex >= 0 && lastSlashIndex < mimeType.length() - 1) {
|
||||
String fileSuffix = mimeType.substring(lastSlashIndex + 1);
|
||||
tmpFileName = tmpFileName + "." + fileSuffix;
|
||||
}
|
||||
}
|
||||
|
||||
Path tmpFilePath = Paths.get(tmpDir, tmpFileName);
|
||||
byte[] data = Base64.getDecoder().decode(base64Data);
|
||||
try {
|
||||
Files.copy(new ByteArrayInputStream(data), tmpFilePath, StandardCopyOption.REPLACE_EXISTING);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return tmpFilePath.toAbsolutePath().toString();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
|
||||
import com.alibaba.dashscope.exception.NoApiKeyException;
|
||||
import dev.langchain4j.data.image.Image;
|
||||
import dev.langchain4j.internal.Utils;
|
||||
import dev.langchain4j.model.dashscope.spi.WanxImageModelBuilderFactory;
|
||||
import dev.langchain4j.model.image.ImageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.model.dashscope.WanxHelper.imageUrl;
|
||||
import static dev.langchain4j.model.dashscope.WanxHelper.imagesFrom;
|
||||
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
||||
|
||||
/**
|
||||
* Represents a Wanx models to generate artistic images.
|
||||
* More details are available <a href="https://help.aliyun.com/zh/dashscope/developer-reference/api-details-9">here</a>.
|
||||
*/
|
||||
public class WanxImageModel implements ImageModel {
|
||||
private final String apiKey;
|
||||
private final String modelName;
|
||||
// The generation method of the reference image. The optional values are
|
||||
// 'repaint' and 'refonly'; repaint represents the reference content and
|
||||
// refonly represents the reference style. Default is 'repaint'.
|
||||
private final WanxImageRefMode refMode;
|
||||
// The similarity between the expected output result and the reference image,
|
||||
// the value range is [0.0, 1.0]. The larger the number, the more similar the
|
||||
// generated result is to the reference image. Default is 0.5.
|
||||
private final Float refStrength;
|
||||
private final Integer seed;
|
||||
// The resolution of the generated image currently only supports '1024*1024',
|
||||
// '720*1280', and '1280*720' resolutions. Default is '1024*1024'.
|
||||
private final WanxImageSize size;
|
||||
private final WanxImageStyle style;
|
||||
private final ImageSynthesis imageSynthesis;
|
||||
|
||||
@Builder
|
||||
public WanxImageModel(String baseUrl,
|
||||
String apiKey,
|
||||
String modelName,
|
||||
WanxImageRefMode refMode,
|
||||
Float refStrength,
|
||||
Integer seed,
|
||||
WanxImageSize size,
|
||||
WanxImageStyle style) {
|
||||
if (Utils.isNullOrBlank(apiKey)) {
|
||||
throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
|
||||
}
|
||||
this.modelName = Utils.isNullOrBlank(modelName) ? WanxModelName.WANX_V1 : modelName;
|
||||
this.apiKey = apiKey;
|
||||
this.refMode = refMode;
|
||||
this.refStrength = refStrength;
|
||||
this.seed = seed;
|
||||
this.size = size;
|
||||
this.style = style;
|
||||
this.imageSynthesis = Utils.isNullOrBlank(baseUrl) ? new ImageSynthesis() : new ImageSynthesis("text2image", baseUrl);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<Image> generate(String prompt) {
|
||||
ImageSynthesisParam param = requestBuilder(prompt).n(1).build();
|
||||
|
||||
try {
|
||||
ImageSynthesisResult result = imageSynthesis.call(param);
|
||||
return Response.from(imagesFrom(result).get(0));
|
||||
} catch (NoApiKeyException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<List<Image>> generate(String prompt, int n) {
|
||||
ImageSynthesisParam param = requestBuilder(prompt).n(n).build();
|
||||
|
||||
try {
|
||||
ImageSynthesisResult result = imageSynthesis.call(param);
|
||||
return Response.from(imagesFrom(result));
|
||||
} catch (NoApiKeyException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<Image> edit(Image image, String prompt) {
|
||||
String imageUrl = imageUrl(image, modelName, apiKey);
|
||||
|
||||
ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> builder = requestBuilder(prompt)
|
||||
.refImage(imageUrl)
|
||||
.n(1);
|
||||
|
||||
if (imageUrl.startsWith("oss://")) {
|
||||
builder.header("X-DashScope-OssResourceResolve", "enable");
|
||||
}
|
||||
|
||||
try {
|
||||
ImageSynthesisResult result = imageSynthesis.call(builder.build());
|
||||
return Response.from(imagesFrom(result).get(0));
|
||||
} catch (NoApiKeyException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> requestBuilder(String prompt) {
|
||||
ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> builder = ImageSynthesisParam.builder()
|
||||
.apiKey(apiKey)
|
||||
.model(modelName)
|
||||
.prompt(prompt);
|
||||
|
||||
if (seed != null) {
|
||||
builder.seed(seed);
|
||||
}
|
||||
|
||||
if (size != null) {
|
||||
builder.size(size.toString());
|
||||
}
|
||||
|
||||
if (style != null) {
|
||||
builder.style(style.toString());
|
||||
}
|
||||
|
||||
if (refMode != null) {
|
||||
builder.parameter("ref_mode", refMode.toString());
|
||||
}
|
||||
|
||||
if (refStrength != null) {
|
||||
builder.parameter("ref_strength", refStrength);
|
||||
}
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
public static WanxImageModel.WanxImageModelBuilder builder() {
|
||||
for (WanxImageModelBuilderFactory factory : loadFactories(WanxImageModelBuilderFactory.class)) {
|
||||
return factory.get();
|
||||
}
|
||||
return new WanxImageModel.WanxImageModelBuilder();
|
||||
}
|
||||
|
||||
public static class WanxImageModelBuilder {
|
||||
public WanxImageModelBuilder() {
|
||||
// This is public so it can be extended
|
||||
// By default with Lombok it becomes package private
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
/**
 * Reference-image modes for Wanx image generation.
 * {@link #toString()} returns the lowercase wire value of the constant.
 */
public enum WanxImageRefMode {

    REPAINT("repaint"),
    REFONLY("refonly");

    /** Wire value returned by {@link #toString()}. */
    private final String wireValue;

    WanxImageRefMode(String wireValue) {
        this.wireValue = wireValue;
    }

    /**
     * @return the wire value of this mode, e.g. {@code "repaint"}
     */
    @Override
    public String toString() {
        return this.wireValue;
    }
}
|
|
@ -0,0 +1,18 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
/**
 * Output resolutions for Wanx image generation.
 * {@link #toString()} returns the resolution as {@code "<width>*<height>"}.
 */
public enum WanxImageSize {

    SIZE_1024_1024("1024*1024"),
    SIZE_720_1280("720*1280"),
    SIZE_1280_720("1280*720");

    /** Resolution string returned by {@link #toString()}. */
    private final String resolution;

    WanxImageSize(String resolution) {
        this.resolution = resolution;
    }

    /**
     * @return the resolution string of this size, e.g. {@code "1024*1024"}
     */
    @Override
    public String toString() {
        return this.resolution;
    }
}
|
|
@ -0,0 +1,25 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
/**
 * Output styles for Wanx image generation.
 * {@link #toString()} returns the angle-bracketed style tag of the constant.
 */
public enum WanxImageStyle {

    PHOTOGRAPHY("<photography>"),
    PORTRAIT("<portrait>"),
    CARTOON_3D("<3d cartoon>"),
    ANIME("<anime>"),
    OIL_PAINTING("<oil painting>"),
    WATERCOLOR("<watercolor>"),
    SKETCH("<sketch>"),
    CHINESE_PAINTING("<chinese painting>"),
    FLAT_ILLUSTRATION("<flat illustration>"),
    AUTO("<auto>");

    /** Style tag returned by {@link #toString()}. */
    private final String tag;

    WanxImageStyle(String tag) {
        this.tag = tag;
    }

    /**
     * @return the style tag of this constant, e.g. {@code "<anime>"}
     */
    @Override
    public String toString() {
        return this.tag;
    }
}
|
|
@ -0,0 +1,6 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
/**
 * Names of the Wanx models, for use with WanxImageModel.
 */
public class WanxModelName {

    /** Wanx model for text-generated images, supports Chinese and English. */
    public static final String WANX_V1 = "wanx-v1";
}
|
|
@ -0,0 +1,8 @@
|
|||
package dev.langchain4j.model.dashscope.spi;
|
||||
|
||||
import dev.langchain4j.model.dashscope.QwenTokenizer;
|
||||
|
||||
import java.util.function.Supplier;
|
||||
|
||||
public interface QwenTokenizerBuilderFactory extends Supplier<QwenTokenizer.QwenTokenizerBuilder> {
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
package dev.langchain4j.model.dashscope.spi;
|
||||
|
||||
import dev.langchain4j.model.dashscope.WanxImageModel;
|
||||
|
||||
import java.util.function.Supplier;
|
||||
|
||||
public interface WanxImageModelBuilderFactory extends Supplier<WanxImageModel.WanxImageModelBuilder> {
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import dev.langchain4j.data.image.Image;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.net.URI;
|
||||
|
||||
import static dev.langchain4j.model.dashscope.QwenTestHelper.apiKey;
|
||||
import static dev.langchain4j.model.dashscope.QwenTestHelper.multimodalImageData;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@EnabledIfEnvironmentVariable(named = "DASHSCOPE_API_KEY", matches = ".+")
|
||||
public class WanxImageModelIT {
|
||||
Logger log = LoggerFactory.getLogger(WanxImageModelIT.class);
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.dashscope.WanxTestHelper#imageModelNameProvider")
|
||||
void simple_image_generation_works(String modelName) {
|
||||
WanxImageModel model = WanxImageModel.builder()
|
||||
.apiKey(apiKey())
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
|
||||
Response<Image> response = model.generate("Beautiful house on country side");
|
||||
|
||||
URI remoteImage = response.content().url();
|
||||
log.info("Your remote image is here: {}", remoteImage);
|
||||
assertThat(remoteImage).isNotNull();
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.dashscope.WanxTestHelper#imageModelNameProvider")
|
||||
void simple_image_edition_works_by_url(String modelName) {
|
||||
WanxImageModel model = WanxImageModel.builder()
|
||||
.apiKey(apiKey())
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
|
||||
Image image = Image.builder()
|
||||
.url("https://img.alicdn.com/imgextra/i4/O1CN01K1DWat25own2MuQgF_!!6000000007574-0-tps-128-128.jpg")
|
||||
.build();
|
||||
|
||||
Response<Image> response = model.edit(image, "Change the parrot's feathers with yellow");
|
||||
|
||||
URI remoteImage = response.content().url();
|
||||
log.info("Your remote image is here: {}", remoteImage);
|
||||
assertThat(remoteImage).isNotNull();
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("dev.langchain4j.model.dashscope.WanxTestHelper#imageModelNameProvider")
|
||||
void simple_image_edition_works_by_data(String modelName) {
|
||||
WanxImageModel model = WanxImageModel.builder()
|
||||
.apiKey(apiKey())
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
|
||||
Image image = Image.builder()
|
||||
.base64Data(multimodalImageData())
|
||||
.mimeType("image/jpg")
|
||||
.build();
|
||||
|
||||
Response<Image> response = model.edit(image, "Change the parrot's feathers with yellow");
|
||||
|
||||
URI remoteImage = response.content().url();
|
||||
log.info("Your remote image is here: {}", remoteImage);
|
||||
assertThat(remoteImage).isNotNull();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package dev.langchain4j.model.dashscope;
|
||||
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class WanxTestHelper {
|
||||
public static Stream<Arguments> imageModelNameProvider() {
|
||||
return Stream.of(
|
||||
Arguments.of(WanxModelName.WANX_V1)
|
||||
);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue