Fix a enum serialization issue leading to stackoverflow when creating schemas from classes (#1450)

- **Fixes #1447 Make vector embedding calculation in batch mode to speed
up the calculation of all embeddings for each label.**
- **Remove unneeded import**
- **[Gemini] Fix enum schema handling**
This commit is contained in:
Guillaume Laforge 2024-07-12 14:28:05 +02:00 committed by GitHub
parent 7b9366c975
commit 03aaa7676c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 3 deletions

View File

@ -76,6 +76,14 @@ public class SchemaHelper {
} else if (Collection.class.isAssignableFrom(theClass)) {
// Because of type erasure, we can't easily know the type of the items in the collection
return Schema.newBuilder().setType(Type.ARRAY).build();
} else if (theClass.isEnum()) {
List<String> enumConstantNames = Arrays.stream(theClass.getEnumConstants())
.map(Object::toString)
.collect(Collectors.toList());
return Schema.newBuilder()
.setType(Type.STRING)
.addAllEnum(enumConstantNames)
.build();
} else {
// This is some kind of object, let's go through its fields
Schema.Builder schemaBuilder = Schema.newBuilder().setType(Type.OBJECT);

View File

@ -2,8 +2,14 @@ package dev.langchain4j.model.vertexai;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Type;
import lombok.Data;
import lombok.Getter;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.Collections;
import static dev.langchain4j.model.vertexai.SchemaHelper.fromClass;
import static org.assertj.core.api.Assertions.assertThat;
public class SchemaHelperTest {
@ -17,10 +23,10 @@ public class SchemaHelperTest {
public int age;
public boolean isStudent;
public String[] friends;
};
}
// when
Schema schema = SchemaHelper.fromClass(Person.class);
Schema schema = fromClass(Person.class);
System.out.println("schema = " + schema);
// then
@ -32,7 +38,7 @@ public class SchemaHelperTest {
assertThat(schema.getPropertiesMap().get("friends").getType()).isEqualTo(Type.ARRAY);
assertThat(schema.getPropertiesMap().get("friends").getItems().getType()).isEqualTo(Type.STRING);
}
@Test
void should_convert_json_schema_string_into_schema() {
@ -78,6 +84,38 @@ public class SchemaHelperTest {
assertThat(schema.getPropertiesMap().get("artist-adult").getType()).isEqualTo(Type.BOOLEAN);
assertThat(schema.getPropertiesMap().get("artist-pets").getType()).isEqualTo(Type.ARRAY);
assertThat(schema.getPropertiesMap().get("artist-pets").getItems().getType()).isEqualTo(Type.STRING);
}
@Getter
enum Sentiment {
POSITIVE(1.0), NEUTRAL(0.0), NEGATIVE(-1.0);
private final double value;
Sentiment(double val) {
this.value = val;
}
}
@Data
static class SentimentClassification {
private Sentiment sentiment;
}
@Test
void should_support_enum_schema_without_stackoverflow() {
// given
Schema schemaFromEnum = fromClass(SentimentClassification.class);
Schema expectedSchema = Schema.newBuilder()
.setType(Type.OBJECT)
.putProperties("sentiment", Schema.newBuilder()
.setType(Type.STRING)
.addAllEnum(Arrays.asList("POSITIVE", "NEUTRAL", "NEGATIVE"))
.build())
.addAllRequired(Collections.singletonList("sentiment"))
.build();
// then
assertThat(schemaFromEnum).isEqualTo(expectedSchema);
}
}