bug(Google AI Gemini) — fix mapping for tools with parameters with nested object structures (#1732)

* bug(Google AI Gemini) — fix mapping for tools with parameters with
nested object structures
* fix(JsonSchemas) — minor improvement to cover all cases of strings
This commit is contained in:
Guillaume Laforge 2024-09-09 10:15:23 +02:00 committed by GitHub
parent 21d35e4434
commit 421b4cd048
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 357 additions and 29 deletions

View File

@ -22,4 +22,13 @@ public class TypeUtils {
public static boolean isJsonBoolean(Class<?> type) { public static boolean isJsonBoolean(Class<?> type) {
return type == boolean.class || type == Boolean.class; return type == boolean.class || type == Boolean.class;
} }
public static boolean isJsonString(Class<?> type) {
return type == String.class || type == char.class || type == Character.class
|| CharSequence.class.isAssignableFrom(type);
}
public static boolean isJsonArray(Class<?> type) {
return type.isArray() || Iterable.class.isAssignableFrom(type);
}
} }

View File

@ -0,0 +1,36 @@
package dev.langchain4j.internal;
import org.junit.jupiter.api.Test;
import java.util.Collection;
import java.util.Deque;
import java.util.List;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat;
public class TypeUtilsTest {
@Test
void testStringIsJsonCompatible() {
assertThat(TypeUtils.isJsonString(char.class)).isTrue();
assertThat(TypeUtils.isJsonString(String.class)).isTrue();
assertThat(TypeUtils.isJsonString(Character.class)).isTrue();
assertThat(TypeUtils.isJsonString(StringBuffer.class)).isTrue();
assertThat(TypeUtils.isJsonString(StringBuilder.class)).isTrue();
assertThat(TypeUtils.isJsonString(CharSequence.class)).isTrue();
}
@Test
void testCollectionIsJsonCompatible() {
assertThat(TypeUtils.isJsonArray(String[].class)).isTrue();
assertThat(TypeUtils.isJsonArray(Integer[].class)).isTrue();
assertThat(TypeUtils.isJsonArray(int[].class)).isTrue();
assertThat(TypeUtils.isJsonArray(List.class)).isTrue();
assertThat(TypeUtils.isJsonArray(Set.class)).isTrue();
assertThat(TypeUtils.isJsonArray(Deque.class)).isTrue();
assertThat(TypeUtils.isJsonArray(Collection.class)).isTrue();
assertThat(TypeUtils.isJsonArray(Iterable.class)).isTrue();
}
}

View File

@ -2,12 +2,13 @@ package dev.langchain4j.model.googleai;
import com.google.gson.Gson; import com.google.gson.Gson;
import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.agent.tool.ToolSpecification;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static java.util.Collections.emptyMap;
class FunctionMapper { class FunctionMapper {
private static final Gson GSON = new Gson(); private static final Gson GSON = new Gson();
@ -34,6 +35,7 @@ class FunctionMapper {
.map(specification -> { .map(specification -> {
GeminiFunctionDeclaration.GeminiFunctionDeclarationBuilder fnBuilder = GeminiFunctionDeclaration.GeminiFunctionDeclarationBuilder fnBuilder =
GeminiFunctionDeclaration.builder(); GeminiFunctionDeclaration.builder();
if (specification.name() != null) { if (specification.name() != null) {
fnBuilder.name(specification.name()); fnBuilder.name(specification.name());
} }
@ -41,30 +43,13 @@ class FunctionMapper {
fnBuilder.description(specification.description()); fnBuilder.description(specification.description());
} }
if (specification.parameters() != null) { if (specification.parameters() != null) {
ToolParameters parameters = specification.parameters(); Map<String, Map<String, Object>> properties = specification.parameters().properties();
final String[] propName = {""}; String type = "object";
fnBuilder.parameters(GeminiSchema.builder() String description = specification.description();
.type(GeminiType.OBJECT) fnBuilder.parameters(fromMap(type, null, null, properties));
.properties(parameters.properties().entrySet().stream()
.map(prop -> {
propName[0] = prop.getKey();
Map<String, Object> propAttributes = prop.getValue();
String typeString = (String) propAttributes.getOrDefault("type", "string");
GeminiType type = GeminiType.valueOf(typeString.toUpperCase());
String description = (String) propAttributes.getOrDefault("description", null);
//TODO need to deal with nested objects
return GeminiSchema.builder()
.description(description)
.type(type)
.build();
})
.collect(Collectors.toMap(schema -> propName[0], schema -> schema)))
.build());
} }
return fnBuilder.build(); return fnBuilder.build();
}) })
.filter(Objects::nonNull) .filter(Objects::nonNull)
@ -77,10 +62,50 @@ class FunctionMapper {
return tool.build(); return tool.build();
} }
private static GeminiSchema fromMap(String type, String arrayType, String description, Map<String, Map<String, Object>> obj) {
GeminiSchema.GeminiSchemaBuilder schemaBuilder = GeminiSchema.builder();
schemaBuilder.type(GeminiType.valueOf(type.toUpperCase()));
schemaBuilder.description(description);
if (type.equals("array")) {
Map<String, Map<String, Object>> arrayObj = (Map<String, Map<String, Object>>) obj.values().iterator().next().get("properties");
schemaBuilder.items(fromMap(arrayType, null, description, arrayObj));
} else {
Map<String, GeminiSchema> props = new LinkedHashMap<>();
if (obj != null) {
for (Map.Entry<String, Map<String, Object>> oneProperty : obj.entrySet()) {
String propName = oneProperty.getKey();
Map<String, Object> propAttributes = oneProperty.getValue();
String propTypeString = (String) propAttributes.getOrDefault("type", "string");
String propDescription = (String) propAttributes.getOrDefault("description", null);
Map<String, Map<String, Object>> childProps =
(Map<String, Map<String, Object>>) propAttributes.getOrDefault("properties", emptyMap());
Map<String, Object> items = (Map<String, Object>) propAttributes.get("items");
Map<String, Map<String, Object>> singleProp = new HashMap<>();
singleProp.put(propName, items);
if (items != null) {
String itemsType = items.get("type").toString();
props.put(propName, fromMap(propTypeString, itemsType, propDescription, singleProp));
} else {
props.put(propName, fromMap(propTypeString, null, propDescription, childProps));
}
}
}
schemaBuilder.properties(props);
}
return schemaBuilder.build();
}
static List<ToolExecutionRequest> fromToolExecReqToGFunCall(List<GeminiFunctionCall> functionCalls) { static List<ToolExecutionRequest> fromToolExecReqToGFunCall(List<GeminiFunctionCall> functionCalls) {
return functionCalls.stream().map(functionCall -> ToolExecutionRequest.builder() return functionCalls.stream()
.map(functionCall -> ToolExecutionRequest.builder()
.name(functionCall.getName()) .name(functionCall.getName())
.arguments(GSON.toJson(functionCall.getArgs())) .arguments(GSON.toJson(functionCall.getArgs()))
.build()).collect(Collectors.toList()); .build())
.collect(Collectors.toList());
} }
} }

View File

@ -1,7 +1,6 @@
package dev.langchain4j.model.googleai; package dev.langchain4j.model.googleai;
enum GeminiType { enum GeminiType {
TYPE_UNSPECIFIED,
STRING, STRING,
NUMBER, NUMBER,
INTEGER, INTEGER,

View File

@ -0,0 +1,258 @@
package dev.langchain4j.model.googleai;
import dev.langchain4j.agent.tool.*;
import dev.langchain4j.model.output.structured.Description;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
import static org.assertj.core.api.Assertions.assertThat;
public class FunctionMapperTest {
enum Projection {
WGS84,
NAD83,
PZ90,
GCJ02,
BD09
}
static class Coordinates {
@Description("latitude")
double latitude;
@Description("latitude")
double longitude;
@Description("Geographic projection system used")
Projection projection;
public Coordinates(double latitude, double longitude) {
this.latitude = latitude;
this.longitude = longitude;
this.projection = Projection.WGS84;
}
}
static class IssTool {
@Tool("Get the distance between the user and the ISS.")
int distanceBetween(
@P("user coordinates") Coordinates userCoordinates,
@P("ISS coordinates") Coordinates issCoordinates
) {
return 3456;
}
}
@Test
void should_convert_nested_structures() {
// when
List<ToolSpecification> toolSpecifications = ToolSpecifications.toolSpecificationsFrom(IssTool.class);
System.out.println("\ntoolSpecifications = " + toolSpecifications);
// then
assertThat(toolSpecifications.size()).isEqualTo(1);
assertThat(toolSpecifications.get(0).name()).isEqualTo("distanceBetween");
assertThat(toolSpecifications.get(0).description()).isEqualTo("Get the distance between the user and the ISS.");
assertThat(toolSpecifications.get(0).parameters().type()).isEqualTo("object");
assertThat(toolSpecifications.get(0).parameters().properties().size()).isEqualTo(2);
assertThat(toolSpecifications.get(0).parameters().properties().keySet()).containsAll(Arrays.asList("userCoordinates", "issCoordinates"));
// when
GeminiTool geminiTool = FunctionMapper.fromToolSepcsToGTool(toolSpecifications, false);
System.out.println("\ngeminiTool = " + withoutNullValues(geminiTool.toString()));
// then
List<GeminiFunctionDeclaration> allGFnDecl = geminiTool.getFunctionDeclarations();
assertThat(allGFnDecl.size()).isEqualTo(1);
GeminiFunctionDeclaration gFnDecl = allGFnDecl.get(0);
assertThat(gFnDecl.getName()).isEqualTo("distanceBetween");
assertThat(gFnDecl.getParameters().getType()).isEqualTo(GeminiType.OBJECT);
Map<String, GeminiSchema> props = gFnDecl.getParameters().getProperties();
assertThat(props.size()).isEqualTo(2);
assertThat(props.keySet()).containsAll(Arrays.asList("userCoordinates", "issCoordinates"));
GeminiSchema userCoord = props.get("userCoordinates");
assertThat(userCoord.getType()).isEqualTo(GeminiType.OBJECT);
GeminiSchema issCoord = props.get("issCoordinates");
assertThat(issCoord.getType()).isEqualTo(GeminiType.OBJECT);
assertThat(userCoord.getProperties().size()).isEqualTo(3);
assertThat(issCoord.getProperties().size()).isEqualTo(3);
assertThat(userCoord.getProperties().keySet()).containsAll(Arrays.asList("latitude", "longitude", "projection"));
assertThat(issCoord.getProperties().keySet()).containsAll(Arrays.asList("latitude", "longitude", "projection"));
}
static class Address {
private final String street;
private final String zipCode;
private final String city;
public Address(String street, String zipCode, String city) {
this.street = street;
this.zipCode = zipCode;
this.city = city;
}
}
static class Customer {
private final String firstname;
private final String lastname;
private final Address shippingAddress;
// private final Address billingAddress;
public Customer(String firstname, String lastname,
Address shippingAddress
// Address billingAddress
) {
this.firstname = firstname;
this.lastname = lastname;
this.shippingAddress = shippingAddress;
// this.billingAddress = billingAddress;
}
}
static class Product {
private final String name;
private final String description;
private final double price;
public Product(String name, String description, double price) {
this.name = name;
this.description = description;
this.price = price;
}
}
static class LineItem {
private final Product product;
private final int quantity;
public LineItem(int quantity, Product product) {
this.product = product;
this.quantity = quantity;
}
}
static class Order {
private final Double totalAmount;
private final List<LineItem> lineItems;
private final Customer customer;
public Order(Double totalAmount, List<LineItem> lineItems, Customer customer) {
this.totalAmount = totalAmount;
this.lineItems = lineItems;
this.customer = customer;
}
}
static class OrderSystem {
@Tool("Make an order")
boolean makeOrder(@P(value = "The order to make") Order order) {
return true;
}
}
@Test
void testComplexNestedGraph() {
// given
List<ToolSpecification> toolSpecifications = ToolSpecifications.toolSpecificationsFrom(OrderSystem.class);
System.out.println("\ntoolSpecifications = " + toolSpecifications);
// when
GeminiTool geminiTool = FunctionMapper.fromToolSepcsToGTool(toolSpecifications, false);
System.out.println("\ngeminiTool = " + withoutNullValues(geminiTool.toString()));
// then
List<GeminiFunctionDeclaration> allGFnDecl = geminiTool.getFunctionDeclarations();
assertThat(allGFnDecl.size()).isEqualTo(1);
GeminiFunctionDeclaration gFnDecl = allGFnDecl.get(0);
assertThat(gFnDecl.getName()).isEqualTo("makeOrder");
assertThat(gFnDecl.getParameters().getType()).isEqualTo(GeminiType.OBJECT);
Map<String, GeminiSchema> props = gFnDecl.getParameters().getProperties();
assertThat(props.size()).isEqualTo(1);
assertThat(props.keySet()).containsExactly("order");
GeminiSchema orderSchema = props.get("order");
assertThat(orderSchema.getType()).isEqualTo(GeminiType.OBJECT);
assertThat(orderSchema.getProperties().size()).isEqualTo(3);
assertThat(orderSchema.getProperties().keySet()).containsAll(Arrays.asList("totalAmount", "lineItems", "customer"));
GeminiSchema totalAmount = orderSchema.getProperties().get("totalAmount");
assertThat(totalAmount.getType()).isEqualTo(GeminiType.NUMBER);
GeminiSchema lineItems = orderSchema.getProperties().get("lineItems");
assertThat(lineItems.getType()).isEqualTo(GeminiType.ARRAY);
GeminiSchema lineItemsItems = lineItems.getItems();
assertThat(lineItemsItems.getType()).isEqualTo(GeminiType.OBJECT);
assertThat(lineItemsItems.getProperties().size()).isEqualTo(2);
assertThat(lineItemsItems.getProperties().keySet()).containsAll(Arrays.asList("product", "quantity"));
GeminiSchema product = lineItemsItems.getProperties().get("product");
assertThat(product.getType()).isEqualTo(GeminiType.OBJECT);
assertThat(product.getProperties().size()).isEqualTo(3);
assertThat(product.getProperties().keySet()).containsAll(Arrays.asList("name", "description", "price"));
GeminiSchema customer = orderSchema.getProperties().get("customer");
assertThat(customer.getType()).isEqualTo(GeminiType.OBJECT);
assertThat(customer.getProperties().size()).isEqualTo(3);
assertThat(customer.getProperties().keySet()).containsAll(Arrays.asList("firstname", "lastname", "shippingAddress"));
GeminiSchema shippingAddress = customer.getProperties().get("shippingAddress");
assertThat(shippingAddress.getType()).isEqualTo(GeminiType.OBJECT);
assertThat(shippingAddress.getProperties().size()).isEqualTo(3);
assertThat(shippingAddress.getProperties().keySet()).containsAll(Arrays.asList("street", "zipCode", "city"));
}
@Test
void testArray() {
// given
ToolSpecification spec = ToolSpecification.builder()
.name("toolName")
.description("tool description")
.addParameter("arrayParameter", ARRAY, items(STRING), description("an array"))
.build();
System.out.println("\nspec = " + spec);
// when
GeminiTool geminiTool = FunctionMapper.fromToolSepcsToGTool(Arrays.asList(spec), false);
System.out.println("\ngeminiTool = " + withoutNullValues(geminiTool.toString()));
// then
List<GeminiFunctionDeclaration> allGFnDecl = geminiTool.getFunctionDeclarations();
assertThat(allGFnDecl.size()).isEqualTo(1);
GeminiFunctionDeclaration gFnDecl = allGFnDecl.get(0);
assertThat(gFnDecl.getName()).isEqualTo("toolName");
assertThat(gFnDecl.getParameters().getType()).isEqualTo(GeminiType.OBJECT);
Map<String, GeminiSchema> props = gFnDecl.getParameters().getProperties();
System.out.println("props = " + withoutNullValues(props.toString()));
assertThat(props.size()).isEqualTo(1);
assertThat(props.keySet()).containsExactly("arrayParameter");
GeminiSchema arrayParameter = props.get("arrayParameter");
assertThat(arrayParameter.getType()).isEqualTo(GeminiType.ARRAY);
assertThat(arrayParameter.getItems().getType()).isEqualTo(GeminiType.STRING);
assertThat(arrayParameter.getItems().getDescription()).isEqualTo("an array");
assertThat(arrayParameter.getItems().getItems()).isNull();
assertThat(arrayParameter.getItems().getProperties()).isEmpty();
}
private static String withoutNullValues(String toString) {
return toString
.replaceAll("(, )?(?<=(, |\\())[^\\s(]+?=null(?:, )?", " ")
.replaceFirst(", \\)$", ")");
}
}

View File

@ -31,6 +31,7 @@ import static dev.langchain4j.exception.IllegalConfigurationException.illegalCon
import static dev.langchain4j.internal.TypeUtils.isJsonBoolean; import static dev.langchain4j.internal.TypeUtils.isJsonBoolean;
import static dev.langchain4j.internal.TypeUtils.isJsonInteger; import static dev.langchain4j.internal.TypeUtils.isJsonInteger;
import static dev.langchain4j.internal.TypeUtils.isJsonNumber; import static dev.langchain4j.internal.TypeUtils.isJsonNumber;
import static dev.langchain4j.internal.TypeUtils.isJsonString;
import static dev.langchain4j.service.TypeUtils.getRawClass; import static dev.langchain4j.service.TypeUtils.getRawClass;
import static dev.langchain4j.service.TypeUtils.resolveFirstGenericParameterClass; import static dev.langchain4j.service.TypeUtils.resolveFirstGenericParameterClass;
import static dev.langchain4j.service.TypeUtils.typeHasRawClass; import static dev.langchain4j.service.TypeUtils.typeHasRawClass;
@ -126,7 +127,7 @@ public class JsonSchemas {
private static JsonSchemaElement jsonSchema(Class<?> clazz, Type type, String fieldDescription) { private static JsonSchemaElement jsonSchema(Class<?> clazz, Type type, String fieldDescription) {
if (clazz == String.class) { if (isJsonString(clazz)) {
return JsonStringSchema.builder() return JsonStringSchema.builder()
.description(fieldDescription) .description(fieldDescription)
.build(); .build();