OpenAI DALL·E support (#298)

Motivating generated picture 😄 ❤️ 

![image](https://github.com/langchain4j/langchain4j/assets/33568148/5f5463f0-8d43-47a3-8127-146340871132)
This commit is contained in:
Alexey Titov 2023-12-19 11:37:11 +01:00 committed by GitHub
parent 38049a197b
commit cf4f2da604
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 353 additions and 2 deletions

View File

@ -1,3 +1,4 @@
{
"tabWidth": 4,
"printWidth": 120
}

View File

@ -0,0 +1,92 @@
package dev.langchain4j.data.image;
import static dev.langchain4j.internal.Utils.quoted;
import java.net.URI;
import java.util.Objects;
/**
 * Immutable value object describing a single image.
 *
 * <p>An image carries up to three optional pieces of information: a {@link URI} pointing to it,
 * its content as a base64 string, and a revised prompt. Any of them may be {@code null};
 * the builder does not enforce that a particular one is present.
 */
public final class Image {

    private final URI url;
    private final String base64;
    private final String revisedPrompt;

    private Image(Builder builder) {
        this.url = builder.url;
        this.base64 = builder.base64;
        this.revisedPrompt = builder.revisedPrompt;
    }

    /**
     * Creates a new, empty builder.
     *
     * @return a builder with no values set
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * @return the URI of the image, or {@code null} if none was set
     */
    public URI url() {
        return url;
    }

    /**
     * @return the base64 representation of the image, or {@code null} if none was set
     */
    public String base64() {
        return base64;
    }

    /**
     * @return the revised prompt associated with the image, or {@code null} if none was set
     */
    public String revisedPrompt() {
        return revisedPrompt;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Image image = (Image) o;
        return (
            Objects.equals(url, image.url) &&
            Objects.equals(base64, image.base64) &&
            Objects.equals(revisedPrompt, image.revisedPrompt)
        );
    }

    @Override
    public int hashCode() {
        // Hash over exactly the fields that equals() compares.
        return Objects.hash(url, base64, revisedPrompt);
    }

    @Override
    public String toString() {
        // url is optional (e.g. a base64-only image), so it must not be dereferenced blindly.
        String urlText = (url == null) ? "null" : quoted(url.toString());
        return (
            "Image{" +
            " url=" +
            urlText +
            ", base64=" +
            quoted(base64) +
            ", revisedPrompt=" +
            quoted(revisedPrompt) +
            '}'
        );
    }

    /**
     * Fluent builder for {@link Image}. Every setter is optional.
     */
    public static class Builder {

        private URI url;
        private String base64;
        private String revisedPrompt;

        /**
         * @param url the URI of the image
         * @return this builder
         */
        public Builder url(URI url) {
            this.url = url;
            return this;
        }

        /**
         * @param base64 the base64 representation of the image
         * @return this builder
         */
        public Builder base64(String base64) {
            this.base64 = base64;
            return this;
        }

        /**
         * @param revisedPrompt the revised prompt associated with the image
         * @return this builder
         */
        public Builder revisedPrompt(String revisedPrompt) {
            this.revisedPrompt = revisedPrompt;
            return this;
        }

        /**
         * @return a new {@link Image} holding the values set so far
         */
        public Image build() {
            return new Image(this);
        }
    }
}

View File

@ -0,0 +1,13 @@
package dev.langchain4j.model.image;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.model.output.Response;
import java.util.List;
/**
 * Contract for models that produce images from a textual prompt.
 */
public interface ImageModel {

    /**
     * Generates a single image for the given prompt.
     *
     * @param prompt the text the image is generated from
     * @return the response wrapping the generated {@link Image}
     */
    Response<Image> generate(String prompt);

    /**
     * Generates several images for the given prompt.
     *
     * <p>This default implementation does not generate anything and always throws;
     * implementations that can produce multiple images override it.
     *
     * @param prompt the text the images are generated from
     * @param n      the number of images requested
     * @return the response wrapping the generated images
     * @throws IllegalArgumentException always, unless overridden
     */
    default Response<List<Image>> generate(String prompt, int n) {
        final String message = "Operation is not supported";
        throw new IllegalArgumentException(message);
    }
}

View File

@ -0,0 +1,150 @@
package dev.langchain4j.model.openai;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.openai.OpenAiModelName.DALL_E_2;
import static java.time.Duration.ofSeconds;
import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.image.GenerateImagesRequest;
import dev.ai4j.openai4j.image.GenerateImagesResponse;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.output.Response;
import java.net.Proxy;
import java.nio.file.Path;
import java.time.Duration;
import java.util.List;
import java.util.stream.Collectors;
import lombok.Builder;
import lombok.NonNull;
/**
 * {@link ImageModel} implementation that generates images with the OpenAI DALL·E models.
 * Versions 2 and 3 (default) are supported.
 */
public class OpenAiImageModel implements ImageModel {

    private final String model;
    private final String size;
    private final String quality;
    private final String style;
    private final String user;
    private final String responseFormat;

    private final OpenAiClient client;

    private final Integer maxRetries;

    /**
     * Instantiates the OpenAI DALL·E image model.
     * The request parameters are described
     * <a href="https://platform.openai.com/docs/api-reference/images/create">here</a>.
     *
     * @param apiKey         the OpenAI API key; must not be {@code null}
     * @param model          dall-e-3 is the default one
     * @param size           passed through to the image request
     * @param quality        passed through to the image request
     * @param style          passed through to the image request
     * @param user           passed through to the image request
     * @param responseFormat passed through to the image request
     * @param timeout        used as call, connect, read and write timeout; 60 seconds when {@code null}
     * @param maxRetries     retry count handed to {@code withRetry}; 3 when {@code null}
     * @param proxy          proxy handed to the underlying client
     * @param logRequests    whether the client logs requests; {@code false} when {@code null}
     * @param logResponses   whether the client logs responses; {@code false} when {@code null}
     * @param withPersisting generated response will be persisted under <code>java.io.tmpdir</code>.
     *                       The URL within <code>dev.ai4j.openai4j.image.GenerateImagesResponse</code> will contain
     *                       the URL to local images then.
     * @param persistTo      specifies the local path where the generated image will be downloaded to
     *                       (in case provided).
     *                       The URL within <code>dev.ai4j.openai4j.image.GenerateImagesResponse</code> will contain
     *                       the URL to local images then.
     */
    @Builder
    @SuppressWarnings("rawtypes")
    public OpenAiImageModel(
        @NonNull String apiKey,
        String model,
        String size,
        String quality,
        String style,
        String user,
        String responseFormat,
        Duration timeout,
        Integer maxRetries,
        Proxy proxy,
        Boolean logRequests,
        Boolean logResponses,
        Boolean withPersisting,
        Path persistTo
    ) {
        // One timeout value is applied to every phase of the HTTP call.
        timeout = getOrDefault(timeout, ofSeconds(60));

        OpenAiClient.Builder cBuilder = OpenAiClient
            .builder()
            .openAiApiKey(apiKey)
            .callTimeout(timeout)
            .connectTimeout(timeout)
            .readTimeout(timeout)
            .writeTimeout(timeout)
            .proxy(proxy)
            .logRequests(getOrDefault(logRequests, false))
            .logResponses(getOrDefault(logResponses, false))
            .persistTo(persistTo);

        if (withPersisting != null && withPersisting) {
            cBuilder.withPersisting();
        }

        this.client = cBuilder.build();
        this.maxRetries = getOrDefault(maxRetries, 3);
        this.model = model;
        this.size = size;
        this.quality = quality;
        this.style = style;
        this.user = user;
        this.responseFormat = responseFormat;
    }

    /**
     * Generates one image for the prompt and returns the first entry of the response data.
     *
     * @param prompt the text the image is generated from
     * @return the response wrapping the generated image
     */
    @Override
    public Response<Image> generate(String prompt) {
        GenerateImagesRequest request = requestBuilder(prompt).build();

        GenerateImagesResponse response = executeWithRetry(request);

        return Response.from(fromImageData(response.data().get(0)));
    }

    /**
     * Generates {@code n} images for the prompt.
     *
     * @param prompt the text the images are generated from
     * @param n      the number of images requested
     * @return the response wrapping every image returned in the response data
     */
    @Override
    public Response<List<Image>> generate(String prompt, int n) {
        GenerateImagesRequest request = requestBuilder(prompt).n(n).build();

        GenerateImagesResponse response = executeWithRetry(request);

        return Response.from(
            response.data().stream().map(OpenAiImageModel::fromImageData).collect(Collectors.toList())
        );
    }

    /**
     * Additions to the Lombok-generated builder.
     */
    public static class OpenAiImageModelBuilder {

        /**
         * Turns persisting on without requiring an explicit boolean argument.
         *
         * @return this builder
         */
        public OpenAiImageModelBuilder withPersisting() {
            withPersisting = true;
            return this;
        }

        /**
         * Sets the API key.
         *
         * @param apiKey the OpenAI API key
         * @return this builder
         */
        public OpenAiImageModelBuilder withApiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }
    }

    // Sends the request with execute() inside the retried action, so a failed call is actually retried.
    private GenerateImagesResponse executeWithRetry(GenerateImagesRequest request) {
        return withRetry(() -> client.imagesGeneration(request).execute(), maxRetries);
    }

    // Maps one entry of the response data onto the library's Image type.
    private static Image fromImageData(GenerateImagesResponse.ImageData data) {
        return Image.builder().url(data.url()).base64(data.b64Json()).revisedPrompt(data.revisedPrompt()).build();
    }

    // Pre-populates a request builder with the prompt and the options configured on this model.
    private GenerateImagesRequest.Builder requestBuilder(String prompt) {
        GenerateImagesRequest.Builder requestBuilder = GenerateImagesRequest
            .builder()
            .prompt(prompt)
            .size(size)
            .quality(quality)
            .style(style)
            .user(user)
            .responseFormat(responseFormat);

        // The model is only set explicitly for DALL·E 2; any other value leaves the request's model unset.
        if (DALL_E_2.equals(model)) {
            requestBuilder.model(dev.ai4j.openai4j.image.ImageModel.DALL_E_2);
        }

        return requestBuilder;
    }
}

View File

@ -25,12 +25,14 @@ public class OpenAiModelName {
// Use with OpenAiLanguageModel and OpenAiStreamingLanguageModel
public static final String GPT_3_5_TURBO_INSTRUCT = "gpt-3.5-turbo-instruct";
// Use with OpenAiEmbeddingModel
public static final String TEXT_EMBEDDING_ADA_002 = "text-embedding-ada-002";
// Use with OpenAiModerationModel
public static final String TEXT_MODERATION_STABLE = "text-moderation-stable";
public static final String TEXT_MODERATION_LATEST = "text-moderation-latest";
// Use with OpenAiImageModel
public static final String DALL_E_2 = "dall-e-2"; // anyone still needs that? :)
public static final String DALL_E_3 = "dall-e-3";
}

View File

@ -0,0 +1,93 @@
package dev.langchain4j.model.openai;
import static dev.ai4j.openai4j.image.ImageModel.DALL_E_QUALITY_HD;
import static dev.ai4j.openai4j.image.ImageModel.DALL_E_RESPONSE_FORMAT_B64_JSON;
import static dev.ai4j.openai4j.image.ImageModel.DALL_E_SIZE_256_x_256;
import static dev.langchain4j.model.openai.OpenAiModelName.DALL_E_2;
import static org.assertj.core.api.Assertions.assertThat;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.model.output.Response;
import java.io.File;
import java.net.URI;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Integration tests for {@link OpenAiImageModel}.
 * The API key is read from the {@code OPENAI_API_KEY} environment variable.
 */
class OpenAiImageModelIT {

    Logger log = LoggerFactory.getLogger(OpenAiImageModelIT.class);

    // Base configuration shared by the DALL·E 2 tests; DALL·E 2 at 256x256 is used to keep the cost down.
    OpenAiImageModel.OpenAiImageModelBuilder modelBuilder = OpenAiImageModel
        .builder()
        .apiKey(System.getenv("OPENAI_API_KEY"))
        .model(DALL_E_2)
        .size(DALL_E_SIZE_256_x_256)
        .logRequests(true)
        .logResponses(true);

    /** A plain request yields an image with a non-null URL. */
    @Test
    void simple_image_generation_works() {
        OpenAiImageModel imageModel = modelBuilder.build();

        Image generated = imageModel.generate("Beautiful house on country side").content();
        URI remoteUri = generated.url();

        log.info("Your remote image is here: {}", remoteUri);
        assertThat(remoteUri).isNotNull();
    }

    /** With persisting switched on, the returned URL points to an existing local file. */
    @Test
    void image_generation_with_persisting_works() {
        modelBuilder.responseFormat(DALL_E_RESPONSE_FORMAT_B64_JSON);
        modelBuilder.withPersisting();
        OpenAiImageModel imageModel = modelBuilder.build();

        Image generated = imageModel.generate("Bird flying in the sky").content();
        URI localUri = generated.url();

        log.info("Your local image is here: {}", localUri);
        File persistedFile = new File(localUri);
        assertThat(persistedFile).exists();
    }

    /** Requesting two images returns two entries, each persisted locally and carrying base64 content. */
    @Test
    void multiple_images_generation_with_base64_works() {
        modelBuilder.responseFormat(DALL_E_RESPONSE_FORMAT_B64_JSON);
        modelBuilder.withPersisting();
        OpenAiImageModel imageModel = modelBuilder.build();

        List<Image> images = imageModel.generate("Cute red parrot sings", 2).content();

        assertThat(images).hasSize(2);

        Image first = images.get(0);
        log.info("Your first local image is here: {}", first.url());
        File firstFile = new File(first.url());
        assertThat(firstFile).exists();
        assertThat(first.base64()).isNotNull().isBase64();

        Image second = images.get(1);
        log.info("Your second local image is here: {}", second.url());
        File secondFile = new File(second.url());
        assertThat(secondFile).exists();
        assertThat(second.base64()).isNotNull().isBase64();
    }

    /** Without an explicit model, an HD-quality request returns a URL and a revised prompt. */
    @Test
    void image_generation_with_dalle3_works() {
        OpenAiImageModel.OpenAiImageModelBuilder dalle3Builder = OpenAiImageModel.builder();
        dalle3Builder.apiKey(System.getenv("OPENAI_API_KEY"));
        dalle3Builder.quality(DALL_E_QUALITY_HD);
        dalle3Builder.logRequests(true);
        dalle3Builder.logResponses(true);
        OpenAiImageModel imageModel = dalle3Builder.build();

        String prompt = "Beautiful house on country side, cowboy plays guitar, dog sitting at the door";
        Image generated = imageModel.generate(prompt).content();

        URI remoteUri = generated.url();
        log.info("Your remote image is here: {}", remoteUri);
        assertThat(remoteUri).isNotNull();

        String revised = generated.revisedPrompt();
        log.info("Your revised prompt: {}", revised);
        assertThat(revised).hasSizeGreaterThan(50);
    }
}