OpenAI DALL·E support (#298)
Motivating generated picture 😄 ❤️ ![image](https://github.com/langchain4j/langchain4j/assets/33568148/5f5463f0-8d43-47a3-8127-146340871132)
This commit is contained in:
parent
38049a197b
commit
cf4f2da604
|
@ -1,3 +1,4 @@
|
||||||
{
|
{
|
||||||
|
"tabWidth": 4,
|
||||||
"printWidth": 120
|
"printWidth": 120
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
package dev.langchain4j.data.image;
|
||||||
|
|
||||||
|
import static dev.langchain4j.internal.Utils.quoted;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public final class Image {
|
||||||
|
|
||||||
|
private URI url;
|
||||||
|
private String base64;
|
||||||
|
private String revisedPrompt;
|
||||||
|
|
||||||
|
private Image(Builder builder) {
|
||||||
|
this.url = builder.url;
|
||||||
|
this.base64 = builder.base64;
|
||||||
|
this.revisedPrompt = builder.revisedPrompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new Builder();
|
||||||
|
}
|
||||||
|
|
||||||
|
public URI url() {
|
||||||
|
return url;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String base64() {
|
||||||
|
return base64;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String revisedPrompt() {
|
||||||
|
return revisedPrompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
Image image = (Image) o;
|
||||||
|
return (
|
||||||
|
Objects.equals(url, image.url) &&
|
||||||
|
Objects.equals(base64, image.base64) &&
|
||||||
|
Objects.equals(revisedPrompt, image.revisedPrompt)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(url, revisedPrompt);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return (
|
||||||
|
"Image{" +
|
||||||
|
" url=" +
|
||||||
|
quoted(url.toString()) +
|
||||||
|
", base64=" +
|
||||||
|
quoted(base64) +
|
||||||
|
", revisedPrompt=" +
|
||||||
|
quoted(revisedPrompt) +
|
||||||
|
'}'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
|
||||||
|
private URI url;
|
||||||
|
private String base64;
|
||||||
|
private String revisedPrompt;
|
||||||
|
|
||||||
|
public Builder url(URI url) {
|
||||||
|
this.url = url;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder base64(String base64) {
|
||||||
|
this.base64 = base64;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder revisedPrompt(String revisedPrompt) {
|
||||||
|
this.revisedPrompt = revisedPrompt;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Image build() {
|
||||||
|
return new Image(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
package dev.langchain4j.model.image;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.image.Image;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface ImageModel {
|
||||||
|
Response<Image> generate(String prompt);
|
||||||
|
|
||||||
|
default Response<List<Image>> generate(String prompt, int n) {
|
||||||
|
throw new IllegalArgumentException("Operation is not supported");
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,150 @@
|
||||||
|
package dev.langchain4j.model.openai;
|
||||||
|
|
||||||
|
import static dev.langchain4j.internal.RetryUtils.withRetry;
|
||||||
|
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||||
|
import static dev.langchain4j.model.openai.OpenAiModelName.DALL_E_2;
|
||||||
|
import static java.time.Duration.ofSeconds;
|
||||||
|
|
||||||
|
import dev.ai4j.openai4j.OpenAiClient;
|
||||||
|
import dev.ai4j.openai4j.image.GenerateImagesRequest;
|
||||||
|
import dev.ai4j.openai4j.image.GenerateImagesResponse;
|
||||||
|
import dev.langchain4j.data.image.Image;
|
||||||
|
import dev.langchain4j.model.image.ImageModel;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
|
import java.net.Proxy;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.NonNull;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents an OpenAI DALL·E models to generate artistic images. Versions 2 and 3 (default) are supported.
|
||||||
|
*/
|
||||||
|
public class OpenAiImageModel implements ImageModel {
|
||||||
|
|
||||||
|
private final String model;
|
||||||
|
private final String size;
|
||||||
|
private final String quality;
|
||||||
|
private final String style;
|
||||||
|
private final String user;
|
||||||
|
private final String responseFormat;
|
||||||
|
|
||||||
|
private final OpenAiClient client;
|
||||||
|
|
||||||
|
private final Integer maxRetries;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Instantiates OpenAI DALL·E image processing model.
|
||||||
|
* Find the parameters description <a href="https://platform.openai.com/docs/api-reference/images/create">here</a>.
|
||||||
|
*
|
||||||
|
* @param model dall-e-3 is default one
|
||||||
|
* @param persistTo specifies the local path where the generated image will be downloaded to (in case provided).
|
||||||
|
* The URL within <code>dev.ai4j.openai4j.image.GenerateImagesResponse</code> will contain
|
||||||
|
* the URL to local images then.
|
||||||
|
* @param withPersisting generated response will be persisted under <code>java.io.tmpdir</code>.
|
||||||
|
* The URL within <code>dev.ai4j.openai4j.image.GenerateImagesResponse</code> will contain
|
||||||
|
* the URL to local images then.
|
||||||
|
*/
|
||||||
|
@Builder
|
||||||
|
@SuppressWarnings("rawtypes")
|
||||||
|
public OpenAiImageModel(
|
||||||
|
@NonNull String apiKey,
|
||||||
|
String model,
|
||||||
|
String size,
|
||||||
|
String quality,
|
||||||
|
String style,
|
||||||
|
String user,
|
||||||
|
String responseFormat,
|
||||||
|
Duration timeout,
|
||||||
|
Integer maxRetries,
|
||||||
|
Proxy proxy,
|
||||||
|
Boolean logRequests,
|
||||||
|
Boolean logResponses,
|
||||||
|
Boolean withPersisting,
|
||||||
|
Path persistTo
|
||||||
|
) {
|
||||||
|
timeout = getOrDefault(timeout, ofSeconds(60));
|
||||||
|
|
||||||
|
OpenAiClient.Builder cBuilder = OpenAiClient
|
||||||
|
.builder()
|
||||||
|
.openAiApiKey(apiKey)
|
||||||
|
.callTimeout(timeout)
|
||||||
|
.connectTimeout(timeout)
|
||||||
|
.readTimeout(timeout)
|
||||||
|
.writeTimeout(timeout)
|
||||||
|
.proxy(proxy)
|
||||||
|
.logRequests(getOrDefault(logRequests, false))
|
||||||
|
.logResponses(getOrDefault(logResponses, false))
|
||||||
|
.persistTo(persistTo);
|
||||||
|
|
||||||
|
if (withPersisting != null && withPersisting) {
|
||||||
|
cBuilder.withPersisting();
|
||||||
|
}
|
||||||
|
|
||||||
|
this.client = cBuilder.build();
|
||||||
|
|
||||||
|
this.maxRetries = getOrDefault(maxRetries, 3);
|
||||||
|
this.model = model;
|
||||||
|
this.size = size;
|
||||||
|
this.quality = quality;
|
||||||
|
this.style = style;
|
||||||
|
this.user = user;
|
||||||
|
this.responseFormat = responseFormat;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Response<Image> generate(String prompt) {
|
||||||
|
GenerateImagesRequest request = requestBuilder(prompt).build();
|
||||||
|
|
||||||
|
GenerateImagesResponse response = withRetry(() -> client.imagesGeneration(request), maxRetries).execute();
|
||||||
|
|
||||||
|
return Response.from(fromImageData(response.data().get(0)));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Response<List<Image>> generate(String prompt, int n) {
|
||||||
|
GenerateImagesRequest request = requestBuilder(prompt).n(n).build();
|
||||||
|
|
||||||
|
GenerateImagesResponse response = withRetry(() -> client.imagesGeneration(request), maxRetries).execute();
|
||||||
|
|
||||||
|
return Response.from(
|
||||||
|
response.data().stream().map(OpenAiImageModel::fromImageData).collect(Collectors.toList())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class OpenAiImageModelBuilder {
|
||||||
|
|
||||||
|
public OpenAiImageModelBuilder withPersisting() {
|
||||||
|
withPersisting = true;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OpenAiImageModelBuilder withApiKey(String apiKey) {
|
||||||
|
this.apiKey = apiKey;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Image fromImageData(GenerateImagesResponse.ImageData data) {
|
||||||
|
return Image.builder().url(data.url()).base64(data.b64Json()).revisedPrompt(data.revisedPrompt()).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private GenerateImagesRequest.Builder requestBuilder(String prompt) {
|
||||||
|
GenerateImagesRequest.Builder requestBuilder = GenerateImagesRequest
|
||||||
|
.builder()
|
||||||
|
.prompt(prompt)
|
||||||
|
.size(size)
|
||||||
|
.quality(quality)
|
||||||
|
.style(style)
|
||||||
|
.user(user)
|
||||||
|
.responseFormat(responseFormat);
|
||||||
|
|
||||||
|
if (DALL_E_2.equals(model)) {
|
||||||
|
requestBuilder.model(dev.ai4j.openai4j.image.ImageModel.DALL_E_2);
|
||||||
|
}
|
||||||
|
|
||||||
|
return requestBuilder;
|
||||||
|
}
|
||||||
|
}
|
|
@ -25,12 +25,14 @@ public class OpenAiModelName {
|
||||||
// Use with OpenAiLanguageModel and OpenAiStreamingLanguageModel
|
// Use with OpenAiLanguageModel and OpenAiStreamingLanguageModel
|
||||||
public static final String GPT_3_5_TURBO_INSTRUCT = "gpt-3.5-turbo-instruct";
|
public static final String GPT_3_5_TURBO_INSTRUCT = "gpt-3.5-turbo-instruct";
|
||||||
|
|
||||||
|
|
||||||
// Use with OpenAiEmbeddingModel
|
// Use with OpenAiEmbeddingModel
|
||||||
public static final String TEXT_EMBEDDING_ADA_002 = "text-embedding-ada-002";
|
public static final String TEXT_EMBEDDING_ADA_002 = "text-embedding-ada-002";
|
||||||
|
|
||||||
|
|
||||||
// Use with OpenAiModerationModel
|
// Use with OpenAiModerationModel
|
||||||
public static final String TEXT_MODERATION_STABLE = "text-moderation-stable";
|
public static final String TEXT_MODERATION_STABLE = "text-moderation-stable";
|
||||||
public static final String TEXT_MODERATION_LATEST = "text-moderation-latest";
|
public static final String TEXT_MODERATION_LATEST = "text-moderation-latest";
|
||||||
|
|
||||||
|
// Use with OpenAiImageModel
|
||||||
|
public static final String DALL_E_2 = "dall-e-2"; // anyone still needs that? :)
|
||||||
|
public static final String DALL_E_3 = "dall-e-3";
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
package dev.langchain4j.model.openai;
|
||||||
|
|
||||||
|
import static dev.ai4j.openai4j.image.ImageModel.DALL_E_QUALITY_HD;
|
||||||
|
import static dev.ai4j.openai4j.image.ImageModel.DALL_E_RESPONSE_FORMAT_B64_JSON;
|
||||||
|
import static dev.ai4j.openai4j.image.ImageModel.DALL_E_SIZE_256_x_256;
|
||||||
|
import static dev.langchain4j.model.openai.OpenAiModelName.DALL_E_2;
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.image.Image;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
|
import java.io.File;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.util.List;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
class OpenAiImageModelIT {
|
||||||
|
|
||||||
|
Logger log = LoggerFactory.getLogger(OpenAiImageModelIT.class);
|
||||||
|
|
||||||
|
OpenAiImageModel.OpenAiImageModelBuilder modelBuilder = OpenAiImageModel
|
||||||
|
.builder()
|
||||||
|
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||||
|
.model(DALL_E_2) // so that you pay not much :)
|
||||||
|
.size(DALL_E_SIZE_256_x_256)
|
||||||
|
.logRequests(true)
|
||||||
|
.logResponses(true);
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void simple_image_generation_works() {
|
||||||
|
OpenAiImageModel model = modelBuilder.build();
|
||||||
|
|
||||||
|
Response<Image> response = model.generate("Beautiful house on country side");
|
||||||
|
|
||||||
|
URI remoteImage = response.content().url();
|
||||||
|
log.info("Your remote image is here: {}", remoteImage);
|
||||||
|
assertThat(remoteImage).isNotNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void image_generation_with_persisting_works() {
|
||||||
|
OpenAiImageModel model = modelBuilder.responseFormat(DALL_E_RESPONSE_FORMAT_B64_JSON).withPersisting().build();
|
||||||
|
|
||||||
|
Response<Image> response = model.generate("Bird flying in the sky");
|
||||||
|
|
||||||
|
URI localImage = response.content().url();
|
||||||
|
log.info("Your local image is here: {}", localImage);
|
||||||
|
assertThat(new File(localImage)).exists();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void multiple_images_generation_with_base64_works() {
|
||||||
|
OpenAiImageModel model = modelBuilder.responseFormat(DALL_E_RESPONSE_FORMAT_B64_JSON).withPersisting().build();
|
||||||
|
|
||||||
|
Response<List<Image>> response = model.generate("Cute red parrot sings", 2);
|
||||||
|
|
||||||
|
assertThat(response.content()).hasSize(2);
|
||||||
|
|
||||||
|
Image localImage1 = response.content().get(0);
|
||||||
|
log.info("Your first local image is here: {}", localImage1.url());
|
||||||
|
assertThat(new File(localImage1.url())).exists();
|
||||||
|
assertThat(localImage1.base64()).isNotNull().isBase64();
|
||||||
|
|
||||||
|
Image localImage2 = response.content().get(1);
|
||||||
|
log.info("Your second local image is here: {}", localImage2.url());
|
||||||
|
assertThat(new File(localImage2.url())).exists();
|
||||||
|
assertThat(localImage2.base64()).isNotNull().isBase64();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void image_generation_with_dalle3_works() {
|
||||||
|
OpenAiImageModel model = OpenAiImageModel
|
||||||
|
.builder()
|
||||||
|
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||||
|
.quality(DALL_E_QUALITY_HD)
|
||||||
|
.logRequests(true)
|
||||||
|
.logResponses(true)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Response<Image> response = model.generate(
|
||||||
|
"Beautiful house on country side, cowboy plays guitar, dog sitting at the door"
|
||||||
|
);
|
||||||
|
|
||||||
|
URI remoteImage = response.content().url();
|
||||||
|
log.info("Your remote image is here: {}", remoteImage);
|
||||||
|
assertThat(remoteImage).isNotNull();
|
||||||
|
|
||||||
|
String revisedPrompt = response.content().revisedPrompt();
|
||||||
|
log.info("Your revised prompt: {}", revisedPrompt);
|
||||||
|
assertThat(revisedPrompt).hasSizeGreaterThan(50);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue