DashScope: Support Wanx Models (for text-generated images) (#1710)

## Change
Alibaba uses Wanx models (not Qwen) to support text-to-image features,
and provides these services on DashScope.
    
See:
https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-wanxiang
    
Integrate them into langchain4j-dashscope as ImageModel.



## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
This commit is contained in:
jiangsier-xyz 2024-09-05 15:22:30 +08:00 committed by GitHub
parent 4ee7b8af8b
commit c14c86c408
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 409 additions and 3 deletions

View File

@ -10,7 +10,7 @@ sidebar_position: 0
| [Anthropic](/integrations/language-models/anthropic) | ✅ | ✅ | text, image | | ✅ |
| [Azure OpenAI](/integrations/language-models/azure-open-ai) | ✅ | ✅ | text, image | | |
| [ChatGLM](/integrations/language-models/chatglm) | | | text | | |
| [DashScope](/integrations/language-models/dashscope) | ✅ | ✅ | text, image | | |
| [DashScope](/integrations/language-models/dashscope) | ✅ | ✅ | text, image, audio | | |
| [Google AI Gemini](/integrations/language-models/google-ai-gemini) | | ✅ | text, image, audio, video, PDF | | |
| [Google Vertex AI Gemini](/integrations/language-models/google-vertex-ai-gemini) | ✅ | ✅ | text, image, audio, video, PDF | | |
| [Google Vertex AI PaLM 2](/integrations/language-models/google-palm) | | | text | | ✅ |

View File

@ -198,7 +198,7 @@ class QwenHelper {
}
}
private static String saveDataAsTemporaryFile(String base64Data, String mimeType) {
static String saveDataAsTemporaryFile(String base64Data, String mimeType) {
String tmpDir = System.getProperty("java.io.tmpdir", "/tmp");
String tmpFileName = UUID.randomUUID().toString();
if (Utils.isNotNullOrBlank(mimeType)) {

View File

@ -24,7 +24,7 @@ public class QwenModelName {
public static final String QWEN_VL_PLUS = "qwen-vl-plus"; // Qwen multi-modal model, supports image and text information.
public static final String QWEN_VL_MAX = "qwen-vl-max"; // Qwen multi-modal model, offers optimal performance on a wider range of complex tasks.
public static final String QWEN_AUDIO_CHAT = "qwen-audio-chat"; // Qwen open sourced speech model, sft for chatting.
public static final String QWEN2_AUDIO_INSTRUCT = "qwen2-audio-instruct"; // Qwen open sourced speech model (v2), sft for instruction
public static final String QWEN2_AUDIO_INSTRUCT = "qwen2-audio-instruct"; // Qwen open sourced speech model (v2)
// Use with QwenEmbeddingModel
public static final String TEXT_EMBEDDING_V1 = "text-embedding-v1"; // Support: en, zh, es, fr, pt, id

View File

@ -9,6 +9,7 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.dashscope.spi.QwenTokenizerBuilderFactory;
import lombok.Builder;
import java.util.Collections;
@ -16,6 +17,7 @@ import java.util.Collections;
import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.model.dashscope.QwenHelper.toQwenMessages;
import static dev.langchain4j.model.dashscope.QwenModelName.QWEN_PLUS;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
public class QwenTokenizer implements Tokenizer {
private final String apiKey;
@ -94,4 +96,18 @@ public class QwenTokenizer implements Tokenizer {
}
return true;
}
    /**
     * Creates a builder for {@code QwenTokenizer}.
     * <p>
     * If an SPI-registered {@link QwenTokenizerBuilderFactory} is found via
     * {@code loadFactories}, its builder is returned; otherwise the default
     * builder is used.
     *
     * @return a new {@code QwenTokenizerBuilder}
     */
    public static QwenTokenizer.QwenTokenizerBuilder builder() {
        // Use the first registered factory, if any; fall back to the default builder.
        for (QwenTokenizerBuilderFactory factory : loadFactories(QwenTokenizerBuilderFactory.class)) {
            return factory.get();
        }
        return new QwenTokenizer.QwenTokenizerBuilder();
    }
    /**
     * Builder for {@code QwenTokenizer}, populated by Lombok's {@code @Builder}.
     * Declared explicitly only to widen the constructor's visibility.
     */
    public static class QwenTokenizerBuilder {
        public QwenTokenizerBuilder() {
            // This is public so it can be extended
            // By default with Lombok it becomes package private
        }
    }
}

View File

@ -0,0 +1,71 @@
package dev.langchain4j.model.dashscope;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisOutput;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.utils.OSSUtils;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.internal.Utils;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.*;
import java.util.stream.Collectors;
/**
 * Static helpers for the Wanx image model integration: converting DashScope
 * image-synthesis results into {@link Image}s and resolving image URLs.
 */
public class WanxHelper {

    private WanxHelper() {
        // Utility class: static helpers only, no instances.
    }

    /**
     * Extracts the generated images from a DashScope image-synthesis result.
     * <p>
     * Each result entry is expected to carry a {@code "url"} key pointing at a
     * generated image. A missing output or result list yields an empty list
     * instead of a {@code NullPointerException}.
     *
     * @param result the raw result returned by the DashScope SDK
     * @return the generated images; possibly empty, never {@code null}
     */
    static List<Image> imagesFrom(ImageSynthesisResult result) {
        return Optional.of(result)
                .map(ImageSynthesisResult::getOutput)
                .map(ImageSynthesisOutput::getResults)
                .orElse(Collections.emptyList())
                .stream()
                .map(resultMap -> resultMap.get("url"))
                .map(url -> Image.builder().url(url).build())
                .collect(Collectors.toList());
    }

    /**
     * Resolves a URL for the given image that the DashScope service can fetch.
     * <p>
     * A remote image's URL is used directly. Base64 image data is written to a
     * temporary file and uploaded to OSS, and the resulting OSS URL is returned.
     *
     * @param image  the source image (must have either a URL or base64 data)
     * @param model  the model name, used as the OSS upload namespace
     * @param apiKey the DashScope API key used for the OSS upload
     * @return a URL the service can resolve
     * @throws IllegalArgumentException if the image has neither a URL nor base64 data
     * @throws RuntimeException         if no API key is available for the OSS upload
     */
    static String imageUrl(Image image, String model, String apiKey) {
        if (image.url() != null) {
            return image.url().toString();
        }
        if (Utils.isNotNullOrBlank(image.base64Data())) {
            String filePath = saveDataAsTemporaryFile(image.base64Data(), image.mimeType());
            try {
                // Upload the local file to OSS so the remote service can reach it.
                return OSSUtils.upload(model, filePath, apiKey);
            } catch (NoApiKeyException e) {
                throw new RuntimeException(e);
            }
        }
        throw new IllegalArgumentException("Failed to get image url from " + image);
    }

    /**
     * Decodes base64 data and writes it to a uniquely named temporary file.
     * <p>
     * The file name gets a suffix derived from the MIME subtype (e.g.
     * {@code "image/png"} -> {@code ".png"}) when a MIME type is provided.
     * The file is registered for deletion on JVM exit so repeated uploads do
     * not accumulate in the temp directory.
     *
     * @param base64Data base64-encoded file content
     * @param mimeType   optional MIME type, e.g. {@code "image/png"}
     * @return the absolute path of the temporary file
     * @throws RuntimeException if the file cannot be written
     */
    static String saveDataAsTemporaryFile(String base64Data, String mimeType) {
        String tmpDir = System.getProperty("java.io.tmpdir", "/tmp");
        String tmpFileName = UUID.randomUUID().toString();
        if (Utils.isNotNullOrBlank(mimeType)) {
            // e.g. "image/png", "image/jpeg"... keep the subtype as the file suffix
            int lastSlashIndex = mimeType.lastIndexOf("/");
            if (lastSlashIndex >= 0 && lastSlashIndex < mimeType.length() - 1) {
                String fileSuffix = mimeType.substring(lastSlashIndex + 1);
                tmpFileName = tmpFileName + "." + fileSuffix;
            }
        }
        Path tmpFilePath = Paths.get(tmpDir, tmpFileName);
        byte[] data = Base64.getDecoder().decode(base64Data);
        try {
            Files.write(tmpFilePath, data);
        } catch (IOException e) {
            throw new RuntimeException("Failed to write temporary file: " + tmpFilePath, e);
        }
        // Best-effort cleanup: remove the temporary file when the JVM exits.
        tmpFilePath.toFile().deleteOnExit();
        return tmpFilePath.toAbsolutePath().toString();
    }
}

View File

@ -0,0 +1,150 @@
package dev.langchain4j.model.dashscope;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
import com.alibaba.dashscope.exception.NoApiKeyException;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.dashscope.spi.WanxImageModelBuilderFactory;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import java.util.List;
import static dev.langchain4j.model.dashscope.WanxHelper.imageUrl;
import static dev.langchain4j.model.dashscope.WanxHelper.imagesFrom;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
/**
 * Represents a Wanx model, used to generate artistic images from text prompts,
 * optionally guided by a reference image.
 * More details are available <a href="https://help.aliyun.com/zh/dashscope/developer-reference/api-details-9">here</a>.
 */
public class WanxImageModel implements ImageModel {

    private final String apiKey;
    private final String modelName;

    // The generation method of the reference image. The optional values are
    // 'repaint' and 'refonly'; repaint represents the reference content and
    // refonly represents the reference style. Default is 'repaint'.
    private final WanxImageRefMode refMode;

    // The similarity between the expected output result and the reference image,
    // the value range is [0.0, 1.0]. The larger the number, the more similar the
    // generated result is to the reference image. Default is 0.5.
    private final Float refStrength;

    // Random seed; optional.
    private final Integer seed;

    // The resolution of the generated image currently only supports '1024*1024',
    // '720*1280', and '1280*720' resolutions. Default is '1024*1024'.
    private final WanxImageSize size;

    // Artistic style applied to the generated image; optional.
    private final WanxImageStyle style;

    // DashScope SDK client used to perform image-synthesis calls.
    private final ImageSynthesis imageSynthesis;

    /**
     * Creates a {@code WanxImageModel}.
     *
     * @param baseUrl     custom DashScope endpoint; when null/blank the SDK default endpoint is used
     * @param apiKey      DashScope API key; required
     * @param modelName   Wanx model name; defaults to {@link WanxModelName#WANX_V1} when null/blank
     * @param refMode     how a reference image is used ('repaint'/'refonly'); optional
     * @param refStrength similarity to the reference image, in [0.0, 1.0]; optional
     * @param seed        random seed; optional
     * @param size        output resolution; optional
     * @param style       artistic style; optional
     * @throws IllegalArgumentException if {@code apiKey} is null or blank
     */
    @Builder
    public WanxImageModel(String baseUrl,
                          String apiKey,
                          String modelName,
                          WanxImageRefMode refMode,
                          Float refStrength,
                          Integer seed,
                          WanxImageSize size,
                          WanxImageStyle style) {
        if (Utils.isNullOrBlank(apiKey)) {
            throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
        }
        this.modelName = Utils.isNullOrBlank(modelName) ? WanxModelName.WANX_V1 : modelName;
        this.apiKey = apiKey;
        this.refMode = refMode;
        this.refStrength = refStrength;
        this.seed = seed;
        this.size = size;
        this.style = style;
        // "text2image" is the task name expected by the custom-endpoint constructor.
        this.imageSynthesis = Utils.isNullOrBlank(baseUrl) ? new ImageSynthesis() : new ImageSynthesis("text2image", baseUrl);
    }

    /**
     * Generates a single image for the given prompt.
     * <p>
     * NOTE(review): assumes the service returns at least one image; an empty
     * result would make {@code get(0)} throw {@code IndexOutOfBoundsException}.
     *
     * @param prompt the text prompt
     * @return a response wrapping the generated image
     * @throws RuntimeException if the SDK reports a missing API key
     */
    @Override
    public Response<Image> generate(String prompt) {
        ImageSynthesisParam param = requestBuilder(prompt).n(1).build();
        try {
            ImageSynthesisResult result = imageSynthesis.call(param);
            return Response.from(imagesFrom(result).get(0));
        } catch (NoApiKeyException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Generates {@code n} images for the given prompt.
     *
     * @param prompt the text prompt
     * @param n      the number of images requested
     * @return a response wrapping the generated images
     * @throws RuntimeException if the SDK reports a missing API key
     */
    @Override
    public Response<List<Image>> generate(String prompt, int n) {
        ImageSynthesisParam param = requestBuilder(prompt).n(n).build();
        try {
            ImageSynthesisResult result = imageSynthesis.call(param);
            return Response.from(imagesFrom(result));
        } catch (NoApiKeyException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Generates one image using {@code image} as a reference, guided by the prompt.
     * <p>
     * Base64 images are uploaded to OSS first (see {@code WanxHelper.imageUrl});
     * for OSS-hosted references the {@code X-DashScope-OssResourceResolve} header
     * is enabled so the service can resolve the {@code oss://} URL.
     *
     * @param image  the reference image (URL or base64 data)
     * @param prompt the text prompt describing the desired edit
     * @return a response wrapping the edited image
     * @throws RuntimeException if the SDK reports a missing API key
     */
    @Override
    public Response<Image> edit(Image image, String prompt) {
        String imageUrl = imageUrl(image, modelName, apiKey);
        ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> builder = requestBuilder(prompt)
                .refImage(imageUrl)
                .n(1);
        if (imageUrl.startsWith("oss://")) {
            builder.header("X-DashScope-OssResourceResolve", "enable");
        }
        try {
            ImageSynthesisResult result = imageSynthesis.call(builder.build());
            return Response.from(imagesFrom(result).get(0));
        } catch (NoApiKeyException e) {
            throw new RuntimeException(e);
        }
    }

    // Assembles the common request parameters; optional settings are only
    // included when the corresponding field is non-null, so the service
    // defaults apply otherwise.
    private ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> requestBuilder(String prompt) {
        ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> builder = ImageSynthesisParam.builder()
                .apiKey(apiKey)
                .model(modelName)
                .prompt(prompt);
        if (seed != null) {
            builder.seed(seed);
        }
        if (size != null) {
            builder.size(size.toString());
        }
        if (style != null) {
            builder.style(style.toString());
        }
        if (refMode != null) {
            builder.parameter("ref_mode", refMode.toString());
        }
        if (refStrength != null) {
            builder.parameter("ref_strength", refStrength);
        }
        return builder;
    }

    /**
     * Creates a builder for {@code WanxImageModel}.
     * <p>
     * If an SPI-registered {@link WanxImageModelBuilderFactory} is found via
     * {@code loadFactories}, its builder is returned; otherwise the default
     * builder is used.
     *
     * @return a new {@code WanxImageModelBuilder}
     */
    public static WanxImageModel.WanxImageModelBuilder builder() {
        for (WanxImageModelBuilderFactory factory : loadFactories(WanxImageModelBuilderFactory.class)) {
            return factory.get();
        }
        return new WanxImageModel.WanxImageModelBuilder();
    }

    /**
     * Builder for {@code WanxImageModel}, populated by Lombok's {@code @Builder}.
     * Declared explicitly only to widen the constructor's visibility.
     */
    public static class WanxImageModelBuilder {
        public WanxImageModelBuilder() {
            // This is public so it can be extended
            // By default with Lombok it becomes package private
        }
    }
}

View File

@ -0,0 +1,17 @@
package dev.langchain4j.model.dashscope;
/**
 * How a reference image guides Wanx image generation:
 * {@code repaint} reuses the reference content, {@code refonly} reuses only its style.
 */
public enum WanxImageRefMode {

    REPAINT("repaint"),
    REFONLY("refonly");

    // Wire value sent to the DashScope API.
    private final String value;

    WanxImageRefMode(String value) {
        this.value = value;
    }

    /** Returns the wire value, e.g. {@code "repaint"}. */
    @Override
    public String toString() {
        return value;
    }
}

View File

@ -0,0 +1,18 @@
package dev.langchain4j.model.dashscope;
/**
 * Output resolutions supported by the Wanx image models.
 */
public enum WanxImageSize {

    SIZE_1024_1024("1024*1024"),
    SIZE_720_1280("720*1280"),
    SIZE_1280_720("1280*720");

    // Wire value sent to the DashScope API, formatted as "width*height".
    private final String value;

    WanxImageSize(String value) {
        this.value = value;
    }

    /** Returns the wire value, e.g. {@code "1024*1024"}. */
    @Override
    public String toString() {
        return value;
    }
}

View File

@ -0,0 +1,25 @@
package dev.langchain4j.model.dashscope;
/**
 * Artistic styles supported by the Wanx image models; {@link #AUTO} lets the
 * service pick one.
 */
public enum WanxImageStyle {

    PHOTOGRAPHY("<photography>"),
    PORTRAIT("<portrait>"),
    CARTOON_3D("<3d cartoon>"),
    ANIME("<anime>"),
    OIL_PAINTING("<oil painting>"),
    WATERCOLOR("<watercolor>"),
    SKETCH("<sketch>"),
    CHINESE_PAINTING("<chinese painting>"),
    FLAT_ILLUSTRATION("<flat illustration>"),
    AUTO("<auto>");

    // Wire value sent to the DashScope API, wrapped in angle brackets.
    private final String value;

    WanxImageStyle(String value) {
        this.value = value;
    }

    /** Returns the wire value, e.g. {@code "<anime>"}. */
    @Override
    public String toString() {
        return value;
    }
}

View File

@ -0,0 +1,6 @@
package dev.langchain4j.model.dashscope;
/**
 * Names of the Wanx (text-to-image) models available on DashScope.
 * Constants holder; not instantiable.
 */
public class WanxModelName {

    // Use with WanxImageModel
    public static final String WANX_V1 = "wanx-v1"; // Wanx model for text-generated images, supports Chinese and English

    private WanxModelName() {
        // Constants holder: prevent instantiation.
    }
}

View File

@ -0,0 +1,8 @@
package dev.langchain4j.model.dashscope.spi;
import dev.langchain4j.model.dashscope.QwenTokenizer;
import java.util.function.Supplier;
/**
 * SPI for supplying a custom {@code QwenTokenizer.QwenTokenizerBuilder};
 * implementations are discovered via {@code ServiceHelper.loadFactories}.
 */
public interface QwenTokenizerBuilderFactory extends Supplier<QwenTokenizer.QwenTokenizerBuilder> {
}

View File

@ -0,0 +1,8 @@
package dev.langchain4j.model.dashscope.spi;
import dev.langchain4j.model.dashscope.WanxImageModel;
import java.util.function.Supplier;
/**
 * SPI for supplying a custom {@code WanxImageModel.WanxImageModelBuilder};
 * implementations are discovered via {@code ServiceHelper.loadFactories}.
 */
public interface WanxImageModelBuilderFactory extends Supplier<WanxImageModel.WanxImageModelBuilder> {
}

View File

@ -0,0 +1,74 @@
package dev.langchain4j.model.dashscope;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.net.URI;
import static dev.langchain4j.model.dashscope.QwenTestHelper.apiKey;
import static dev.langchain4j.model.dashscope.QwenTestHelper.multimodalImageData;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Integration tests for {@code WanxImageModel}; run only when a
 * {@code DASHSCOPE_API_KEY} environment variable is set.
 */
@EnabledIfEnvironmentVariable(named = "DASHSCOPE_API_KEY", matches = ".+")
public class WanxImageModelIT {

    Logger log = LoggerFactory.getLogger(WanxImageModelIT.class);

    // Builds the model under test for the given model name; extracted to
    // avoid repeating identical builder chains in every test.
    private static WanxImageModel modelOf(String modelName) {
        return WanxImageModel.builder()
                .apiKey(apiKey())
                .modelName(modelName)
                .build();
    }

    @ParameterizedTest
    @MethodSource("dev.langchain4j.model.dashscope.WanxTestHelper#imageModelNameProvider")
    void simple_image_generation_works(String modelName) {
        WanxImageModel model = modelOf(modelName);

        Response<Image> response = model.generate("Beautiful house on country side");

        URI remoteImage = response.content().url();
        log.info("Your remote image is here: {}", remoteImage);
        assertThat(remoteImage).isNotNull();
    }

    @ParameterizedTest
    @MethodSource("dev.langchain4j.model.dashscope.WanxTestHelper#imageModelNameProvider")
    void simple_image_edition_works_by_url(String modelName) {
        WanxImageModel model = modelOf(modelName);

        Image image = Image.builder()
                .url("https://img.alicdn.com/imgextra/i4/O1CN01K1DWat25own2MuQgF_!!6000000007574-0-tps-128-128.jpg")
                .build();

        Response<Image> response = model.edit(image, "Change the parrot's feathers with yellow");

        URI remoteImage = response.content().url();
        log.info("Your remote image is here: {}", remoteImage);
        assertThat(remoteImage).isNotNull();
    }

    @ParameterizedTest
    @MethodSource("dev.langchain4j.model.dashscope.WanxTestHelper#imageModelNameProvider")
    void simple_image_edition_works_by_data(String modelName) {
        WanxImageModel model = modelOf(modelName);

        Image image = Image.builder()
                .base64Data(multimodalImageData())
                .mimeType("image/jpg")
                .build();

        Response<Image> response = model.edit(image, "Change the parrot's feathers with yellow");

        URI remoteImage = response.content().url();
        log.info("Your remote image is here: {}", remoteImage);
        assertThat(remoteImage).isNotNull();
    }
}

View File

@ -0,0 +1,13 @@
package dev.langchain4j.model.dashscope;
import org.junit.jupiter.params.provider.Arguments;
import java.util.stream.Stream;
/**
 * Shared fixtures for Wanx model tests.
 */
public class WanxTestHelper {

    /**
     * Supplies the Wanx model names exercised by parameterized tests.
     *
     * @return a stream of test arguments, one per model name
     */
    public static Stream<Arguments> imageModelNameProvider() {
        return Stream.of(
                Arguments.of(WanxModelName.WANX_V1)
        );
    }
}