OpenAI: Structured Outputs (#1590)

## Issue
Closes #1581

## Change
- OpenAI: added support for [Structured
Outputs](https://openai.com/index/introducing-structured-outputs-in-the-api/):
  - for tools
  - for json mode
- Introduced new (still experimental) `ChatLanguageModel` API (which
supports specifying json schema)

### OpenAI Structured Outputs for tools
To enable Structured Outputs feature for tools, set `.strictTools(true)`
when buidling the model:
```java
OpenAiChatModel.builder()
    ...
    .strictTools(true)
    .build(),
```
Please note that this will automatically make all tool parameters
mandatory (`required` in json schema)
and set `additionalProperties=false` for each `object` in json schema.
This is due to the current OpenAI limitations.

### OpenAI Structured Outputs for json mode
To enable Structured Outputs feature for json mode, set
`.responseFormat("json_schema")` and `.strictJsonSchema(true)` when
buidling the model:
```java
OpenAiChatModel.builder()
    ...
    .responseFormat("json_schema")
    .strictJsonSchema(true)
    .build(),
```
In this case `AiServices` will not append "You must answer strictly in
the following JSON format: ..." string to the end of the last
`UserMessage`, but will create a Json schema from the given POJO and
pass it to the LLM.
Please note that this works only when method return type is a POJO. If
the return type is something else, (like an enum or a `List<String>`),
the old behaviour is applied (with "You must answer strictly ..."). All
return types will be supported in the near future.

Please note that this feature is available now only for `gpt-4o-mini`
and `gpt-4o-2024-08-06` models.

### Experimental `ChatLanguageModel` API
This was drafted in
https://github.com/langchain4j/langchain4j/pull/1261, but now it has to
be rushed a bit in order to enable new Structured Outputs feature for
OpenAI.
A new method `ChatResponse chat(ChatRequest request)` was added into
`ChatLanguageModel` which allows to specify messages, tools and response
format (with json schema). In the future it will also support specifying
model parameters like temperature.

## Upcoming Changes
- Adopt new `ChatLanguageModel` API for Gemini
- Adopt new `ChatLanguageModel` API for Azure OpenAI (once available)
- Support Structured Outputs with all other method return types like
`List<Pojo>`
- Adopt new `JsonSchema` type for tools (instead of `ToolParameters`)

Reated changes in openai4j:
https://github.com/ai-for-java/openai4j/pull/33

## General checklist
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
This commit is contained in:
LangChain4j 2024-08-14 15:25:25 +02:00 committed by GitHub
parent 56d6184aae
commit c5c146f9a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 3771 additions and 118 deletions

View File

@ -126,6 +126,9 @@
<excludes>
<exclude>dev.langchain4j.data.document</exclude>
<exclude>dev.langchain4j.model.chat.listener</exclude>
<exclude>dev.langchain4j.model.chat.request</exclude>
<exclude>dev.langchain4j.model.chat.request.json</exclude>
<exclude>dev.langchain4j.model.chat.response</exclude>
<exclude>dev.langchain4j.model.listener</exclude>
<exclude>dev.langchain4j.store.embedding</exclude>
<exclude>dev.langchain4j.store.embedding.filter</exclude>
@ -146,7 +149,7 @@
<limit>
<counter>INSTRUCTION</counter>
<value>COVEREDRATIO</value>
<minimum>0.9</minimum>
<minimum>0.8</minimum>
</limit>
</limits>
</rule>

View File

@ -11,6 +11,7 @@ import static java.util.Arrays.asList;
* Describes a {@link Tool}.
*/
public class ToolSpecification {
private final String name;
private final String description;
private final ToolParameters parameters;

View File

@ -11,8 +11,11 @@ import static dev.langchain4j.agent.tool.JsonSchemaProperty.enums;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.from;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.items;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.objectItems;
import static dev.langchain4j.internal.TypeUtils.*;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import dev.langchain4j.model.output.structured.Description;
import static java.lang.String.format;
import static java.util.Arrays.stream;
import static java.util.stream.Collectors.toList;
@ -22,8 +25,6 @@ import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
@ -199,15 +200,15 @@ public class ToolSpecifications {
return removeNulls(STRING, description);
}
if (isBoolean(type)) {
if (isJsonBoolean(type)) {
return removeNulls(BOOLEAN, description);
}
if (isInteger(type)) {
if (isJsonInteger(type)) {
return removeNulls(INTEGER, description);
}
if (isNumber(type)) {
if (isJsonNumber(type)) {
return removeNulls(NUMBER, description);
}
@ -234,36 +235,17 @@ public class ToolSpecifications {
return items(JsonSchemaProperty.OBJECT);
}
// TODO put constraints on min and max?
private static boolean isNumber(Class<?> type) {
return type == float.class || type == Float.class
|| type == double.class || type == Double.class
|| type == BigDecimal.class;
}
private static boolean isInteger(Class<?> type) {
return type == byte.class || type == Byte.class
|| type == short.class || type == Short.class
|| type == int.class || type == Integer.class
|| type == long.class || type == Long.class
|| type == BigInteger.class;
}
private static boolean isBoolean(Class<?> type) {
return type == boolean.class || type == Boolean.class;
}
private static JsonSchemaProperty arrayTypeFrom(Class<?> clazz) {
if (clazz == String.class) {
return items(JsonSchemaProperty.STRING);
}
if (isBoolean(clazz)) {
if (isJsonBoolean(clazz)) {
return items(JsonSchemaProperty.BOOLEAN);
}
if (isInteger(clazz)) {
if (isJsonInteger(clazz)) {
return items(JsonSchemaProperty.INTEGER);
}
if (isNumber(clazz)) {
if (isJsonNumber(clazz)) {
return items(JsonSchemaProperty.NUMBER);
}
return objectItems(schema(clazz));

View File

@ -7,36 +7,80 @@ import java.io.*;
import java.nio.charset.StandardCharsets;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.Map;
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE;
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE_TIME;
import static java.time.format.DateTimeFormatter.*;
class GsonJsonCodec implements Json.JsonCodec {
private static final Gson GSON = new GsonBuilder()
.setPrettyPrinting()
.registerTypeAdapter(
LocalDate.class,
(JsonSerializer<LocalDate>) (localDate, type, context) ->
new JsonPrimitive(localDate.format(ISO_LOCAL_DATE))
)
.registerTypeAdapter(
LocalDate.class,
(JsonDeserializer<LocalDate>) (json, type, context) ->
LocalDate.parse(json.getAsString(), ISO_LOCAL_DATE)
)
.registerTypeAdapter(
LocalDateTime.class,
(JsonSerializer<LocalDateTime>) (localDateTime, type, context) ->
new JsonPrimitive(localDateTime.format(ISO_LOCAL_DATE_TIME))
)
.registerTypeAdapter(
LocalDateTime.class,
(JsonDeserializer<LocalDateTime>) (json, type, context) ->
LocalDateTime.parse(json.getAsString(), ISO_LOCAL_DATE_TIME)
)
.create();
.setPrettyPrinting()
.registerTypeAdapter(
LocalDate.class,
(JsonSerializer<LocalDate>) (localDate, type, context) ->
new JsonPrimitive(localDate.format(ISO_LOCAL_DATE))
)
.registerTypeAdapter(
LocalDate.class,
(JsonDeserializer<LocalDate>) (json, type, context) -> {
if (json.isJsonObject()) {
JsonObject jsonObject = (JsonObject) json;
int year = jsonObject.get("year").getAsInt();
int month = jsonObject.get("month").getAsInt();
int day = jsonObject.get("day").getAsInt();
return LocalDate.of(year, month, day);
} else {
return LocalDate.parse(json.getAsString(), ISO_LOCAL_DATE);
}
}
)
.registerTypeAdapter(
LocalTime.class,
(JsonSerializer<LocalTime>) (localTime, type, context) ->
new JsonPrimitive(localTime.format(ISO_LOCAL_TIME))
)
.registerTypeAdapter(
LocalTime.class,
(JsonDeserializer<LocalTime>) (json, type, context) -> {
if (json.isJsonObject()) {
JsonObject jsonObject = (JsonObject) json;
int hour = jsonObject.get("hour").getAsInt();
int minute = jsonObject.get("minute").getAsInt();
int second = jsonObject.get("second").getAsInt();
int nano = jsonObject.get("nano").getAsInt();
return LocalTime.of(hour, minute, second, nano);
} else {
return LocalTime.parse(json.getAsString(), ISO_LOCAL_TIME);
}
}
)
.registerTypeAdapter(
LocalDateTime.class,
(JsonSerializer<LocalDateTime>) (localDateTime, type, context) ->
new JsonPrimitive(localDateTime.format(ISO_LOCAL_DATE_TIME))
)
.registerTypeAdapter(
LocalDateTime.class,
(JsonDeserializer<LocalDateTime>) (json, type, context) -> {
if (json.isJsonObject()) {
JsonObject jsonObject = (JsonObject) json;
JsonObject date = jsonObject.get("date").getAsJsonObject();
int year = date.get("year").getAsInt();
int month = date.get("month").getAsInt();
int day = date.get("day").getAsInt();
JsonObject time = jsonObject.get("time").getAsJsonObject();
int hour = time.get("hour").getAsInt();
int minute = time.get("minute").getAsInt();
int second = time.get("second").getAsInt();
int nano = time.get("nano").getAsInt();
return LocalDateTime.of(year, month, day, hour, minute, second, nano);
} else {
return LocalDateTime.parse(json.getAsString(), ISO_LOCAL_DATE_TIME);
}
}
)
.create();
@Override
public String toJson(Object o) {

View File

@ -0,0 +1,25 @@
package dev.langchain4j.internal;
import java.math.BigDecimal;
import java.math.BigInteger;
public class TypeUtils {
public static boolean isJsonInteger(Class<?> type) {
return type == byte.class || type == Byte.class
|| type == short.class || type == Short.class
|| type == int.class || type == Integer.class
|| type == long.class || type == Long.class
|| type == BigInteger.class;
}
public static boolean isJsonNumber(Class<?> type) {
return type == float.class || type == Float.class
|| type == double.class || type == Double.class
|| type == BigDecimal.class;
}
public static boolean isJsonBoolean(Class<?> type) {
return type == boolean.class || type == Boolean.class;
}
}

View File

@ -0,0 +1,9 @@
package dev.langchain4j.model.chat;
import dev.langchain4j.Experimental;
@Experimental
public enum Capability {
RESPONSE_FORMAT_JSON_SCHEMA
}

View File

@ -1,14 +1,19 @@
package dev.langchain4j.model.chat;
import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.Response;
import java.util.List;
import java.util.Set;
import static java.util.Arrays.asList;
import static java.util.Collections.emptySet;
/**
* Represents a language model that has a chat interface.
@ -80,4 +85,14 @@ public interface ChatLanguageModel {
default Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
throw new IllegalArgumentException("Tools are currently not supported by this model");
}
@Experimental
default ChatResponse chat(ChatRequest request) {
throw new UnsupportedOperationException();
}
@Experimental
default Set<Capability> supportedCapabilities() {
return emptySet();
}
}

View File

@ -0,0 +1,101 @@
package dev.langchain4j.model.chat.request;
import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.ChatMessage;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import static dev.langchain4j.internal.Utils.copyIfNotNull;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static java.util.Arrays.asList;
@Experimental
public class ChatRequest {
private final List<ChatMessage> messages;
private final List<ToolSpecification> toolSpecifications;
private final ResponseFormat responseFormat;
private ChatRequest(Builder builder) {
this.messages = new ArrayList<>(ensureNotEmpty(builder.messages, "messages"));
this.toolSpecifications = copyIfNotNull(builder.toolSpecifications);
this.responseFormat = builder.responseFormat;
}
public List<ChatMessage> messages() {
return messages;
}
public List<ToolSpecification> toolSpecifications() {
return toolSpecifications;
}
public ResponseFormat responseFormat() {
return responseFormat;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ChatRequest that = (ChatRequest) o;
return Objects.equals(this.messages, that.messages)
&& Objects.equals(this.toolSpecifications, that.toolSpecifications)
&& Objects.equals(this.responseFormat, that.responseFormat);
}
@Override
public int hashCode() {
return Objects.hash(messages, toolSpecifications, responseFormat);
}
@Override
public String toString() {
return "ChatRequest {" +
" messages = " + messages +
", toolSpecifications = " + toolSpecifications +
", responseFormat = " + responseFormat +
" }";
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private List<ChatMessage> messages;
private List<ToolSpecification> toolSpecifications;
private ResponseFormat responseFormat;
public Builder messages(List<ChatMessage> messages) {
this.messages = messages;
return this;
}
public Builder messages(ChatMessage... messages) {
return messages(asList(messages));
}
public Builder toolSpecifications(List<ToolSpecification> toolSpecifications) {
this.toolSpecifications = toolSpecifications;
return this;
}
public Builder toolSpecifications(ToolSpecification... toolSpecifications) {
return toolSpecifications(asList(toolSpecifications));
}
public Builder responseFormat(ResponseFormat responseFormat) {
this.responseFormat = responseFormat;
return this;
}
public ChatRequest build() {
return new ChatRequest(this);
}
}
}

View File

@ -0,0 +1,78 @@
package dev.langchain4j.model.chat.request;
import dev.langchain4j.Experimental;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import java.util.Objects;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
@Experimental
public class ResponseFormat {
private final ResponseFormatType type;
private final JsonSchema jsonSchema;
private ResponseFormat(Builder builder) {
this.type = ensureNotNull(builder.type, "type");
this.jsonSchema = builder.jsonSchema;
if (jsonSchema != null && type != JSON) {
throw new IllegalStateException("JsonSchema can be specified only when type=JSON");
}
}
public ResponseFormatType type() {
return type;
}
public JsonSchema jsonSchema() {
return jsonSchema;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ResponseFormat that = (ResponseFormat) o;
return Objects.equals(this.type, that.type)
&& Objects.equals(this.jsonSchema, that.jsonSchema);
}
@Override
public int hashCode() {
return Objects.hash(type, jsonSchema);
}
@Override
public String toString() {
return "ResponseFormat {" +
" type = " + type +
", jsonSchema = " + jsonSchema +
" }";
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private ResponseFormatType type;
private JsonSchema jsonSchema;
public Builder type(ResponseFormatType type) {
this.type = type;
return this;
}
public Builder jsonSchema(JsonSchema jsonSchema) {
this.jsonSchema = jsonSchema;
return this;
}
public ResponseFormat build() {
return new ResponseFormat(this);
}
}
}

View File

@ -0,0 +1,9 @@
package dev.langchain4j.model.chat.request;
import dev.langchain4j.Experimental;
@Experimental
public enum ResponseFormatType {
TEXT, JSON
}

View File

@ -0,0 +1,74 @@
package dev.langchain4j.model.chat.request.json;
import dev.langchain4j.Experimental;
import java.util.Objects;
import static dev.langchain4j.internal.Utils.quoted;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
@Experimental
public class JsonArraySchema implements JsonSchemaElement {
private final String description;
private final JsonSchemaElement items;
public JsonArraySchema(Builder builder) {
this.description = builder.description;
this.items = ensureNotNull(builder.items, "items");
}
public String description() {
return description;
}
public JsonSchemaElement items() {
return items;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private String description;
private JsonSchemaElement items;
public Builder description(String description) {
this.description = description;
return this;
}
public Builder items(JsonSchemaElement items) {
this.items = items;
return this;
}
public JsonArraySchema build() {
return new JsonArraySchema(this);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
JsonArraySchema that = (JsonArraySchema) o;
return Objects.equals(this.description, that.description)
&& Objects.equals(this.items, that.items);
}
@Override
public int hashCode() {
return Objects.hash(description, items);
}
@Override
public String toString() {
return "JsonArraySchema {" +
"description = " + quoted(description) +
", items = " + items +
" }";
}
}

View File

@ -0,0 +1,61 @@
package dev.langchain4j.model.chat.request.json;
import dev.langchain4j.Experimental;
import java.util.Objects;
import static dev.langchain4j.internal.Utils.quoted;
@Experimental
public class JsonBooleanSchema implements JsonSchemaElement {
public static final JsonBooleanSchema JSON_BOOLEAN_SCHEMA = JsonBooleanSchema.builder().build();
private final String description;
public JsonBooleanSchema(Builder builder) {
this.description = builder.description;
}
public String description() {
return description;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private String description;
public Builder description(String description) {
this.description = description;
return this;
}
public JsonBooleanSchema build() {
return new JsonBooleanSchema(this);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
JsonBooleanSchema that = (JsonBooleanSchema) o;
return Objects.equals(this.description, that.description);
}
@Override
public int hashCode() {
return Objects.hash(description);
}
@Override
public String toString() {
return "JsonBooleanSchema {" +
"description = " + quoted(description) +
" }";
}
}

View File

@ -0,0 +1,97 @@
package dev.langchain4j.model.chat.request.json;
import com.google.gson.annotations.SerializedName;
import dev.langchain4j.Experimental;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import static dev.langchain4j.internal.Utils.quoted;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static java.util.Arrays.asList;
import static java.util.Arrays.stream;
import static java.util.stream.Collectors.toList;
@Experimental
public class JsonEnumSchema implements JsonSchemaElement {
private final String description;
@SerializedName("enum")
private final List<String> enumValues;
public JsonEnumSchema(Builder builder) {
this.description = builder.description;
this.enumValues = new ArrayList<>(ensureNotEmpty(builder.enumValues, "enumValues"));
}
public String description() {
return description;
}
public List<String> enumValues() {
return enumValues;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private String description;
private List<String> enumValues;
public Builder description(String description) {
this.description = description;
return this;
}
public Builder enumValues(List<String> enumValues) {
this.enumValues = enumValues;
return this;
}
public Builder enumValues(String... enumValues) {
return enumValues(asList(enumValues));
}
public Builder enumValues(Class<?> enumClass) {
if (!enumClass.isEnum()) {
throw new RuntimeException("Class " + enumClass.getName() + " must be enum");
}
List<String> enumValues = stream(enumClass.getEnumConstants())
.map(Object::toString)
.collect(toList());
return enumValues(enumValues);
}
public JsonEnumSchema build() {
return new JsonEnumSchema(this);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
JsonEnumSchema that = (JsonEnumSchema) o;
return Objects.equals(this.description, that.description)
&& Objects.equals(this.enumValues, that.enumValues);
}
@Override
public int hashCode() {
return Objects.hash(description, enumValues);
}
@Override
public String toString() {
return "JsonEnumSchema {" +
"description = " + quoted(description) +
", enumValues = " + enumValues +
" }";
}
}

View File

@ -0,0 +1,61 @@
package dev.langchain4j.model.chat.request.json;
import dev.langchain4j.Experimental;
import java.util.Objects;
import static dev.langchain4j.internal.Utils.quoted;
@Experimental
public class JsonIntegerSchema implements JsonSchemaElement {
public static final JsonIntegerSchema JSON_INTEGER_SCHEMA = JsonIntegerSchema.builder().build();
private final String description;
public JsonIntegerSchema(Builder builder) {
this.description = builder.description;
}
public String description() {
return description;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private String description;
public Builder description(String description) {
this.description = description;
return this;
}
public JsonIntegerSchema build() {
return new JsonIntegerSchema(this);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
JsonIntegerSchema that = (JsonIntegerSchema) o;
return Objects.equals(this.description, that.description);
}
@Override
public int hashCode() {
return Objects.hash(description);
}
@Override
public String toString() {
return "JsonIntegerSchema {" +
"description = " + quoted(description) +
" }";
}
}

View File

@ -0,0 +1,61 @@
package dev.langchain4j.model.chat.request.json;
import dev.langchain4j.Experimental;
import java.util.Objects;
import static dev.langchain4j.internal.Utils.quoted;
@Experimental
public class JsonNumberSchema implements JsonSchemaElement {
public static final JsonNumberSchema JSON_NUMBER_SCHEMA = JsonNumberSchema.builder().build();
private final String description;
public JsonNumberSchema(Builder builder) {
this.description = builder.description;
}
public String description() {
return description;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private String description;
public Builder description(String description) {
this.description = description;
return this;
}
public JsonNumberSchema build() {
return new JsonNumberSchema(this);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
JsonNumberSchema that = (JsonNumberSchema) o;
return Objects.equals(this.description, that.description);
}
@Override
public int hashCode() {
return Objects.hash(description);
}
@Override
public String toString() {
return "JsonNumberSchema {" +
"description = " + quoted(description) +
" }";
}
}

View File

@ -0,0 +1,114 @@
package dev.langchain4j.model.chat.request.json;
import com.google.gson.annotations.SerializedName;
import dev.langchain4j.Experimental;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import static dev.langchain4j.internal.Utils.copyIfNotNull;
import static dev.langchain4j.internal.Utils.quoted;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static java.util.Arrays.asList;
@Experimental
public class JsonObjectSchema implements JsonSchemaElement {
private final String description;
private final Map<String, JsonSchemaElement> properties;
private final List<String> required;
@SerializedName("additionalProperties")
private final Boolean additionalProperties;
public JsonObjectSchema(Builder builder) {
this.description = builder.description;
this.properties = new LinkedHashMap<>(ensureNotEmpty(builder.properties, "properties"));
this.required = copyIfNotNull(builder.required);
this.additionalProperties = builder.additionalProperties;
}
public String description() {
return description;
}
public Map<String, JsonSchemaElement> properties() {
return properties;
}
public List<String> required() {
return required;
}
public Boolean additionalProperties() {
return additionalProperties;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private String description;
private Map<String, JsonSchemaElement> properties = new LinkedHashMap<>();
private List<String> required = new ArrayList<>();
private Boolean additionalProperties;
public Builder description(String description) {
this.description = description;
return this;
}
public Builder properties(Map<String, JsonSchemaElement> properties) {
this.properties = properties;
return this;
}
public Builder required(List<String> required) {
this.required = required;
return this;
}
public Builder required(String... required) {
return required(asList(required));
}
public Builder additionalProperties(Boolean additionalProperties) {
this.additionalProperties = additionalProperties;
return this;
}
public JsonObjectSchema build() {
return new JsonObjectSchema(this);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
JsonObjectSchema that = (JsonObjectSchema) o;
return Objects.equals(this.description, that.description)
&& Objects.equals(this.properties, that.properties)
&& Objects.equals(this.required, that.required)
&& Objects.equals(this.additionalProperties, that.additionalProperties);
}
@Override
public int hashCode() {
return Objects.hash(description, properties, required, additionalProperties);
}
@Override
public String toString() {
return "JsonObjectSchema {" +
"description = " + quoted(description) +
", properties = " + properties +
", required = " + required +
", additionalProperties = " + additionalProperties +
" }";
}
}

View File

@ -0,0 +1,73 @@
package dev.langchain4j.model.chat.request.json;
import dev.langchain4j.Experimental;
import java.util.Objects;
import static dev.langchain4j.internal.Utils.quoted;
@Experimental
public class JsonSchema {
private final String name;
private final JsonObjectSchema schema;
public JsonSchema(Builder builder) {
this.name = builder.name;
this.schema = builder.schema;
}
public String name() {
return name;
}
public JsonObjectSchema schema() {
return schema;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private String name;
private JsonObjectSchema schema;
public Builder name(String name) {
this.name = name;
return this;
}
public Builder schema(JsonObjectSchema schema) {
this.schema = schema;
return this;
}
public JsonSchema build() {
return new JsonSchema(this);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
JsonSchema that = (JsonSchema) o;
return Objects.equals(this.name, that.name)
&& Objects.equals(this.schema, that.schema);
}
@Override
public int hashCode() {
return Objects.hash(name, schema);
}
@Override
public String toString() {
return "JsonSchema {" +
" name = " + quoted(name) +
", schema = " + schema +
" }";
}
}

View File

@ -0,0 +1,8 @@
package dev.langchain4j.model.chat.request.json;
import dev.langchain4j.Experimental;
@Experimental
public interface JsonSchemaElement {
}

View File

@ -0,0 +1,61 @@
package dev.langchain4j.model.chat.request.json;
import dev.langchain4j.Experimental;
import java.util.Objects;
import static dev.langchain4j.internal.Utils.quoted;
@Experimental
public class JsonStringSchema implements JsonSchemaElement {
public static final JsonStringSchema JSON_STRING_SCHEMA = JsonStringSchema.builder().build();
private final String description;
public JsonStringSchema(Builder builder) {
this.description = builder.description;
}
public String description() {
return description;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private String description;
public Builder description(String description) {
this.description = description;
return this;
}
public JsonStringSchema build() {
return new JsonStringSchema(this);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
JsonStringSchema that = (JsonStringSchema) o;
return Objects.equals(this.description, that.description);
}
@Override
public int hashCode() {
return Objects.hash(description);
}
@Override
public String toString() {
return "JsonStringSchema {" +
"description = " + quoted(description) +
" }";
}
}

View File

@ -0,0 +1,90 @@
package dev.langchain4j.model.chat.response;
import dev.langchain4j.Experimental;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
import java.util.Objects;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
@Experimental
public class ChatResponse {
private final AiMessage aiMessage;
private final TokenUsage tokenUsage;
private final FinishReason finishReason;
private ChatResponse(Builder builder) {
this.aiMessage = ensureNotNull(builder.aiMessage, "aiMessage");
this.tokenUsage = builder.tokenUsage;
this.finishReason = builder.finishReason;
}
public AiMessage aiMessage() {
return aiMessage;
}
public TokenUsage tokenUsage() {
return tokenUsage;
}
public FinishReason finishReason() {
return finishReason;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ChatResponse that = (ChatResponse) o;
return Objects.equals(this.aiMessage, that.aiMessage)
&& Objects.equals(this.tokenUsage, that.tokenUsage)
&& Objects.equals(this.finishReason, that.finishReason);
}
@Override
public int hashCode() {
return Objects.hash(aiMessage, tokenUsage, finishReason);
}
@Override
public String toString() {
return "ChatResponse {" +
" aiMessage = " + aiMessage +
", tokenUsage = " + tokenUsage +
", finishReason = " + finishReason +
" }";
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private AiMessage aiMessage;
private TokenUsage tokenUsage;
private FinishReason finishReason;
public Builder aiMessage(AiMessage aiMessage) {
this.aiMessage = aiMessage;
return this;
}
public Builder tokenUsage(TokenUsage tokenUsage) {
this.tokenUsage = tokenUsage;
return this;
}
public Builder finishReason(FinishReason finishReason) {
this.finishReason = finishReason;
return this;
}
public ChatResponse build() {
return new ChatResponse(this);
}
}
}

View File

@ -4,18 +4,20 @@ import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import static java.lang.annotation.ElementType.FIELD;
import static java.lang.annotation.ElementType.TYPE;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
/**
* Annotation to attach a description to a class field.
*/
@Target(FIELD)
@Target({FIELD, TYPE})
@Retention(RUNTIME)
public @interface Description {
/**
* The description can be defined in one line or multiple lines.
* If the description is defined in multiple lines, the lines will be joined with a space (" ") automatically.
*
* @return The description.
*/
String[] value();

View File

@ -1,31 +1,64 @@
package dev.langchain4j.model.openai;
import dev.ai4j.openai4j.chat.*;
import dev.ai4j.openai4j.chat.AssistantMessage;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.ContentType;
import dev.ai4j.openai4j.chat.Function;
import dev.ai4j.openai4j.chat.FunctionCall;
import dev.ai4j.openai4j.chat.FunctionMessage;
import dev.ai4j.openai4j.chat.ImageDetail;
import dev.ai4j.openai4j.chat.ImageUrl;
import dev.ai4j.openai4j.chat.Message;
import dev.ai4j.openai4j.chat.Tool;
import dev.ai4j.openai4j.chat.ToolCall;
import dev.ai4j.openai4j.chat.ToolMessage;
import dev.ai4j.openai4j.shared.Usage;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import static dev.ai4j.openai4j.chat.ContentType.IMAGE_URL;
import static dev.ai4j.openai4j.chat.ContentType.TEXT;
import static dev.ai4j.openai4j.chat.ResponseFormatType.JSON_OBJECT;
import static dev.ai4j.openai4j.chat.ResponseFormatType.JSON_SCHEMA;
import static dev.ai4j.openai4j.chat.ToolType.FUNCTION;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.output.FinishReason.*;
import static dev.langchain4j.model.chat.request.ResponseFormatType.TEXT;
import static dev.langchain4j.model.output.FinishReason.CONTENT_FILTER;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
import static java.lang.String.format;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
@ -129,14 +162,14 @@ public class InternalOpenAiHelper {
private static dev.ai4j.openai4j.chat.Content toOpenAiContent(TextContent content) {
return dev.ai4j.openai4j.chat.Content.builder()
.type(TEXT)
.type(ContentType.TEXT)
.text(content.text())
.build();
}
private static dev.ai4j.openai4j.chat.Content toOpenAiContent(ImageContent content) {
return dev.ai4j.openai4j.chat.Content.builder()
.type(IMAGE_URL)
.type(ContentType.IMAGE_URL)
.imageUrl(ImageUrl.builder()
.url(toUrl(content.image()))
.detail(toDetail(content.detailLevel()))
@ -158,23 +191,26 @@ public class InternalOpenAiHelper {
return ImageDetail.valueOf(detailLevel.name());
}
public static List<Tool> toTools(Collection<ToolSpecification> toolSpecifications) {
public static List<Tool> toTools(Collection<ToolSpecification> toolSpecifications, boolean strict) {
return toolSpecifications.stream()
.map(InternalOpenAiHelper::toTool)
.map((ToolSpecification toolSpecification) -> toTool(toolSpecification, strict))
.collect(toList());
}
private static Tool toTool(ToolSpecification toolSpecification) {
Function function = Function.builder()
private static Tool toTool(ToolSpecification toolSpecification, boolean strict) {
Function.Builder functionBuilder = Function.builder()
.name(toolSpecification.name())
.description(toolSpecification.description())
.parameters(toOpenAiParameters(toolSpecification.parameters()))
.build();
.parameters(toOpenAiParameters(toolSpecification.parameters(), strict));
if (strict) {
functionBuilder.strict(true);
}
Function function = functionBuilder.build();
return Tool.from(function);
}
/**
* @deprecated Functions are deprecated by OpenAI, use {@link #toTools(Collection)} instead
* @deprecated Functions are deprecated by OpenAI, use {@link #toTools(Collection, boolean)} instead
*/
@Deprecated
public static List<Function> toFunctions(Collection<ToolSpecification> toolSpecifications) {
@ -184,25 +220,101 @@ public class InternalOpenAiHelper {
}
/**
* @deprecated Functions are deprecated by OpenAI, use {@link #toTool(ToolSpecification)} ()} instead
* @deprecated Functions are deprecated by OpenAI, use {@link #toTool(ToolSpecification, boolean)} instead
*/
@Deprecated
private static Function toFunction(ToolSpecification toolSpecification) {
return Function.builder()
.name(toolSpecification.name())
.description(toolSpecification.description())
.parameters(toOpenAiParameters(toolSpecification.parameters()))
.parameters(toOpenAiParameters(toolSpecification.parameters(), false))
.build();
}
private static dev.ai4j.openai4j.chat.Parameters toOpenAiParameters(ToolParameters toolParameters) {
private static dev.ai4j.openai4j.chat.JsonObjectSchema toOpenAiParameters(ToolParameters toolParameters, boolean strict) {
if (toolParameters == null) {
return dev.ai4j.openai4j.chat.Parameters.builder().build();
dev.ai4j.openai4j.chat.JsonObjectSchema.Builder builder = dev.ai4j.openai4j.chat.JsonObjectSchema.builder();
if (strict) {
// when strict, additionalProperties must be false:
// https://platform.openai.com/docs/guides/structured-outputs/additionalproperties-false-must-always-be-set-in-objects
builder.additionalProperties(false);
}
return builder.build();
}
dev.ai4j.openai4j.chat.JsonObjectSchema.Builder builder = dev.ai4j.openai4j.chat.JsonObjectSchema.builder()
.properties(toOpenAiProperties(toolParameters.properties(), strict))
.required(toolParameters.required());
if (strict) {
builder
// when strict, all fields must be required:
// https://platform.openai.com/docs/guides/structured-outputs/all-fields-must-be-required
.required(new ArrayList<>(toolParameters.properties().keySet()))
// when strict, additionalProperties must be false:
// https://platform.openai.com/docs/guides/structured-outputs/additionalproperties-false-must-always-be-set-in-objects
.additionalProperties(false);
}
return builder.build();
}
private static Map<String, dev.ai4j.openai4j.chat.JsonSchemaElement> toOpenAiProperties(Map<String, ?> properties, boolean strict) {
Map<String, dev.ai4j.openai4j.chat.JsonSchemaElement> openAiProperties = new LinkedHashMap<>();
properties.forEach((key, value) ->
openAiProperties.put(key, toOpenAiJsonSchemaElement((Map<String, ?>) value, strict)));
return openAiProperties;
}
private static dev.ai4j.openai4j.chat.JsonSchemaElement toOpenAiJsonSchemaElement(Map<String, ?> properties, boolean strict) {
// TODO rewrite when JsonSchemaElement will be used for ToolSpecification.properties
Object type = properties.get("type");
String description = (String) properties.get("description");
if ("object".equals(type)) {
List<String> required = (List<String>) properties.get("required");
dev.ai4j.openai4j.chat.JsonObjectSchema.Builder builder = dev.ai4j.openai4j.chat.JsonObjectSchema.builder()
.description(description)
.properties(toOpenAiProperties((Map<String, ?>) properties.get("properties"), strict));
if (required != null) {
builder.required(required);
}
if (strict) {
builder
// when strict, all fields must be required:
// https://platform.openai.com/docs/guides/structured-outputs/all-fields-must-be-required
.required(new ArrayList<>(((Map<String, ?>) properties.get("properties")).keySet()))
// when strict, additionalProperties must be false:
// https://platform.openai.com/docs/guides/structured-outputs/additionalproperties-false-must-always-be-set-in-objects
.additionalProperties(false);
}
return builder.build();
} else if ("array".equals(type)) {
return dev.ai4j.openai4j.chat.JsonArraySchema.builder()
.description(description)
.items(toOpenAiJsonSchemaElement((Map<String, ?>) properties.get("items"), strict))
.build();
} else if (properties.get("enum") != null) {
return dev.ai4j.openai4j.chat.JsonEnumSchema.builder()
.description(description)
.enumValues((List<String>) properties.get("enum"))
.build();
} else if ("string".equals(type)) {
return dev.ai4j.openai4j.chat.JsonStringSchema.builder()
.description(description)
.build();
} else if ("integer".equals(type)) {
return dev.ai4j.openai4j.chat.JsonIntegerSchema.builder()
.description(description)
.build();
} else if ("number".equals(type)) {
return dev.ai4j.openai4j.chat.JsonNumberSchema.builder()
.description(description)
.build();
} else if ("boolean".equals(type)) {
return dev.ai4j.openai4j.chat.JsonBooleanSchema.builder()
.description(description)
.build();
} else {
throw new IllegalArgumentException("Unknown type " + type);
}
return dev.ai4j.openai4j.chat.Parameters.builder()
.properties(toolParameters.properties())
.required(toolParameters.required())
.build();
}
public static AiMessage aiMessageFrom(ChatCompletionResponse response) {
@ -317,4 +429,64 @@ public class InternalOpenAiHelper {
.aiMessage(response.content())
.build();
}
static dev.ai4j.openai4j.chat.ResponseFormat toOpenAiResponseFormat(ResponseFormat responseFormat, Boolean strict) {
if (responseFormat == null || responseFormat.type() == TEXT) {
return null;
}
JsonSchema jsonSchema = responseFormat.jsonSchema();
if (jsonSchema == null) {
return new dev.ai4j.openai4j.chat.ResponseFormat(JSON_OBJECT, null);
} else {
dev.ai4j.openai4j.chat.JsonSchema openAiJsonSchema = dev.ai4j.openai4j.chat.JsonSchema.builder()
.name(jsonSchema.name())
.strict(strict)
.schema((dev.ai4j.openai4j.chat.JsonObjectSchema) toOpenAiJsonSchemaElement(jsonSchema.schema()))
.build();
return new dev.ai4j.openai4j.chat.ResponseFormat(JSON_SCHEMA, openAiJsonSchema);
}
}
private static dev.ai4j.openai4j.chat.JsonSchemaElement toOpenAiJsonSchemaElement(JsonSchemaElement jsonSchemaElement) {
if (jsonSchemaElement instanceof JsonStringSchema) {
return dev.ai4j.openai4j.chat.JsonStringSchema.builder()
.description(((JsonStringSchema) jsonSchemaElement).description())
.build();
} else if (jsonSchemaElement instanceof JsonIntegerSchema) {
return dev.ai4j.openai4j.chat.JsonIntegerSchema.builder()
.description(((JsonIntegerSchema) jsonSchemaElement).description())
.build();
} else if (jsonSchemaElement instanceof JsonNumberSchema) {
return dev.ai4j.openai4j.chat.JsonNumberSchema.builder()
.description(((JsonNumberSchema) jsonSchemaElement).description())
.build();
} else if (jsonSchemaElement instanceof JsonBooleanSchema) {
return dev.ai4j.openai4j.chat.JsonBooleanSchema.builder()
.description(((JsonBooleanSchema) jsonSchemaElement).description())
.build();
} else if (jsonSchemaElement instanceof JsonEnumSchema) {
return dev.ai4j.openai4j.chat.JsonEnumSchema.builder()
.description(((JsonEnumSchema) jsonSchemaElement).description())
.enumValues(((JsonEnumSchema) jsonSchemaElement).enumValues())
.build();
} else if (jsonSchemaElement instanceof JsonArraySchema) {
return dev.ai4j.openai4j.chat.JsonArraySchema.builder()
.description(((JsonArraySchema) jsonSchemaElement).description())
.items(toOpenAiJsonSchemaElement(((JsonArraySchema) jsonSchemaElement).items()))
.build();
} else if (jsonSchemaElement instanceof JsonObjectSchema) {
Map<String, JsonSchemaElement> properties = ((JsonObjectSchema) jsonSchemaElement).properties();
Map<String, dev.ai4j.openai4j.chat.JsonSchemaElement> openAiProperties = new LinkedHashMap<>();
properties.forEach((key, value) -> openAiProperties.put(key, toOpenAiJsonSchemaElement(value)));
return dev.ai4j.openai4j.chat.JsonObjectSchema.builder()
.description(((JsonObjectSchema) jsonSchemaElement).description())
.properties(openAiProperties)
.required(((JsonObjectSchema) jsonSchemaElement).required())
.additionalProperties(((JsonObjectSchema) jsonSchemaElement).additionalProperties())
.build();
} else {
throw new IllegalArgumentException("Unknown type: " + jsonSchemaElement);
}
}
}

View File

@ -4,13 +4,22 @@ import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.ResponseFormat;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.openai.spi.OpenAiChatModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
@ -19,13 +28,27 @@ import lombok.extern.slf4j.Slf4j;
import java.net.Proxy;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.DEFAULT_USER_AGENT;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_API_KEY;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_URL;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.aiMessageFrom;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerRequest;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerResponse;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiResponseFormat;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.time.Duration.ofSeconds;
@ -48,9 +71,12 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final Double presencePenalty;
private final Double frequencyPenalty;
private final Map<String, Integer> logitBias;
private final String responseFormat;
private final ResponseFormat responseFormat;
private final Boolean strictJsonSchema;
private final Integer seed;
private final String user;
private final Boolean strictTools;
private final Boolean parallelToolCalls;
private final Integer maxRetries;
private final Tokenizer tokenizer;
private final List<ChatModelListener> listeners;
@ -68,8 +94,11 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
Double frequencyPenalty,
Map<String, Integer> logitBias,
String responseFormat,
Boolean strictJsonSchema,
Integer seed,
String user,
Boolean strictTools,
Boolean parallelToolCalls,
Duration timeout,
Integer maxRetries,
Proxy proxy,
@ -108,9 +137,12 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
this.presencePenalty = presencePenalty;
this.frequencyPenalty = frequencyPenalty;
this.logitBias = logitBias;
this.responseFormat = responseFormat;
this.responseFormat = responseFormat == null ? null : new ResponseFormat(responseFormat, null);
this.strictJsonSchema = getOrDefault(strictJsonSchema, false);
this.seed = seed;
this.user = user;
this.strictTools = getOrDefault(strictTools, false);
this.parallelToolCalls = parallelToolCalls;
this.maxRetries = getOrDefault(maxRetries, 3);
this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
@ -120,25 +152,56 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
return modelName;
}
@Override
public Set<Capability> supportedCapabilities() {
Set<Capability> capabilities = new HashSet<>();
if (responseFormat != null && "json_schema".equals(responseFormat.type())) {
capabilities.add(RESPONSE_FORMAT_JSON_SCHEMA);
}
return capabilities;
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
return generate(messages, null, null);
return generate(messages, null, null, this.responseFormat);
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
return generate(messages, toolSpecifications, null);
return generate(messages, toolSpecifications, null, this.responseFormat);
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(messages, singletonList(toolSpecification), toolSpecification);
return generate(messages, singletonList(toolSpecification), toolSpecification, this.responseFormat);
}
@Override
public ChatResponse chat(ChatRequest request) {
Response<AiMessage> response = generate(
request.messages(),
request.toolSpecifications(),
null,
getOrDefault(toOpenAiResponseFormat(request.responseFormat(), strictJsonSchema), this.responseFormat)
);
return ChatResponse.builder()
.aiMessage(response.content())
.tokenUsage(response.tokenUsage())
.finishReason(response.finishReason())
.build();
}
private Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted
) {
ToolSpecification toolThatMustBeExecuted,
ResponseFormat responseFormat) {
if (responseFormat != null
&& "json_schema".equals(responseFormat.type())
&& responseFormat.jsonSchema() == null) {
responseFormat = null;
}
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.model(modelName)
.messages(toOpenAiMessages(messages))
@ -151,10 +214,11 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
.logitBias(logitBias)
.responseFormat(responseFormat)
.seed(seed)
.user(user);
.user(user)
.parallelToolCalls(parallelToolCalls);
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
requestBuilder.tools(toTools(toolSpecifications));
requestBuilder.tools(toTools(toolSpecifications, strictTools));
}
if (toolThatMustBeExecuted != null) {
requestBuilder.toolChoice(toolThatMustBeExecuted.name());

View File

@ -54,6 +54,8 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
private final String responseFormat;
private final Integer seed;
private final String user;
private final Boolean strictTools;
private final Boolean parallelToolCalls;
private final Tokenizer tokenizer;
private final boolean isOpenAiModel;
private final List<ChatModelListener> listeners;
@ -73,6 +75,8 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
String responseFormat,
Integer seed,
String user,
Boolean strictTools,
Boolean parallelToolCalls,
Duration timeout,
Proxy proxy,
Boolean logRequests,
@ -108,6 +112,8 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
this.responseFormat = responseFormat;
this.seed = seed;
this.user = user;
this.strictTools = getOrDefault(strictTools, false);
this.parallelToolCalls = parallelToolCalls;
this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
this.isOpenAiModel = isOpenAiModel(this.modelName);
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
@ -150,13 +156,14 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
.logitBias(logitBias)
.responseFormat(responseFormat)
.seed(seed)
.user(user);
.user(user)
.parallelToolCalls(parallelToolCalls);
if (toolThatMustBeExecuted != null) {
requestBuilder.tools(toTools(singletonList(toolThatMustBeExecuted)));
requestBuilder.tools(toTools(singletonList(toolThatMustBeExecuted), strictTools));
requestBuilder.toolChoice(toolThatMustBeExecuted.name());
} else if (!isNullOrEmpty(toolSpecifications)) {
requestBuilder.tools(toTools(toolSpecifications));
requestBuilder.tools(toTools(toolSpecifications, strictTools));
}
ChatCompletionRequest request = requestBuilder.build();

View File

@ -44,11 +44,6 @@ class OpenAiTokenizerIT {
GPT_4_32K_0613
));
private static final Set<ChatCompletionModel> MODELS_WITHOUT_TOOL_SUPPORT = new HashSet<>(asList(
GPT_4_0314,
GPT_4_VISION_PREVIEW
));
private static final Set<ChatCompletionModel> MODELS_WITH_PARALLEL_TOOL_SUPPORT = new HashSet<>(asList(
// TODO add GPT_3_5_TURBO once it points to GPT_3_5_TURBO_1106
GPT_3_5_TURBO_1106,
@ -151,7 +146,6 @@ class OpenAiTokenizerIT {
static Stream<Arguments> should_count_tokens_in_messages_with_single_tool() {
return stream(ChatCompletionModel.values())
.filter(model -> !MODELS_WITHOUT_ACCESS.contains(model))
.filter(model -> !MODELS_WITHOUT_TOOL_SUPPORT.contains(model))
.flatMap(model -> Stream.of(
// various tool "name" lengths
@ -798,7 +792,6 @@ class OpenAiTokenizerIT {
static Stream<Arguments> should_count_tokens_in_tool_specifications() {
return stream(ChatCompletionModel.values())
.filter(model -> !MODELS_WITHOUT_ACCESS.contains(model))
.filter(model -> !MODELS_WITHOUT_TOOL_SUPPORT.contains(model))
.flatMap(model -> Stream.of(
// "name" of various lengths
@ -1146,7 +1139,6 @@ class OpenAiTokenizerIT {
static Stream<Arguments> should_count_tokens_in_tool_execution_request() {
return stream(ChatCompletionModel.values())
.filter(model -> !MODELS_WITHOUT_ACCESS.contains(model))
.filter(model -> !MODELS_WITHOUT_TOOL_SUPPORT.contains(model))
.flatMap(model -> Stream.of(
// no arguments, different lengths of "name"

View File

@ -18,7 +18,7 @@
<maven.compiler.target>8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.build.outputTimestamp>1714382357</project.build.outputTimestamp>
<openai4j.version>0.17.0</openai4j.version>
<openai4j.version>0.18.0</openai4j.version>
<azure-ai-openai.version>1.0.0-beta.10</azure-ai-openai.version>
<azure-ai-search.version>11.7.0</azure-ai-search.version>
<azure.storage-blob.version>12.27.0</azure.storage-blob.version>

View File

@ -1,10 +1,16 @@
package dev.langchain4j.service;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.message.*;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.input.structured.StructuredPrompt;
@ -19,8 +25,19 @@ import dev.langchain4j.service.output.ServiceOutputParser;
import dev.langchain4j.service.tool.ToolExecutor;
import java.io.InputStream;
import java.lang.reflect.*;
import java.util.*;
import java.lang.reflect.Array;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.Proxy;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Scanner;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
@ -29,7 +46,10 @@ import static dev.langchain4j.exception.IllegalConfigurationException.illegalCon
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.Exceptions.runtime;
import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
import static dev.langchain4j.service.TypeUtils.typeHasRawClass;
import static dev.langchain4j.service.output.JsonSchemas.jsonSchemaFrom;
class DefaultAiServices<T> extends AiServices<T> {
@ -111,12 +131,16 @@ class DefaultAiServices<T> extends AiServices<T> {
// TODO give user ability to provide custom OutputParser
Type returnType = method.getGenericReturnType();
String outputFormatInstructions = serviceOutputParser.outputFormatInstructions(returnType);
String text = userMessage.singleText() + outputFormatInstructions;
if (isNotNullOrBlank(userMessage.name())) {
userMessage = UserMessage.from(userMessage.name(), text);
} else {
userMessage = UserMessage.from(text);
boolean supportsJsonSchema = supportsJsonSchema();
Optional<JsonSchema> jsonSchema = Optional.empty();
if (supportsJsonSchema) {
jsonSchema = jsonSchemaFrom(returnType);
}
if (!supportsJsonSchema || !jsonSchema.isPresent()) {
// TODO append after storing in the memory?
userMessage = appendOutputFormatInstructions(returnType, userMessage);
}
if (context.hasChatMemory()) {
@ -140,9 +164,31 @@ class DefaultAiServices<T> extends AiServices<T> {
return new AiServiceTokenStream(messages, context, memoryId); // TODO moderation
}
Response<AiMessage> response = context.toolSpecifications == null
? context.chatModel.generate(messages)
: context.chatModel.generate(messages, context.toolSpecifications);
Response<AiMessage> response;
if (supportsJsonSchema && jsonSchema.isPresent()) {
ChatRequest chatRequest = ChatRequest.builder()
.messages(messages)
.toolSpecifications(context.toolSpecifications)
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(jsonSchema.get())
.build())
.build();
ChatResponse chatResponse = context.chatModel.chat(chatRequest);
response = new Response<>(
chatResponse.aiMessage(),
chatResponse.tokenUsage(),
chatResponse.finishReason()
);
} else {
// TODO migrate to new API
response = context.toolSpecifications == null
? context.chatModel.generate(messages)
: context.chatModel.generate(messages, context.toolSpecifications);
}
TokenUsage tokenUsageAccumulator = response.tokenUsage();
verifyModerationIfNeeded(moderationFuture);
@ -206,6 +252,22 @@ class DefaultAiServices<T> extends AiServices<T> {
}
}
private boolean supportsJsonSchema() {
return context.chatModel != null
&& context.chatModel.supportedCapabilities().contains(RESPONSE_FORMAT_JSON_SCHEMA);
}
private UserMessage appendOutputFormatInstructions(Type returnType, UserMessage userMessage) {
String outputFormatInstructions = serviceOutputParser.outputFormatInstructions(returnType);
String text = userMessage.singleText() + outputFormatInstructions;
if (isNotNullOrBlank(userMessage.name())) {
userMessage = UserMessage.from(userMessage.name(), text);
} else {
userMessage = UserMessage.from(text);
}
return userMessage;
}
private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {
if (method.isAnnotationPresent(Moderate.class)) {
return executor.submit(() -> {

View File

@ -0,0 +1,187 @@
package dev.langchain4j.service.output;
import dev.langchain4j.Experimental;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.service.Result;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.TypeUtils;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.TypeUtils.isJsonBoolean;
import static dev.langchain4j.internal.TypeUtils.isJsonInteger;
import static dev.langchain4j.internal.TypeUtils.isJsonNumber;
import static dev.langchain4j.service.TypeUtils.getRawClass;
import static dev.langchain4j.service.TypeUtils.resolveFirstGenericParameterClass;
import static dev.langchain4j.service.TypeUtils.typeHasRawClass;
import static java.lang.reflect.Modifier.isStatic;
@Experimental
public class JsonSchemas {
public static Optional<JsonSchema> jsonSchemaFrom(Type returnType) {
if (typeHasRawClass(returnType, Result.class)) {
returnType = resolveFirstGenericParameterClass(returnType);
}
// TODO validate this earlier
if (returnType == void.class) {
throw illegalConfiguration("Return type of method '%s' cannot be void");
}
if (!isPojo(returnType)) {
return Optional.empty();
}
Class<?> rawClass = getRawClass(returnType);
JsonSchema jsonSchema = JsonSchema.builder()
.name(rawClass.getSimpleName())
.schema(toJsonObjectSchema(rawClass, null))
.build();
return Optional.of(jsonSchema);
}
private static boolean isPojo(Type returnType) {
if (returnType == String.class
|| returnType == AiMessage.class
|| returnType == TokenStream.class
|| returnType == Response.class) {
return false;
}
// Explanation (which will make this a lot easier to understand):
// In the case of List<String> these two would be set like:
// rawClass: List.class
// typeArgumentClass: String.class
Class<?> rawClass = getRawClass(returnType);
Class<?> typeArgumentClass = TypeUtils.resolveFirstGenericParameterClass(returnType);
Optional<OutputParser<?>> outputParser = new DefaultOutputParserFactory().get(rawClass, typeArgumentClass);
if (outputParser.isPresent()) {
return false;
}
return true;
}
private static JsonObjectSchema toJsonObjectSchema(Class<?> type, String description) {
Map<String, JsonSchemaElement> properties = new LinkedHashMap<>();
for (Field field : type.getDeclaredFields()) {
String fieldName = field.getName();
if (isStatic(field.getModifiers()) || fieldName.equals("__$hits$__") || fieldName.startsWith("this$")) {
continue;
}
String fieldDescription = getDescription(field);
JsonSchemaElement jsonSchemaElement = jsonSchema(field.getType(), field.getGenericType(), fieldDescription);
properties.put(fieldName, jsonSchemaElement);
}
return JsonObjectSchema.builder()
.description(Optional.ofNullable(description).orElse(getDescription(type)))
.properties(properties)
.required(new ArrayList<>(properties.keySet()))
.additionalProperties(false)
.build();
}
private static String getDescription(Field field) {
return getDescription(field.getAnnotation(Description.class));
}
private static String getDescription(Class<?> type) {
return getDescription(type.getAnnotation(Description.class));
}
private static String getDescription(Description description) {
if (description == null) {
return null;
}
return String.join(" ", description.value());
}
private static JsonSchemaElement jsonSchema(Class<?> clazz, Type type, String fieldDescription) {
if (clazz == String.class) {
return JsonStringSchema.builder()
.description(fieldDescription)
.build();
}
if (isJsonInteger(clazz)) {
return JsonIntegerSchema.builder()
.description(fieldDescription)
.build();
}
if (isJsonNumber(clazz)) {
return JsonNumberSchema.builder()
.description(fieldDescription)
.build();
}
if (isJsonBoolean(clazz)) {
return JsonBooleanSchema.builder()
.description(fieldDescription)
.build();
}
if (clazz.isEnum()) {
return JsonEnumSchema.builder()
.enumValues(clazz)
.description(Optional.ofNullable(fieldDescription).orElse(getDescription(clazz)))
.build();
}
if (clazz.isArray()) {
return JsonArraySchema.builder()
.items(jsonSchema(clazz.getComponentType(), null, null))
.description(fieldDescription)
.build();
}
if (clazz.equals(List.class) || clazz.equals(Set.class)) {
return JsonArraySchema.builder()
.items(jsonSchema(getActualType(type), null, null))
.description(fieldDescription)
.build();
}
return toJsonObjectSchema(clazz, fieldDescription);
}
private static Class<?> getActualType(Type type) {
if (type instanceof ParameterizedType) {
ParameterizedType parameterizedType = (ParameterizedType) type;
Type[] actualTypeArguments = parameterizedType.getActualTypeArguments();
if (actualTypeArguments.length == 1) {
return (Class<?>) actualTypeArguments[0];
}
}
return null;
}
}

View File

@ -83,6 +83,7 @@ public class ServiceOutputParser {
return "";
}
// TODO validate this earlier
if (returnType == void.class) {
throw illegalConfiguration("Return type of method '%s' cannot be void");
}

View File

@ -27,15 +27,23 @@ import java.util.List;
import static dev.langchain4j.data.message.SystemMessage.systemMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO;
import static dev.langchain4j.service.AiServicesIT.Ingredient.*;
import static dev.langchain4j.service.AiServicesIT.IssueCategory.*;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
import static dev.langchain4j.service.AiServicesIT.Ingredient.OIL;
import static dev.langchain4j.service.AiServicesIT.Ingredient.PEPPER;
import static dev.langchain4j.service.AiServicesIT.Ingredient.SALT;
import static dev.langchain4j.service.AiServicesIT.IssueCategory.COMFORT_ISSUE;
import static dev.langchain4j.service.AiServicesIT.IssueCategory.MAINTENANCE_ISSUE;
import static dev.langchain4j.service.AiServicesIT.IssueCategory.OVERALL_EXPERIENCE_ISSUE;
import static dev.langchain4j.service.AiServicesIT.IssueCategory.SERVICE_ISSUE;
import static dev.langchain4j.service.AiServicesIT.Sentiment.POSITIVE;
import static java.time.Month.JULY;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ExtendWith(MockitoExtension.class)
public class AiServicesIT {
@ -45,6 +53,7 @@ public class AiServicesIT {
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.modelName(GPT_4_O_MINI)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
@ -85,6 +94,7 @@ public class AiServicesIT {
verify(chatLanguageModel).generate(singletonList(userMessage("Count number of 'egg' mentions in this sentence:\n" +
"|||I have ten eggs in my basket and three in my pocket.|||\n" +
"You must answer strictly in the following format: integer number")));
verify(chatLanguageModel).supportedCapabilities();
}
@ -105,6 +115,7 @@ public class AiServicesIT {
assertThat(joke).isNotBlank();
verify(chatLanguageModel).generate(singletonList(userMessage("Tell me a joke about AI")));
verify(chatLanguageModel).supportedCapabilities();
}
@ -135,6 +146,7 @@ public class AiServicesIT {
verify(chatLanguageModel).generate(singletonList(userMessage(
"Extract date from " + text + "\n" +
"You must answer strictly in the following format: yyyy-MM-dd")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -152,6 +164,7 @@ public class AiServicesIT {
verify(chatLanguageModel).generate(singletonList(userMessage(
"Extract time from " + text + "\n" +
"You must answer strictly in the following format: HH:mm:ss")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -169,6 +182,7 @@ public class AiServicesIT {
verify(chatLanguageModel).generate(singletonList(userMessage(
"Extract date and time from " + text + "\n" +
"You must answer strictly in the following format: yyyy-MM-ddTHH:mm:ss")));
verify(chatLanguageModel).supportedCapabilities();
}
@ -201,6 +215,7 @@ public class AiServicesIT {
"POSITIVE\n" +
"NEUTRAL\n" +
"NEGATIVE")));
verify(chatLanguageModel).supportedCapabilities();
}
public enum Weather {
@ -239,6 +254,7 @@ public class AiServicesIT {
"CLOUDY - The sky is covered with clouds with no rain, often creating a gray and overcast appearance\n" +
"RAINY - Precipitation in the form of rain, with cloudy skies and wet conditions\n" +
"SNOWY - Snowfall occurs, covering the ground in white and creating cold, wintry conditions")));
verify(chatLanguageModel).supportedCapabilities();
}
public enum Ingredient {
@ -270,6 +286,7 @@ public class AiServicesIT {
"PEPPER\n" +
"VINEGAR\n" +
"OIL")));
verify(chatLanguageModel).supportedCapabilities();
}
public enum IssueCategory {
@ -320,6 +337,7 @@ public class AiServicesIT {
"CONNECTIVITY_ISSUE - The feedback mentions issues with internet connectivity, such as unreliable Wi-Fi\n" +
"CHECK_IN_ISSUE - The feedback mentions issues with the check-in process, such as it being tedious and time-consuming\n" +
"OVERALL_EXPERIENCE_ISSUE - The feedback mentions a general dissatisfaction with the overall hotel experience due to multiple issues")));
verify(chatLanguageModel).supportedCapabilities();
}
@ToString
@ -377,6 +395,7 @@ public class AiServicesIT {
"\"city\": (type: string)\n" +
"})\n" +
"}")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -424,6 +443,7 @@ public class AiServicesIT {
"\"city\": (type: string)\n" +
"})\n" +
"}")));
verify(chatLanguageModel).supportedCapabilities();
}
@ -481,6 +501,7 @@ public class AiServicesIT {
"\"steps\": (each step should be described in 4 words, steps should rhyme; type: array of string),\n" +
"\"preparationTimeMinutes\": (type: integer)\n" +
"}")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -504,6 +525,7 @@ public class AiServicesIT {
"\"steps\": (each step should be described in 4 words, steps should rhyme; type: array of string),\n" +
"\"preparationTimeMinutes\": (type: integer)\n" +
"}")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -527,6 +549,7 @@ public class AiServicesIT {
"\"steps\": (each step should be described in 4 words, steps should rhyme; type: array of string),\n" +
"\"preparationTimeMinutes\": (type: integer)\n" +
"}")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -550,6 +573,7 @@ public class AiServicesIT {
"\"steps\": (each step should be described in 4 words, steps should rhyme; type: array of string),\n" +
"\"preparationTimeMinutes\": (type: integer)\n" +
"}")));
verify(chatLanguageModel).supportedCapabilities();
}
interface BadChef {
@ -626,6 +650,7 @@ public class AiServicesIT {
"\"steps\": (each step should be described in 4 words, steps should rhyme; type: array of string),\n" +
"\"preparationTimeMinutes\": (type: integer)\n" +
"}")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -657,6 +682,7 @@ public class AiServicesIT {
"\"preparationTimeMinutes\": (type: integer)\n" +
"}")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -688,6 +714,7 @@ public class AiServicesIT {
"\"preparationTimeMinutes\": (type: integer)\n" +
"}")
));
verify(chatLanguageModel).supportedCapabilities();
}
interface ProfessionalChef {
@ -712,6 +739,7 @@ public class AiServicesIT {
systemMessage("You are a professional chef. You are friendly, polite and concise."),
userMessage(question)
));
verify(chatLanguageModel).supportedCapabilities();
}
@ -738,6 +766,7 @@ public class AiServicesIT {
systemMessage("You are a professional translator into german"),
userMessage("Translate the following text: Hello, how are you?")
));
verify(chatLanguageModel).supportedCapabilities();
}
interface Summarizer {
@ -764,6 +793,7 @@ public class AiServicesIT {
systemMessage("Summarize every message from user in 3 bullet points. Provide only bullet points."),
userMessage(text + "\nYou must put every item on a separate line.")
));
verify(chatLanguageModel).supportedCapabilities();
}
@ -788,6 +818,7 @@ public class AiServicesIT {
.hasMessage("Text \"" + message + "\" violates content policy");
verify(chatLanguageModel).generate(singletonList(userMessage(message)));
verify(chatLanguageModel).supportedCapabilities();
verify(moderationModel).moderate(singletonList(userMessage(message)));
}
@ -806,6 +837,7 @@ public class AiServicesIT {
assertThat(response).isNotBlank();
verify(chatLanguageModel).generate(singletonList(userMessage(message)));
verify(chatLanguageModel).supportedCapabilities();
verify(moderationModel).moderate(singletonList(userMessage(message)));
}
@ -839,6 +871,7 @@ public class AiServicesIT {
assertThat(result.sources()).isNull();
verify(chatLanguageModel).generate(singletonList(userMessage(userMessage)));
verify(chatLanguageModel).supportedCapabilities();
}
@ -877,6 +910,6 @@ public class AiServicesIT {
"\"bookingId\": (type: string)\n" +
"}")
));
verify(chatLanguageModel).supportedCapabilities();
}
}

View File

@ -0,0 +1,901 @@
package dev.langchain4j.service;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
import static dev.langchain4j.model.chat.request.json.JsonBooleanSchema.JSON_BOOLEAN_SCHEMA;
import static dev.langchain4j.model.chat.request.json.JsonIntegerSchema.JSON_INTEGER_SCHEMA;
import static dev.langchain4j.model.chat.request.json.JsonNumberSchema.JSON_NUMBER_SCHEMA;
import static dev.langchain4j.model.chat.request.json.JsonStringSchema.JSON_STRING_SCHEMA;
import static dev.langchain4j.service.AiServicesJsonSchemaIT.PersonExtractor3.MaritalStatus.SINGLE;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
public class AiServicesJsonSchemaIT {
static Stream<ChatLanguageModel> models() {
return Stream.of(
OpenAiChatModel.builder()
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.modelName(GPT_4_O_MINI)
.responseFormat("json_schema")
.strictJsonSchema(true)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build(),
OpenAiChatModel.builder()
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.modelName(GPT_4_O_MINI)
.responseFormat("json_schema")
.strictJsonSchema(false)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build()
);
}
interface PersonExtractor1 {
class Person {
String name;
int age;
Double height;
boolean married;
}
Person extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_extract_pojo_with_primitives(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor1 personExtractor = AiServices.create(PersonExtractor1.class, model);
String text = "Klaus is 37 years old, 1.78m height and single";
// when
PersonExtractor1.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.age).isEqualTo(37);
assertThat(person.height).isEqualTo(1.78);
assertThat(person.married).isFalse();
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
put("age", JSON_INTEGER_SCHEMA);
put("height", JSON_NUMBER_SCHEMA);
put("married", JSON_BOOLEAN_SCHEMA);
}})
.required("name", "age", "height", "married")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor2 {
class Person {
String name;
Address address;
}
class Address {
String city;
}
Person extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_extract_pojo_with_nested_pojo(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor2 personExtractor = AiServices.create(PersonExtractor2.class, model);
String text = "Klaus lives in Langley Falls";
// when
PersonExtractor2.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.address.city).isEqualTo("Langley Falls");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
put("address", JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("city", JSON_STRING_SCHEMA);
}})
.required("city")
.additionalProperties(false)
.build());
}})
.required("name", "address")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor3 {
class Person {
String name;
MaritalStatus maritalStatus;
}
enum MaritalStatus {
SINGLE, MARRIED
}
Person extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_extract_pojo_with_enum(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor3 personExtractor = AiServices.create(PersonExtractor3.class, model);
String text = "Klaus is single";
// when
PersonExtractor3.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.maritalStatus).isEqualTo(SINGLE);
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
put("maritalStatus", JsonEnumSchema.builder()
.enumValues("SINGLE", "MARRIED")
.build());
}})
.required("name", "maritalStatus")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor4 {
class Person {
String name;
String[] favouriteColors;
}
Person extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_extract_pojo_with_array_of_primitives(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor4 personExtractor = AiServices.create(PersonExtractor4.class, model);
String text = "Klaus likes orange and green";
// when
PersonExtractor4.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.favouriteColors).containsExactly("orange", "green");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
put("favouriteColors", JsonArraySchema.builder()
.items(JSON_STRING_SCHEMA)
.build());
}})
.required("name", "favouriteColors")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor5 {
class Person {
String name;
List<String> favouriteColors;
}
Person extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_extract_pojo_with_list_of_primitives(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor5 personExtractor = AiServices.create(PersonExtractor5.class, model);
String text = "Klaus likes orange and green";
// when
PersonExtractor5.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.favouriteColors).containsExactly("orange", "green");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
put("favouriteColors", JsonArraySchema.builder()
.items(JSON_STRING_SCHEMA)
.build());
}})
.required("name", "favouriteColors")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor6 {
class Person {
String name;
Set<String> favouriteColors;
}
Person extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_extract_pojo_with_set_of_primitives(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor6 personExtractor = AiServices.create(PersonExtractor6.class, model);
String text = "Klaus likes orange and green";
// when
PersonExtractor6.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.favouriteColors).containsExactly("orange", "green");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
put("favouriteColors", JsonArraySchema.builder()
.items(JSON_STRING_SCHEMA)
.build());
}})
.required("name", "favouriteColors")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor7 {
class Person {
String name;
Pet[] pets;
}
class Pet {
String name;
}
Person extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_extract_pojo_with_array_of_pojos(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor7 personExtractor = AiServices.create(PersonExtractor7.class, model);
String text = "Klaus has 2 pets: Peanut and Muffin";
// when
PersonExtractor7.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.pets).hasSize(2);
assertThat(person.pets[0].name).isEqualTo("Peanut");
assertThat(person.pets[1].name).isEqualTo("Muffin");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
put("pets", JsonArraySchema.builder()
.items(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
}})
.required("name")
.additionalProperties(false)
.build())
.build());
}})
.required("name", "pets")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor8 {
class Person {
String name;
List<Pet> pets;
}
class Pet {
String name;
}
Person extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_extract_pojo_with_list_of_pojos(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor8 personExtractor = AiServices.create(PersonExtractor8.class, model);
String text = "Klaus has 2 pets: Peanut and Muffin";
// when
PersonExtractor8.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.pets).hasSize(2);
assertThat(person.pets.get(0).name).isEqualTo("Peanut");
assertThat(person.pets.get(1).name).isEqualTo("Muffin");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
put("pets", JsonArraySchema.builder()
.items(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
}})
.required("name")
.additionalProperties(false)
.build())
.build());
}})
.required("name", "pets")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor9 {
class Person {
String name;
Set<Pet> pets;
}
class Pet {
String name;
}
Person extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_extract_pojo_with_set_of_pojos(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor9 personExtractor = AiServices.create(PersonExtractor9.class, model);
String text = "Klaus has 2 pets: Peanut and Muffin";
// when
PersonExtractor9.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.pets).hasSize(2);
Iterator<PersonExtractor9.Pet> iterator = person.pets.iterator();
assertThat(iterator.next().name).isEqualTo("Peanut");
assertThat(iterator.next().name).isEqualTo("Muffin");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
put("pets", JsonArraySchema.builder()
.items(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
}})
.required("name")
.additionalProperties(false)
.build())
.build());
}})
.required("name", "pets")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor10 {
class Person {
String name;
Group[] groups;
}
enum Group {
A, B, C
}
Person extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_extract_pojo_with_array_of_enums(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor10 personExtractor = AiServices.create(PersonExtractor10.class, model);
String text = "Klaus is assigned to groups A and C";
// when
PersonExtractor10.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.groups).containsExactly(PersonExtractor10.Group.A, PersonExtractor10.Group.C);
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
put("groups", JsonArraySchema.builder()
.items(JsonEnumSchema.builder()
.enumValues("A", "B", "C")
.build())
.build());
}})
.required("name", "groups")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor11 {
class Person {
String name;
List<Group> groups;
}
enum Group {
A, B, C
}
Person extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_extract_pojo_with_list_of_enums(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor11 personExtractor = AiServices.create(PersonExtractor11.class, model);
String text = "Klaus is assigned to groups A and C";
// when
PersonExtractor11.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.groups).containsExactly(PersonExtractor11.Group.A, PersonExtractor11.Group.C);
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
put("groups", JsonArraySchema.builder()
.items(JsonEnumSchema.builder()
.enumValues("A", "B", "C")
.build())
.build());
}})
.required("name", "groups")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor12 {
class Person {
String name;
Set<Group> groups;
}
enum Group {
A, B, C
}
Person extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_extract_pojo_with_set_of_enums(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor12 personExtractor = AiServices.create(PersonExtractor12.class, model);
String text = "Klaus is assigned to groups A and C";
// when
PersonExtractor12.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.groups).containsExactly(PersonExtractor12.Group.A, PersonExtractor12.Group.C);
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
put("groups", JsonArraySchema.builder()
.items(JsonEnumSchema.builder()
.enumValues("A", "B", "C")
.build())
.build());
}})
.required("name", "groups")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor13 {
class Person {
String name;
LocalDate birthDate;
LocalTime birthTime;
LocalDateTime birthDateTime;
}
Person extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_extract_pojo_with_local_date_time_fields(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor13 personExtractor = AiServices.create(PersonExtractor13.class, model);
String text = "Klaus was born at 14:43:26 on 12th of August 1976";
// when
PersonExtractor13.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.birthDate).isEqualTo(LocalDate.of(1976, 8, 12));
assertThat(person.birthTime).isEqualTo(LocalTime.of(14, 43, 26));
assertThat(person.birthDateTime)
.isEqualTo(LocalDateTime.of(1976, 8, 12, 14, 43, 26));
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
put("birthDate", JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("year", JSON_INTEGER_SCHEMA);
put("month", JSON_INTEGER_SCHEMA);
put("day", JSON_INTEGER_SCHEMA);
}})
.required("year", "month", "day")
.additionalProperties(false)
.build());
put("birthTime", JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("hour", JSON_INTEGER_SCHEMA);
put("minute", JSON_INTEGER_SCHEMA);
put("second", JSON_INTEGER_SCHEMA);
put("nano", JSON_INTEGER_SCHEMA);
}})
.required("hour", "minute", "second", "nano")
.additionalProperties(false)
.build());
put("birthDateTime", JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("date", JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("year", JSON_INTEGER_SCHEMA);
put("month", JSON_INTEGER_SCHEMA);
put("day", JSON_INTEGER_SCHEMA);
}})
.required("year", "month", "day")
.additionalProperties(false)
.build());
put("time", JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("hour", JSON_INTEGER_SCHEMA);
put("minute", JSON_INTEGER_SCHEMA);
put("second", JSON_INTEGER_SCHEMA);
put("nano", JSON_INTEGER_SCHEMA);
}})
.required("hour", "minute", "second", "nano")
.additionalProperties(false)
.build());
}})
.required("date", "time")
.additionalProperties(false)
.build());
}})
.required("name", "birthDate", "birthTime", "birthDateTime")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor14 {
class Person {
String name;
}
Result<Person> extractPersonFrom(String text);
}
@ParameterizedTest
@MethodSource("models")
void should_return_result_with_pojo(ChatLanguageModel model) {
// given
model = spy(model);
PersonExtractor14 personExtractor = AiServices.create(PersonExtractor14.class, model);
String text = "Klaus";
// when
Result<PersonExtractor14.Person> result = personExtractor.extractPersonFrom(text);
PersonExtractor14.Person person = result.content();
// then
assertThat(person.name).isEqualTo("Klaus");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JSON_STRING_SCHEMA);
}})
.required("name")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
}

View File

@ -0,0 +1,969 @@
package dev.langchain4j.service;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Set;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
import static dev.langchain4j.model.chat.request.json.JsonIntegerSchema.JSON_INTEGER_SCHEMA;
import static dev.langchain4j.model.chat.request.json.JsonStringSchema.JSON_STRING_SCHEMA;
import static dev.langchain4j.service.AiServicesJsonSchemaWithDescriptionsIT.PersonExtractor3.MaritalStatus.SINGLE;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ExtendWith(MockitoExtension.class)
public class AiServicesJsonSchemaWithDescriptionsIT {
@Spy
ChatLanguageModel model = OpenAiChatModel.builder()
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.modelName(GPT_4_O_MINI)
.responseFormat("json_schema")
.strictJsonSchema(true)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build();
@AfterEach
void afterEach() {
verifyNoMoreInteractions(model);
}
interface PersonExtractor1 {
@Description("a person")
class Person {
@Description("a name")
String name;
@Description("an age")
int age;
@Description("a height")
Double height;
@Description("married or not")
boolean married;
}
Person extractPersonFrom(String text);
}
@Test
void should_extract_pojo_with_primitives() {
// given
PersonExtractor1 personExtractor = AiServices.create(PersonExtractor1.class, model);
String text = "Klaus is 37 years old, 1.78m height and single";
// when
PersonExtractor1.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.age).isEqualTo(37);
assertThat(person.height).isEqualTo(1.78);
assertThat(person.married).isFalse();
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.description("a person")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name")
.build());
put("age", JsonIntegerSchema.builder()
.description("an age")
.build());
put("height", JsonNumberSchema.builder()
.description("a height")
.build());
put("married", JsonBooleanSchema.builder()
.description("married or not")
.build());
}})
.required("name", "age", "height", "married")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor2 {
@Description("a person")
class Person {
@Description("a name")
String name;
@Description("an address override")
Address address;
}
@Description("an address")
class Address {
@Description("a city")
String city;
}
Person extractPersonFrom(String text);
}
@Test
void should_extract_pojo_with_nested_pojo() {
// given
PersonExtractor2 personExtractor = AiServices.create(PersonExtractor2.class, model);
String text = "Klaus lives in Langley Falls";
// when
PersonExtractor2.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.address.city).isEqualTo("Langley Falls");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.description("a person")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name")
.build());
put("address", JsonObjectSchema.builder()
.description("an address override")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("city", JsonStringSchema.builder()
.description("a city")
.build());
}})
.required("city")
.additionalProperties(false)
.build());
}})
.required("name", "address")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor3 {
@Description("a person")
class Person {
@Description("a name")
String name;
@Description("marital status override")
MaritalStatus maritalStatus;
}
@Description("marital status")
enum MaritalStatus {
SINGLE, MARRIED
}
Person extractPersonFrom(String text);
}
@Test
void should_extract_pojo_with_enum() {
// given
PersonExtractor3 personExtractor = AiServices.create(PersonExtractor3.class, model);
String text = "Klaus is single";
// when
PersonExtractor3.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.maritalStatus).isEqualTo(SINGLE);
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.description("a person")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name")
.build());
put("maritalStatus", JsonEnumSchema.builder()
.enumValues("SINGLE", "MARRIED")
.description("marital status override")
.build());
}})
.required("name", "maritalStatus")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor4 {
@Description("a person")
class Person {
@Description("a name")
String name;
@Description("favourite colors")
String[] favouriteColors;
}
Person extractPersonFrom(String text);
}
@Test
void should_extract_pojo_with_array_of_primitives() {
// given
PersonExtractor4 personExtractor = AiServices.create(PersonExtractor4.class, model);
String text = "Klaus likes orange and green";
// when
PersonExtractor4.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.favouriteColors).containsExactly("orange", "green");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.description("a person")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name")
.build());
put("favouriteColors", JsonArraySchema.builder()
.items(JSON_STRING_SCHEMA)
.description("favourite colors")
.build());
}})
.required("name", "favouriteColors")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor5 {
@Description("a person")
class Person {
@Description("a name")
String name;
@Description("favourite colors")
List<String> favouriteColors;
}
Person extractPersonFrom(String text);
}
@Test
void should_extract_pojo_with_list_of_primitives() {
// given
PersonExtractor5 personExtractor = AiServices.create(PersonExtractor5.class, model);
String text = "Klaus likes orange and green";
// when
PersonExtractor5.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.favouriteColors).containsExactly("orange", "green");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.description("a person")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name")
.build());
put("favouriteColors", JsonArraySchema.builder()
.items(JSON_STRING_SCHEMA)
.description("favourite colors")
.build());
}})
.required("name", "favouriteColors")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor6 {
@Description("a person")
class Person {
@Description("a name")
String name;
@Description("favourite colors")
Set<String> favouriteColors;
}
Person extractPersonFrom(String text);
}
@Test
void should_extract_pojo_with_set_of_primitives() {
// given
PersonExtractor6 personExtractor = AiServices.create(PersonExtractor6.class, model);
String text = "Klaus likes orange and green";
// when
PersonExtractor6.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.favouriteColors).containsExactly("orange", "green");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.description("a person")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name")
.build());
put("favouriteColors", JsonArraySchema.builder()
.items(JSON_STRING_SCHEMA)
.description("favourite colors")
.build());
}})
.required("name", "favouriteColors")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor7 {
@Description("a person")
class Person {
@Description("a name")
String name;
@Description("pets of a person")
Pet[] pets;
}
@Description("a pet")
class Pet {
@Description("a name of a pet")
String name;
}
Person extractPersonFrom(String text);
}
@Test
void should_extract_pojo_with_array_of_pojos() {
// given
PersonExtractor7 personExtractor = AiServices.create(PersonExtractor7.class, model);
String text = "Klaus has 2 pets: Peanut and Muffin";
// when
PersonExtractor7.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.pets).hasSize(2);
assertThat(person.pets[0].name).isEqualTo("Peanut");
assertThat(person.pets[1].name).isEqualTo("Muffin");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.description("a person")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name")
.build());
put("pets", JsonArraySchema.builder()
.items(JsonObjectSchema.builder()
.description("a pet")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name of a pet")
.build());
}})
.required("name")
.additionalProperties(false)
.build())
.description("pets of a person")
.build());
}})
.required("name", "pets")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor8 {
@Description("a person")
class Person {
@Description("a name")
String name;
@Description("pets of a person")
List<Pet> pets;
}
@Description("a pet")
class Pet {
@Description("a name of a pet")
String name;
}
Person extractPersonFrom(String text);
}
@Test
void should_extract_pojo_with_list_of_pojos() {
// given
PersonExtractor8 personExtractor = AiServices.create(PersonExtractor8.class, model);
String text = "Klaus has 2 pets: Peanut and Muffin";
// when
PersonExtractor8.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.pets).hasSize(2);
assertThat(person.pets.get(0).name).isEqualTo("Peanut");
assertThat(person.pets.get(1).name).isEqualTo("Muffin");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.description("a person")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name")
.build());
put("pets", JsonArraySchema.builder()
.items(JsonObjectSchema.builder()
.description("a pet")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name of a pet")
.build());
}})
.required("name")
.additionalProperties(false)
.build())
.description("pets of a person")
.build());
}})
.required("name", "pets")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor9 {
@Description("a person")
class Person {
@Description("a name")
String name;
@Description("pets of a person")
Set<Pet> pets;
}
@Description("a pet")
class Pet {
@Description("a name of a pet")
String name;
}
Person extractPersonFrom(String text);
}
@Test
void should_extract_pojo_with_set_of_pojos() {
// given
PersonExtractor9 personExtractor = AiServices.create(PersonExtractor9.class, model);
String text = "Klaus has 2 pets: Peanut and Muffin";
// when
PersonExtractor9.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.pets).hasSize(2);
Iterator<PersonExtractor9.Pet> iterator = person.pets.iterator();
assertThat(iterator.next().name).isEqualTo("Peanut");
assertThat(iterator.next().name).isEqualTo("Muffin");
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.description("a person")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name")
.build());
put("pets", JsonArraySchema.builder()
.items(JsonObjectSchema.builder()
.description("a pet")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name of a pet")
.build());
}})
.required("name")
.additionalProperties(false)
.build())
.description("pets of a person")
.build());
}})
.required("name", "pets")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor10 {
@Description("a person")
class Person {
@Description("a name")
String name;
@Description("groups")
Group[] groups;
}
@Description("a group")
enum Group {
A, B, C
}
Person extractPersonFrom(String text);
}
@Test
void should_extract_pojo_with_array_of_enums() {
// given
PersonExtractor10 personExtractor = AiServices.create(PersonExtractor10.class, model);
String text = "Klaus is assigned to groups A and C";
// when
PersonExtractor10.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.groups).containsExactly(PersonExtractor10.Group.A, PersonExtractor10.Group.C);
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.description("a person")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name")
.build());
put("groups", JsonArraySchema.builder()
.description("groups")
.items(JsonEnumSchema.builder()
.description("a group")
.enumValues("A", "B", "C")
.build())
.build());
}})
.required("name", "groups")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor11 {
@Description("a person")
class Person {
@Description("a name")
String name;
@Description("groups")
List<Group> groups;
}
@Description("a group")
enum Group {
A, B, C
}
Person extractPersonFrom(String text);
}
@Test
void should_extract_pojo_with_list_of_enums() {
// given
PersonExtractor11 personExtractor = AiServices.create(PersonExtractor11.class, model);
String text = "Klaus is assigned to groups A and C";
// when
PersonExtractor11.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.groups).containsExactly(PersonExtractor11.Group.A, PersonExtractor11.Group.C);
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.description("a person")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name")
.build());
put("groups", JsonArraySchema.builder()
.description("groups")
.items(JsonEnumSchema.builder()
.description("a group")
.enumValues("A", "B", "C")
.build())
.build());
}})
.required("name", "groups")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor12 {
@Description("a person")
class Person {
@Description("a name")
String name;
@Description("groups")
Set<Group> groups;
}
@Description("a group")
enum Group {
A, B, C
}
Person extractPersonFrom(String text);
}
@Test
void should_extract_pojo_with_set_of_enums() {
// given
PersonExtractor12 personExtractor = AiServices.create(PersonExtractor12.class, model);
String text = "Klaus is assigned to groups A and C";
// when
PersonExtractor12.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.groups).containsExactly(PersonExtractor12.Group.A, PersonExtractor12.Group.C);
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.description("a person")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name")
.build());
put("groups", JsonArraySchema.builder()
.description("groups")
.items(JsonEnumSchema.builder()
.description("a group")
.enumValues("A", "B", "C")
.build())
.build());
}})
.required("name", "groups")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
interface PersonExtractor13 {
@Description("a person")
class Person {
@Description("a name")
String name;
@Description("a birth date")
LocalDate birthDate;
@Description("a birth time")
LocalTime birthTime;
@Description("a birth date and time")
LocalDateTime birthDateTime;
}
Person extractPersonFrom(String text);
}
@Test
void should_extract_pojo_with_local_date_time_fields() {
// given
PersonExtractor13 personExtractor = AiServices.create(PersonExtractor13.class, model);
String text = "Klaus was born at 14:43:26 on 12th of August 1976";
// when
PersonExtractor13.Person person = personExtractor.extractPersonFrom(text);
// then
assertThat(person.name).isEqualTo("Klaus");
assertThat(person.birthDate).isEqualTo(LocalDate.of(1976, 8, 12));
assertThat(person.birthTime).isEqualTo(LocalTime.of(14, 43, 26));
assertThat(person.birthDateTime)
.isEqualTo(LocalDateTime.of(1976, 8, 12, 14, 43, 26));
verify(model).chat(ChatRequest.builder()
.messages(singletonList(userMessage(text)))
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.name("Person")
.schema(JsonObjectSchema.builder()
.description("a person")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("name", JsonStringSchema.builder()
.description("a name")
.build());
put("birthDate", JsonObjectSchema.builder()
.description("a birth date")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("year", JSON_INTEGER_SCHEMA);
put("month", JSON_INTEGER_SCHEMA);
put("day", JSON_INTEGER_SCHEMA);
}})
.required("year", "month", "day")
.additionalProperties(false)
.build());
put("birthTime", JsonObjectSchema.builder()
.description("a birth time")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("hour", JSON_INTEGER_SCHEMA);
put("minute", JSON_INTEGER_SCHEMA);
put("second", JSON_INTEGER_SCHEMA);
put("nano", JSON_INTEGER_SCHEMA);
}})
.required("hour", "minute", "second", "nano")
.additionalProperties(false)
.build());
put("birthDateTime", JsonObjectSchema.builder()
.description("a birth date and time")
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("date", JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("year", JSON_INTEGER_SCHEMA);
put("month", JSON_INTEGER_SCHEMA);
put("day", JSON_INTEGER_SCHEMA);
}})
.required("year", "month", "day")
.additionalProperties(false)
.build());
put("time", JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("hour", JSON_INTEGER_SCHEMA);
put("minute", JSON_INTEGER_SCHEMA);
put("second", JSON_INTEGER_SCHEMA);
put("nano", JSON_INTEGER_SCHEMA);
}})
.required("hour", "minute", "second", "nano")
.additionalProperties(false)
.build());
}})
.required("date", "time")
.additionalProperties(false)
.build());
}})
.required("name", "birthDate", "birthTime", "birthDateTime")
.additionalProperties(false)
.build())
.build())
.build())
.build());
verify(model).supportedCapabilities();
}
}

View File

@ -129,6 +129,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -146,6 +147,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -163,6 +165,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -180,6 +183,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -197,6 +201,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -214,6 +219,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -231,6 +237,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -248,6 +255,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -265,6 +273,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -282,6 +291,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -300,6 +310,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -318,6 +329,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -336,6 +348,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -354,6 +367,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -372,6 +386,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -390,6 +405,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -408,6 +424,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -426,6 +443,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -444,6 +462,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -462,6 +481,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("Given a name of a country, answer with a name of it's capital"),
userMessage("Country: Germany")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -480,6 +500,7 @@ class AiServicesSystemAndUserMessageConfigsTest {
systemMessage("This message should take precedence over the one provided by systemMessageProvider"),
userMessage("What is the capital of Germany?")
));
verify(chatLanguageModel).supportedCapabilities();
}
@Test

View File

@ -79,6 +79,7 @@ class AiServicesUserMessageConfigTest {
assertThat(aiService.chat1("What is the capital of Germany?"))
.containsIgnoringCase("Berlin");
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -93,6 +94,7 @@ class AiServicesUserMessageConfigTest {
assertThat(aiService.chat2("What is the capital of Germany?"))
.containsIgnoringCase("Berlin");
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -107,6 +109,7 @@ class AiServicesUserMessageConfigTest {
assertThat(aiService.chat3("What is the capital of {{country}}?", "Germany"))
.containsIgnoringCase("Berlin");
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -121,6 +124,7 @@ class AiServicesUserMessageConfigTest {
assertThat(aiService.chat4())
.containsIgnoringCase("Berlin");
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -135,6 +139,7 @@ class AiServicesUserMessageConfigTest {
assertThat(aiService.chat5("Germany"))
.containsIgnoringCase("Berlin");
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -149,6 +154,7 @@ class AiServicesUserMessageConfigTest {
assertThat(aiService.chat6("Germany"))
.containsIgnoringCase("Berlin");
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test
@ -163,6 +169,7 @@ class AiServicesUserMessageConfigTest {
assertThat(aiService.chat7("capital", "Germany"))
.containsIgnoringCase("Berlin");
verify(chatLanguageModel).generate(singletonList(userMessage("What is the capital of Germany?")));
verify(chatLanguageModel).supportedCapabilities();
}
@Test

View File

@ -27,7 +27,9 @@ import static dev.langchain4j.service.AiServicesWithChatMemoryIT.ChatWithMemory.
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ExtendWith(MockitoExtension.class)
class AiServicesWithChatMemoryIT {
@ -142,6 +144,7 @@ class AiServicesWithChatMemoryIT {
));
verify(chatMemory).add(aiMessage(fourthAiMessage));
verify(chatLanguageModel, times(4)).supportedCapabilities();
verify(chatMemory, times(12)).messages();
}
@ -171,6 +174,7 @@ class AiServicesWithChatMemoryIT {
aiMessage(firstAiMessage),
userMessage(secondUserMessage)
));
verify(chatLanguageModel, times(2)).supportedCapabilities();
verify(chatMemory, times(2)).add(systemMessage(SYSTEM_MESSAGE));
verify(chatMemory).add(userMessage(firstUserMessage));
@ -207,6 +211,7 @@ class AiServicesWithChatMemoryIT {
systemMessage(ANOTHER_SYSTEM_MESSAGE),
userMessage(secondUserMessage)
));
verify(chatLanguageModel, times(2)).supportedCapabilities();
verify(chatMemory).add(systemMessage(SYSTEM_MESSAGE));
verify(chatMemory).add(userMessage(firstUserMessage));
@ -304,5 +309,7 @@ class AiServicesWithChatMemoryIT {
userMessage(secondMessageFromSecondUser),
aiMessage(secondAiResponseToSecondUser)
);
verify(chatLanguageModel, times(4)).supportedCapabilities();
}
}

View File

@ -3,7 +3,10 @@ package dev.langchain4j.service;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.agent.tool.*;
import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
@ -33,7 +36,7 @@ import java.util.stream.Stream;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.description;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
import static dev.langchain4j.model.mistralai.MistralAiChatModelName.MISTRAL_LARGE_LATEST;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO_0613;
import static dev.langchain4j.model.openai.OpenAiChatModelName.*;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.service.AiServicesWithToolsIT.Operator.EQUALS;
import static dev.langchain4j.service.AiServicesWithToolsIT.TemperatureUnit.Kelvin;
@ -54,6 +57,17 @@ class AiServicesWithToolsIT {
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.modelName(GPT_4_O)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build(),
OpenAiChatModel.builder()
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.modelName(GPT_4_O_MINI)
.strictTools(true)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
@ -616,7 +630,7 @@ class AiServicesWithToolsIT {
Operator operator;
@Description("Value to compare with")
Object value;
String value;
}
enum Operator {
@ -678,6 +692,33 @@ class AiServicesWithToolsIT {
assertThat(response.content().text()).contains("2027");
}
static class Clock {
@Tool
String currentTime() {
return "16:37:43";
}
}
@ParameterizedTest
@MethodSource("models")
void should_execute_tool_without_parameters(ChatLanguageModel chatLanguageModel) {
// given
Clock clock = spy(new Clock());
Assistant assistant = AiServices.builder(Assistant.class)
.chatLanguageModel(chatLanguageModel)
.tools(clock)
.build();
// when
Response<AiMessage> response = assistant.chat("What is the time now?");
// then
assertThat(response.content().text()).contains("16:37:43");
}
private static Map<String, Object> toMap(String arguments) {
try {
return new ObjectMapper().readValue(arguments, new TypeReference<Map<String, Object>>() {

View File

@ -0,0 +1,220 @@
package dev.langchain4j.service.output;
import com.google.gson.reflect.TypeToken;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.service.Result;
import org.junit.jupiter.api.Test;
import java.time.LocalDate;
import java.util.Optional;
import static dev.langchain4j.service.output.JsonSchemas.jsonSchemaFrom;
import static org.assertj.core.api.Assertions.assertThat;
class JsonSchemasTest {
class Pojo {
String field;
}
@Test
void should_return_json_schema_for_pojos() {
assertThat(jsonSchemaFrom(Pojo.class)).isPresent();
assertThat(jsonSchemaFrom(new TypeToken<Result<Pojo>>() {
}.getType())).isPresent();
}
@Test
void should_return_empty_for_not_pojos() {
assertThat(jsonSchemaFrom(String.class)).isEmpty();
assertThat(jsonSchemaFrom(AiMessage.class)).isEmpty();
assertThat(jsonSchemaFrom(Response.class)).isEmpty();
assertThat(jsonSchemaFrom(Integer.class)).isEmpty();
assertThat(jsonSchemaFrom(LocalDate.class)).isEmpty();
assertThat(jsonSchemaFrom(new TypeToken<Result<String>>() {
}.getType())).isEmpty();
}
// POJO
@Test
void should_take_pojo_description_from_the_field() {
// given
class Address {
String street;
String city;
}
class Person {
@Description("an address")
Address address;
}
// when
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
// then
JsonObjectSchema addressSchema = (JsonObjectSchema) jsonSchema.get().schema().properties().get("address");
assertThat(addressSchema.description()).isEqualTo("an address");
}
@Test
void should_take_pojo_description_from_the_class() {
// given
@Description("an address")
class Address {
String street;
String city;
}
class Person {
Address address;
}
// when
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
// then
JsonObjectSchema addressSchema = (JsonObjectSchema) jsonSchema.get().schema().properties().get("address");
assertThat(addressSchema.description()).isEqualTo("an address");
}
@Test
void pojo_field_description_should_override_class_description() {
// given
@Description("an address")
class Address {
String street;
String city;
}
class Person {
@Description("an address 2")
Address address;
}
// when
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
// then
JsonObjectSchema addressSchema = (JsonObjectSchema) jsonSchema.get().schema().properties().get("address");
assertThat(addressSchema.description()).isEqualTo("an address 2");
}
// ENUM
enum MaritalStatus {
SINGLE, MARRIED
}
@Test
void should_take_enum_description_from_the_field() {
// given
class Person {
@Description("marital status")
MaritalStatus maritalStatus;
}
// when
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
// then
JsonEnumSchema maritalStatusSchema = (JsonEnumSchema) jsonSchema.get().schema().properties().get("maritalStatus");
assertThat(maritalStatusSchema.description()).isEqualTo("marital status");
}
@Description("marital status")
enum MaritalStatus2 {
SINGLE, MARRIED
}
@Test
void should_take_enum_description_from_the_enum() {
// given
class Person {
MaritalStatus2 maritalStatus;
}
// when
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
// then
JsonEnumSchema maritalStatusSchema = (JsonEnumSchema) jsonSchema.get().schema().properties().get("maritalStatus");
assertThat(maritalStatusSchema.description()).isEqualTo("marital status");
}
@Description("marital status")
enum MaritalStatus3 {
SINGLE, MARRIED
}
@Test
void enum_field_description_should_override_class_description() {
// given
class Person {
@Description("marital status 2")
MaritalStatus3 maritalStatus;
}
// when
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
// then
JsonEnumSchema maritalStatusSchema = (JsonEnumSchema) jsonSchema.get().schema().properties().get("maritalStatus");
assertThat(maritalStatusSchema.description()).isEqualTo("marital status 2");
}
// ARRAY
@Test
void should_take_array_description_from_the_field() {
// given
class Pet {
String name;
}
class Person {
@Description("pets of a person")
Pet[] pets;
}
// when
Optional<JsonSchema> jsonSchema = jsonSchemaFrom(Person.class);
// then
JsonArraySchema petsSchema = (JsonArraySchema) jsonSchema.get().schema().properties().get("pets");
assertThat(petsSchema.description()).isEqualTo("pets of a person");
}
}