OpenAI: Structured Outputs (#1590)
## Issue Closes #1581 ## Change - OpenAI: added support for [Structured Outputs](https://openai.com/index/introducing-structured-outputs-in-the-api/): - for tools - for json mode - Introduced new (still experimental) `ChatLanguageModel` API (which supports specifying json schema) ### OpenAI Structured Outputs for tools To enable Structured Outputs feature for tools, set `.strictTools(true)` when buidling the model: ```java OpenAiChatModel.builder() ... .strictTools(true) .build(), ``` Please note that this will automatically make all tool parameters mandatory (`required` in json schema) and set `additionalProperties=false` for each `object` in json schema. This is due to the current OpenAI limitations. ### OpenAI Structured Outputs for json mode To enable Structured Outputs feature for json mode, set `.responseFormat("json_schema")` and `.strictJsonSchema(true)` when buidling the model: ```java OpenAiChatModel.builder() ... .responseFormat("json_schema") .strictJsonSchema(true) .build(), ``` In this case `AiServices` will not append "You must answer strictly in the following JSON format: ..." string to the end of the last `UserMessage`, but will create a Json schema from the given POJO and pass it to the LLM. Please note that this works only when method return type is a POJO. If the return type is something else, (like an enum or a `List<String>`), the old behaviour is applied (with "You must answer strictly ..."). All return types will be supported in the near future. Please note that this feature is available now only for `gpt-4o-mini` and `gpt-4o-2024-08-06` models. ### Experimental `ChatLanguageModel` API This was drafted in https://github.com/langchain4j/langchain4j/pull/1261, but now it has to be rushed a bit in order to enable new Structured Outputs feature for OpenAI. A new method `ChatResponse chat(ChatRequest request)` was added into `ChatLanguageModel` which allows to specify messages, tools and response format (with json schema). In the future it will also support specifying model parameters like temperature. ## Upcoming Changes - Adopt new `ChatLanguageModel` API for Gemini - Adopt new `ChatLanguageModel` API for Azure OpenAI (once available) - Support Structured Outputs with all other method return types like `List<Pojo>` - Adopt new `JsonSchema` type for tools (instead of `ToolParameters`) Reated changes in openai4j: https://github.com/ai-for-java/openai4j/pull/33 ## General checklist - [X] There are no breaking changes - [X] I have added unit and integration tests for my change - [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [X] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green <!-- Before adding documentation and example(s) (below), please wait until the PR is reviewed and approved. --> - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
This commit is contained in:
parent
56d6184aae
commit
c5c146f9a0
|
@ -126,6 +126,9 @@
|
|||
<excludes>
|
||||
<exclude>dev.langchain4j.data.document</exclude>
|
||||
<exclude>dev.langchain4j.model.chat.listener</exclude>
|
||||
<exclude>dev.langchain4j.model.chat.request</exclude>
|
||||
<exclude>dev.langchain4j.model.chat.request.json</exclude>
|
||||
<exclude>dev.langchain4j.model.chat.response</exclude>
|
||||
<exclude>dev.langchain4j.model.listener</exclude>
|
||||
<exclude>dev.langchain4j.store.embedding</exclude>
|
||||
<exclude>dev.langchain4j.store.embedding.filter</exclude>
|
||||
|
@ -146,7 +149,7 @@
|
|||
<limit>
|
||||
<counter>INSTRUCTION</counter>
|
||||
<value>COVEREDRATIO</value>
|
||||
<minimum>0.9</minimum>
|
||||
<minimum>0.8</minimum>
|
||||
</limit>
|
||||
</limits>
|
||||
</rule>
|
||||
|
|
|
@ -11,6 +11,7 @@ import static java.util.Arrays.asList;
|
|||
* Describes a {@link Tool}.
|
||||
*/
|
||||
public class ToolSpecification {
|
||||
|
||||
private final String name;
|
||||
private final String description;
|
||||
private final ToolParameters parameters;
|
||||
|
|
|
@ -11,8 +11,11 @@ import static dev.langchain4j.agent.tool.JsonSchemaProperty.enums;
|
|||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.from;
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.items;
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.objectItems;
|
||||
import static dev.langchain4j.internal.TypeUtils.*;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
||||
|
||||
import dev.langchain4j.model.output.structured.Description;
|
||||
|
||||
import static java.lang.String.format;
|
||||
import static java.util.Arrays.stream;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
@ -22,8 +25,6 @@ import java.lang.reflect.Method;
|
|||
import java.lang.reflect.Parameter;
|
||||
import java.lang.reflect.ParameterizedType;
|
||||
import java.lang.reflect.Type;
|
||||
import java.math.BigDecimal;
|
||||
import java.math.BigInteger;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
|
@ -199,15 +200,15 @@ public class ToolSpecifications {
|
|||
return removeNulls(STRING, description);
|
||||
}
|
||||
|
||||
if (isBoolean(type)) {
|
||||
if (isJsonBoolean(type)) {
|
||||
return removeNulls(BOOLEAN, description);
|
||||
}
|
||||
|
||||
if (isInteger(type)) {
|
||||
if (isJsonInteger(type)) {
|
||||
return removeNulls(INTEGER, description);
|
||||
}
|
||||
|
||||
if (isNumber(type)) {
|
||||
if (isJsonNumber(type)) {
|
||||
return removeNulls(NUMBER, description);
|
||||
}
|
||||
|
||||
|
@ -234,36 +235,17 @@ public class ToolSpecifications {
|
|||
return items(JsonSchemaProperty.OBJECT);
|
||||
}
|
||||
|
||||
// TODO put constraints on min and max?
|
||||
private static boolean isNumber(Class<?> type) {
|
||||
return type == float.class || type == Float.class
|
||||
|| type == double.class || type == Double.class
|
||||
|| type == BigDecimal.class;
|
||||
}
|
||||
|
||||
private static boolean isInteger(Class<?> type) {
|
||||
return type == byte.class || type == Byte.class
|
||||
|| type == short.class || type == Short.class
|
||||
|| type == int.class || type == Integer.class
|
||||
|| type == long.class || type == Long.class
|
||||
|| type == BigInteger.class;
|
||||
}
|
||||
|
||||
private static boolean isBoolean(Class<?> type) {
|
||||
return type == boolean.class || type == Boolean.class;
|
||||
}
|
||||
|
||||
private static JsonSchemaProperty arrayTypeFrom(Class<?> clazz) {
|
||||
if (clazz == String.class) {
|
||||
return items(JsonSchemaProperty.STRING);
|
||||
}
|
||||
if (isBoolean(clazz)) {
|
||||
if (isJsonBoolean(clazz)) {
|
||||
return items(JsonSchemaProperty.BOOLEAN);
|
||||
}
|
||||
if (isInteger(clazz)) {
|
||||
if (isJsonInteger(clazz)) {
|
||||
return items(JsonSchemaProperty.INTEGER);
|
||||
}
|
||||
if (isNumber(clazz)) {
|
||||
if (isJsonNumber(clazz)) {
|
||||
return items(JsonSchemaProperty.NUMBER);
|
||||
}
|
||||
return objectItems(schema(clazz));
|
||||
|
|
|
@ -7,36 +7,80 @@ import java.io.*;
|
|||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.LocalTime;
|
||||
import java.util.Map;
|
||||
|
||||
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE;
|
||||
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE_TIME;
|
||||
import static java.time.format.DateTimeFormatter.*;
|
||||
|
||||
class GsonJsonCodec implements Json.JsonCodec {
|
||||
|
||||
private static final Gson GSON = new GsonBuilder()
|
||||
.setPrettyPrinting()
|
||||
.registerTypeAdapter(
|
||||
LocalDate.class,
|
||||
(JsonSerializer<LocalDate>) (localDate, type, context) ->
|
||||
new JsonPrimitive(localDate.format(ISO_LOCAL_DATE))
|
||||
)
|
||||
.registerTypeAdapter(
|
||||
LocalDate.class,
|
||||
(JsonDeserializer<LocalDate>) (json, type, context) ->
|
||||
LocalDate.parse(json.getAsString(), ISO_LOCAL_DATE)
|
||||
)
|
||||
.registerTypeAdapter(
|
||||
LocalDateTime.class,
|
||||
(JsonSerializer<LocalDateTime>) (localDateTime, type, context) ->
|
||||
new JsonPrimitive(localDateTime.format(ISO_LOCAL_DATE_TIME))
|
||||
)
|
||||
.registerTypeAdapter(
|
||||
LocalDateTime.class,
|
||||
(JsonDeserializer<LocalDateTime>) (json, type, context) ->
|
||||
LocalDateTime.parse(json.getAsString(), ISO_LOCAL_DATE_TIME)
|
||||
)
|
||||
.create();
|
||||
.setPrettyPrinting()
|
||||
.registerTypeAdapter(
|
||||
LocalDate.class,
|
||||
(JsonSerializer<LocalDate>) (localDate, type, context) ->
|
||||
new JsonPrimitive(localDate.format(ISO_LOCAL_DATE))
|
||||
)
|
||||
.registerTypeAdapter(
|
||||
LocalDate.class,
|
||||
(JsonDeserializer<LocalDate>) (json, type, context) -> {
|
||||
if (json.isJsonObject()) {
|
||||
JsonObject jsonObject = (JsonObject) json;
|
||||
int year = jsonObject.get("year").getAsInt();
|
||||
int month = jsonObject.get("month").getAsInt();
|
||||
int day = jsonObject.get("day").getAsInt();
|
||||
return LocalDate.of(year, month, day);
|
||||
} else {
|
||||
return LocalDate.parse(json.getAsString(), ISO_LOCAL_DATE);
|
||||
}
|
||||
}
|
||||
)
|
||||
.registerTypeAdapter(
|
||||
LocalTime.class,
|
||||
(JsonSerializer<LocalTime>) (localTime, type, context) ->
|
||||
new JsonPrimitive(localTime.format(ISO_LOCAL_TIME))
|
||||
)
|
||||
.registerTypeAdapter(
|
||||
LocalTime.class,
|
||||
(JsonDeserializer<LocalTime>) (json, type, context) -> {
|
||||
if (json.isJsonObject()) {
|
||||
JsonObject jsonObject = (JsonObject) json;
|
||||
int hour = jsonObject.get("hour").getAsInt();
|
||||
int minute = jsonObject.get("minute").getAsInt();
|
||||
int second = jsonObject.get("second").getAsInt();
|
||||
int nano = jsonObject.get("nano").getAsInt();
|
||||
return LocalTime.of(hour, minute, second, nano);
|
||||
} else {
|
||||
return LocalTime.parse(json.getAsString(), ISO_LOCAL_TIME);
|
||||
}
|
||||
}
|
||||
)
|
||||
.registerTypeAdapter(
|
||||
LocalDateTime.class,
|
||||
(JsonSerializer<LocalDateTime>) (localDateTime, type, context) ->
|
||||
new JsonPrimitive(localDateTime.format(ISO_LOCAL_DATE_TIME))
|
||||
)
|
||||
.registerTypeAdapter(
|
||||
LocalDateTime.class,
|
||||
(JsonDeserializer<LocalDateTime>) (json, type, context) -> {
|
||||
if (json.isJsonObject()) {
|
||||
JsonObject jsonObject = (JsonObject) json;
|
||||
JsonObject date = jsonObject.get("date").getAsJsonObject();
|
||||
int year = date.get("year").getAsInt();
|
||||
int month = date.get("month").getAsInt();
|
||||
int day = date.get("day").getAsInt();
|
||||
JsonObject time = jsonObject.get("time").getAsJsonObject();
|
||||
int hour = time.get("hour").getAsInt();
|
||||
int minute = time.get("minute").getAsInt();
|
||||
int second = time.get("second").getAsInt();
|
||||
int nano = time.get("nano").getAsInt();
|
||||
return LocalDateTime.of(year, month, day, hour, minute, second, nano);
|
||||
} else {
|
||||
return LocalDateTime.parse(json.getAsString(), ISO_LOCAL_DATE_TIME);
|
||||
}
|
||||
}
|
||||
)
|
||||
.create();
|
||||
|
||||
@Override
|
||||
public String toJson(Object o) {
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
package dev.langchain4j.internal;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.math.BigInteger;
|
||||
|
||||
public class TypeUtils {
|
||||
|
||||
public static boolean isJsonInteger(Class<?> type) {
|
||||
return type == byte.class || type == Byte.class
|
||||
|| type == short.class || type == Short.class
|
||||
|| type == int.class || type == Integer.class
|
||||
|| type == long.class || type == Long.class
|
||||
|| type == BigInteger.class;
|
||||
}
|
||||
|
||||
public static boolean isJsonNumber(Class<?> type) {
|
||||
return type == float.class || type == Float.class
|
||||
|| type == double.class || type == Double.class
|
||||
|| type == BigDecimal.class;
|
||||
}
|
||||
|
||||
public static boolean isJsonBoolean(Class<?> type) {
|
||||
return type == boolean.class || type == Boolean.class;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
package dev.langchain4j.model.chat;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
|
||||
@Experimental
|
||||
public enum Capability {
|
||||
|
||||
RESPONSE_FORMAT_JSON_SCHEMA
|
||||
}
|
|
@ -1,14 +1,19 @@
|
|||
package dev.langchain4j.model.chat;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.model.chat.request.ChatRequest;
|
||||
import dev.langchain4j.model.chat.response.ChatResponse;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.emptySet;
|
||||
|
||||
/**
|
||||
* Represents a language model that has a chat interface.
|
||||
|
@ -80,4 +85,14 @@ public interface ChatLanguageModel {
|
|||
default Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
|
||||
throw new IllegalArgumentException("Tools are currently not supported by this model");
|
||||
}
|
||||
|
||||
@Experimental
|
||||
default ChatResponse chat(ChatRequest request) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Experimental
|
||||
default Set<Capability> supportedCapabilities() {
|
||||
return emptySet();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
package dev.langchain4j.model.chat.request;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.copyIfNotNull;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
||||
import static java.util.Arrays.asList;
|
||||
|
||||
@Experimental
|
||||
public class ChatRequest {
|
||||
|
||||
private final List<ChatMessage> messages;
|
||||
private final List<ToolSpecification> toolSpecifications;
|
||||
private final ResponseFormat responseFormat;
|
||||
|
||||
private ChatRequest(Builder builder) {
|
||||
this.messages = new ArrayList<>(ensureNotEmpty(builder.messages, "messages"));
|
||||
this.toolSpecifications = copyIfNotNull(builder.toolSpecifications);
|
||||
this.responseFormat = builder.responseFormat;
|
||||
}
|
||||
|
||||
public List<ChatMessage> messages() {
|
||||
return messages;
|
||||
}
|
||||
|
||||
public List<ToolSpecification> toolSpecifications() {
|
||||
return toolSpecifications;
|
||||
}
|
||||
|
||||
public ResponseFormat responseFormat() {
|
||||
return responseFormat;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ChatRequest that = (ChatRequest) o;
|
||||
return Objects.equals(this.messages, that.messages)
|
||||
&& Objects.equals(this.toolSpecifications, that.toolSpecifications)
|
||||
&& Objects.equals(this.responseFormat, that.responseFormat);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(messages, toolSpecifications, responseFormat);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "ChatRequest {" +
|
||||
" messages = " + messages +
|
||||
", toolSpecifications = " + toolSpecifications +
|
||||
", responseFormat = " + responseFormat +
|
||||
" }";
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private List<ChatMessage> messages;
|
||||
private List<ToolSpecification> toolSpecifications;
|
||||
private ResponseFormat responseFormat;
|
||||
|
||||
public Builder messages(List<ChatMessage> messages) {
|
||||
this.messages = messages;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder messages(ChatMessage... messages) {
|
||||
return messages(asList(messages));
|
||||
}
|
||||
|
||||
public Builder toolSpecifications(List<ToolSpecification> toolSpecifications) {
|
||||
this.toolSpecifications = toolSpecifications;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder toolSpecifications(ToolSpecification... toolSpecifications) {
|
||||
return toolSpecifications(asList(toolSpecifications));
|
||||
}
|
||||
|
||||
public Builder responseFormat(ResponseFormat responseFormat) {
|
||||
this.responseFormat = responseFormat;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatRequest build() {
|
||||
return new ChatRequest(this);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
package dev.langchain4j.model.chat.request;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchema;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
|
||||
|
||||
@Experimental
|
||||
public class ResponseFormat {
|
||||
|
||||
private final ResponseFormatType type;
|
||||
private final JsonSchema jsonSchema;
|
||||
|
||||
private ResponseFormat(Builder builder) {
|
||||
this.type = ensureNotNull(builder.type, "type");
|
||||
this.jsonSchema = builder.jsonSchema;
|
||||
if (jsonSchema != null && type != JSON) {
|
||||
throw new IllegalStateException("JsonSchema can be specified only when type=JSON");
|
||||
}
|
||||
}
|
||||
|
||||
public ResponseFormatType type() {
|
||||
return type;
|
||||
}
|
||||
|
||||
public JsonSchema jsonSchema() {
|
||||
return jsonSchema;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ResponseFormat that = (ResponseFormat) o;
|
||||
return Objects.equals(this.type, that.type)
|
||||
&& Objects.equals(this.jsonSchema, that.jsonSchema);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(type, jsonSchema);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "ResponseFormat {" +
|
||||
" type = " + type +
|
||||
", jsonSchema = " + jsonSchema +
|
||||
" }";
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private ResponseFormatType type;
|
||||
private JsonSchema jsonSchema;
|
||||
|
||||
public Builder type(ResponseFormatType type) {
|
||||
this.type = type;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder jsonSchema(JsonSchema jsonSchema) {
|
||||
this.jsonSchema = jsonSchema;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ResponseFormat build() {
|
||||
return new ResponseFormat(this);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
package dev.langchain4j.model.chat.request;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
|
||||
@Experimental
|
||||
public enum ResponseFormatType {
|
||||
|
||||
TEXT, JSON
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
package dev.langchain4j.model.chat.request.json;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.quoted;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
|
||||
@Experimental
|
||||
public class JsonArraySchema implements JsonSchemaElement {
|
||||
|
||||
private final String description;
|
||||
private final JsonSchemaElement items;
|
||||
|
||||
public JsonArraySchema(Builder builder) {
|
||||
this.description = builder.description;
|
||||
this.items = ensureNotNull(builder.items, "items");
|
||||
}
|
||||
|
||||
public String description() {
|
||||
return description;
|
||||
}
|
||||
|
||||
public JsonSchemaElement items() {
|
||||
return items;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String description;
|
||||
private JsonSchemaElement items;
|
||||
|
||||
public Builder description(String description) {
|
||||
this.description = description;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder items(JsonSchemaElement items) {
|
||||
this.items = items;
|
||||
return this;
|
||||
}
|
||||
|
||||
public JsonArraySchema build() {
|
||||
return new JsonArraySchema(this);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
JsonArraySchema that = (JsonArraySchema) o;
|
||||
return Objects.equals(this.description, that.description)
|
||||
&& Objects.equals(this.items, that.items);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(description, items);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "JsonArraySchema {" +
|
||||
"description = " + quoted(description) +
|
||||
", items = " + items +
|
||||
" }";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package dev.langchain4j.model.chat.request.json;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.quoted;
|
||||
|
||||
@Experimental
|
||||
public class JsonBooleanSchema implements JsonSchemaElement {
|
||||
|
||||
public static final JsonBooleanSchema JSON_BOOLEAN_SCHEMA = JsonBooleanSchema.builder().build();
|
||||
|
||||
private final String description;
|
||||
|
||||
public JsonBooleanSchema(Builder builder) {
|
||||
this.description = builder.description;
|
||||
}
|
||||
|
||||
public String description() {
|
||||
return description;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String description;
|
||||
|
||||
public Builder description(String description) {
|
||||
this.description = description;
|
||||
return this;
|
||||
}
|
||||
|
||||
public JsonBooleanSchema build() {
|
||||
return new JsonBooleanSchema(this);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
JsonBooleanSchema that = (JsonBooleanSchema) o;
|
||||
return Objects.equals(this.description, that.description);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(description);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "JsonBooleanSchema {" +
|
||||
"description = " + quoted(description) +
|
||||
" }";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
package dev.langchain4j.model.chat.request.json;
|
||||
|
||||
import com.google.gson.annotations.SerializedName;
|
||||
import dev.langchain4j.Experimental;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.quoted;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Arrays.stream;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
@Experimental
|
||||
public class JsonEnumSchema implements JsonSchemaElement {
|
||||
|
||||
private final String description;
|
||||
@SerializedName("enum")
|
||||
private final List<String> enumValues;
|
||||
|
||||
public JsonEnumSchema(Builder builder) {
|
||||
this.description = builder.description;
|
||||
this.enumValues = new ArrayList<>(ensureNotEmpty(builder.enumValues, "enumValues"));
|
||||
}
|
||||
|
||||
public String description() {
|
||||
return description;
|
||||
}
|
||||
|
||||
public List<String> enumValues() {
|
||||
return enumValues;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String description;
|
||||
private List<String> enumValues;
|
||||
|
||||
public Builder description(String description) {
|
||||
this.description = description;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder enumValues(List<String> enumValues) {
|
||||
this.enumValues = enumValues;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder enumValues(String... enumValues) {
|
||||
return enumValues(asList(enumValues));
|
||||
}
|
||||
|
||||
public Builder enumValues(Class<?> enumClass) {
|
||||
if (!enumClass.isEnum()) {
|
||||
throw new RuntimeException("Class " + enumClass.getName() + " must be enum");
|
||||
}
|
||||
|
||||
List<String> enumValues = stream(enumClass.getEnumConstants())
|
||||
.map(Object::toString)
|
||||
.collect(toList());
|
||||
|
||||
return enumValues(enumValues);
|
||||
}
|
||||
|
||||
public JsonEnumSchema build() {
|
||||
return new JsonEnumSchema(this);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
JsonEnumSchema that = (JsonEnumSchema) o;
|
||||
return Objects.equals(this.description, that.description)
|
||||
&& Objects.equals(this.enumValues, that.enumValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(description, enumValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "JsonEnumSchema {" +
|
||||
"description = " + quoted(description) +
|
||||
", enumValues = " + enumValues +
|
||||
" }";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package dev.langchain4j.model.chat.request.json;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.quoted;
|
||||
|
||||
@Experimental
|
||||
public class JsonIntegerSchema implements JsonSchemaElement {
|
||||
|
||||
public static final JsonIntegerSchema JSON_INTEGER_SCHEMA = JsonIntegerSchema.builder().build();
|
||||
|
||||
private final String description;
|
||||
|
||||
public JsonIntegerSchema(Builder builder) {
|
||||
this.description = builder.description;
|
||||
}
|
||||
|
||||
public String description() {
|
||||
return description;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String description;
|
||||
|
||||
public Builder description(String description) {
|
||||
this.description = description;
|
||||
return this;
|
||||
}
|
||||
|
||||
public JsonIntegerSchema build() {
|
||||
return new JsonIntegerSchema(this);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
JsonIntegerSchema that = (JsonIntegerSchema) o;
|
||||
return Objects.equals(this.description, that.description);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(description);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "JsonIntegerSchema {" +
|
||||
"description = " + quoted(description) +
|
||||
" }";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package dev.langchain4j.model.chat.request.json;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.quoted;
|
||||
|
||||
@Experimental
|
||||
public class JsonNumberSchema implements JsonSchemaElement {
|
||||
|
||||
public static final JsonNumberSchema JSON_NUMBER_SCHEMA = JsonNumberSchema.builder().build();
|
||||
|
||||
private final String description;
|
||||
|
||||
public JsonNumberSchema(Builder builder) {
|
||||
this.description = builder.description;
|
||||
}
|
||||
|
||||
public String description() {
|
||||
return description;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String description;
|
||||
|
||||
public Builder description(String description) {
|
||||
this.description = description;
|
||||
return this;
|
||||
}
|
||||
|
||||
public JsonNumberSchema build() {
|
||||
return new JsonNumberSchema(this);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
JsonNumberSchema that = (JsonNumberSchema) o;
|
||||
return Objects.equals(this.description, that.description);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(description);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "JsonNumberSchema {" +
|
||||
"description = " + quoted(description) +
|
||||
" }";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,114 @@
|
|||
package dev.langchain4j.model.chat.request.json;
|
||||
|
||||
import com.google.gson.annotations.SerializedName;
|
||||
import dev.langchain4j.Experimental;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.copyIfNotNull;
|
||||
import static dev.langchain4j.internal.Utils.quoted;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
||||
import static java.util.Arrays.asList;
|
||||
|
||||
@Experimental
|
||||
public class JsonObjectSchema implements JsonSchemaElement {
|
||||
|
||||
private final String description;
|
||||
private final Map<String, JsonSchemaElement> properties;
|
||||
private final List<String> required;
|
||||
@SerializedName("additionalProperties")
|
||||
private final Boolean additionalProperties;
|
||||
|
||||
public JsonObjectSchema(Builder builder) {
|
||||
this.description = builder.description;
|
||||
this.properties = new LinkedHashMap<>(ensureNotEmpty(builder.properties, "properties"));
|
||||
this.required = copyIfNotNull(builder.required);
|
||||
this.additionalProperties = builder.additionalProperties;
|
||||
}
|
||||
|
||||
public String description() {
|
||||
return description;
|
||||
}
|
||||
|
||||
public Map<String, JsonSchemaElement> properties() {
|
||||
return properties;
|
||||
}
|
||||
|
||||
public List<String> required() {
|
||||
return required;
|
||||
}
|
||||
|
||||
public Boolean additionalProperties() {
|
||||
return additionalProperties;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String description;
|
||||
private Map<String, JsonSchemaElement> properties = new LinkedHashMap<>();
|
||||
private List<String> required = new ArrayList<>();
|
||||
private Boolean additionalProperties;
|
||||
|
||||
public Builder description(String description) {
|
||||
this.description = description;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder properties(Map<String, JsonSchemaElement> properties) {
|
||||
this.properties = properties;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder required(List<String> required) {
|
||||
this.required = required;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder required(String... required) {
|
||||
return required(asList(required));
|
||||
}
|
||||
|
||||
public Builder additionalProperties(Boolean additionalProperties) {
|
||||
this.additionalProperties = additionalProperties;
|
||||
return this;
|
||||
}
|
||||
|
||||
public JsonObjectSchema build() {
|
||||
return new JsonObjectSchema(this);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
JsonObjectSchema that = (JsonObjectSchema) o;
|
||||
return Objects.equals(this.description, that.description)
|
||||
&& Objects.equals(this.properties, that.properties)
|
||||
&& Objects.equals(this.required, that.required)
|
||||
&& Objects.equals(this.additionalProperties, that.additionalProperties);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(description, properties, required, additionalProperties);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "JsonObjectSchema {" +
|
||||
"description = " + quoted(description) +
|
||||
", properties = " + properties +
|
||||
", required = " + required +
|
||||
", additionalProperties = " + additionalProperties +
|
||||
" }";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
package dev.langchain4j.model.chat.request.json;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.quoted;
|
||||
|
||||
@Experimental
|
||||
public class JsonSchema {
|
||||
|
||||
private final String name;
|
||||
private final JsonObjectSchema schema;
|
||||
|
||||
public JsonSchema(Builder builder) {
|
||||
this.name = builder.name;
|
||||
this.schema = builder.schema;
|
||||
}
|
||||
|
||||
public String name() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public JsonObjectSchema schema() {
|
||||
return schema;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String name;
|
||||
private JsonObjectSchema schema;
|
||||
|
||||
public Builder name(String name) {
|
||||
this.name = name;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder schema(JsonObjectSchema schema) {
|
||||
this.schema = schema;
|
||||
return this;
|
||||
}
|
||||
|
||||
public JsonSchema build() {
|
||||
return new JsonSchema(this);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
JsonSchema that = (JsonSchema) o;
|
||||
return Objects.equals(this.name, that.name)
|
||||
&& Objects.equals(this.schema, that.schema);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(name, schema);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "JsonSchema {" +
|
||||
" name = " + quoted(name) +
|
||||
", schema = " + schema +
|
||||
" }";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
package dev.langchain4j.model.chat.request.json;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
|
||||
@Experimental
|
||||
public interface JsonSchemaElement {
|
||||
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package dev.langchain4j.model.chat.request.json;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.quoted;
|
||||
|
||||
@Experimental
|
||||
public class JsonStringSchema implements JsonSchemaElement {
|
||||
|
||||
public static final JsonStringSchema JSON_STRING_SCHEMA = JsonStringSchema.builder().build();
|
||||
|
||||
private final String description;
|
||||
|
||||
public JsonStringSchema(Builder builder) {
|
||||
this.description = builder.description;
|
||||
}
|
||||
|
||||
public String description() {
|
||||
return description;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String description;
|
||||
|
||||
public Builder description(String description) {
|
||||
this.description = description;
|
||||
return this;
|
||||
}
|
||||
|
||||
public JsonStringSchema build() {
|
||||
return new JsonStringSchema(this);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
JsonStringSchema that = (JsonStringSchema) o;
|
||||
return Objects.equals(this.description, that.description);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(description);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "JsonStringSchema {" +
|
||||
"description = " + quoted(description) +
|
||||
" }";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
package dev.langchain4j.model.chat.response;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.output.FinishReason;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
|
||||
@Experimental
|
||||
public class ChatResponse {
|
||||
|
||||
private final AiMessage aiMessage;
|
||||
private final TokenUsage tokenUsage;
|
||||
private final FinishReason finishReason;
|
||||
|
||||
private ChatResponse(Builder builder) {
|
||||
this.aiMessage = ensureNotNull(builder.aiMessage, "aiMessage");
|
||||
this.tokenUsage = builder.tokenUsage;
|
||||
this.finishReason = builder.finishReason;
|
||||
}
|
||||
|
||||
public AiMessage aiMessage() {
|
||||
return aiMessage;
|
||||
}
|
||||
|
||||
public TokenUsage tokenUsage() {
|
||||
return tokenUsage;
|
||||
}
|
||||
|
||||
public FinishReason finishReason() {
|
||||
return finishReason;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ChatResponse that = (ChatResponse) o;
|
||||
return Objects.equals(this.aiMessage, that.aiMessage)
|
||||
&& Objects.equals(this.tokenUsage, that.tokenUsage)
|
||||
&& Objects.equals(this.finishReason, that.finishReason);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(aiMessage, tokenUsage, finishReason);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "ChatResponse {" +
|
||||
" aiMessage = " + aiMessage +
|
||||
", tokenUsage = " + tokenUsage +
|
||||
", finishReason = " + finishReason +
|
||||
" }";
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private AiMessage aiMessage;
|
||||
private TokenUsage tokenUsage;
|
||||
private FinishReason finishReason;
|
||||
|
||||
public Builder aiMessage(AiMessage aiMessage) {
|
||||
this.aiMessage = aiMessage;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder tokenUsage(TokenUsage tokenUsage) {
|
||||
this.tokenUsage = tokenUsage;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder finishReason(FinishReason finishReason) {
|
||||
this.finishReason = finishReason;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatResponse build() {
|
||||
return new ChatResponse(this);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,18 +4,20 @@ import java.lang.annotation.Retention;
|
|||
import java.lang.annotation.Target;
|
||||
|
||||
import static java.lang.annotation.ElementType.FIELD;
|
||||
import static java.lang.annotation.ElementType.TYPE;
|
||||
import static java.lang.annotation.RetentionPolicy.RUNTIME;
|
||||
|
||||
/**
|
||||
* Annotation to attach a description to a class field.
|
||||
*/
|
||||
@Target(FIELD)
|
||||
@Target({FIELD, TYPE})
|
||||
@Retention(RUNTIME)
|
||||
public @interface Description {
|
||||
|
||||
/**
|
||||
* The description can be defined in one line or multiple lines.
|
||||
* If the description is defined in multiple lines, the lines will be joined with a space (" ") automatically.
|
||||
*
|
||||
* @return The description.
|
||||
*/
|
||||
String[] value();
|
||||
|
|
|
@ -1,31 +1,64 @@
|
|||
package dev.langchain4j.model.openai;
|
||||
|
||||
import dev.ai4j.openai4j.chat.*;
|
||||
import dev.ai4j.openai4j.chat.AssistantMessage;
|
||||
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
|
||||
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
|
||||
import dev.ai4j.openai4j.chat.ContentType;
|
||||
import dev.ai4j.openai4j.chat.Function;
|
||||
import dev.ai4j.openai4j.chat.FunctionCall;
|
||||
import dev.ai4j.openai4j.chat.FunctionMessage;
|
||||
import dev.ai4j.openai4j.chat.ImageDetail;
|
||||
import dev.ai4j.openai4j.chat.ImageUrl;
|
||||
import dev.ai4j.openai4j.chat.Message;
|
||||
import dev.ai4j.openai4j.chat.Tool;
|
||||
import dev.ai4j.openai4j.chat.ToolCall;
|
||||
import dev.ai4j.openai4j.chat.ToolMessage;
|
||||
import dev.ai4j.openai4j.shared.Usage;
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.agent.tool.ToolParameters;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.image.Image;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.Content;
|
||||
import dev.langchain4j.data.message.ImageContent;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
import dev.langchain4j.data.message.TextContent;
|
||||
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.data.message.*;
|
||||
import dev.langchain4j.model.chat.listener.ChatModelRequest;
|
||||
import dev.langchain4j.model.chat.listener.ChatModelResponse;
|
||||
import dev.langchain4j.model.chat.request.ResponseFormat;
|
||||
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
|
||||
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
|
||||
import dev.langchain4j.model.output.FinishReason;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static dev.ai4j.openai4j.chat.ContentType.IMAGE_URL;
|
||||
import static dev.ai4j.openai4j.chat.ContentType.TEXT;
|
||||
import static dev.ai4j.openai4j.chat.ResponseFormatType.JSON_OBJECT;
|
||||
import static dev.ai4j.openai4j.chat.ResponseFormatType.JSON_SCHEMA;
|
||||
import static dev.ai4j.openai4j.chat.ToolType.FUNCTION;
|
||||
import static dev.langchain4j.internal.Exceptions.illegalArgument;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.model.output.FinishReason.*;
|
||||
import static dev.langchain4j.model.chat.request.ResponseFormatType.TEXT;
|
||||
import static dev.langchain4j.model.output.FinishReason.CONTENT_FILTER;
|
||||
import static dev.langchain4j.model.output.FinishReason.LENGTH;
|
||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
|
||||
import static java.lang.String.format;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
@ -129,14 +162,14 @@ public class InternalOpenAiHelper {
|
|||
|
||||
private static dev.ai4j.openai4j.chat.Content toOpenAiContent(TextContent content) {
|
||||
return dev.ai4j.openai4j.chat.Content.builder()
|
||||
.type(TEXT)
|
||||
.type(ContentType.TEXT)
|
||||
.text(content.text())
|
||||
.build();
|
||||
}
|
||||
|
||||
private static dev.ai4j.openai4j.chat.Content toOpenAiContent(ImageContent content) {
|
||||
return dev.ai4j.openai4j.chat.Content.builder()
|
||||
.type(IMAGE_URL)
|
||||
.type(ContentType.IMAGE_URL)
|
||||
.imageUrl(ImageUrl.builder()
|
||||
.url(toUrl(content.image()))
|
||||
.detail(toDetail(content.detailLevel()))
|
||||
|
@ -158,23 +191,26 @@ public class InternalOpenAiHelper {
|
|||
return ImageDetail.valueOf(detailLevel.name());
|
||||
}
|
||||
|
||||
public static List<Tool> toTools(Collection<ToolSpecification> toolSpecifications) {
|
||||
public static List<Tool> toTools(Collection<ToolSpecification> toolSpecifications, boolean strict) {
|
||||
return toolSpecifications.stream()
|
||||
.map(InternalOpenAiHelper::toTool)
|
||||
.map((ToolSpecification toolSpecification) -> toTool(toolSpecification, strict))
|
||||
.collect(toList());
|
||||
}
|
||||
|
||||
private static Tool toTool(ToolSpecification toolSpecification) {
|
||||
Function function = Function.builder()
|
||||
private static Tool toTool(ToolSpecification toolSpecification, boolean strict) {
|
||||
Function.Builder functionBuilder = Function.builder()
|
||||
.name(toolSpecification.name())
|
||||
.description(toolSpecification.description())
|
||||
.parameters(toOpenAiParameters(toolSpecification.parameters()))
|
||||
.build();
|
||||
.parameters(toOpenAiParameters(toolSpecification.parameters(), strict));
|
||||
if (strict) {
|
||||
functionBuilder.strict(true);
|
||||
}
|
||||
Function function = functionBuilder.build();
|
||||
return Tool.from(function);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Functions are deprecated by OpenAI, use {@link #toTools(Collection)} instead
|
||||
* @deprecated Functions are deprecated by OpenAI, use {@link #toTools(Collection, boolean)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public static List<Function> toFunctions(Collection<ToolSpecification> toolSpecifications) {
|
||||
|
@ -184,25 +220,101 @@ public class InternalOpenAiHelper {
|
|||
}
|
||||
|
||||
/**
|
||||
* @deprecated Functions are deprecated by OpenAI, use {@link #toTool(ToolSpecification)} ()} instead
|
||||
* @deprecated Functions are deprecated by OpenAI, use {@link #toTool(ToolSpecification, boolean)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
private static Function toFunction(ToolSpecification toolSpecification) {
|
||||
return Function.builder()
|
||||
.name(toolSpecification.name())
|
||||
.description(toolSpecification.description())
|
||||
.parameters(toOpenAiParameters(toolSpecification.parameters()))
|
||||
.parameters(toOpenAiParameters(toolSpecification.parameters(), false))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static dev.ai4j.openai4j.chat.Parameters toOpenAiParameters(ToolParameters toolParameters) {
|
||||
private static dev.ai4j.openai4j.chat.JsonObjectSchema toOpenAiParameters(ToolParameters toolParameters, boolean strict) {
|
||||
if (toolParameters == null) {
|
||||
return dev.ai4j.openai4j.chat.Parameters.builder().build();
|
||||
dev.ai4j.openai4j.chat.JsonObjectSchema.Builder builder = dev.ai4j.openai4j.chat.JsonObjectSchema.builder();
|
||||
if (strict) {
|
||||
// when strict, additionalProperties must be false:
|
||||
// https://platform.openai.com/docs/guides/structured-outputs/additionalproperties-false-must-always-be-set-in-objects
|
||||
builder.additionalProperties(false);
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
dev.ai4j.openai4j.chat.JsonObjectSchema.Builder builder = dev.ai4j.openai4j.chat.JsonObjectSchema.builder()
|
||||
.properties(toOpenAiProperties(toolParameters.properties(), strict))
|
||||
.required(toolParameters.required());
|
||||
if (strict) {
|
||||
builder
|
||||
// when strict, all fields must be required:
|
||||
// https://platform.openai.com/docs/guides/structured-outputs/all-fields-must-be-required
|
||||
.required(new ArrayList<>(toolParameters.properties().keySet()))
|
||||
// when strict, additionalProperties must be false:
|
||||
// https://platform.openai.com/docs/guides/structured-outputs/additionalproperties-false-must-always-be-set-in-objects
|
||||
.additionalProperties(false);
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
private static Map<String, dev.ai4j.openai4j.chat.JsonSchemaElement> toOpenAiProperties(Map<String, ?> properties, boolean strict) {
|
||||
Map<String, dev.ai4j.openai4j.chat.JsonSchemaElement> openAiProperties = new LinkedHashMap<>();
|
||||
properties.forEach((key, value) ->
|
||||
openAiProperties.put(key, toOpenAiJsonSchemaElement((Map<String, ?>) value, strict)));
|
||||
return openAiProperties;
|
||||
}
|
||||
|
||||
private static dev.ai4j.openai4j.chat.JsonSchemaElement toOpenAiJsonSchemaElement(Map<String, ?> properties, boolean strict) {
|
||||
// TODO rewrite when JsonSchemaElement will be used for ToolSpecification.properties
|
||||
Object type = properties.get("type");
|
||||
String description = (String) properties.get("description");
|
||||
if ("object".equals(type)) {
|
||||
List<String> required = (List<String>) properties.get("required");
|
||||
dev.ai4j.openai4j.chat.JsonObjectSchema.Builder builder = dev.ai4j.openai4j.chat.JsonObjectSchema.builder()
|
||||
.description(description)
|
||||
.properties(toOpenAiProperties((Map<String, ?>) properties.get("properties"), strict));
|
||||
if (required != null) {
|
||||
builder.required(required);
|
||||
}
|
||||
if (strict) {
|
||||
builder
|
||||
// when strict, all fields must be required:
|
||||
// https://platform.openai.com/docs/guides/structured-outputs/all-fields-must-be-required
|
||||
.required(new ArrayList<>(((Map<String, ?>) properties.get("properties")).keySet()))
|
||||
// when strict, additionalProperties must be false:
|
||||
// https://platform.openai.com/docs/guides/structured-outputs/additionalproperties-false-must-always-be-set-in-objects
|
||||
.additionalProperties(false);
|
||||
}
|
||||
return builder.build();
|
||||
} else if ("array".equals(type)) {
|
||||
return dev.ai4j.openai4j.chat.JsonArraySchema.builder()
|
||||
.description(description)
|
||||
.items(toOpenAiJsonSchemaElement((Map<String, ?>) properties.get("items"), strict))
|
||||
.build();
|
||||
} else if (properties.get("enum") != null) {
|
||||
return dev.ai4j.openai4j.chat.JsonEnumSchema.builder()
|
||||
.description(description)
|
||||
.enumValues((List<String>) properties.get("enum"))
|
||||
.build();
|
||||
} else if ("string".equals(type)) {
|
||||
return dev.ai4j.openai4j.chat.JsonStringSchema.builder()
|
||||
.description(description)
|
||||
.build();
|
||||
} else if ("integer".equals(type)) {
|
||||
return dev.ai4j.openai4j.chat.JsonIntegerSchema.builder()
|
||||
.description(description)
|
||||
.build();
|
||||
} else if ("number".equals(type)) {
|
||||
return dev.ai4j.openai4j.chat.JsonNumberSchema.builder()
|
||||
.description(description)
|
||||
.build();
|
||||
} else if ("boolean".equals(type)) {
|
||||
return dev.ai4j.openai4j.chat.JsonBooleanSchema.builder()
|
||||
.description(description)
|
||||
.build();
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unknown type " + type);
|
||||
}
|
||||
return dev.ai4j.openai4j.chat.Parameters.builder()
|
||||
.properties(toolParameters.properties())
|
||||
.required(toolParameters.required())
|
||||
.build();
|
||||
}
|
||||
|
||||
public static AiMessage aiMessageFrom(ChatCompletionResponse response) {
|
||||
|
@ -317,4 +429,64 @@ public class InternalOpenAiHelper {
|
|||
.aiMessage(response.content())
|
||||
.build();
|
||||
}
|
||||
|
||||
static dev.ai4j.openai4j.chat.ResponseFormat toOpenAiResponseFormat(ResponseFormat responseFormat, Boolean strict) {
|
||||
if (responseFormat == null || responseFormat.type() == TEXT) {
|
||||
return null;
|
||||
}
|
||||
|
||||
JsonSchema jsonSchema = responseFormat.jsonSchema();
|
||||
if (jsonSchema == null) {
|
||||
return new dev.ai4j.openai4j.chat.ResponseFormat(JSON_OBJECT, null);
|
||||
} else {
|
||||
dev.ai4j.openai4j.chat.JsonSchema openAiJsonSchema = dev.ai4j.openai4j.chat.JsonSchema.builder()
|
||||
.name(jsonSchema.name())
|
||||
.strict(strict)
|
||||
.schema((dev.ai4j.openai4j.chat.JsonObjectSchema) toOpenAiJsonSchemaElement(jsonSchema.schema()))
|
||||
.build();
|
||||
return new dev.ai4j.openai4j.chat.ResponseFormat(JSON_SCHEMA, openAiJsonSchema);
|
||||
}
|
||||
}
|
||||
|
||||
private static dev.ai4j.openai4j.chat.JsonSchemaElement toOpenAiJsonSchemaElement(JsonSchemaElement jsonSchemaElement) {
|
||||
if (jsonSchemaElement instanceof JsonStringSchema) {
|
||||
return dev.ai4j.openai4j.chat.JsonStringSchema.builder()
|
||||
.description(((JsonStringSchema) jsonSchemaElement).description())
|
||||
.build();
|
||||
} else if (jsonSchemaElement instanceof JsonIntegerSchema) {
|
||||
return dev.ai4j.openai4j.chat.JsonIntegerSchema.builder()
|
||||
.description(((JsonIntegerSchema) jsonSchemaElement).description())
|
||||
.build();
|
||||
} else if (jsonSchemaElement instanceof JsonNumberSchema) {
|
||||
return dev.ai4j.openai4j.chat.JsonNumberSchema.builder()
|
||||
.description(((JsonNumberSchema) jsonSchemaElement).description())
|
||||
.build();
|
||||
} else if (jsonSchemaElement instanceof JsonBooleanSchema) {
|
||||
return dev.ai4j.openai4j.chat.JsonBooleanSchema.builder()
|
||||
.description(((JsonBooleanSchema) jsonSchemaElement).description())
|
||||
.build();
|
||||
} else if (jsonSchemaElement instanceof JsonEnumSchema) {
|
||||
return dev.ai4j.openai4j.chat.JsonEnumSchema.builder()
|
||||
.description(((JsonEnumSchema) jsonSchemaElement).description())
|
||||
.enumValues(((JsonEnumSchema) jsonSchemaElement).enumValues())
|
||||
.build();
|
||||
} else if (jsonSchemaElement instanceof JsonArraySchema) {
|
||||
return dev.ai4j.openai4j.chat.JsonArraySchema.builder()
|
||||
.description(((JsonArraySchema) jsonSchemaElement).description())
|
||||
.items(toOpenAiJsonSchemaElement(((JsonArraySchema) jsonSchemaElement).items()))
|
||||
.build();
|
||||
} else if (jsonSchemaElement instanceof JsonObjectSchema) {
|
||||
Map<String, JsonSchemaElement> properties = ((JsonObjectSchema) jsonSchemaElement).properties();
|
||||
Map<String, dev.ai4j.openai4j.chat.JsonSchemaElement> openAiProperties = new LinkedHashMap<>();
|
||||
properties.forEach((key, value) -> openAiProperties.put(key, toOpenAiJsonSchemaElement(value)));
|
||||
return dev.ai4j.openai4j.chat.JsonObjectSchema.builder()
|
||||
.description(((JsonObjectSchema) jsonSchemaElement).description())
|
||||
.properties(openAiProperties)
|
||||
.required(((JsonObjectSchema) jsonSchemaElement).required())
|
||||
.additionalProperties(((JsonObjectSchema) jsonSchemaElement).additionalProperties())
|
||||
.build();
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unknown type: " + jsonSchemaElement);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,13 +4,22 @@ import dev.ai4j.openai4j.OpenAiClient;
|
|||
import dev.ai4j.openai4j.OpenAiHttpException;
|
||||
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
|
||||
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
|
||||
import dev.ai4j.openai4j.chat.ResponseFormat;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.model.Tokenizer;
|
||||
import dev.langchain4j.model.chat.Capability;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.TokenCountEstimator;
|
||||
import dev.langchain4j.model.chat.listener.*;
|
||||
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
|
||||
import dev.langchain4j.model.chat.listener.ChatModelListener;
|
||||
import dev.langchain4j.model.chat.listener.ChatModelRequest;
|
||||
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
|
||||
import dev.langchain4j.model.chat.listener.ChatModelResponse;
|
||||
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
|
||||
import dev.langchain4j.model.chat.request.ChatRequest;
|
||||
import dev.langchain4j.model.chat.response.ChatResponse;
|
||||
import dev.langchain4j.model.openai.spi.OpenAiChatModelBuilderFactory;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.Builder;
|
||||
|
@ -19,13 +28,27 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import java.net.Proxy;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import static dev.langchain4j.internal.RetryUtils.withRetry;
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
|
||||
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.DEFAULT_USER_AGENT;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_API_KEY;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_URL;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.aiMessageFrom;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerRequest;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerResponse;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiResponseFormat;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
||||
import static java.time.Duration.ofSeconds;
|
||||
|
@ -48,9 +71,12 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||
private final Double presencePenalty;
|
||||
private final Double frequencyPenalty;
|
||||
private final Map<String, Integer> logitBias;
|
||||
private final String responseFormat;
|
||||
private final ResponseFormat responseFormat;
|
||||
private final Boolean strictJsonSchema;
|
||||
private final Integer seed;
|
||||
private final String user;
|
||||
private final Boolean strictTools;
|
||||
private final Boolean parallelToolCalls;
|
||||
private final Integer maxRetries;
|
||||
private final Tokenizer tokenizer;
|
||||
private final List<ChatModelListener> listeners;
|
||||
|
@ -68,8 +94,11 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||
Double frequencyPenalty,
|
||||
Map<String, Integer> logitBias,
|
||||
String responseFormat,
|
||||
Boolean strictJsonSchema,
|
||||
Integer seed,
|
||||
String user,
|
||||
Boolean strictTools,
|
||||
Boolean parallelToolCalls,
|
||||
Duration timeout,
|
||||
Integer maxRetries,
|
||||
Proxy proxy,
|
||||
|
@ -108,9 +137,12 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||
this.presencePenalty = presencePenalty;
|
||||
this.frequencyPenalty = frequencyPenalty;
|
||||
this.logitBias = logitBias;
|
||||
this.responseFormat = responseFormat;
|
||||
this.responseFormat = responseFormat == null ? null : new ResponseFormat(responseFormat, null);
|
||||
this.strictJsonSchema = getOrDefault(strictJsonSchema, false);
|
||||
this.seed = seed;
|
||||
this.user = user;
|
||||
this.strictTools = getOrDefault(strictTools, false);
|
||||
this.parallelToolCalls = parallelToolCalls;
|
||||
this.maxRetries = getOrDefault(maxRetries, 3);
|
||||
this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
|
||||
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
|
||||
|
@ -120,25 +152,56 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||
return modelName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<Capability> supportedCapabilities() {
|
||||
Set<Capability> capabilities = new HashSet<>();
|
||||
if (responseFormat != null && "json_schema".equals(responseFormat.type())) {
|
||||
capabilities.add(RESPONSE_FORMAT_JSON_SCHEMA);
|
||||
}
|
||||
return capabilities;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages) {
|
||||
return generate(messages, null, null);
|
||||
return generate(messages, null, null, this.responseFormat);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
|
||||
return generate(messages, toolSpecifications, null);
|
||||
return generate(messages, toolSpecifications, null, this.responseFormat);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
|
||||
return generate(messages, singletonList(toolSpecification), toolSpecification);
|
||||
return generate(messages, singletonList(toolSpecification), toolSpecification, this.responseFormat);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse chat(ChatRequest request) {
|
||||
Response<AiMessage> response = generate(
|
||||
request.messages(),
|
||||
request.toolSpecifications(),
|
||||
null,
|
||||
getOrDefault(toOpenAiResponseFormat(request.responseFormat(), strictJsonSchema), this.responseFormat)
|
||||
);
|
||||
return ChatResponse.builder()
|
||||
.aiMessage(response.content())
|
||||
.tokenUsage(response.tokenUsage())
|
||||
.finishReason(response.finishReason())
|
||||
.build();
|
||||
}
|
||||
|
||||
private Response<AiMessage> generate(List<ChatMessage> messages,
|
||||
List<ToolSpecification> toolSpecifications,
|
||||
ToolSpecification toolThatMustBeExecuted
|
||||
) {
|
||||
ToolSpecification toolThatMustBeExecuted,
|
||||
ResponseFormat responseFormat) {
|
||||
|
||||
if (responseFormat != null
|
||||
&& "json_schema".equals(responseFormat.type())
|
||||
&& responseFormat.jsonSchema() == null) {
|
||||
responseFormat = null;
|
||||
}
|
||||
|
||||
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
|
||||
.model(modelName)
|
||||
.messages(toOpenAiMessages(messages))
|
||||
|
@ -151,10 +214,11 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||
.logitBias(logitBias)
|
||||
.responseFormat(responseFormat)
|
||||
.seed(seed)
|
||||
.user(user);
|
||||
.user(user)
|
||||
.parallelToolCalls(parallelToolCalls);
|
||||
|
||||
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
|
||||
requestBuilder.tools(toTools(toolSpecifications));
|
||||
requestBuilder.tools(toTools(toolSpecifications, strictTools));
|
||||
}
|
||||
if (toolThatMustBeExecuted != null) {
|
||||
requestBuilder.toolChoice(toolThatMustBeExecuted.name());
|
||||
|
|
|
@ -54,6 +54,8 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
|
|||
private final String responseFormat;
|
||||
private final Integer seed;
|
||||
private final String user;
|
||||
private final Boolean strictTools;
|
||||
private final Boolean parallelToolCalls;
|
||||
private final Tokenizer tokenizer;
|
||||
private final boolean isOpenAiModel;
|
||||
private final List<ChatModelListener> listeners;
|
||||
|
@ -73,6 +75,8 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
|
|||
String responseFormat,
|
||||
Integer seed,
|
||||
String user,
|
||||
Boolean strictTools,
|
||||
Boolean parallelToolCalls,
|
||||
Duration timeout,
|
||||
Proxy proxy,
|
||||
Boolean logRequests,
|
||||
|
@ -108,6 +112,8 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
|
|||
this.responseFormat = responseFormat;
|
||||
this.seed = seed;
|
||||
this.user = user;
|
||||
this.strictTools = getOrDefault(strictTools, false);
|
||||
this.parallelToolCalls = parallelToolCalls;
|
||||
this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
|
||||
this.isOpenAiModel = isOpenAiModel(this.modelName);
|
||||
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
|
||||
|
@ -150,13 +156,14 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
|
|||
.logitBias(logitBias)
|
||||
.responseFormat(responseFormat)
|
||||
.seed(seed)
|
||||
.user(user);
|
||||
.user(user)
|
||||
.parallelToolCalls(parallelToolCalls);
|
||||
|
||||
if (toolThatMustBeExecuted != null) {
|
||||
requestBuilder.tools(toTools(singletonList(toolThatMustBeExecuted)));
|
||||
requestBuilder.tools(toTools(singletonList(toolThatMustBeExecuted), strictTools));
|
||||
requestBuilder.toolChoice(toolThatMustBeExecuted.name());
|
||||
} else if (!isNullOrEmpty(toolSpecifications)) {
|
||||
requestBuilder.tools(toTools(toolSpecifications));
|
||||
requestBuilder.tools(toTools(toolSpecifications, strictTools));
|
||||
}
|
||||
|
||||
ChatCompletionRequest request = requestBuilder.build();
|
||||
|
|
|
@ -44,11 +44,6 @@ class OpenAiTokenizerIT {
|
|||
GPT_4_32K_0613
|
||||
));
|
||||
|
||||
private static final Set<ChatCompletionModel> MODELS_WITHOUT_TOOL_SUPPORT = new HashSet<>(asList(
|
||||
GPT_4_0314,
|
||||
GPT_4_VISION_PREVIEW
|
||||
));
|
||||
|
||||
private static final Set<ChatCompletionModel> MODELS_WITH_PARALLEL_TOOL_SUPPORT = new HashSet<>(asList(
|
||||
// TODO add GPT_3_5_TURBO once it points to GPT_3_5_TURBO_1106
|
||||
GPT_3_5_TURBO_1106,
|
||||
|
@ -151,7 +146,6 @@ class OpenAiTokenizerIT {
|
|||
static Stream<Arguments> should_count_tokens_in_messages_with_single_tool() {
|
||||
return stream(ChatCompletionModel.values())
|
||||
.filter(model -> !MODELS_WITHOUT_ACCESS.contains(model))
|
||||
.filter(model -> !MODELS_WITHOUT_TOOL_SUPPORT.contains(model))
|
||||
.flatMap(model -> Stream.of(
|
||||
|
||||
// various tool "name" lengths
|
||||
|
@ -798,7 +792,6 @@ class OpenAiTokenizerIT {
|
|||
static Stream<Arguments> should_count_tokens_in_tool_specifications() {
|
||||
return stream(ChatCompletionModel.values())
|
||||
.filter(model -> !MODELS_WITHOUT_ACCESS.contains(model))
|
||||
.filter(model -> !MODELS_WITHOUT_TOOL_SUPPORT.contains(model))
|
||||
.flatMap(model -> Stream.of(
|
||||
|
||||
// "name" of various lengths
|
||||
|
@ -1146,7 +1139,6 @@ class OpenAiTokenizerIT {
|
|||
static Stream<Arguments> should_count_tokens_in_tool_execution_request() {
|
||||
return stream(ChatCompletionModel.values())
|
||||
.filter(model -> !MODELS_WITHOUT_ACCESS.contains(model))
|
||||
.filter(model -> !MODELS_WITHOUT_TOOL_SUPPORT.contains(model))
|
||||
.flatMap(model -> Stream.of(
|
||||
|
||||
// no arguments, different lengths of "name"
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
<maven.compiler.target>8</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<project.build.outputTimestamp>1714382357</project.build.outputTimestamp>
|
||||
<openai4j.version>0.17.0</openai4j.version>
|
||||
<openai4j.version>0.18.0</openai4j.version>
|
||||
<azure-ai-openai.version>1.0.0-beta.10</azure-ai-openai.version>
|
||||
<azure-ai-search.version>11.7.0</azure-ai-search.version>
|
||||
<azure.storage-blob.version>12.27.0</azure.storage-blob.version>
|
||||
|
|
|
@ -1,10 +1,16 @@
|
|||
package dev.langchain4j.service;
|
||||
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.data.message.*;
|
||||
import dev.langchain4j.memory.ChatMemory;
|
||||
import dev.langchain4j.model.chat.request.ChatRequest;
|
||||
import dev.langchain4j.model.chat.request.ResponseFormat;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchema;
|
||||
import dev.langchain4j.model.chat.response.ChatResponse;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.input.structured.StructuredPrompt;
|
||||
|
@ -19,8 +25,19 @@ import dev.langchain4j.service.output.ServiceOutputParser;
|
|||
import dev.langchain4j.service.tool.ToolExecutor;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.*;
|
||||
import java.util.*;
|
||||
import java.lang.reflect.Array;
|
||||
import java.lang.reflect.InvocationHandler;
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Parameter;
|
||||
import java.lang.reflect.Proxy;
|
||||
import java.lang.reflect.Type;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Scanner;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.Future;
|
||||
|
@ -29,7 +46,10 @@ import static dev.langchain4j.exception.IllegalConfigurationException.illegalCon
|
|||
import static dev.langchain4j.internal.Exceptions.illegalArgument;
|
||||
import static dev.langchain4j.internal.Exceptions.runtime;
|
||||
import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
|
||||
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
|
||||
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
|
||||
import static dev.langchain4j.service.TypeUtils.typeHasRawClass;
|
||||
import static dev.langchain4j.service.output.JsonSchemas.jsonSchemaFrom;
|
||||
|
||||
class DefaultAiServices<T> extends AiServices<T> {
|
||||
|
||||
|
@ -111,12 +131,16 @@ class DefaultAiServices<T> extends AiServices<T> {
|
|||
|
||||
// TODO give user ability to provide custom OutputParser
|
||||
Type returnType = method.getGenericReturnType();
|
||||
String outputFormatInstructions = serviceOutputParser.outputFormatInstructions(returnType);
|
||||
String text = userMessage.singleText() + outputFormatInstructions;
|
||||
if (isNotNullOrBlank(userMessage.name())) {
|
||||
userMessage = UserMessage.from(userMessage.name(), text);
|
||||
} else {
|
||||
userMessage = UserMessage.from(text);
|
||||
|
||||
boolean supportsJsonSchema = supportsJsonSchema();
|
||||
Optional<JsonSchema> jsonSchema = Optional.empty();
|
||||
if (supportsJsonSchema) {
|
||||
jsonSchema = jsonSchemaFrom(returnType);
|
||||
}
|
||||
|
||||
if (!supportsJsonSchema || !jsonSchema.isPresent()) {
|
||||
// TODO append after storing in the memory?
|
||||
userMessage = appendOutputFormatInstructions(returnType, userMessage);
|
||||
}
|
||||
|
||||
if (context.hasChatMemory()) {
|
||||
|
@ -140,9 +164,31 @@ class DefaultAiServices<T> extends AiServices<T> {
|
|||
return new AiServiceTokenStream(messages, context, memoryId); // TODO moderation
|
||||
}
|
||||
|
||||
Response<AiMessage> response = context.toolSpecifications == null
|
||||
? context.chatModel.generate(messages)
|
||||
: context.chatModel.generate(messages, context.toolSpecifications);
|
||||
Response<AiMessage> response;
|
||||
if (supportsJsonSchema && jsonSchema.isPresent()) {
|
||||
ChatRequest chatRequest = ChatRequest.builder()
|
||||
.messages(messages)
|
||||
.toolSpecifications(context.toolSpecifications)
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(jsonSchema.get())
|
||||
.build())
|
||||
.build();
|
||||
|
||||
ChatResponse chatResponse = context.chatModel.chat(chatRequest);
|
||||
|
||||
response = new Response<>(
|
||||
chatResponse.aiMessage(),
|
||||
chatResponse.tokenUsage(),
|
||||
chatResponse.finishReason()
|
||||
);
|
||||
} else {
|
||||
// TODO migrate to new API
|
||||
response = context.toolSpecifications == null
|
||||
? context.chatModel.generate(messages)
|
||||
: context.chatModel.generate(messages, context.toolSpecifications);
|
||||
}
|
||||
|
||||
TokenUsage tokenUsageAccumulator = response.tokenUsage();
|
||||
|
||||
verifyModerationIfNeeded(moderationFuture);
|
||||
|
@ -206,6 +252,22 @@ class DefaultAiServices<T> extends AiServices<T> {
|
|||
}
|
||||
}
|
||||
|
||||
private boolean supportsJsonSchema() {
|
||||
return context.chatModel != null
|
||||
&& context.chatModel.supportedCapabilities().contains(RESPONSE_FORMAT_JSON_SCHEMA);
|
||||
}
|
||||
|
||||
private UserMessage appendOutputFormatInstructions(Type returnType, UserMessage userMessage) {
|
||||
String outputFormatInstructions = serviceOutputParser.outputFormatInstructions(returnType);
|
||||
String text = userMessage.singleText() + outputFormatInstructions;
|
||||
if (isNotNullOrBlank(userMessage.name())) {
|
||||
userMessage = UserMessage.from(userMessage.name(), text);
|
||||
} else {
|
||||
userMessage = UserMessage.from(text);
|
||||
}
|
||||
return userMessage;
|
||||
}
|
||||
|
||||
private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {
|
||||
if (method.isAnnotationPresent(Moderate.class)) {
|
||||
return executor.submit(() -> {
|
||||
|
|
|
@ -0,0 +1,187 @@
|
|||
package dev.langchain4j.service.output;
|
||||
|
||||
import dev.langchain4j.Experimental;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
|
||||
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.structured.Description;
|
||||
import dev.langchain4j.service.Result;
|
||||
import dev.langchain4j.service.TokenStream;
|
||||
import dev.langchain4j.service.TypeUtils;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.ParameterizedType;
|
||||
import java.lang.reflect.Type;
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
|
||||
import static dev.langchain4j.internal.TypeUtils.isJsonBoolean;
|
||||
import static dev.langchain4j.internal.TypeUtils.isJsonInteger;
|
||||
import static dev.langchain4j.internal.TypeUtils.isJsonNumber;
|
||||
import static dev.langchain4j.service.TypeUtils.getRawClass;
|
||||
import static dev.langchain4j.service.TypeUtils.resolveFirstGenericParameterClass;
|
||||
import static dev.langchain4j.service.TypeUtils.typeHasRawClass;
|
||||
import static java.lang.reflect.Modifier.isStatic;
|
||||
|
||||
@Experimental
|
||||
public class JsonSchemas {
|
||||
|
||||
public static Optional<JsonSchema> jsonSchemaFrom(Type returnType) {
|
||||
|
||||
if (typeHasRawClass(returnType, Result.class)) {
|
||||
returnType = resolveFirstGenericParameterClass(returnType);
|
||||
}
|
||||
|
||||
// TODO validate this earlier
|
||||
if (returnType == void.class) {
|
||||
throw illegalConfiguration("Return type of method '%s' cannot be void");
|
||||
}
|
||||
|
||||
if (!isPojo(returnType)) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
Class<?> rawClass = getRawClass(returnType);
|
||||
|
||||
JsonSchema jsonSchema = JsonSchema.builder()
|
||||
.name(rawClass.getSimpleName())
|
||||
.schema(toJsonObjectSchema(rawClass, null))
|
||||
.build();
|
||||
|
||||
return Optional.of(jsonSchema);
|
||||
}
|
||||
|
||||
private static boolean isPojo(Type returnType) {
|
||||
|
||||
if (returnType == String.class
|
||||
|| returnType == AiMessage.class
|
||||
|| returnType == TokenStream.class
|
||||
|| returnType == Response.class) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Explanation (which will make this a lot easier to understand):
|
||||
// In the case of List<String> these two would be set like:
|
||||
// rawClass: List.class
|
||||
// typeArgumentClass: String.class
|
||||
Class<?> rawClass = getRawClass(returnType);
|
||||
Class<?> typeArgumentClass = TypeUtils.resolveFirstGenericParameterClass(returnType);
|
||||
|
||||
Optional<OutputParser<?>> outputParser = new DefaultOutputParserFactory().get(rawClass, typeArgumentClass);
|
||||
if (outputParser.isPresent()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private static JsonObjectSchema toJsonObjectSchema(Class<?> type, String description) {
|
||||
|
||||
Map<String, JsonSchemaElement> properties = new LinkedHashMap<>();
|
||||
for (Field field : type.getDeclaredFields()) {
|
||||
String fieldName = field.getName();
|
||||
if (isStatic(field.getModifiers()) || fieldName.equals("__$hits$__") || fieldName.startsWith("this$")) {
|
||||
continue;
|
||||
}
|
||||
String fieldDescription = getDescription(field);
|
||||
JsonSchemaElement jsonSchemaElement = jsonSchema(field.getType(), field.getGenericType(), fieldDescription);
|
||||
properties.put(fieldName, jsonSchemaElement);
|
||||
}
|
||||
|
||||
return JsonObjectSchema.builder()
|
||||
.description(Optional.ofNullable(description).orElse(getDescription(type)))
|
||||
.properties(properties)
|
||||
.required(new ArrayList<>(properties.keySet()))
|
||||
.additionalProperties(false)
|
||||
.build();
|
||||
}
|
||||
|
||||
private static String getDescription(Field field) {
|
||||
return getDescription(field.getAnnotation(Description.class));
|
||||
}
|
||||
|
||||
private static String getDescription(Class<?> type) {
|
||||
return getDescription(type.getAnnotation(Description.class));
|
||||
}
|
||||
|
||||
private static String getDescription(Description description) {
|
||||
if (description == null) {
|
||||
return null;
|
||||
}
|
||||
return String.join(" ", description.value());
|
||||
}
|
||||
|
||||
private static JsonSchemaElement jsonSchema(Class<?> clazz, Type type, String fieldDescription) {
|
||||
|
||||
if (clazz == String.class) {
|
||||
return JsonStringSchema.builder()
|
||||
.description(fieldDescription)
|
||||
.build();
|
||||
}
|
||||
|
||||
if (isJsonInteger(clazz)) {
|
||||
return JsonIntegerSchema.builder()
|
||||
.description(fieldDescription)
|
||||
.build();
|
||||
}
|
||||
|
||||
if (isJsonNumber(clazz)) {
|
||||
return JsonNumberSchema.builder()
|
||||
.description(fieldDescription)
|
||||
.build();
|
||||
}
|
||||
|
||||
if (isJsonBoolean(clazz)) {
|
||||
return JsonBooleanSchema.builder()
|
||||
.description(fieldDescription)
|
||||
.build();
|
||||
}
|
||||
|
||||
if (clazz.isEnum()) {
|
||||
return JsonEnumSchema.builder()
|
||||
.enumValues(clazz)
|
||||
.description(Optional.ofNullable(fieldDescription).orElse(getDescription(clazz)))
|
||||
.build();
|
||||
}
|
||||
|
||||
if (clazz.isArray()) {
|
||||
return JsonArraySchema.builder()
|
||||
.items(jsonSchema(clazz.getComponentType(), null, null))
|
||||
.description(fieldDescription)
|
||||
.build();
|
||||
}
|
||||
|
||||
if (clazz.equals(List.class) || clazz.equals(Set.class)) {
|
||||
return JsonArraySchema.builder()
|
||||
.items(jsonSchema(getActualType(type), null, null))
|
||||
.description(fieldDescription)
|
||||
.build();
|
||||
}
|
||||
|
||||
return toJsonObjectSchema(clazz, fieldDescription);
|
||||
}
|
||||
|
||||
private static Class<?> getActualType(Type type) {
|
||||
if (type instanceof ParameterizedType) {
|
||||
ParameterizedType parameterizedType = (ParameterizedType) type;
|
||||
Type[] actualTypeArguments = parameterizedType.getActualTypeArguments();
|
||||
if (actualTypeArguments.length == 1) {
|
||||
return (Class<?>) actualTypeArguments[0];
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -83,6 +83,7 @@ public class ServiceOutputParser {
|
|||
return "";
|
||||
}
|
||||
|
||||
// TODO validate this earlier
|
||||
if (returnType == void.class) {
|
||||
throw illegalConfiguration("Return type of method '%s' cannot be void");
|
||||
}
|
||||
|
|
|
@ -27,15 +27,23 @@ import java.util.List;
|
|||
import static dev.langchain4j.data.message.SystemMessage.systemMessage;
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO;
|
||||
import static dev.langchain4j.service.AiServicesIT.Ingredient.*;
|
||||
import static dev.langchain4j.service.AiServicesIT.IssueCategory.*;
|
||||
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
|
||||
import static dev.langchain4j.service.AiServicesIT.Ingredient.OIL;
|
||||
import static dev.langchain4j.service.AiServicesIT.Ingredient.PEPPER;
|
||||
import static dev.langchain4j.service.AiServicesIT.Ingredient.SALT;
|
||||
import static dev.langchain4j.service.AiServicesIT.IssueCategory.COMFORT_ISSUE;
|
||||
import static dev.langchain4j.service.AiServicesIT.IssueCategory.MAINTENANCE_ISSUE;
|
||||
import static dev.langchain4j.service.AiServicesIT.IssueCategory.OVERALL_EXPERIENCE_ISSUE;
|
||||
import static dev.langchain4j.service.AiServicesIT.IssueCategory.SERVICE_ISSUE;
|
||||
import static dev.langchain4j.service.AiServicesIT.Sentiment.POSITIVE;
|
||||
import static java.time.Month.JULY;
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.mockito.Mockito.spy;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
public class AiServicesIT {
|
||||
|
@ -45,6 +53,7 @@ public class AiServicesIT {
|
|||
.baseUrl(System.getenv("OPENAI_BASE_URL"))
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
|
||||
.modelName(GPT_4_O_MINI)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
|
@ -85,6 +94,7 @@ public class AiServicesIT {
|
|||
verify(chatLanguageModel).generate(singletonList(userMessage("Count number of 'egg' mentions in this sentence:\n" +
|
||||
"|||I have ten eggs in my basket and three in my pocket.|||\n" +
|
||||
"You must answer strictly in the following format: integer number")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
|
@ -105,6 +115,7 @@ public class AiServicesIT {
|
|||
assertThat(joke).isNotBlank();
|
||||
|
||||
verify(chatLanguageModel).generate(singletonList(userMessage("Tell me a joke about AI")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
|
@ -135,6 +146,7 @@ public class AiServicesIT {
|
|||
verify(chatLanguageModel).generate(singletonList(userMessage(
|
||||
"Extract date from " + text + "\n" +
|
||||
"You must answer strictly in the following format: yyyy-MM-dd")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -152,6 +164,7 @@ public class AiServicesIT {
|
|||
verify(chatLanguageModel).generate(singletonList(userMessage(
|
||||
"Extract time from " + text + "\n" +
|
||||
"You must answer strictly in the following format: HH:mm:ss")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -169,6 +182,7 @@ public class AiServicesIT {
|
|||
verify(chatLanguageModel).generate(singletonList(userMessage(
|
||||
"Extract date and time from " + text + "\n" +
|
||||
"You must answer strictly in the following format: yyyy-MM-ddTHH:mm:ss")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
|
@ -201,6 +215,7 @@ public class AiServicesIT {
|
|||
"POSITIVE\n" +
|
||||
"NEUTRAL\n" +
|
||||
"NEGATIVE")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
public enum Weather {
|
||||
|
@ -239,6 +254,7 @@ public class AiServicesIT {
|
|||
"CLOUDY - The sky is covered with clouds with no rain, often creating a gray and overcast appearance\n" +
|
||||
"RAINY - Precipitation in the form of rain, with cloudy skies and wet conditions\n" +
|
||||
"SNOWY - Snowfall occurs, covering the ground in white and creating cold, wintry conditions")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
public enum Ingredient {
|
||||
|
@ -270,6 +286,7 @@ public class AiServicesIT {
|
|||
"PEPPER\n" +
|
||||
"VINEGAR\n" +
|
||||
"OIL")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
public enum IssueCategory {
|
||||
|
@ -320,6 +337,7 @@ public class AiServicesIT {
|
|||
"CONNECTIVITY_ISSUE - The feedback mentions issues with internet connectivity, such as unreliable Wi-Fi\n" +
|
||||
"CHECK_IN_ISSUE - The feedback mentions issues with the check-in process, such as it being tedious and time-consuming\n" +
|
||||
"OVERALL_EXPERIENCE_ISSUE - The feedback mentions a general dissatisfaction with the overall hotel experience due to multiple issues")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@ToString
|
||||
|
@ -377,6 +395,7 @@ public class AiServicesIT {
|
|||
"\"city\": (type: string)\n" +
|
||||
"})\n" +
|
||||
"}")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -424,6 +443,7 @@ public class AiServicesIT {
|
|||
"\"city\": (type: string)\n" +
|
||||
"})\n" +
|
||||
"}")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
|
@ -481,6 +501,7 @@ public class AiServicesIT {
|
|||
"\"steps\": (each step should be described in 4 words, steps should rhyme; type: array of string),\n" +
|
||||
"\"preparationTimeMinutes\": (type: integer)\n" +
|
||||
"}")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -504,6 +525,7 @@ public class AiServicesIT {
|
|||
"\"steps\": (each step should be described in 4 words, steps should rhyme; type: array of string),\n" +
|
||||
"\"preparationTimeMinutes\": (type: integer)\n" +
|
||||
"}")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -527,6 +549,7 @@ public class AiServicesIT {
|
|||
"\"steps\": (each step should be described in 4 words, steps should rhyme; type: array of string),\n" +
|
||||
"\"preparationTimeMinutes\": (type: integer)\n" +
|
||||
"}")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -550,6 +573,7 @@ public class AiServicesIT {
|
|||
"\"steps\": (each step should be described in 4 words, steps should rhyme; type: array of string),\n" +
|
||||
"\"preparationTimeMinutes\": (type: integer)\n" +
|
||||
"}")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
interface BadChef {
|
||||
|
@ -626,6 +650,7 @@ public class AiServicesIT {
|
|||
"\"steps\": (each step should be described in 4 words, steps should rhyme; type: array of string),\n" +
|
||||
"\"preparationTimeMinutes\": (type: integer)\n" +
|
||||
"}")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -657,6 +682,7 @@ public class AiServicesIT {
|
|||
"\"preparationTimeMinutes\": (type: integer)\n" +
|
||||
"}")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -688,6 +714,7 @@ public class AiServicesIT {
|
|||
"\"preparationTimeMinutes\": (type: integer)\n" +
|
||||
"}")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
interface ProfessionalChef {
|
||||
|
@ -712,6 +739,7 @@ public class AiServicesIT {
|
|||
systemMessage("You are a professional chef. You are friendly, polite and concise."),
|
||||
userMessage(question)
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
|
@ -738,6 +766,7 @@ public class AiServicesIT {
|
|||
systemMessage("You are a professional translator into german"),
|
||||
userMessage("Translate the following text: Hello, how are you?")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
interface Summarizer {
|
||||
|
@ -764,6 +793,7 @@ public class AiServicesIT {
|
|||
systemMessage("Summarize every message from user in 3 bullet points. Provide only bullet points."),
|
||||
userMessage(text + "\nYou must put every item on a separate line.")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
|
@ -788,6 +818,7 @@ public class AiServicesIT {
|
|||
.hasMessage("Text \"" + message + "\" violates content policy");
|
||||
|
||||
verify(chatLanguageModel).generate(singletonList(userMessage(message)));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
verify(moderationModel).moderate(singletonList(userMessage(message)));
|
||||
}
|
||||
|
||||
|
@ -806,6 +837,7 @@ public class AiServicesIT {
|
|||
assertThat(response).isNotBlank();
|
||||
|
||||
verify(chatLanguageModel).generate(singletonList(userMessage(message)));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
verify(moderationModel).moderate(singletonList(userMessage(message)));
|
||||
}
|
||||
|
||||
|
@ -839,6 +871,7 @@ public class AiServicesIT {
|
|||
assertThat(result.sources()).isNull();
|
||||
|
||||
verify(chatLanguageModel).generate(singletonList(userMessage(userMessage)));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
|
@ -877,6 +910,6 @@ public class AiServicesIT {
|
|||
"\"bookingId\": (type: string)\n" +
|
||||
"}")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,901 @@
|
|||
package dev.langchain4j.service;
|
||||
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.request.ChatRequest;
|
||||
import dev.langchain4j.model.chat.request.ResponseFormat;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.LocalTime;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
|
||||
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
|
||||
import static dev.langchain4j.model.chat.request.json.JsonBooleanSchema.JSON_BOOLEAN_SCHEMA;
|
||||
import static dev.langchain4j.model.chat.request.json.JsonIntegerSchema.JSON_INTEGER_SCHEMA;
|
||||
import static dev.langchain4j.model.chat.request.json.JsonNumberSchema.JSON_NUMBER_SCHEMA;
|
||||
import static dev.langchain4j.model.chat.request.json.JsonStringSchema.JSON_STRING_SCHEMA;
|
||||
import static dev.langchain4j.service.AiServicesJsonSchemaIT.PersonExtractor3.MaritalStatus.SINGLE;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.spy;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
public class AiServicesJsonSchemaIT {
|
||||
|
||||
static Stream<ChatLanguageModel> models() {
|
||||
return Stream.of(
|
||||
OpenAiChatModel.builder()
|
||||
.baseUrl(System.getenv("OPENAI_BASE_URL"))
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
|
||||
.modelName(GPT_4_O_MINI)
|
||||
.responseFormat("json_schema")
|
||||
.strictJsonSchema(true)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build(),
|
||||
OpenAiChatModel.builder()
|
||||
.baseUrl(System.getenv("OPENAI_BASE_URL"))
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
|
||||
.modelName(GPT_4_O_MINI)
|
||||
.responseFormat("json_schema")
|
||||
.strictJsonSchema(false)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor1 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
int age;
|
||||
Double height;
|
||||
boolean married;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_extract_pojo_with_primitives(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor1 personExtractor = AiServices.create(PersonExtractor1.class, model);
|
||||
|
||||
String text = "Klaus is 37 years old, 1.78m height and single";
|
||||
|
||||
// when
|
||||
PersonExtractor1.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.age).isEqualTo(37);
|
||||
assertThat(person.height).isEqualTo(1.78);
|
||||
assertThat(person.married).isFalse();
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
put("age", JSON_INTEGER_SCHEMA);
|
||||
put("height", JSON_NUMBER_SCHEMA);
|
||||
put("married", JSON_BOOLEAN_SCHEMA);
|
||||
}})
|
||||
.required("name", "age", "height", "married")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor2 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
Address address;
|
||||
}
|
||||
|
||||
class Address {
|
||||
|
||||
String city;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_extract_pojo_with_nested_pojo(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor2 personExtractor = AiServices.create(PersonExtractor2.class, model);
|
||||
|
||||
String text = "Klaus lives in Langley Falls";
|
||||
|
||||
// when
|
||||
PersonExtractor2.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.address.city).isEqualTo("Langley Falls");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
put("address", JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("city", JSON_STRING_SCHEMA);
|
||||
}})
|
||||
.required("city")
|
||||
.additionalProperties(false)
|
||||
.build());
|
||||
}})
|
||||
.required("name", "address")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor3 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
MaritalStatus maritalStatus;
|
||||
}
|
||||
|
||||
enum MaritalStatus {
|
||||
|
||||
SINGLE, MARRIED
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_extract_pojo_with_enum(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor3 personExtractor = AiServices.create(PersonExtractor3.class, model);
|
||||
|
||||
String text = "Klaus is single";
|
||||
|
||||
// when
|
||||
PersonExtractor3.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.maritalStatus).isEqualTo(SINGLE);
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
put("maritalStatus", JsonEnumSchema.builder()
|
||||
.enumValues("SINGLE", "MARRIED")
|
||||
.build());
|
||||
}})
|
||||
.required("name", "maritalStatus")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor4 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
String[] favouriteColors;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_extract_pojo_with_array_of_primitives(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor4 personExtractor = AiServices.create(PersonExtractor4.class, model);
|
||||
|
||||
String text = "Klaus likes orange and green";
|
||||
|
||||
// when
|
||||
PersonExtractor4.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.favouriteColors).containsExactly("orange", "green");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
put("favouriteColors", JsonArraySchema.builder()
|
||||
.items(JSON_STRING_SCHEMA)
|
||||
.build());
|
||||
}})
|
||||
.required("name", "favouriteColors")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor5 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
List<String> favouriteColors;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_extract_pojo_with_list_of_primitives(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor5 personExtractor = AiServices.create(PersonExtractor5.class, model);
|
||||
|
||||
String text = "Klaus likes orange and green";
|
||||
|
||||
// when
|
||||
PersonExtractor5.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.favouriteColors).containsExactly("orange", "green");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
put("favouriteColors", JsonArraySchema.builder()
|
||||
.items(JSON_STRING_SCHEMA)
|
||||
.build());
|
||||
}})
|
||||
.required("name", "favouriteColors")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor6 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
Set<String> favouriteColors;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_extract_pojo_with_set_of_primitives(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor6 personExtractor = AiServices.create(PersonExtractor6.class, model);
|
||||
|
||||
String text = "Klaus likes orange and green";
|
||||
|
||||
// when
|
||||
PersonExtractor6.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.favouriteColors).containsExactly("orange", "green");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
put("favouriteColors", JsonArraySchema.builder()
|
||||
.items(JSON_STRING_SCHEMA)
|
||||
.build());
|
||||
}})
|
||||
.required("name", "favouriteColors")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor7 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
Pet[] pets;
|
||||
}
|
||||
|
||||
class Pet {
|
||||
|
||||
String name;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_extract_pojo_with_array_of_pojos(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor7 personExtractor = AiServices.create(PersonExtractor7.class, model);
|
||||
|
||||
String text = "Klaus has 2 pets: Peanut and Muffin";
|
||||
|
||||
// when
|
||||
PersonExtractor7.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.pets).hasSize(2);
|
||||
assertThat(person.pets[0].name).isEqualTo("Peanut");
|
||||
assertThat(person.pets[1].name).isEqualTo("Muffin");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
put("pets", JsonArraySchema.builder()
|
||||
.items(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
}})
|
||||
.required("name")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build());
|
||||
}})
|
||||
.required("name", "pets")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor8 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
List<Pet> pets;
|
||||
}
|
||||
|
||||
class Pet {
|
||||
|
||||
String name;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_extract_pojo_with_list_of_pojos(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor8 personExtractor = AiServices.create(PersonExtractor8.class, model);
|
||||
|
||||
String text = "Klaus has 2 pets: Peanut and Muffin";
|
||||
|
||||
// when
|
||||
PersonExtractor8.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.pets).hasSize(2);
|
||||
assertThat(person.pets.get(0).name).isEqualTo("Peanut");
|
||||
assertThat(person.pets.get(1).name).isEqualTo("Muffin");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
put("pets", JsonArraySchema.builder()
|
||||
.items(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
}})
|
||||
.required("name")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build());
|
||||
}})
|
||||
.required("name", "pets")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor9 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
Set<Pet> pets;
|
||||
}
|
||||
|
||||
class Pet {
|
||||
|
||||
String name;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_extract_pojo_with_set_of_pojos(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor9 personExtractor = AiServices.create(PersonExtractor9.class, model);
|
||||
|
||||
String text = "Klaus has 2 pets: Peanut and Muffin";
|
||||
|
||||
// when
|
||||
PersonExtractor9.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.pets).hasSize(2);
|
||||
Iterator<PersonExtractor9.Pet> iterator = person.pets.iterator();
|
||||
assertThat(iterator.next().name).isEqualTo("Peanut");
|
||||
assertThat(iterator.next().name).isEqualTo("Muffin");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
put("pets", JsonArraySchema.builder()
|
||||
.items(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
}})
|
||||
.required("name")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build());
|
||||
}})
|
||||
.required("name", "pets")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor10 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
Group[] groups;
|
||||
}
|
||||
|
||||
enum Group {
|
||||
|
||||
A, B, C
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_extract_pojo_with_array_of_enums(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor10 personExtractor = AiServices.create(PersonExtractor10.class, model);
|
||||
|
||||
String text = "Klaus is assigned to groups A and C";
|
||||
|
||||
// when
|
||||
PersonExtractor10.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.groups).containsExactly(PersonExtractor10.Group.A, PersonExtractor10.Group.C);
|
||||
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
put("groups", JsonArraySchema.builder()
|
||||
.items(JsonEnumSchema.builder()
|
||||
.enumValues("A", "B", "C")
|
||||
.build())
|
||||
.build());
|
||||
}})
|
||||
.required("name", "groups")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor11 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
List<Group> groups;
|
||||
}
|
||||
|
||||
enum Group {
|
||||
|
||||
A, B, C
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_extract_pojo_with_list_of_enums(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor11 personExtractor = AiServices.create(PersonExtractor11.class, model);
|
||||
|
||||
String text = "Klaus is assigned to groups A and C";
|
||||
|
||||
// when
|
||||
PersonExtractor11.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.groups).containsExactly(PersonExtractor11.Group.A, PersonExtractor11.Group.C);
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
put("groups", JsonArraySchema.builder()
|
||||
.items(JsonEnumSchema.builder()
|
||||
.enumValues("A", "B", "C")
|
||||
.build())
|
||||
.build());
|
||||
}})
|
||||
.required("name", "groups")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor12 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
Set<Group> groups;
|
||||
}
|
||||
|
||||
enum Group {
|
||||
|
||||
A, B, C
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_extract_pojo_with_set_of_enums(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor12 personExtractor = AiServices.create(PersonExtractor12.class, model);
|
||||
|
||||
String text = "Klaus is assigned to groups A and C";
|
||||
|
||||
// when
|
||||
PersonExtractor12.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.groups).containsExactly(PersonExtractor12.Group.A, PersonExtractor12.Group.C);
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
put("groups", JsonArraySchema.builder()
|
||||
.items(JsonEnumSchema.builder()
|
||||
.enumValues("A", "B", "C")
|
||||
.build())
|
||||
.build());
|
||||
}})
|
||||
.required("name", "groups")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor13 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
LocalDate birthDate;
|
||||
LocalTime birthTime;
|
||||
LocalDateTime birthDateTime;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_extract_pojo_with_local_date_time_fields(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor13 personExtractor = AiServices.create(PersonExtractor13.class, model);
|
||||
|
||||
String text = "Klaus was born at 14:43:26 on 12th of August 1976";
|
||||
|
||||
// when
|
||||
PersonExtractor13.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.birthDate).isEqualTo(LocalDate.of(1976, 8, 12));
|
||||
assertThat(person.birthTime).isEqualTo(LocalTime.of(14, 43, 26));
|
||||
assertThat(person.birthDateTime)
|
||||
.isEqualTo(LocalDateTime.of(1976, 8, 12, 14, 43, 26));
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
put("birthDate", JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("year", JSON_INTEGER_SCHEMA);
|
||||
put("month", JSON_INTEGER_SCHEMA);
|
||||
put("day", JSON_INTEGER_SCHEMA);
|
||||
}})
|
||||
.required("year", "month", "day")
|
||||
.additionalProperties(false)
|
||||
.build());
|
||||
put("birthTime", JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("hour", JSON_INTEGER_SCHEMA);
|
||||
put("minute", JSON_INTEGER_SCHEMA);
|
||||
put("second", JSON_INTEGER_SCHEMA);
|
||||
put("nano", JSON_INTEGER_SCHEMA);
|
||||
}})
|
||||
.required("hour", "minute", "second", "nano")
|
||||
.additionalProperties(false)
|
||||
.build());
|
||||
put("birthDateTime", JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("date", JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("year", JSON_INTEGER_SCHEMA);
|
||||
put("month", JSON_INTEGER_SCHEMA);
|
||||
put("day", JSON_INTEGER_SCHEMA);
|
||||
}})
|
||||
.required("year", "month", "day")
|
||||
.additionalProperties(false)
|
||||
.build());
|
||||
put("time", JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("hour", JSON_INTEGER_SCHEMA);
|
||||
put("minute", JSON_INTEGER_SCHEMA);
|
||||
put("second", JSON_INTEGER_SCHEMA);
|
||||
put("nano", JSON_INTEGER_SCHEMA);
|
||||
}})
|
||||
.required("hour", "minute", "second", "nano")
|
||||
.additionalProperties(false)
|
||||
.build());
|
||||
}})
|
||||
.required("date", "time")
|
||||
.additionalProperties(false)
|
||||
.build());
|
||||
}})
|
||||
.required("name", "birthDate", "birthTime", "birthDateTime")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor14 {
|
||||
|
||||
class Person {
|
||||
|
||||
String name;
|
||||
}
|
||||
|
||||
Result<Person> extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_return_result_with_pojo(ChatLanguageModel model) {
|
||||
|
||||
// given
|
||||
model = spy(model);
|
||||
PersonExtractor14 personExtractor = AiServices.create(PersonExtractor14.class, model);
|
||||
|
||||
String text = "Klaus";
|
||||
|
||||
// when
|
||||
Result<PersonExtractor14.Person> result = personExtractor.extractPersonFrom(text);
|
||||
PersonExtractor14.Person person = result.content();
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JSON_STRING_SCHEMA);
|
||||
}})
|
||||
.required("name")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,969 @@
|
|||
package dev.langchain4j.service;
|
||||
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.request.ChatRequest;
|
||||
import dev.langchain4j.model.chat.request.ResponseFormat;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import dev.langchain4j.model.output.structured.Description;
|
||||
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
|
||||
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Spy;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.LocalTime;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
|
||||
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
|
||||
import static dev.langchain4j.model.chat.request.json.JsonIntegerSchema.JSON_INTEGER_SCHEMA;
|
||||
import static dev.langchain4j.model.chat.request.json.JsonStringSchema.JSON_STRING_SCHEMA;
|
||||
import static dev.langchain4j.service.AiServicesJsonSchemaWithDescriptionsIT.PersonExtractor3.MaritalStatus.SINGLE;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
public class AiServicesJsonSchemaWithDescriptionsIT {
|
||||
|
||||
@Spy
|
||||
ChatLanguageModel model = OpenAiChatModel.builder()
|
||||
.baseUrl(System.getenv("OPENAI_BASE_URL"))
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
|
||||
.modelName(GPT_4_O_MINI)
|
||||
.responseFormat("json_schema")
|
||||
.strictJsonSchema(true)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build();
|
||||
|
||||
@AfterEach
|
||||
void afterEach() {
|
||||
verifyNoMoreInteractions(model);
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor1 {
|
||||
|
||||
@Description("a person")
|
||||
class Person {
|
||||
|
||||
@Description("a name")
|
||||
String name;
|
||||
|
||||
@Description("an age")
|
||||
int age;
|
||||
|
||||
@Description("a height")
|
||||
Double height;
|
||||
|
||||
@Description("married or not")
|
||||
boolean married;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_extract_pojo_with_primitives() {
|
||||
|
||||
// given
|
||||
PersonExtractor1 personExtractor = AiServices.create(PersonExtractor1.class, model);
|
||||
|
||||
String text = "Klaus is 37 years old, 1.78m height and single";
|
||||
|
||||
// when
|
||||
PersonExtractor1.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.age).isEqualTo(37);
|
||||
assertThat(person.height).isEqualTo(1.78);
|
||||
assertThat(person.married).isFalse();
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.description("a person")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name")
|
||||
.build());
|
||||
put("age", JsonIntegerSchema.builder()
|
||||
.description("an age")
|
||||
.build());
|
||||
put("height", JsonNumberSchema.builder()
|
||||
.description("a height")
|
||||
.build());
|
||||
put("married", JsonBooleanSchema.builder()
|
||||
.description("married or not")
|
||||
.build());
|
||||
}})
|
||||
.required("name", "age", "height", "married")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor2 {
|
||||
|
||||
@Description("a person")
|
||||
class Person {
|
||||
|
||||
@Description("a name")
|
||||
String name;
|
||||
|
||||
@Description("an address override")
|
||||
Address address;
|
||||
}
|
||||
|
||||
@Description("an address")
|
||||
class Address {
|
||||
|
||||
@Description("a city")
|
||||
String city;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_extract_pojo_with_nested_pojo() {
|
||||
|
||||
// given
|
||||
PersonExtractor2 personExtractor = AiServices.create(PersonExtractor2.class, model);
|
||||
|
||||
String text = "Klaus lives in Langley Falls";
|
||||
|
||||
// when
|
||||
PersonExtractor2.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.address.city).isEqualTo("Langley Falls");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.description("a person")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name")
|
||||
.build());
|
||||
put("address", JsonObjectSchema.builder()
|
||||
.description("an address override")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("city", JsonStringSchema.builder()
|
||||
.description("a city")
|
||||
.build());
|
||||
}})
|
||||
.required("city")
|
||||
.additionalProperties(false)
|
||||
.build());
|
||||
}})
|
||||
.required("name", "address")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor3 {
|
||||
|
||||
@Description("a person")
|
||||
class Person {
|
||||
|
||||
@Description("a name")
|
||||
String name;
|
||||
|
||||
@Description("marital status override")
|
||||
MaritalStatus maritalStatus;
|
||||
}
|
||||
|
||||
@Description("marital status")
|
||||
enum MaritalStatus {
|
||||
|
||||
SINGLE, MARRIED
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_extract_pojo_with_enum() {
|
||||
|
||||
// given
|
||||
PersonExtractor3 personExtractor = AiServices.create(PersonExtractor3.class, model);
|
||||
|
||||
String text = "Klaus is single";
|
||||
|
||||
// when
|
||||
PersonExtractor3.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.maritalStatus).isEqualTo(SINGLE);
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.description("a person")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name")
|
||||
.build());
|
||||
put("maritalStatus", JsonEnumSchema.builder()
|
||||
.enumValues("SINGLE", "MARRIED")
|
||||
.description("marital status override")
|
||||
.build());
|
||||
}})
|
||||
.required("name", "maritalStatus")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor4 {
|
||||
|
||||
@Description("a person")
|
||||
class Person {
|
||||
|
||||
@Description("a name")
|
||||
String name;
|
||||
|
||||
@Description("favourite colors")
|
||||
String[] favouriteColors;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_extract_pojo_with_array_of_primitives() {
|
||||
|
||||
// given
|
||||
PersonExtractor4 personExtractor = AiServices.create(PersonExtractor4.class, model);
|
||||
|
||||
String text = "Klaus likes orange and green";
|
||||
|
||||
// when
|
||||
PersonExtractor4.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.favouriteColors).containsExactly("orange", "green");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.description("a person")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name")
|
||||
.build());
|
||||
put("favouriteColors", JsonArraySchema.builder()
|
||||
.items(JSON_STRING_SCHEMA)
|
||||
.description("favourite colors")
|
||||
.build());
|
||||
}})
|
||||
.required("name", "favouriteColors")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor5 {
|
||||
|
||||
@Description("a person")
|
||||
class Person {
|
||||
|
||||
@Description("a name")
|
||||
String name;
|
||||
|
||||
@Description("favourite colors")
|
||||
List<String> favouriteColors;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_extract_pojo_with_list_of_primitives() {
|
||||
|
||||
// given
|
||||
PersonExtractor5 personExtractor = AiServices.create(PersonExtractor5.class, model);
|
||||
|
||||
String text = "Klaus likes orange and green";
|
||||
|
||||
// when
|
||||
PersonExtractor5.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.favouriteColors).containsExactly("orange", "green");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.description("a person")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name")
|
||||
.build());
|
||||
put("favouriteColors", JsonArraySchema.builder()
|
||||
.items(JSON_STRING_SCHEMA)
|
||||
.description("favourite colors")
|
||||
.build());
|
||||
}})
|
||||
.required("name", "favouriteColors")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor6 {
|
||||
|
||||
@Description("a person")
|
||||
class Person {
|
||||
|
||||
@Description("a name")
|
||||
String name;
|
||||
|
||||
@Description("favourite colors")
|
||||
Set<String> favouriteColors;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_extract_pojo_with_set_of_primitives() {
|
||||
|
||||
// given
|
||||
PersonExtractor6 personExtractor = AiServices.create(PersonExtractor6.class, model);
|
||||
|
||||
String text = "Klaus likes orange and green";
|
||||
|
||||
// when
|
||||
PersonExtractor6.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.favouriteColors).containsExactly("orange", "green");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.description("a person")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name")
|
||||
.build());
|
||||
put("favouriteColors", JsonArraySchema.builder()
|
||||
.items(JSON_STRING_SCHEMA)
|
||||
.description("favourite colors")
|
||||
.build());
|
||||
}})
|
||||
.required("name", "favouriteColors")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor7 {
|
||||
|
||||
@Description("a person")
|
||||
class Person {
|
||||
|
||||
@Description("a name")
|
||||
String name;
|
||||
|
||||
@Description("pets of a person")
|
||||
Pet[] pets;
|
||||
}
|
||||
|
||||
@Description("a pet")
|
||||
class Pet {
|
||||
|
||||
@Description("a name of a pet")
|
||||
String name;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_extract_pojo_with_array_of_pojos() {
|
||||
|
||||
// given
|
||||
PersonExtractor7 personExtractor = AiServices.create(PersonExtractor7.class, model);
|
||||
|
||||
String text = "Klaus has 2 pets: Peanut and Muffin";
|
||||
|
||||
// when
|
||||
PersonExtractor7.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.pets).hasSize(2);
|
||||
assertThat(person.pets[0].name).isEqualTo("Peanut");
|
||||
assertThat(person.pets[1].name).isEqualTo("Muffin");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.description("a person")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name")
|
||||
.build());
|
||||
put("pets", JsonArraySchema.builder()
|
||||
.items(JsonObjectSchema.builder()
|
||||
.description("a pet")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name of a pet")
|
||||
.build());
|
||||
}})
|
||||
.required("name")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.description("pets of a person")
|
||||
.build());
|
||||
}})
|
||||
.required("name", "pets")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor8 {
|
||||
|
||||
@Description("a person")
|
||||
class Person {
|
||||
|
||||
@Description("a name")
|
||||
String name;
|
||||
|
||||
@Description("pets of a person")
|
||||
List<Pet> pets;
|
||||
}
|
||||
|
||||
@Description("a pet")
|
||||
class Pet {
|
||||
|
||||
@Description("a name of a pet")
|
||||
String name;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_extract_pojo_with_list_of_pojos() {
|
||||
|
||||
// given
|
||||
PersonExtractor8 personExtractor = AiServices.create(PersonExtractor8.class, model);
|
||||
|
||||
String text = "Klaus has 2 pets: Peanut and Muffin";
|
||||
|
||||
// when
|
||||
PersonExtractor8.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.pets).hasSize(2);
|
||||
assertThat(person.pets.get(0).name).isEqualTo("Peanut");
|
||||
assertThat(person.pets.get(1).name).isEqualTo("Muffin");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.description("a person")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name")
|
||||
.build());
|
||||
put("pets", JsonArraySchema.builder()
|
||||
.items(JsonObjectSchema.builder()
|
||||
.description("a pet")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name of a pet")
|
||||
.build());
|
||||
}})
|
||||
.required("name")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.description("pets of a person")
|
||||
.build());
|
||||
}})
|
||||
.required("name", "pets")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor9 {
|
||||
|
||||
@Description("a person")
|
||||
class Person {
|
||||
|
||||
@Description("a name")
|
||||
String name;
|
||||
|
||||
@Description("pets of a person")
|
||||
Set<Pet> pets;
|
||||
}
|
||||
|
||||
@Description("a pet")
|
||||
class Pet {
|
||||
|
||||
@Description("a name of a pet")
|
||||
String name;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_extract_pojo_with_set_of_pojos() {
|
||||
|
||||
// given
|
||||
PersonExtractor9 personExtractor = AiServices.create(PersonExtractor9.class, model);
|
||||
|
||||
String text = "Klaus has 2 pets: Peanut and Muffin";
|
||||
|
||||
// when
|
||||
PersonExtractor9.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.pets).hasSize(2);
|
||||
Iterator<PersonExtractor9.Pet> iterator = person.pets.iterator();
|
||||
assertThat(iterator.next().name).isEqualTo("Peanut");
|
||||
assertThat(iterator.next().name).isEqualTo("Muffin");
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.description("a person")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name")
|
||||
.build());
|
||||
put("pets", JsonArraySchema.builder()
|
||||
.items(JsonObjectSchema.builder()
|
||||
.description("a pet")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name of a pet")
|
||||
.build());
|
||||
}})
|
||||
.required("name")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.description("pets of a person")
|
||||
.build());
|
||||
}})
|
||||
.required("name", "pets")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor10 {
|
||||
|
||||
@Description("a person")
|
||||
class Person {
|
||||
|
||||
@Description("a name")
|
||||
String name;
|
||||
|
||||
@Description("groups")
|
||||
Group[] groups;
|
||||
}
|
||||
|
||||
@Description("a group")
|
||||
enum Group {
|
||||
|
||||
A, B, C
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_extract_pojo_with_array_of_enums() {
|
||||
|
||||
// given
|
||||
PersonExtractor10 personExtractor = AiServices.create(PersonExtractor10.class, model);
|
||||
|
||||
String text = "Klaus is assigned to groups A and C";
|
||||
|
||||
// when
|
||||
PersonExtractor10.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.groups).containsExactly(PersonExtractor10.Group.A, PersonExtractor10.Group.C);
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.description("a person")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name")
|
||||
.build());
|
||||
put("groups", JsonArraySchema.builder()
|
||||
.description("groups")
|
||||
.items(JsonEnumSchema.builder()
|
||||
.description("a group")
|
||||
.enumValues("A", "B", "C")
|
||||
.build())
|
||||
.build());
|
||||
}})
|
||||
.required("name", "groups")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor11 {
|
||||
|
||||
@Description("a person")
|
||||
class Person {
|
||||
|
||||
@Description("a name")
|
||||
String name;
|
||||
|
||||
@Description("groups")
|
||||
List<Group> groups;
|
||||
}
|
||||
|
||||
@Description("a group")
|
||||
enum Group {
|
||||
|
||||
A, B, C
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_extract_pojo_with_list_of_enums() {
|
||||
|
||||
// given
|
||||
PersonExtractor11 personExtractor = AiServices.create(PersonExtractor11.class, model);
|
||||
|
||||
String text = "Klaus is assigned to groups A and C";
|
||||
|
||||
// when
|
||||
PersonExtractor11.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.groups).containsExactly(PersonExtractor11.Group.A, PersonExtractor11.Group.C);
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.description("a person")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name")
|
||||
.build());
|
||||
put("groups", JsonArraySchema.builder()
|
||||
.description("groups")
|
||||
.items(JsonEnumSchema.builder()
|
||||
.description("a group")
|
||||
.enumValues("A", "B", "C")
|
||||
.build())
|
||||
.build());
|
||||
}})
|
||||
.required("name", "groups")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor12 {
|
||||
|
||||
@Description("a person")
|
||||
class Person {
|
||||
|
||||
@Description("a name")
|
||||
String name;
|
||||
|
||||
@Description("groups")
|
||||
Set<Group> groups;
|
||||
}
|
||||
|
||||
@Description("a group")
|
||||
enum Group {
|
||||
|
||||
A, B, C
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_extract_pojo_with_set_of_enums() {
|
||||
|
||||
// given
|
||||
PersonExtractor12 personExtractor = AiServices.create(PersonExtractor12.class, model);
|
||||
|
||||
String text = "Klaus is assigned to groups A and C";
|
||||
|
||||
// when
|
||||
PersonExtractor12.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.groups).containsExactly(PersonExtractor12.Group.A, PersonExtractor12.Group.C);
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.description("a person")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name")
|
||||
.build());
|
||||
put("groups", JsonArraySchema.builder()
|
||||
.description("groups")
|
||||
.items(JsonEnumSchema.builder()
|
||||
.description("a group")
|
||||
.enumValues("A", "B", "C")
|
||||
.build())
|
||||
.build());
|
||||
}})
|
||||
.required("name", "groups")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
|
||||
|
||||
interface PersonExtractor13 {
|
||||
|
||||
@Description("a person")
|
||||
class Person {
|
||||
|
||||
@Description("a name")
|
||||
String name;
|
||||
|
||||
@Description("a birth date")
|
||||
LocalDate birthDate;
|
||||
|
||||
@Description("a birth time")
|
||||
LocalTime birthTime;
|
||||
|
||||
@Description("a birth date and time")
|
||||
LocalDateTime birthDateTime;
|
||||
}
|
||||
|
||||
Person extractPersonFrom(String text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_extract_pojo_with_local_date_time_fields() {
|
||||
|
||||
// given
|
||||
PersonExtractor13 personExtractor = AiServices.create(PersonExtractor13.class, model);
|
||||
|
||||
String text = "Klaus was born at 14:43:26 on 12th of August 1976";
|
||||
|
||||
// when
|
||||
PersonExtractor13.Person person = personExtractor.extractPersonFrom(text);
|
||||
|
||||
// then
|
||||
assertThat(person.name).isEqualTo("Klaus");
|
||||
assertThat(person.birthDate).isEqualTo(LocalDate.of(1976, 8, 12));
|
||||
assertThat(person.birthTime).isEqualTo(LocalTime.of(14, 43, 26));
|
||||
assertThat(person.birthDateTime)
|
||||
.isEqualTo(LocalDateTime.of(1976, 8, 12, 14, 43, 26));
|
||||
|
||||
verify(model).chat(ChatRequest.builder()
|
||||
.messages(singletonList(userMessage(text)))
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchema.builder()
|
||||
.name("Person")
|
||||
.schema(JsonObjectSchema.builder()
|
||||
.description("a person")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("name", JsonStringSchema.builder()
|
||||
.description("a name")
|
||||
.build());
|
||||
put("birthDate", JsonObjectSchema.builder()
|
||||
.description("a birth date")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("year", JSON_INTEGER_SCHEMA);
|
||||
put("month", JSON_INTEGER_SCHEMA);
|
||||
put("day", JSON_INTEGER_SCHEMA);
|
||||
}})
|
||||
.required("year", "month", "day")
|
||||
.additionalProperties(false)
|
||||
.build());
|
||||
put("birthTime", JsonObjectSchema.builder()
|
||||
.description("a birth time")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("hour", JSON_INTEGER_SCHEMA);
|
||||
put("minute", JSON_INTEGER_SCHEMA);
|
||||
put("second", JSON_INTEGER_SCHEMA);
|
||||
put("nano", JSON_INTEGER_SCHEMA);
|
||||
}})
|
||||
.required("hour", "minute", "second", "nano")
|
||||
.additionalProperties(false)
|
||||
.build());
|
||||
put("birthDateTime", JsonObjectSchema.builder()
|
||||
.description("a birth date and time")
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("date", JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("year", JSON_INTEGER_SCHEMA);
|
||||
put("month", JSON_INTEGER_SCHEMA);
|
||||
put("day", JSON_INTEGER_SCHEMA);
|
||||
}})
|
||||
.required("year", "month", "day")
|
||||
.additionalProperties(false)
|
||||
.build());
|
||||
put("time", JsonObjectSchema.builder()
|
||||
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
|
||||
put("hour", JSON_INTEGER_SCHEMA);
|
||||
put("minute", JSON_INTEGER_SCHEMA);
|
||||
put("second", JSON_INTEGER_SCHEMA);
|
||||
put("nano", JSON_INTEGER_SCHEMA);
|
||||
}})
|
||||
.required("hour", "minute", "second", "nano")
|
||||
.additionalProperties(false)
|
||||
.build());
|
||||
}})
|
||||
.required("date", "time")
|
||||
.additionalProperties(false)
|
||||
.build());
|
||||
}})
|
||||
.required("name", "birthDate", "birthTime", "birthDateTime")
|
||||
.additionalProperties(false)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build());
|
||||
verify(model).supportedCapabilities();
|
||||
}
|
||||
}
|
|
@ -129,6 +129,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -146,6 +147,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -163,6 +165,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -180,6 +183,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -197,6 +201,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -214,6 +219,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -231,6 +237,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -248,6 +255,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -265,6 +273,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -282,6 +291,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -300,6 +310,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -318,6 +329,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -336,6 +348,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -354,6 +367,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -372,6 +386,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -390,6 +405,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -408,6 +424,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -426,6 +443,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -444,6 +462,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -462,6 +481,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("Given a name of a country, answer with a name of it's capital"),
|
||||
userMessage("Country: Germany")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -480,6 +500,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
|
|||
systemMessage("This message should take precedence over the one provided by systemMessageProvider"),
|
||||
userMessage("What is the capital of Germany?")
|
||||
));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -79,6 +79,7 @@ class AiServicesUserMessageConfigTest {
|
|||
assertThat(aiService.chat1("What is the capital of Germany?"))
|
||||
.containsIgnoringCase("Berlin");
|
||||
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -93,6 +94,7 @@ class AiServicesUserMessageConfigTest {
|
|||
assertThat(aiService.chat2("What is the capital of Germany?"))
|
||||
.containsIgnoringCase("Berlin");
|
||||
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -107,6 +109,7 @@ class AiServicesUserMessageConfigTest {
|
|||
assertThat(aiService.chat3("What is the capital of {{country}}?", "Germany"))
|
||||
.containsIgnoringCase("Berlin");
|
||||
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -121,6 +124,7 @@ class AiServicesUserMessageConfigTest {
|
|||
assertThat(aiService.chat4())
|
||||
.containsIgnoringCase("Berlin");
|
||||
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -135,6 +139,7 @@ class AiServicesUserMessageConfigTest {
|
|||
assertThat(aiService.chat5("Germany"))
|
||||
.containsIgnoringCase("Berlin");
|
||||
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -149,6 +154,7 @@ class AiServicesUserMessageConfigTest {
|
|||
assertThat(aiService.chat6("Germany"))
|
||||
.containsIgnoringCase("Berlin");
|
||||
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -163,6 +169,7 @@ class AiServicesUserMessageConfigTest {
|
|||
assertThat(aiService.chat7("capital", "Germany"))
|
||||
.containsIgnoringCase("Berlin");
|
||||
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
|
||||
verify(chatLanguageModel).supportedCapabilities();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -27,7 +27,9 @@ import static dev.langchain4j.service.AiServicesWithChatMemoryIT.ChatWithMemory.
|
|||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class AiServicesWithChatMemoryIT {
|
||||
|
@ -142,6 +144,7 @@ class AiServicesWithChatMemoryIT {
|
|||
));
|
||||
verify(chatMemory).add(aiMessage(fourthAiMessage));
|
||||
|
||||
verify(chatLanguageModel, times(4)).supportedCapabilities();
|
||||
verify(chatMemory, times(12)).messages();
|
||||
}
|
||||
|
||||
|
@ -171,6 +174,7 @@ class AiServicesWithChatMemoryIT {
|
|||
aiMessage(firstAiMessage),
|
||||
userMessage(secondUserMessage)
|
||||
));
|
||||
verify(chatLanguageModel, times(2)).supportedCapabilities();
|
||||
|
||||
verify(chatMemory, times(2)).add(systemMessage(SYSTEM_MESSAGE));
|
||||
verify(chatMemory).add(userMessage(firstUserMessage));
|
||||
|
@ -207,6 +211,7 @@ class AiServicesWithChatMemoryIT {
|
|||
systemMessage(ANOTHER_SYSTEM_MESSAGE),
|
||||
userMessage(secondUserMessage)
|
||||
));
|
||||
verify(chatLanguageModel, times(2)).supportedCapabilities();
|
||||
|
||||
verify(chatMemory).add(systemMessage(SYSTEM_MESSAGE));
|
||||
verify(chatMemory).add(userMessage(firstUserMessage));
|
||||
|
@ -304,5 +309,7 @@ class AiServicesWithChatMemoryIT {
|
|||
userMessage(secondMessageFromSecondUser),
|
||||
aiMessage(secondAiResponseToSecondUser)
|
||||
);
|
||||
|
||||
verify(chatLanguageModel, times(4)).supportedCapabilities();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,10 @@ package dev.langchain4j.service;
|
|||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import dev.langchain4j.agent.tool.*;
|
||||
import dev.langchain4j.agent.tool.P;
|
||||
import dev.langchain4j.agent.tool.Tool;
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||
|
@ -33,7 +36,7 @@ import java.util.stream.Stream;
|
|||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.description;
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
|
||||
import static dev.langchain4j.model.mistralai.MistralAiChatModelName.MISTRAL_LARGE_LATEST;
|
||||
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO_0613;
|
||||
import static dev.langchain4j.model.openai.OpenAiChatModelName.*;
|
||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||
import static dev.langchain4j.service.AiServicesWithToolsIT.Operator.EQUALS;
|
||||
import static dev.langchain4j.service.AiServicesWithToolsIT.TemperatureUnit.Kelvin;
|
||||
|
@ -54,6 +57,17 @@ class AiServicesWithToolsIT {
|
|||
.baseUrl(System.getenv("OPENAI_BASE_URL"))
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
|
||||
.modelName(GPT_4_O)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build(),
|
||||
OpenAiChatModel.builder()
|
||||
.baseUrl(System.getenv("OPENAI_BASE_URL"))
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
|
||||
.modelName(GPT_4_O_MINI)
|
||||
.strictTools(true)
|
||||
.temperature(0.0)
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
|
@ -616,7 +630,7 @@ class AiServicesWithToolsIT {
|
|||
Operator operator;
|
||||
|
||||
@Description("Value to compare with")
|
||||
Object value;
|
||||
String value;
|
||||
}
|
||||
|
||||
enum Operator {
|
||||
|
@ -678,6 +692,33 @@ class AiServicesWithToolsIT {
|
|||
assertThat(response.content().text()).contains("2027");
|
||||
}
|
||||
|
||||
static class Clock {
|
||||
|
||||
@Tool
|
||||
String currentTime() {
|
||||
return "16:37:43";
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("models")
|
||||
void should_execute_tool_without_parameters(ChatLanguageModel chatLanguageModel) {
|
||||
|
||||
// given
|
||||
Clock clock = spy(new Clock());
|
||||
|
||||
Assistant assistant = AiServices.builder(Assistant.class)
|
||||
.chatLanguageModel(chatLanguageModel)
|
||||
.tools(clock)
|
||||
.build();
|
||||
|
||||
// when
|
||||
Response<AiMessage> response = assistant.chat("What is the time now?");
|
||||
|
||||
// then
|
||||
assertThat(response.content().text()).contains("16:37:43");
|
||||
}
|
||||
|
||||
private static Map<String, Object> toMap(String arguments) {
|
||||
try {
|
||||
return new ObjectMapper().readValue(arguments, new TypeReference<Map<String, Object>>() {
|
||||
|
|
|
@ -0,0 +1,220 @@
|
|||
package dev.langchain4j.service.output;
|
||||
|
||||
import com.google.gson.reflect.TypeToken;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchema;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.structured.Description;
|
||||
import dev.langchain4j.service.Result;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.Optional;
|
||||
|
||||
import static dev.langchain4j.service.output.JsonSchemas.jsonSchemaFrom;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class JsonSchemasTest {
|
||||
|
||||
class Pojo {
|
||||
|
||||
String field;
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_return_json_schema_for_pojos() {
|
||||
assertThat(jsonSchemaFrom(Pojo.class)).isPresent();
|
||||
assertThat(jsonSchemaFrom(new TypeToken<Result<Pojo>>() {
|
||||
}.getType())).isPresent();
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_return_empty_for_not_pojos() {
|
||||
assertThat(jsonSchemaFrom(String.class)).isEmpty();
|
||||
assertThat(jsonSchemaFrom(AiMessage.class)).isEmpty();
|
||||
assertThat(jsonSchemaFrom(Response.class)).isEmpty();
|
||||
assertThat(jsonSchemaFrom(Integer.class)).isEmpty();
|
||||
assertThat(jsonSchemaFrom(LocalDate.class)).isEmpty();
|
||||
assertThat(jsonSchemaFrom(new TypeToken<Result<String>>() {
|
||||
}.getType())).isEmpty();
|
||||
}
|
||||
|
||||
|
||||
// POJO
|
||||
|
||||
@Test
|
||||
void should_take_pojo_description_from_the_field() {
|
||||
|
||||
// given
|
||||
class Address {
|
||||
|
||||
String street;
|
||||
String city;
|
||||
}
|
||||
|
||||
class Person {
|
||||
|
||||
@Description("an address")
|
||||
Address address;
|
||||
}
|
||||
|
||||
// when
|
||||
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
|
||||
|
||||
// then
|
||||
JsonObjectSchema addressSchema = (JsonObjectSchema) jsonSchema.get().schema().properties().get("address");
|
||||
assertThat(addressSchema.description()).isEqualTo("an address");
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_take_pojo_description_from_the_class() {
|
||||
|
||||
// given
|
||||
@Description("an address")
|
||||
class Address {
|
||||
|
||||
String street;
|
||||
String city;
|
||||
}
|
||||
|
||||
class Person {
|
||||
|
||||
Address address;
|
||||
}
|
||||
|
||||
// when
|
||||
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
|
||||
|
||||
// then
|
||||
JsonObjectSchema addressSchema = (JsonObjectSchema) jsonSchema.get().schema().properties().get("address");
|
||||
assertThat(addressSchema.description()).isEqualTo("an address");
|
||||
}
|
||||
|
||||
@Test
|
||||
void pojo_field_description_should_override_class_description() {
|
||||
|
||||
// given
|
||||
@Description("an address")
|
||||
class Address {
|
||||
|
||||
String street;
|
||||
String city;
|
||||
}
|
||||
|
||||
class Person {
|
||||
|
||||
@Description("an address 2")
|
||||
Address address;
|
||||
}
|
||||
|
||||
// when
|
||||
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
|
||||
|
||||
// then
|
||||
JsonObjectSchema addressSchema = (JsonObjectSchema) jsonSchema.get().schema().properties().get("address");
|
||||
assertThat(addressSchema.description()).isEqualTo("an address 2");
|
||||
}
|
||||
|
||||
|
||||
// ENUM
|
||||
|
||||
enum MaritalStatus {
|
||||
|
||||
SINGLE, MARRIED
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_take_enum_description_from_the_field() {
|
||||
|
||||
// given
|
||||
class Person {
|
||||
|
||||
@Description("marital status")
|
||||
MaritalStatus maritalStatus;
|
||||
}
|
||||
|
||||
// when
|
||||
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
|
||||
|
||||
// then
|
||||
JsonEnumSchema maritalStatusSchema = (JsonEnumSchema) jsonSchema.get().schema().properties().get("maritalStatus");
|
||||
assertThat(maritalStatusSchema.description()).isEqualTo("marital status");
|
||||
}
|
||||
|
||||
|
||||
@Description("marital status")
|
||||
enum MaritalStatus2 {
|
||||
|
||||
SINGLE, MARRIED
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_take_enum_description_from_the_enum() {
|
||||
|
||||
// given
|
||||
class Person {
|
||||
|
||||
MaritalStatus2 maritalStatus;
|
||||
}
|
||||
|
||||
// when
|
||||
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
|
||||
|
||||
// then
|
||||
JsonEnumSchema maritalStatusSchema = (JsonEnumSchema) jsonSchema.get().schema().properties().get("maritalStatus");
|
||||
assertThat(maritalStatusSchema.description()).isEqualTo("marital status");
|
||||
}
|
||||
|
||||
|
||||
@Description("marital status")
|
||||
enum MaritalStatus3 {
|
||||
|
||||
SINGLE, MARRIED
|
||||
}
|
||||
|
||||
@Test
|
||||
void enum_field_description_should_override_class_description() {
|
||||
|
||||
// given
|
||||
class Person {
|
||||
|
||||
@Description("marital status 2")
|
||||
MaritalStatus3 maritalStatus;
|
||||
}
|
||||
|
||||
// when
|
||||
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
|
||||
|
||||
// then
|
||||
JsonEnumSchema maritalStatusSchema = (JsonEnumSchema) jsonSchema.get().schema().properties().get("maritalStatus");
|
||||
assertThat(maritalStatusSchema.description()).isEqualTo("marital status 2");
|
||||
}
|
||||
|
||||
// ARRAY
|
||||
|
||||
@Test
|
||||
void should_take_array_description_from_the_field() {
|
||||
|
||||
// given
|
||||
class Pet {
|
||||
|
||||
String name;
|
||||
}
|
||||
|
||||
class Person {
|
||||
|
||||
@Description("pets of a person")
|
||||
Pet[] pets;
|
||||
}
|
||||
|
||||
// when
|
||||
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
|
||||
|
||||
// then
|
||||
JsonArraySchema petsSchema = (JsonArraySchema) jsonSchema.get().schema().properties().get("pets");
|
||||
assertThat(petsSchema.description()).isEqualTo("pets of a person");
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue