Allow user to modify text key (#1723)

## Issue
Closes #1722

## Change
Parameterized the field name used to look up the text value in Weaviate.

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x ] There are no breaking changes
- [? ] I have added unit and integration tests for my change (Based on
the test that are there I am not really sure what to add that would
really test it well)
- [x ] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
This commit is contained in:
Jaland 2024-09-10 04:00:32 -04:00 committed by GitHub
parent e1f2b1729e
commit 86353f9b13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 11 additions and 8 deletions

View File

@ -39,7 +39,6 @@ import static java.util.stream.Collectors.toList;
*/ */
public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> { public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
private static final String METADATA_TEXT_SEGMENT = "text";
private static final String ADDITIONALS = "_additional"; private static final String ADDITIONALS = "_additional";
private static final String METADATA = "_metadata"; private static final String METADATA = "_metadata";
private static final String NULL_VALUE = "<null>"; private static final String NULL_VALUE = "<null>";
@ -49,6 +48,7 @@ public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
private final boolean avoidDups; private final boolean avoidDups;
private final String consistencyLevel; private final String consistencyLevel;
private final Collection<String> metadataKeys; private final Collection<String> metadataKeys;
private final String textFieldName;
/** /**
* Creates a new WeaviateEmbeddingStore instance. * Creates a new WeaviateEmbeddingStore instance.
@ -67,6 +67,7 @@ public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
* @param useGrpcForInserts Use GRPC instead of HTTP for batch inserts only. <b>You still need HTTP configured for search</b> * @param useGrpcForInserts Use GRPC instead of HTTP for batch inserts only. <b>You still need HTTP configured for search</b>
* @param securedGrpc The GRPC connection is secured * @param securedGrpc The GRPC connection is secured
* @param grpcPort The port, e.g. 50051. This parameter is optional. * @param grpcPort The port, e.g. 50051. This parameter is optional.
* @param textFieldName The name of the field that contains the text of a {@link TextSegment}. Default is "text".
*/ */
@Builder @Builder
public WeaviateEmbeddingStore( public WeaviateEmbeddingStore(
@ -80,7 +81,8 @@ public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
String objectClass, String objectClass,
Boolean avoidDups, Boolean avoidDups,
String consistencyLevel, String consistencyLevel,
Collection<String> metadataKeys Collection<String> metadataKeys,
String textFieldName
) { ) {
try { try {
@ -104,6 +106,7 @@ public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
this.avoidDups = getOrDefault(avoidDups, true); this.avoidDups = getOrDefault(avoidDups, true);
this.consistencyLevel = getOrDefault(consistencyLevel, QUORUM); this.consistencyLevel = getOrDefault(consistencyLevel, QUORUM);
this.metadataKeys = getOrDefault(metadataKeys, Collections.emptyList()); this.metadataKeys = getOrDefault(metadataKeys, Collections.emptyList());
this.textFieldName = getOrDefault(textFieldName, "text");
} }
private static String concatenate(String host, Integer port) { private static String concatenate(String host, Integer port) {
@ -180,7 +183,7 @@ public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
double minCertainty double minCertainty
) { ) {
List<Field> fields = new ArrayList<>(); List<Field> fields = new ArrayList<>();
fields.add(Field.builder().name(METADATA_TEXT_SEGMENT).build()); fields.add(Field.builder().name(textFieldName).build());
fields.add(Field fields.add(Field
.builder() .builder()
.name(ADDITIONALS) .name(ADDITIONALS)
@ -236,7 +239,7 @@ public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
List<Map<String, ?>> resItems = ((Map.Entry<String, List<Map<String, ?>>>) resItemsPart.get()).getValue(); List<Map<String, ?>> resItems = ((Map.Entry<String, List<Map<String, ?>>>) resItemsPart.get()).getValue();
return resItems.stream().map(WeaviateEmbeddingStore::toEmbeddingMatch).collect(toList()); return resItems.stream().map(item -> toEmbeddingMatch(item)).collect(toList());
} }
private List<String> addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) { private List<String> addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
@ -264,7 +267,7 @@ public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
Map<String, Object> props = new HashMap<>(); Map<String, Object> props = new HashMap<>();
Map<String, Object> metadata = prefillMetadata(); Map<String, Object> metadata = prefillMetadata();
if (segment != null) { if (segment != null) {
props.put(METADATA_TEXT_SEGMENT, segment.text()); props.put(textFieldName, segment.text());
if (!segment.metadata().toMap().isEmpty()) { if (!segment.metadata().toMap().isEmpty()) {
for (String property : metadataKeys) { for (String property : metadataKeys) {
if (segment.metadata().containsKey(property)) { if (segment.metadata().containsKey(property)) {
@ -274,7 +277,7 @@ public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
} }
setMetadata(props, metadata); setMetadata(props, metadata);
} else { } else {
props.put(METADATA_TEXT_SEGMENT, ""); props.put(textFieldName, "");
setMetadata(props, metadata); setMetadata(props, metadata);
} }
props.put("indexFilterable", true); props.put("indexFilterable", true);
@ -302,7 +305,7 @@ public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
return metadata; return metadata;
} }
private static EmbeddingMatch<TextSegment> toEmbeddingMatch(Map<String, ?> item) { private EmbeddingMatch<TextSegment> toEmbeddingMatch(Map<String, ?> item) {
Map<String, ?> additional = (Map<String, ?>) item.get(ADDITIONALS); Map<String, ?> additional = (Map<String, ?>) item.get(ADDITIONALS);
final Metadata metadata = new Metadata(); final Metadata metadata = new Metadata();
if (item.get(METADATA) != null && item.get(METADATA) instanceof Map) { if (item.get(METADATA) != null && item.get(METADATA) instanceof Map) {
@ -313,7 +316,7 @@ public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {
} }
} }
} }
String text = (String) item.get(METADATA_TEXT_SEGMENT); String text = (String) item.get(textFieldName);
return new EmbeddingMatch<>( return new EmbeddingMatch<>(
(Double) additional.get("certainty"), (Double) additional.get("certainty"),