From 421b4cd048fdc113881c4c63effe24ff3f5746ac Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Mon, 9 Sep 2024 10:15:23 +0200 Subject: [PATCH] =?UTF-8?q?bug(Google=20AI=20Gemini)=20=E2=80=94=20fix=20m?= =?UTF-8?q?apping=20for=20tools=20with=20parameters=20with=20nested=20obje?= =?UTF-8?q?ct=20structures=20(#1732)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bug(Google AI Gemini) — fix mapping for tools with parameters with nested object structures * fix(JsonSchemas) — minor improvement to cover all cases of strings --- .../dev/langchain4j/internal/TypeUtils.java | 9 + .../langchain4j/internal/TypeUtilsTest.java | 36 +++ .../model/googleai/FunctionMapper.java | 79 ++++-- .../model/googleai/GeminiType.java | 1 - .../model/googleai/FunctionMapperTest.java | 258 ++++++++++++++++++ .../service/output/JsonSchemas.java | 3 +- 6 files changed, 357 insertions(+), 29 deletions(-) create mode 100644 langchain4j-core/src/test/java/dev/langchain4j/internal/TypeUtilsTest.java create mode 100644 langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/FunctionMapperTest.java diff --git a/langchain4j-core/src/main/java/dev/langchain4j/internal/TypeUtils.java b/langchain4j-core/src/main/java/dev/langchain4j/internal/TypeUtils.java index 7b0bb1783..46e659523 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/internal/TypeUtils.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/internal/TypeUtils.java @@ -22,4 +22,13 @@ public class TypeUtils { public static boolean isJsonBoolean(Class type) { return type == boolean.class || type == Boolean.class; } + + public static boolean isJsonString(Class type) { + return type == String.class || type == char.class || type == Character.class + || CharSequence.class.isAssignableFrom(type); + } + + public static boolean isJsonArray(Class type) { + return type.isArray() || Iterable.class.isAssignableFrom(type); + } } diff --git a/langchain4j-core/src/test/java/dev/langchain4j/internal/TypeUtilsTest.java b/langchain4j-core/src/test/java/dev/langchain4j/internal/TypeUtilsTest.java new file mode 100644 index 000000000..9fa774d33 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/internal/TypeUtilsTest.java @@ -0,0 +1,36 @@ +package dev.langchain4j.internal; + +import org.junit.jupiter.api.Test; + +import java.util.Collection; +import java.util.Deque; +import java.util.List; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; + +public class TypeUtilsTest { + + @Test + void testStringIsJsonCompatible() { + assertThat(TypeUtils.isJsonString(char.class)).isTrue(); + assertThat(TypeUtils.isJsonString(String.class)).isTrue(); + assertThat(TypeUtils.isJsonString(Character.class)).isTrue(); + assertThat(TypeUtils.isJsonString(StringBuffer.class)).isTrue(); + assertThat(TypeUtils.isJsonString(StringBuilder.class)).isTrue(); + assertThat(TypeUtils.isJsonString(CharSequence.class)).isTrue(); + } + + @Test + void testCollectionIsJsonCompatible() { + assertThat(TypeUtils.isJsonArray(String[].class)).isTrue(); + assertThat(TypeUtils.isJsonArray(Integer[].class)).isTrue(); + assertThat(TypeUtils.isJsonArray(int[].class)).isTrue(); + + assertThat(TypeUtils.isJsonArray(List.class)).isTrue(); + assertThat(TypeUtils.isJsonArray(Set.class)).isTrue(); + assertThat(TypeUtils.isJsonArray(Deque.class)).isTrue(); + assertThat(TypeUtils.isJsonArray(Collection.class)).isTrue(); + assertThat(TypeUtils.isJsonArray(Iterable.class)).isTrue(); + } +} diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/FunctionMapper.java b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/FunctionMapper.java index dcefce729..f90190bc3 100644 --- a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/FunctionMapper.java +++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/FunctionMapper.java @@ -2,12 +2,13 @@ package dev.langchain4j.model.googleai; import com.google.gson.Gson; import dev.langchain4j.agent.tool.ToolExecutionRequest; -import dev.langchain4j.agent.tool.ToolParameters; import dev.langchain4j.agent.tool.ToolSpecification; import java.util.*; import java.util.stream.Collectors; +import static java.util.Collections.emptyMap; + class FunctionMapper { private static final Gson GSON = new Gson(); @@ -34,6 +35,7 @@ class FunctionMapper { .map(specification -> { GeminiFunctionDeclaration.GeminiFunctionDeclarationBuilder fnBuilder = GeminiFunctionDeclaration.builder(); + if (specification.name() != null) { fnBuilder.name(specification.name()); } @@ -41,30 +43,13 @@ class FunctionMapper { fnBuilder.description(specification.description()); } if (specification.parameters() != null) { - ToolParameters parameters = specification.parameters(); + Map> properties = specification.parameters().properties(); - final String[] propName = {""}; - fnBuilder.parameters(GeminiSchema.builder() - .type(GeminiType.OBJECT) - .properties(parameters.properties().entrySet().stream() - .map(prop -> { - propName[0] = prop.getKey(); - Map propAttributes = prop.getValue(); - - String typeString = (String) propAttributes.getOrDefault("type", "string"); - GeminiType type = GeminiType.valueOf(typeString.toUpperCase()); - String description = (String) propAttributes.getOrDefault("description", null); - - //TODO need to deal with nested objects - - return GeminiSchema.builder() - .description(description) - .type(type) - .build(); - }) - .collect(Collectors.toMap(schema -> propName[0], schema -> schema))) - .build()); + String type = "object"; + String description = specification.description(); + fnBuilder.parameters(fromMap(type, null, null, properties)); } + return fnBuilder.build(); }) .filter(Objects::nonNull) @@ -77,10 +62,50 @@ class FunctionMapper { return tool.build(); } + private static GeminiSchema fromMap(String type, String arrayType, String description, Map> obj) { + GeminiSchema.GeminiSchemaBuilder schemaBuilder = GeminiSchema.builder(); + + schemaBuilder.type(GeminiType.valueOf(type.toUpperCase())); + schemaBuilder.description(description); + + if (type.equals("array")) { + Map> arrayObj = (Map>) obj.values().iterator().next().get("properties"); + + schemaBuilder.items(fromMap(arrayType, null, description, arrayObj)); + } else { + Map props = new LinkedHashMap<>(); + if (obj != null) { + for (Map.Entry> oneProperty : obj.entrySet()) { + String propName = oneProperty.getKey(); + Map propAttributes = oneProperty.getValue(); + String propTypeString = (String) propAttributes.getOrDefault("type", "string"); + String propDescription = (String) propAttributes.getOrDefault("description", null); + Map> childProps = + (Map>) propAttributes.getOrDefault("properties", emptyMap()); + Map items = (Map) propAttributes.get("items"); + Map> singleProp = new HashMap<>(); + singleProp.put(propName, items); + + if (items != null) { + String itemsType = items.get("type").toString(); + props.put(propName, fromMap(propTypeString, itemsType, propDescription, singleProp)); + } else { + props.put(propName, fromMap(propTypeString, null, propDescription, childProps)); + } + } + } + schemaBuilder.properties(props); + } + + return schemaBuilder.build(); + } + static List fromToolExecReqToGFunCall(List functionCalls) { - return functionCalls.stream().map(functionCall -> ToolExecutionRequest.builder() - .name(functionCall.getName()) - .arguments(GSON.toJson(functionCall.getArgs())) - .build()).collect(Collectors.toList()); + return functionCalls.stream() + .map(functionCall -> ToolExecutionRequest.builder() + .name(functionCall.getName()) + .arguments(GSON.toJson(functionCall.getArgs())) + .build()) + .collect(Collectors.toList()); } } diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiType.java b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiType.java index 3762a62d2..dd00275fd 100644 --- a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiType.java +++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiType.java @@ -1,7 +1,6 @@ package dev.langchain4j.model.googleai; enum GeminiType { - TYPE_UNSPECIFIED, STRING, NUMBER, INTEGER, diff --git a/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/FunctionMapperTest.java b/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/FunctionMapperTest.java new file mode 100644 index 000000000..3277e1e14 --- /dev/null +++ b/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/FunctionMapperTest.java @@ -0,0 +1,258 @@ +package dev.langchain4j.model.googleai; + +import dev.langchain4j.agent.tool.*; +import dev.langchain4j.model.output.structured.Description; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static dev.langchain4j.agent.tool.JsonSchemaProperty.*; +import static org.assertj.core.api.Assertions.assertThat; + +public class FunctionMapperTest { + + enum Projection { + WGS84, + NAD83, + PZ90, + GCJ02, + BD09 + } + + static class Coordinates { + @Description("latitude") + double latitude; + @Description("latitude") + double longitude; + @Description("Geographic projection system used") + Projection projection; + + public Coordinates(double latitude, double longitude) { + this.latitude = latitude; + this.longitude = longitude; + this.projection = Projection.WGS84; + } + } + + static class IssTool { + @Tool("Get the distance between the user and the ISS.") + int distanceBetween( + @P("user coordinates") Coordinates userCoordinates, + @P("ISS coordinates") Coordinates issCoordinates + ) { + return 3456; + } + } + + @Test + void should_convert_nested_structures() { + // when + List toolSpecifications = ToolSpecifications.toolSpecificationsFrom(IssTool.class); + System.out.println("\ntoolSpecifications = " + toolSpecifications); + + // then + assertThat(toolSpecifications.size()).isEqualTo(1); + assertThat(toolSpecifications.get(0).name()).isEqualTo("distanceBetween"); + assertThat(toolSpecifications.get(0).description()).isEqualTo("Get the distance between the user and the ISS."); + assertThat(toolSpecifications.get(0).parameters().type()).isEqualTo("object"); + assertThat(toolSpecifications.get(0).parameters().properties().size()).isEqualTo(2); + assertThat(toolSpecifications.get(0).parameters().properties().keySet()).containsAll(Arrays.asList("userCoordinates", "issCoordinates")); + + // when + GeminiTool geminiTool = FunctionMapper.fromToolSepcsToGTool(toolSpecifications, false); + System.out.println("\ngeminiTool = " + withoutNullValues(geminiTool.toString())); + + // then + List allGFnDecl = geminiTool.getFunctionDeclarations(); + assertThat(allGFnDecl.size()).isEqualTo(1); + + GeminiFunctionDeclaration gFnDecl = allGFnDecl.get(0); + assertThat(gFnDecl.getName()).isEqualTo("distanceBetween"); + + assertThat(gFnDecl.getParameters().getType()).isEqualTo(GeminiType.OBJECT); + Map props = gFnDecl.getParameters().getProperties(); + + assertThat(props.size()).isEqualTo(2); + assertThat(props.keySet()).containsAll(Arrays.asList("userCoordinates", "issCoordinates")); + + GeminiSchema userCoord = props.get("userCoordinates"); + assertThat(userCoord.getType()).isEqualTo(GeminiType.OBJECT); + + GeminiSchema issCoord = props.get("issCoordinates"); + assertThat(issCoord.getType()).isEqualTo(GeminiType.OBJECT); + + assertThat(userCoord.getProperties().size()).isEqualTo(3); + assertThat(issCoord.getProperties().size()).isEqualTo(3); + + assertThat(userCoord.getProperties().keySet()).containsAll(Arrays.asList("latitude", "longitude", "projection")); + assertThat(issCoord.getProperties().keySet()).containsAll(Arrays.asList("latitude", "longitude", "projection")); + } + + static class Address { + private final String street; + private final String zipCode; + private final String city; + + public Address(String street, String zipCode, String city) { + this.street = street; + this.zipCode = zipCode; + this.city = city; + } + } + + static class Customer { + private final String firstname; + private final String lastname; + + private final Address shippingAddress; +// private final Address billingAddress; + + public Customer(String firstname, String lastname, + Address shippingAddress +// Address billingAddress + ) { + this.firstname = firstname; + this.lastname = lastname; + this.shippingAddress = shippingAddress; +// this.billingAddress = billingAddress; + } + } + + static class Product { + private final String name; + private final String description; + private final double price; + + public Product(String name, String description, double price) { + this.name = name; + this.description = description; + this.price = price; + } + } + + static class LineItem { + private final Product product; + private final int quantity; + + public LineItem(int quantity, Product product) { + this.product = product; + this.quantity = quantity; + } + } + + static class Order { + private final Double totalAmount; + private final List lineItems; + private final Customer customer; + + public Order(Double totalAmount, List lineItems, Customer customer) { + this.totalAmount = totalAmount; + this.lineItems = lineItems; + this.customer = customer; + } + } + + static class OrderSystem { + @Tool("Make an order") + boolean makeOrder(@P(value = "The order to make") Order order) { + return true; + } + } + + @Test + void testComplexNestedGraph() { + // given + List toolSpecifications = ToolSpecifications.toolSpecificationsFrom(OrderSystem.class); + System.out.println("\ntoolSpecifications = " + toolSpecifications); + + // when + GeminiTool geminiTool = FunctionMapper.fromToolSepcsToGTool(toolSpecifications, false); + System.out.println("\ngeminiTool = " + withoutNullValues(geminiTool.toString())); + + // then + List allGFnDecl = geminiTool.getFunctionDeclarations(); + assertThat(allGFnDecl.size()).isEqualTo(1); + + GeminiFunctionDeclaration gFnDecl = allGFnDecl.get(0); + assertThat(gFnDecl.getName()).isEqualTo("makeOrder"); + assertThat(gFnDecl.getParameters().getType()).isEqualTo(GeminiType.OBJECT); + + Map props = gFnDecl.getParameters().getProperties(); + assertThat(props.size()).isEqualTo(1); + assertThat(props.keySet()).containsExactly("order"); + + GeminiSchema orderSchema = props.get("order"); + assertThat(orderSchema.getType()).isEqualTo(GeminiType.OBJECT); + assertThat(orderSchema.getProperties().size()).isEqualTo(3); + assertThat(orderSchema.getProperties().keySet()).containsAll(Arrays.asList("totalAmount", "lineItems", "customer")); + + GeminiSchema totalAmount = orderSchema.getProperties().get("totalAmount"); + assertThat(totalAmount.getType()).isEqualTo(GeminiType.NUMBER); + + GeminiSchema lineItems = orderSchema.getProperties().get("lineItems"); + assertThat(lineItems.getType()).isEqualTo(GeminiType.ARRAY); + + GeminiSchema lineItemsItems = lineItems.getItems(); + assertThat(lineItemsItems.getType()).isEqualTo(GeminiType.OBJECT); + assertThat(lineItemsItems.getProperties().size()).isEqualTo(2); + assertThat(lineItemsItems.getProperties().keySet()).containsAll(Arrays.asList("product", "quantity")); + + GeminiSchema product = lineItemsItems.getProperties().get("product"); + assertThat(product.getType()).isEqualTo(GeminiType.OBJECT); + assertThat(product.getProperties().size()).isEqualTo(3); + assertThat(product.getProperties().keySet()).containsAll(Arrays.asList("name", "description", "price")); + + GeminiSchema customer = orderSchema.getProperties().get("customer"); + assertThat(customer.getType()).isEqualTo(GeminiType.OBJECT); + assertThat(customer.getProperties().size()).isEqualTo(3); + assertThat(customer.getProperties().keySet()).containsAll(Arrays.asList("firstname", "lastname", "shippingAddress")); + + GeminiSchema shippingAddress = customer.getProperties().get("shippingAddress"); + assertThat(shippingAddress.getType()).isEqualTo(GeminiType.OBJECT); + assertThat(shippingAddress.getProperties().size()).isEqualTo(3); + assertThat(shippingAddress.getProperties().keySet()).containsAll(Arrays.asList("street", "zipCode", "city")); + } + + @Test + void testArray() { + // given + ToolSpecification spec = ToolSpecification.builder() + .name("toolName") + .description("tool description") + .addParameter("arrayParameter", ARRAY, items(STRING), description("an array")) + .build(); + + System.out.println("\nspec = " + spec); + + // when + GeminiTool geminiTool = FunctionMapper.fromToolSepcsToGTool(Arrays.asList(spec), false); + System.out.println("\ngeminiTool = " + withoutNullValues(geminiTool.toString())); + + // then + List allGFnDecl = geminiTool.getFunctionDeclarations(); + assertThat(allGFnDecl.size()).isEqualTo(1); + GeminiFunctionDeclaration gFnDecl = allGFnDecl.get(0); + assertThat(gFnDecl.getName()).isEqualTo("toolName"); + assertThat(gFnDecl.getParameters().getType()).isEqualTo(GeminiType.OBJECT); + + Map props = gFnDecl.getParameters().getProperties(); + System.out.println("props = " + withoutNullValues(props.toString())); + assertThat(props.size()).isEqualTo(1); + assertThat(props.keySet()).containsExactly("arrayParameter"); + + GeminiSchema arrayParameter = props.get("arrayParameter"); + assertThat(arrayParameter.getType()).isEqualTo(GeminiType.ARRAY); + assertThat(arrayParameter.getItems().getType()).isEqualTo(GeminiType.STRING); + assertThat(arrayParameter.getItems().getDescription()).isEqualTo("an array"); + assertThat(arrayParameter.getItems().getItems()).isNull(); + assertThat(arrayParameter.getItems().getProperties()).isEmpty(); + } + + private static String withoutNullValues(String toString) { + return toString + .replaceAll("(, )?(?<=(, |\\())[^\\s(]+?=null(?:, )?", " ") + .replaceFirst(", \\)$", ")"); + } +} diff --git a/langchain4j/src/main/java/dev/langchain4j/service/output/JsonSchemas.java b/langchain4j/src/main/java/dev/langchain4j/service/output/JsonSchemas.java index 81ade79d5..c74de6ece 100644 --- a/langchain4j/src/main/java/dev/langchain4j/service/output/JsonSchemas.java +++ b/langchain4j/src/main/java/dev/langchain4j/service/output/JsonSchemas.java @@ -31,6 +31,7 @@ import static dev.langchain4j.exception.IllegalConfigurationException.illegalCon import static dev.langchain4j.internal.TypeUtils.isJsonBoolean; import static dev.langchain4j.internal.TypeUtils.isJsonInteger; import static dev.langchain4j.internal.TypeUtils.isJsonNumber; +import static dev.langchain4j.internal.TypeUtils.isJsonString; import static dev.langchain4j.service.TypeUtils.getRawClass; import static dev.langchain4j.service.TypeUtils.resolveFirstGenericParameterClass; import static dev.langchain4j.service.TypeUtils.typeHasRawClass; @@ -126,7 +127,7 @@ public class JsonSchemas { private static JsonSchemaElement jsonSchema(Class clazz, Type type, String fieldDescription) { - if (clazz == String.class) { + if (isJsonString(clazz)) { return JsonStringSchema.builder() .description(fieldDescription) .build();