bug(Google AI Gemini) — fix mapping for tools with parameters with nested object structures (#1732)
* bug(Google AI Gemini) — fix mapping for tools with parameters with nested object structures * fix(JsonSchemas) — minor improvement to cover all cases of strings
This commit is contained in:
parent
21d35e4434
commit
421b4cd048
|
@ -22,4 +22,13 @@ public class TypeUtils {
|
|||
public static boolean isJsonBoolean(Class<?> type) {
|
||||
return type == boolean.class || type == Boolean.class;
|
||||
}
|
||||
|
||||
public static boolean isJsonString(Class<?> type) {
|
||||
return type == String.class || type == char.class || type == Character.class
|
||||
|| CharSequence.class.isAssignableFrom(type);
|
||||
}
|
||||
|
||||
public static boolean isJsonArray(Class<?> type) {
|
||||
return type.isArray() || Iterable.class.isAssignableFrom(type);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
package dev.langchain4j.internal;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Deque;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class TypeUtilsTest {
|
||||
|
||||
@Test
|
||||
void testStringIsJsonCompatible() {
|
||||
assertThat(TypeUtils.isJsonString(char.class)).isTrue();
|
||||
assertThat(TypeUtils.isJsonString(String.class)).isTrue();
|
||||
assertThat(TypeUtils.isJsonString(Character.class)).isTrue();
|
||||
assertThat(TypeUtils.isJsonString(StringBuffer.class)).isTrue();
|
||||
assertThat(TypeUtils.isJsonString(StringBuilder.class)).isTrue();
|
||||
assertThat(TypeUtils.isJsonString(CharSequence.class)).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCollectionIsJsonCompatible() {
|
||||
assertThat(TypeUtils.isJsonArray(String[].class)).isTrue();
|
||||
assertThat(TypeUtils.isJsonArray(Integer[].class)).isTrue();
|
||||
assertThat(TypeUtils.isJsonArray(int[].class)).isTrue();
|
||||
|
||||
assertThat(TypeUtils.isJsonArray(List.class)).isTrue();
|
||||
assertThat(TypeUtils.isJsonArray(Set.class)).isTrue();
|
||||
assertThat(TypeUtils.isJsonArray(Deque.class)).isTrue();
|
||||
assertThat(TypeUtils.isJsonArray(Collection.class)).isTrue();
|
||||
assertThat(TypeUtils.isJsonArray(Iterable.class)).isTrue();
|
||||
}
|
||||
}
|
|
@ -2,12 +2,13 @@ package dev.langchain4j.model.googleai;
|
|||
|
||||
import com.google.gson.Gson;
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.agent.tool.ToolParameters;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static java.util.Collections.emptyMap;
|
||||
|
||||
class FunctionMapper {
|
||||
|
||||
private static final Gson GSON = new Gson();
|
||||
|
@ -34,6 +35,7 @@ class FunctionMapper {
|
|||
.map(specification -> {
|
||||
GeminiFunctionDeclaration.GeminiFunctionDeclarationBuilder fnBuilder =
|
||||
GeminiFunctionDeclaration.builder();
|
||||
|
||||
if (specification.name() != null) {
|
||||
fnBuilder.name(specification.name());
|
||||
}
|
||||
|
@ -41,30 +43,13 @@ class FunctionMapper {
|
|||
fnBuilder.description(specification.description());
|
||||
}
|
||||
if (specification.parameters() != null) {
|
||||
ToolParameters parameters = specification.parameters();
|
||||
Map<String, Map<String, Object>> properties = specification.parameters().properties();
|
||||
|
||||
final String[] propName = {""};
|
||||
fnBuilder.parameters(GeminiSchema.builder()
|
||||
.type(GeminiType.OBJECT)
|
||||
.properties(parameters.properties().entrySet().stream()
|
||||
.map(prop -> {
|
||||
propName[0] = prop.getKey();
|
||||
Map<String, Object> propAttributes = prop.getValue();
|
||||
|
||||
String typeString = (String) propAttributes.getOrDefault("type", "string");
|
||||
GeminiType type = GeminiType.valueOf(typeString.toUpperCase());
|
||||
String description = (String) propAttributes.getOrDefault("description", null);
|
||||
|
||||
//TODO need to deal with nested objects
|
||||
|
||||
return GeminiSchema.builder()
|
||||
.description(description)
|
||||
.type(type)
|
||||
.build();
|
||||
})
|
||||
.collect(Collectors.toMap(schema -> propName[0], schema -> schema)))
|
||||
.build());
|
||||
String type = "object";
|
||||
String description = specification.description();
|
||||
fnBuilder.parameters(fromMap(type, null, null, properties));
|
||||
}
|
||||
|
||||
return fnBuilder.build();
|
||||
})
|
||||
.filter(Objects::nonNull)
|
||||
|
@ -77,10 +62,50 @@ class FunctionMapper {
|
|||
return tool.build();
|
||||
}
|
||||
|
||||
private static GeminiSchema fromMap(String type, String arrayType, String description, Map<String, Map<String, Object>> obj) {
|
||||
GeminiSchema.GeminiSchemaBuilder schemaBuilder = GeminiSchema.builder();
|
||||
|
||||
schemaBuilder.type(GeminiType.valueOf(type.toUpperCase()));
|
||||
schemaBuilder.description(description);
|
||||
|
||||
if (type.equals("array")) {
|
||||
Map<String, Map<String, Object>> arrayObj = (Map<String, Map<String, Object>>) obj.values().iterator().next().get("properties");
|
||||
|
||||
schemaBuilder.items(fromMap(arrayType, null, description, arrayObj));
|
||||
} else {
|
||||
Map<String, GeminiSchema> props = new LinkedHashMap<>();
|
||||
if (obj != null) {
|
||||
for (Map.Entry<String, Map<String, Object>> oneProperty : obj.entrySet()) {
|
||||
String propName = oneProperty.getKey();
|
||||
Map<String, Object> propAttributes = oneProperty.getValue();
|
||||
String propTypeString = (String) propAttributes.getOrDefault("type", "string");
|
||||
String propDescription = (String) propAttributes.getOrDefault("description", null);
|
||||
Map<String, Map<String, Object>> childProps =
|
||||
(Map<String, Map<String, Object>>) propAttributes.getOrDefault("properties", emptyMap());
|
||||
Map<String, Object> items = (Map<String, Object>) propAttributes.get("items");
|
||||
Map<String, Map<String, Object>> singleProp = new HashMap<>();
|
||||
singleProp.put(propName, items);
|
||||
|
||||
if (items != null) {
|
||||
String itemsType = items.get("type").toString();
|
||||
props.put(propName, fromMap(propTypeString, itemsType, propDescription, singleProp));
|
||||
} else {
|
||||
props.put(propName, fromMap(propTypeString, null, propDescription, childProps));
|
||||
}
|
||||
}
|
||||
}
|
||||
schemaBuilder.properties(props);
|
||||
}
|
||||
|
||||
return schemaBuilder.build();
|
||||
}
|
||||
|
||||
static List<ToolExecutionRequest> fromToolExecReqToGFunCall(List<GeminiFunctionCall> functionCalls) {
|
||||
return functionCalls.stream().map(functionCall -> ToolExecutionRequest.builder()
|
||||
.name(functionCall.getName())
|
||||
.arguments(GSON.toJson(functionCall.getArgs()))
|
||||
.build()).collect(Collectors.toList());
|
||||
return functionCalls.stream()
|
||||
.map(functionCall -> ToolExecutionRequest.builder()
|
||||
.name(functionCall.getName())
|
||||
.arguments(GSON.toJson(functionCall.getArgs()))
|
||||
.build())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package dev.langchain4j.model.googleai;
|
||||
|
||||
enum GeminiType {
|
||||
TYPE_UNSPECIFIED,
|
||||
STRING,
|
||||
NUMBER,
|
||||
INTEGER,
|
||||
|
|
|
@ -0,0 +1,258 @@
|
|||
package dev.langchain4j.model.googleai;
|
||||
|
||||
import dev.langchain4j.agent.tool.*;
|
||||
import dev.langchain4j.model.output.structured.Description;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class FunctionMapperTest {
|
||||
|
||||
enum Projection {
|
||||
WGS84,
|
||||
NAD83,
|
||||
PZ90,
|
||||
GCJ02,
|
||||
BD09
|
||||
}
|
||||
|
||||
static class Coordinates {
|
||||
@Description("latitude")
|
||||
double latitude;
|
||||
@Description("latitude")
|
||||
double longitude;
|
||||
@Description("Geographic projection system used")
|
||||
Projection projection;
|
||||
|
||||
public Coordinates(double latitude, double longitude) {
|
||||
this.latitude = latitude;
|
||||
this.longitude = longitude;
|
||||
this.projection = Projection.WGS84;
|
||||
}
|
||||
}
|
||||
|
||||
static class IssTool {
|
||||
@Tool("Get the distance between the user and the ISS.")
|
||||
int distanceBetween(
|
||||
@P("user coordinates") Coordinates userCoordinates,
|
||||
@P("ISS coordinates") Coordinates issCoordinates
|
||||
) {
|
||||
return 3456;
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_convert_nested_structures() {
|
||||
// when
|
||||
List<ToolSpecification> toolSpecifications = ToolSpecifications.toolSpecificationsFrom(IssTool.class);
|
||||
System.out.println("\ntoolSpecifications = " + toolSpecifications);
|
||||
|
||||
// then
|
||||
assertThat(toolSpecifications.size()).isEqualTo(1);
|
||||
assertThat(toolSpecifications.get(0).name()).isEqualTo("distanceBetween");
|
||||
assertThat(toolSpecifications.get(0).description()).isEqualTo("Get the distance between the user and the ISS.");
|
||||
assertThat(toolSpecifications.get(0).parameters().type()).isEqualTo("object");
|
||||
assertThat(toolSpecifications.get(0).parameters().properties().size()).isEqualTo(2);
|
||||
assertThat(toolSpecifications.get(0).parameters().properties().keySet()).containsAll(Arrays.asList("userCoordinates", "issCoordinates"));
|
||||
|
||||
// when
|
||||
GeminiTool geminiTool = FunctionMapper.fromToolSepcsToGTool(toolSpecifications, false);
|
||||
System.out.println("\ngeminiTool = " + withoutNullValues(geminiTool.toString()));
|
||||
|
||||
// then
|
||||
List<GeminiFunctionDeclaration> allGFnDecl = geminiTool.getFunctionDeclarations();
|
||||
assertThat(allGFnDecl.size()).isEqualTo(1);
|
||||
|
||||
GeminiFunctionDeclaration gFnDecl = allGFnDecl.get(0);
|
||||
assertThat(gFnDecl.getName()).isEqualTo("distanceBetween");
|
||||
|
||||
assertThat(gFnDecl.getParameters().getType()).isEqualTo(GeminiType.OBJECT);
|
||||
Map<String, GeminiSchema> props = gFnDecl.getParameters().getProperties();
|
||||
|
||||
assertThat(props.size()).isEqualTo(2);
|
||||
assertThat(props.keySet()).containsAll(Arrays.asList("userCoordinates", "issCoordinates"));
|
||||
|
||||
GeminiSchema userCoord = props.get("userCoordinates");
|
||||
assertThat(userCoord.getType()).isEqualTo(GeminiType.OBJECT);
|
||||
|
||||
GeminiSchema issCoord = props.get("issCoordinates");
|
||||
assertThat(issCoord.getType()).isEqualTo(GeminiType.OBJECT);
|
||||
|
||||
assertThat(userCoord.getProperties().size()).isEqualTo(3);
|
||||
assertThat(issCoord.getProperties().size()).isEqualTo(3);
|
||||
|
||||
assertThat(userCoord.getProperties().keySet()).containsAll(Arrays.asList("latitude", "longitude", "projection"));
|
||||
assertThat(issCoord.getProperties().keySet()).containsAll(Arrays.asList("latitude", "longitude", "projection"));
|
||||
}
|
||||
|
||||
static class Address {
|
||||
private final String street;
|
||||
private final String zipCode;
|
||||
private final String city;
|
||||
|
||||
public Address(String street, String zipCode, String city) {
|
||||
this.street = street;
|
||||
this.zipCode = zipCode;
|
||||
this.city = city;
|
||||
}
|
||||
}
|
||||
|
||||
static class Customer {
|
||||
private final String firstname;
|
||||
private final String lastname;
|
||||
|
||||
private final Address shippingAddress;
|
||||
// private final Address billingAddress;
|
||||
|
||||
public Customer(String firstname, String lastname,
|
||||
Address shippingAddress
|
||||
// Address billingAddress
|
||||
) {
|
||||
this.firstname = firstname;
|
||||
this.lastname = lastname;
|
||||
this.shippingAddress = shippingAddress;
|
||||
// this.billingAddress = billingAddress;
|
||||
}
|
||||
}
|
||||
|
||||
static class Product {
|
||||
private final String name;
|
||||
private final String description;
|
||||
private final double price;
|
||||
|
||||
public Product(String name, String description, double price) {
|
||||
this.name = name;
|
||||
this.description = description;
|
||||
this.price = price;
|
||||
}
|
||||
}
|
||||
|
||||
static class LineItem {
|
||||
private final Product product;
|
||||
private final int quantity;
|
||||
|
||||
public LineItem(int quantity, Product product) {
|
||||
this.product = product;
|
||||
this.quantity = quantity;
|
||||
}
|
||||
}
|
||||
|
||||
static class Order {
|
||||
private final Double totalAmount;
|
||||
private final List<LineItem> lineItems;
|
||||
private final Customer customer;
|
||||
|
||||
public Order(Double totalAmount, List<LineItem> lineItems, Customer customer) {
|
||||
this.totalAmount = totalAmount;
|
||||
this.lineItems = lineItems;
|
||||
this.customer = customer;
|
||||
}
|
||||
}
|
||||
|
||||
static class OrderSystem {
|
||||
@Tool("Make an order")
|
||||
boolean makeOrder(@P(value = "The order to make") Order order) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testComplexNestedGraph() {
|
||||
// given
|
||||
List<ToolSpecification> toolSpecifications = ToolSpecifications.toolSpecificationsFrom(OrderSystem.class);
|
||||
System.out.println("\ntoolSpecifications = " + toolSpecifications);
|
||||
|
||||
// when
|
||||
GeminiTool geminiTool = FunctionMapper.fromToolSepcsToGTool(toolSpecifications, false);
|
||||
System.out.println("\ngeminiTool = " + withoutNullValues(geminiTool.toString()));
|
||||
|
||||
// then
|
||||
List<GeminiFunctionDeclaration> allGFnDecl = geminiTool.getFunctionDeclarations();
|
||||
assertThat(allGFnDecl.size()).isEqualTo(1);
|
||||
|
||||
GeminiFunctionDeclaration gFnDecl = allGFnDecl.get(0);
|
||||
assertThat(gFnDecl.getName()).isEqualTo("makeOrder");
|
||||
assertThat(gFnDecl.getParameters().getType()).isEqualTo(GeminiType.OBJECT);
|
||||
|
||||
Map<String, GeminiSchema> props = gFnDecl.getParameters().getProperties();
|
||||
assertThat(props.size()).isEqualTo(1);
|
||||
assertThat(props.keySet()).containsExactly("order");
|
||||
|
||||
GeminiSchema orderSchema = props.get("order");
|
||||
assertThat(orderSchema.getType()).isEqualTo(GeminiType.OBJECT);
|
||||
assertThat(orderSchema.getProperties().size()).isEqualTo(3);
|
||||
assertThat(orderSchema.getProperties().keySet()).containsAll(Arrays.asList("totalAmount", "lineItems", "customer"));
|
||||
|
||||
GeminiSchema totalAmount = orderSchema.getProperties().get("totalAmount");
|
||||
assertThat(totalAmount.getType()).isEqualTo(GeminiType.NUMBER);
|
||||
|
||||
GeminiSchema lineItems = orderSchema.getProperties().get("lineItems");
|
||||
assertThat(lineItems.getType()).isEqualTo(GeminiType.ARRAY);
|
||||
|
||||
GeminiSchema lineItemsItems = lineItems.getItems();
|
||||
assertThat(lineItemsItems.getType()).isEqualTo(GeminiType.OBJECT);
|
||||
assertThat(lineItemsItems.getProperties().size()).isEqualTo(2);
|
||||
assertThat(lineItemsItems.getProperties().keySet()).containsAll(Arrays.asList("product", "quantity"));
|
||||
|
||||
GeminiSchema product = lineItemsItems.getProperties().get("product");
|
||||
assertThat(product.getType()).isEqualTo(GeminiType.OBJECT);
|
||||
assertThat(product.getProperties().size()).isEqualTo(3);
|
||||
assertThat(product.getProperties().keySet()).containsAll(Arrays.asList("name", "description", "price"));
|
||||
|
||||
GeminiSchema customer = orderSchema.getProperties().get("customer");
|
||||
assertThat(customer.getType()).isEqualTo(GeminiType.OBJECT);
|
||||
assertThat(customer.getProperties().size()).isEqualTo(3);
|
||||
assertThat(customer.getProperties().keySet()).containsAll(Arrays.asList("firstname", "lastname", "shippingAddress"));
|
||||
|
||||
GeminiSchema shippingAddress = customer.getProperties().get("shippingAddress");
|
||||
assertThat(shippingAddress.getType()).isEqualTo(GeminiType.OBJECT);
|
||||
assertThat(shippingAddress.getProperties().size()).isEqualTo(3);
|
||||
assertThat(shippingAddress.getProperties().keySet()).containsAll(Arrays.asList("street", "zipCode", "city"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testArray() {
|
||||
// given
|
||||
ToolSpecification spec = ToolSpecification.builder()
|
||||
.name("toolName")
|
||||
.description("tool description")
|
||||
.addParameter("arrayParameter", ARRAY, items(STRING), description("an array"))
|
||||
.build();
|
||||
|
||||
System.out.println("\nspec = " + spec);
|
||||
|
||||
// when
|
||||
GeminiTool geminiTool = FunctionMapper.fromToolSepcsToGTool(Arrays.asList(spec), false);
|
||||
System.out.println("\ngeminiTool = " + withoutNullValues(geminiTool.toString()));
|
||||
|
||||
// then
|
||||
List<GeminiFunctionDeclaration> allGFnDecl = geminiTool.getFunctionDeclarations();
|
||||
assertThat(allGFnDecl.size()).isEqualTo(1);
|
||||
GeminiFunctionDeclaration gFnDecl = allGFnDecl.get(0);
|
||||
assertThat(gFnDecl.getName()).isEqualTo("toolName");
|
||||
assertThat(gFnDecl.getParameters().getType()).isEqualTo(GeminiType.OBJECT);
|
||||
|
||||
Map<String, GeminiSchema> props = gFnDecl.getParameters().getProperties();
|
||||
System.out.println("props = " + withoutNullValues(props.toString()));
|
||||
assertThat(props.size()).isEqualTo(1);
|
||||
assertThat(props.keySet()).containsExactly("arrayParameter");
|
||||
|
||||
GeminiSchema arrayParameter = props.get("arrayParameter");
|
||||
assertThat(arrayParameter.getType()).isEqualTo(GeminiType.ARRAY);
|
||||
assertThat(arrayParameter.getItems().getType()).isEqualTo(GeminiType.STRING);
|
||||
assertThat(arrayParameter.getItems().getDescription()).isEqualTo("an array");
|
||||
assertThat(arrayParameter.getItems().getItems()).isNull();
|
||||
assertThat(arrayParameter.getItems().getProperties()).isEmpty();
|
||||
}
|
||||
|
||||
private static String withoutNullValues(String toString) {
|
||||
return toString
|
||||
.replaceAll("(, )?(?<=(, |\\())[^\\s(]+?=null(?:, )?", " ")
|
||||
.replaceFirst(", \\)$", ")");
|
||||
}
|
||||
}
|
|
@ -31,6 +31,7 @@ import static dev.langchain4j.exception.IllegalConfigurationException.illegalCon
|
|||
import static dev.langchain4j.internal.TypeUtils.isJsonBoolean;
|
||||
import static dev.langchain4j.internal.TypeUtils.isJsonInteger;
|
||||
import static dev.langchain4j.internal.TypeUtils.isJsonNumber;
|
||||
import static dev.langchain4j.internal.TypeUtils.isJsonString;
|
||||
import static dev.langchain4j.service.TypeUtils.getRawClass;
|
||||
import static dev.langchain4j.service.TypeUtils.resolveFirstGenericParameterClass;
|
||||
import static dev.langchain4j.service.TypeUtils.typeHasRawClass;
|
||||
|
@ -126,7 +127,7 @@ public class JsonSchemas {
|
|||
|
||||
private static JsonSchemaElement jsonSchema(Class<?> clazz, Type type, String fieldDescription) {
|
||||
|
||||
if (clazz == String.class) {
|
||||
if (isJsonString(clazz)) {
|
||||
return JsonStringSchema.builder()
|
||||
.description(fieldDescription)
|
||||
.build();
|
||||
|
|
Loading…
Reference in New Issue