## Issue Contributes to #1636 ## Change Get rid of Lombok in langchain4j-core: Run Delombok refactoring in IntelliJ IDEA to remove Lombok annotations and replace them with the equivalent Java code. This pull request focuses on removing the Lombok dependency and replacing it with manually implemented builder patterns across several classes. Additionally, it includes some minor code improvements. ## General checklist - [x] There are no breaking changes - [ ] I have added unit and integration tests for my change - [ ] I have manually run all the unit tests in all modules, and they are all green - [ ] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [ ] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
This commit is contained in:
parent
09faf9e55f
commit
99dec54a59
|
@ -0,0 +1,17 @@
|
|||
root = true
|
||||
|
||||
[*]
|
||||
charset = utf-8
|
||||
end_of_line = lf
|
||||
indent_size = 4
|
||||
indent_style = space
|
||||
insert_final_newline = true
|
||||
max_line_length = 100
|
||||
tab_width = 4
|
||||
|
||||
[*.java]
|
||||
ij_java_names_count_to_use_import_on_demand = 999
|
||||
ij_java_class_count_to_use_import_on_demand = 999
|
||||
|
||||
[{*.yaml,*.yml}]
|
||||
indent_size = 2
|
|
@ -5,8 +5,9 @@ Thank you for investing your time and effort in contributing to our project, we
|
|||
- If you want to contribute a bug fix or a new feature that isn't listed in the [issues](https://github.com/langchain4j/langchain4j/issues) yet, please open a new issue for it. We will prioritize is shortly.
|
||||
- Follow [Google's Best Practices for Java Libraries](https://jlbp.dev/)
|
||||
- Keep the code compatible with Java 17.
|
||||
- Avoid adding new dependencies as much as possible (new dependencies with test scope are OK). If absolutely necessary, try to use the same libraries which are already used in the project.
|
||||
- Avoid adding new dependencies as much as possible (new dependencies with test scope are OK). If absolutely necessary, try to use the same libraries which are already used in the project. Make sure you run `mvn dependency:analyze` to identify unnecessary dependencies.
|
||||
- Write unit and/or integration tests for your code. This is critical: no tests, no review!
|
||||
- Make sure you run all unit tests on all modules with `mvn clean test`
|
||||
- Avoid making breaking changes. Always keep backward compatibility in mind. For example, instead of removing fields/methods/etc, mark them `@Deprecated` and make sure they still work as before.
|
||||
- Follow existing naming conventions.
|
||||
- Avoid using Lombok in the new code, and remove it from the old code if you get a chance.
|
||||
|
|
|
@ -34,12 +34,6 @@
|
|||
<artifactId>slf4j-api</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
|
|
|
@ -2,7 +2,9 @@ package dev.langchain4j;
|
|||
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
import static java.lang.annotation.ElementType.*;
|
||||
import static java.lang.annotation.ElementType.CONSTRUCTOR;
|
||||
import static java.lang.annotation.ElementType.METHOD;
|
||||
import static java.lang.annotation.ElementType.TYPE;
|
||||
|
||||
/**
|
||||
* Indicates that a class/constructor/method is experimental and might change in the future.
|
||||
|
|
|
@ -3,7 +3,12 @@ package dev.langchain4j.data.document;
|
|||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
|
||||
import static dev.langchain4j.internal.Exceptions.illegalArgument;
|
||||
import static dev.langchain4j.internal.Exceptions.runtime;
|
||||
|
@ -67,8 +72,8 @@ public class Metadata {
|
|||
validate(key, value);
|
||||
if (!SUPPORTED_VALUE_TYPES.contains(value.getClass())) {
|
||||
throw illegalArgument("The metadata key '%s' has the value '%s', which is of the unsupported type '%s'. " +
|
||||
"Currently, the supported types are: %s",
|
||||
key, value, value.getClass().getName(), SUPPORTED_VALUE_TYPES
|
||||
"Currently, the supported types are: %s",
|
||||
key, value, value.getClass().getName(), SUPPORTED_VALUE_TYPES
|
||||
);
|
||||
}
|
||||
});
|
||||
|
@ -116,7 +121,7 @@ public class Metadata {
|
|||
}
|
||||
|
||||
throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " +
|
||||
"It cannot be returned as a String.", key, value, value.getClass().getName());
|
||||
"It cannot be returned as a String.", key, value, value.getClass().getName());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -140,7 +145,7 @@ public class Metadata {
|
|||
}
|
||||
|
||||
throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " +
|
||||
"It cannot be returned as a UUID.", key, value, value.getClass().getName());
|
||||
"It cannot be returned as a UUID.", key, value, value.getClass().getName());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -172,7 +177,7 @@ public class Metadata {
|
|||
}
|
||||
|
||||
throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " +
|
||||
"It cannot be returned as an Integer.", key, value, value.getClass().getName());
|
||||
"It cannot be returned as an Integer.", key, value, value.getClass().getName());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -204,7 +209,7 @@ public class Metadata {
|
|||
}
|
||||
|
||||
throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " +
|
||||
"It cannot be returned as a Long.", key, value, value.getClass().getName());
|
||||
"It cannot be returned as a Long.", key, value, value.getClass().getName());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -236,7 +241,7 @@ public class Metadata {
|
|||
}
|
||||
|
||||
throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " +
|
||||
"It cannot be returned as a Float.", key, value, value.getClass().getName());
|
||||
"It cannot be returned as a Float.", key, value, value.getClass().getName());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -268,7 +273,7 @@ public class Metadata {
|
|||
}
|
||||
|
||||
throw runtime("Metadata entry with the key '%s' has a value of '%s' and type '%s'. " +
|
||||
"It cannot be returned as a Double.", key, value, value.getClass().getName());
|
||||
"It cannot be returned as a Double.", key, value, value.getClass().getName());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -449,8 +454,8 @@ public class Metadata {
|
|||
@Override
|
||||
public String toString() {
|
||||
return "Metadata {" +
|
||||
" metadata = " + metadata +
|
||||
" }";
|
||||
" metadata = " + metadata +
|
||||
" }";
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -8,7 +8,9 @@ import java.util.Objects;
|
|||
import static dev.langchain4j.data.message.ChatMessageType.AI;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.internal.Utils.quoted;
|
||||
import static dev.langchain4j.internal.ValidationUtils.*;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static java.util.Arrays.asList;
|
||||
|
||||
/**
|
||||
|
@ -90,7 +92,7 @@ public class AiMessage implements ChatMessage {
|
|||
if (o == null || getClass() != o.getClass()) return false;
|
||||
AiMessage that = (AiMessage) o;
|
||||
return Objects.equals(this.text, that.text)
|
||||
&& Objects.equals(this.toolExecutionRequests, that.toolExecutionRequests);
|
||||
&& Objects.equals(this.toolExecutionRequests, that.toolExecutionRequests);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -101,9 +103,9 @@ public class AiMessage implements ChatMessage {
|
|||
@Override
|
||||
public String toString() {
|
||||
return "AiMessage {" +
|
||||
" text = " + quoted(text) +
|
||||
" toolExecutionRequests = " + toolExecutionRequests +
|
||||
" }";
|
||||
" text = " + quoted(text) +
|
||||
" toolExecutionRequests = " + toolExecutionRequests +
|
||||
" }";
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -139,7 +141,7 @@ public class AiMessage implements ChatMessage {
|
|||
/**
|
||||
* Create a new {@link AiMessage} with the given text and tool execution requests.
|
||||
*
|
||||
* @param text the text of the message.
|
||||
* @param text the text of the message.
|
||||
* @param toolExecutionRequests the tool execution requests of the message.
|
||||
* @return the new {@link AiMessage}.
|
||||
*/
|
||||
|
@ -180,7 +182,7 @@ public class AiMessage implements ChatMessage {
|
|||
/**
|
||||
* Create a new {@link AiMessage} with the given text and tool execution requests.
|
||||
*
|
||||
* @param text the text of the message.
|
||||
* @param text the text of the message.
|
||||
* @param toolExecutionRequests the tool execution requests of the message.
|
||||
* @return the new {@link AiMessage}.
|
||||
*/
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
package dev.langchain4j.data.message;
|
||||
|
||||
import static dev.langchain4j.data.message.ChatMessageSerializer.CODEC;
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.data.message.ChatMessageSerializer.CODEC;
|
||||
|
||||
/**
|
||||
* A deserializer for {@link ChatMessage} objects.
|
||||
*/
|
||||
|
|
|
@ -1,19 +1,26 @@
|
|||
package dev.langchain4j.data.message;
|
||||
|
||||
import com.google.gson.*;
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.GsonBuilder;
|
||||
import com.google.gson.JsonDeserializationContext;
|
||||
import com.google.gson.JsonDeserializer;
|
||||
import com.google.gson.JsonElement;
|
||||
import com.google.gson.JsonObject;
|
||||
import com.google.gson.JsonSerializationContext;
|
||||
import com.google.gson.JsonSerializer;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
|
||||
class GsonChatMessageAdapter implements JsonDeserializer<ChatMessage>, JsonSerializer<ChatMessage> {
|
||||
|
||||
private static final Gson GSON = new GsonBuilder()
|
||||
.registerTypeAdapter(Content.class, new GsonContentAdapter())
|
||||
.registerTypeAdapter(TextContent.class, new GsonContentAdapter())
|
||||
.registerTypeAdapter(ImageContent.class, new GsonContentAdapter())
|
||||
.registerTypeAdapter(AudioContent.class, new GsonContentAdapter())
|
||||
.registerTypeAdapter(VideoContent.class, new GsonContentAdapter())
|
||||
.registerTypeAdapter(PdfFileContent.class, new GsonContentAdapter())
|
||||
.create();
|
||||
.registerTypeAdapter(Content.class, new GsonContentAdapter())
|
||||
.registerTypeAdapter(TextContent.class, new GsonContentAdapter())
|
||||
.registerTypeAdapter(ImageContent.class, new GsonContentAdapter())
|
||||
.registerTypeAdapter(AudioContent.class, new GsonContentAdapter())
|
||||
.registerTypeAdapter(VideoContent.class, new GsonContentAdapter())
|
||||
.registerTypeAdapter(PdfFileContent.class, new GsonContentAdapter())
|
||||
.create();
|
||||
|
||||
private static final String CHAT_MESSAGE_TYPE = "type"; // do not change, will break backward compatibility!
|
||||
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
package dev.langchain4j.data.message;
|
||||
|
||||
import static java.util.Collections.emptyList;
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.GsonBuilder;
|
||||
import com.google.gson.reflect.TypeToken;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
import java.util.List;
|
||||
|
||||
import static java.util.Collections.emptyList;
|
||||
|
||||
/**
|
||||
* A codec for serializing and deserializing {@link ChatMessage} objects to and from JSON.
|
||||
*/
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
package dev.langchain4j.data.message;
|
||||
|
||||
import com.google.gson.*;
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.JsonDeserializationContext;
|
||||
import com.google.gson.JsonDeserializer;
|
||||
import com.google.gson.JsonElement;
|
||||
import com.google.gson.JsonObject;
|
||||
import com.google.gson.JsonSerializationContext;
|
||||
import com.google.gson.JsonSerializer;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@ import dev.langchain4j.agent.tool.ToolSpecification;
|
|||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
@ -25,7 +24,6 @@ public class ChatModelRequest {
|
|||
private final List<ChatMessage> messages;
|
||||
private final List<ToolSpecification> toolSpecifications;
|
||||
|
||||
@Builder
|
||||
public ChatModelRequest(String model,
|
||||
Double temperature,
|
||||
Double topP,
|
||||
|
@ -40,6 +38,10 @@ public class ChatModelRequest {
|
|||
this.toolSpecifications = copyIfNotNull(toolSpecifications);
|
||||
}
|
||||
|
||||
public static ChatModelRequestBuilder builder() {
|
||||
return new ChatModelRequestBuilder();
|
||||
}
|
||||
|
||||
public String model() {
|
||||
return model;
|
||||
}
|
||||
|
@ -63,4 +65,54 @@ public class ChatModelRequest {
|
|||
public List<ToolSpecification> toolSpecifications() {
|
||||
return toolSpecifications;
|
||||
}
|
||||
|
||||
public static class ChatModelRequestBuilder {
|
||||
private String model;
|
||||
private Double temperature;
|
||||
private Double topP;
|
||||
private Integer maxTokens;
|
||||
private List<ChatMessage> messages;
|
||||
private List<ToolSpecification> toolSpecifications;
|
||||
|
||||
ChatModelRequestBuilder() {
|
||||
}
|
||||
|
||||
public ChatModelRequestBuilder model(String model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatModelRequestBuilder temperature(Double temperature) {
|
||||
this.temperature = temperature;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatModelRequestBuilder topP(Double topP) {
|
||||
this.topP = topP;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatModelRequestBuilder maxTokens(Integer maxTokens) {
|
||||
this.maxTokens = maxTokens;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatModelRequestBuilder messages(List<ChatMessage> messages) {
|
||||
this.messages = messages;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatModelRequestBuilder toolSpecifications(List<ToolSpecification> toolSpecifications) {
|
||||
this.toolSpecifications = toolSpecifications;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatModelRequest build() {
|
||||
return new ChatModelRequest(this.model, this.temperature, this.topP, this.maxTokens, this.messages, this.toolSpecifications);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "ChatModelRequest.ChatModelRequestBuilder(model=" + this.model + ", temperature=" + this.temperature + ", topP=" + this.topP + ", maxTokens=" + this.maxTokens + ", messages=" + this.messages + ", toolSpecifications=" + this.toolSpecifications + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
|||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.output.FinishReason;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import lombok.Builder;
|
||||
|
||||
/**
|
||||
* A response from the {@link ChatLanguageModel} or {@link StreamingChatLanguageModel},
|
||||
|
@ -21,7 +20,6 @@ public class ChatModelResponse {
|
|||
private final FinishReason finishReason;
|
||||
private final AiMessage aiMessage;
|
||||
|
||||
@Builder
|
||||
public ChatModelResponse(String id,
|
||||
String model,
|
||||
TokenUsage tokenUsage,
|
||||
|
@ -34,6 +32,10 @@ public class ChatModelResponse {
|
|||
this.aiMessage = aiMessage;
|
||||
}
|
||||
|
||||
public static ChatModelResponseBuilder builder() {
|
||||
return new ChatModelResponseBuilder();
|
||||
}
|
||||
|
||||
public String id() {
|
||||
return id;
|
||||
}
|
||||
|
@ -53,4 +55,48 @@ public class ChatModelResponse {
|
|||
public AiMessage aiMessage() {
|
||||
return aiMessage;
|
||||
}
|
||||
|
||||
public static class ChatModelResponseBuilder {
|
||||
private String id;
|
||||
private String model;
|
||||
private TokenUsage tokenUsage;
|
||||
private FinishReason finishReason;
|
||||
private AiMessage aiMessage;
|
||||
|
||||
ChatModelResponseBuilder() {
|
||||
}
|
||||
|
||||
public ChatModelResponseBuilder id(String id) {
|
||||
this.id = id;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatModelResponseBuilder model(String model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatModelResponseBuilder tokenUsage(TokenUsage tokenUsage) {
|
||||
this.tokenUsage = tokenUsage;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatModelResponseBuilder finishReason(FinishReason finishReason) {
|
||||
this.finishReason = finishReason;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatModelResponseBuilder aiMessage(AiMessage aiMessage) {
|
||||
this.aiMessage = aiMessage;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatModelResponse build() {
|
||||
return new ChatModelResponse(this.id, this.model, this.tokenUsage, this.finishReason, this.aiMessage);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "ChatModelResponse.ChatModelResponseBuilder(id=" + this.id + ", model=" + this.model + ", tokenUsage=" + this.tokenUsage + ", finishReason=" + this.finishReason + ", aiMessage=" + this.aiMessage + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import dev.langchain4j.model.ModelDisabledException;
|
|||
import dev.langchain4j.model.output.Response;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* An {@link EmbeddingModel} which throws a {@link ModelDisabledException} for all of its methods
|
||||
|
|
|
@ -2,6 +2,7 @@ package dev.langchain4j.model.image;
|
|||
|
||||
import dev.langchain4j.data.image.Image;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
|
|
|
@ -7,6 +7,7 @@ import com.google.gson.reflect.TypeToken;
|
|||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.spi.prompt.structured.StructuredPromptFactory;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
|
|
|
@ -2,7 +2,6 @@ package dev.langchain4j.rag;
|
|||
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.rag.content.Content;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
@ -24,12 +23,15 @@ public class AugmentationResult {
|
|||
*/
|
||||
private final List<Content> contents;
|
||||
|
||||
@Builder
|
||||
public AugmentationResult(ChatMessage chatMessage, List<Content> contents) {
|
||||
this.chatMessage = ensureNotNull(chatMessage, "chatMessage");
|
||||
this.contents = copyIfNotNull(contents);
|
||||
}
|
||||
|
||||
public static AugmentationResultBuilder builder() {
|
||||
return new AugmentationResultBuilder();
|
||||
}
|
||||
|
||||
public ChatMessage chatMessage() {
|
||||
return chatMessage;
|
||||
}
|
||||
|
@ -37,4 +39,30 @@ public class AugmentationResult {
|
|||
public List<Content> contents() {
|
||||
return contents;
|
||||
}
|
||||
|
||||
public static class AugmentationResultBuilder {
|
||||
private ChatMessage chatMessage;
|
||||
private List<Content> contents;
|
||||
|
||||
AugmentationResultBuilder() {
|
||||
}
|
||||
|
||||
public AugmentationResultBuilder chatMessage(ChatMessage chatMessage) {
|
||||
this.chatMessage = chatMessage;
|
||||
return this;
|
||||
}
|
||||
|
||||
public AugmentationResultBuilder contents(List<Content> contents) {
|
||||
this.contents = contents;
|
||||
return this;
|
||||
}
|
||||
|
||||
public AugmentationResult build() {
|
||||
return new AugmentationResult(this.chatMessage, this.contents);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "AugmentationResult.AugmentationResultBuilder(chatMessage=" + this.chatMessage + ", contents=" + this.contents + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,22 +14,30 @@ import dev.langchain4j.rag.query.router.DefaultQueryRouter;
|
|||
import dev.langchain4j.rag.query.router.QueryRouter;
|
||||
import dev.langchain4j.rag.query.transformer.DefaultQueryTransformer;
|
||||
import dev.langchain4j.rag.query.transformer.QueryTransformer;
|
||||
import lombok.Builder;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.*;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.SynchronousQueue;
|
||||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static java.util.Collections.*;
|
||||
import static java.util.Collections.emptyMap;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.Collections.singletonMap;
|
||||
import static java.util.concurrent.CompletableFuture.allOf;
|
||||
import static java.util.concurrent.CompletableFuture.supplyAsync;
|
||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||
import static java.util.stream.Collectors.*;
|
||||
import static java.util.stream.Collectors.joining;
|
||||
import static java.util.stream.Collectors.toMap;
|
||||
|
||||
/**
|
||||
* The default implementation of {@link RetrievalAugmentor} intended to be suitable for the majority of use cases.
|
||||
|
@ -109,7 +117,6 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
|
|||
private final ContentInjector contentInjector;
|
||||
private final Executor executor;
|
||||
|
||||
@Builder
|
||||
public DefaultRetrievalAugmentor(QueryTransformer queryTransformer,
|
||||
QueryRouter queryRouter,
|
||||
ContentAggregator contentAggregator,
|
||||
|
@ -124,9 +131,9 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
|
|||
|
||||
private static ExecutorService createDefaultExecutor() {
|
||||
return new ThreadPoolExecutor(
|
||||
0, Integer.MAX_VALUE,
|
||||
1, SECONDS,
|
||||
new SynchronousQueue<>()
|
||||
0, Integer.MAX_VALUE,
|
||||
1, SECONDS,
|
||||
new SynchronousQueue<>()
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -160,9 +167,9 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
|
|||
log(augmentedChatMessage);
|
||||
|
||||
return AugmentationResult.builder()
|
||||
.chatMessage(augmentedChatMessage)
|
||||
.contents(contents)
|
||||
.build();
|
||||
.chatMessage(augmentedChatMessage)
|
||||
.contents(contents)
|
||||
.build();
|
||||
}
|
||||
|
||||
private Map<Query, Collection<List<Content>>> process(Collection<Query> queries) {
|
||||
|
@ -183,13 +190,13 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
|
|||
Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents = new ConcurrentHashMap<>();
|
||||
queries.forEach(query -> {
|
||||
CompletableFuture<Collection<List<Content>>> futureContents =
|
||||
supplyAsync(() -> {
|
||||
Collection<ContentRetriever> retrievers = queryRouter.route(query);
|
||||
log(query, retrievers);
|
||||
return retrievers;
|
||||
},
|
||||
executor
|
||||
).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
|
||||
supplyAsync(() -> {
|
||||
Collection<ContentRetriever> retrievers = queryRouter.route(query);
|
||||
log(query, retrievers);
|
||||
return retrievers;
|
||||
},
|
||||
executor
|
||||
).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
|
||||
queryToFutureContents.put(query, futureContents);
|
||||
});
|
||||
return join(queryToFutureContents);
|
||||
|
@ -201,15 +208,14 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
|
|||
private CompletableFuture<Collection<List<Content>>> retrieveFromAll(Collection<ContentRetriever> retrievers,
|
||||
Query query) {
|
||||
List<CompletableFuture<List<Content>>> futureContents = retrievers.stream()
|
||||
.map(retriever -> supplyAsync(() -> retrieve(retriever, query), executor))
|
||||
.collect(toList());
|
||||
.map(retriever -> supplyAsync(() -> retrieve(retriever, query), executor))
|
||||
.toList();
|
||||
|
||||
return allOf(futureContents.toArray(new CompletableFuture[0]))
|
||||
.thenApply(ignored ->
|
||||
futureContents.stream()
|
||||
.map(CompletableFuture::join)
|
||||
.collect(toList())
|
||||
);
|
||||
.thenApply(ignored ->
|
||||
futureContents.stream()
|
||||
.map(CompletableFuture::join)
|
||||
.toList());
|
||||
}
|
||||
|
||||
private static List<Content> retrieve(ContentRetriever retriever, Query query) {
|
||||
|
@ -219,15 +225,15 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
|
|||
}
|
||||
|
||||
private static Map<Query, Collection<List<Content>>> join(
|
||||
Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents) {
|
||||
Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents) {
|
||||
return allOf(queryToFutureContents.values().toArray(new CompletableFuture[0]))
|
||||
.thenApply(ignored ->
|
||||
queryToFutureContents.entrySet().stream()
|
||||
.collect(toMap(
|
||||
Map.Entry::getKey,
|
||||
entry -> entry.getValue().join()
|
||||
))
|
||||
).join();
|
||||
.thenApply(ignored ->
|
||||
queryToFutureContents.entrySet().stream()
|
||||
.collect(toMap(
|
||||
Map.Entry::getKey,
|
||||
entry -> entry.getValue().join()
|
||||
))
|
||||
).join();
|
||||
}
|
||||
|
||||
private static void logQueries(Query originalQuery, Collection<Query> queries) {
|
||||
|
@ -235,14 +241,14 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
|
|||
Query transformedQuery = queries.iterator().next();
|
||||
if (!transformedQuery.equals(originalQuery)) {
|
||||
log.debug("Transformed original query '{}' into '{}'",
|
||||
originalQuery.text(), transformedQuery.text());
|
||||
originalQuery.text(), transformedQuery.text());
|
||||
}
|
||||
} else {
|
||||
} else if (log.isDebugEnabled()){
|
||||
log.debug("Transformed original query '{}' into the following queries:\n{}",
|
||||
originalQuery.text(), queries.stream()
|
||||
.map(Query::text)
|
||||
.map(query -> "- '" + query + "'")
|
||||
.collect(joining("\n")));
|
||||
originalQuery.text(), queries.stream()
|
||||
.map(Query::text)
|
||||
.map(query -> "- '" + query + "'")
|
||||
.collect(joining("\n")));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -250,27 +256,40 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
|
|||
// TODO use retriever id
|
||||
if (retrievers.size() == 1) {
|
||||
log.debug("Routing query '{}' to the following retriever: {}",
|
||||
query.text(), retrievers.iterator().next());
|
||||
} else {
|
||||
query.text(), retrievers.iterator().next());
|
||||
} else if (log.isDebugEnabled()) {
|
||||
log.debug("Routing query '{}' to the following retrievers:\n{}",
|
||||
query.text(), retrievers.stream()
|
||||
.map(retriever -> "- " + retriever.toString())
|
||||
.collect(joining("\n")));
|
||||
query.text(), retrievers.stream()
|
||||
.map(retriever -> "- " + retriever.toString())
|
||||
.collect(joining("\n")));
|
||||
}
|
||||
}
|
||||
|
||||
private static void log(Query query, ContentRetriever retriever, List<Content> contents) {
|
||||
// TODO use retriever id
|
||||
log.debug("Retrieved {} contents using query '{}' and retriever '{}'",
|
||||
contents.size(), query.text(), retriever);
|
||||
contents.size(), query.text(), retriever);
|
||||
|
||||
if (contents.size() > 0) {
|
||||
log.trace("Retrieved {} contents using query '{}' and retriever '{}':\n{}",
|
||||
contents.size(), query.text(), retriever, contents.stream()
|
||||
.map(Content::textSegment)
|
||||
.map(segment -> "- " + escapeNewlines(segment.text()))
|
||||
.collect(joining("\n")));
|
||||
if (!log.isTraceEnabled()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!contents.isEmpty()) {
|
||||
final var contentsSting = contents.stream()
|
||||
.map(Content::textSegment)
|
||||
.map(segment -> "- " + escapeNewlines(segment.text()))
|
||||
.collect(joining("\n"));
|
||||
log.trace("Retrieved {} contents using query '{}' and retriever '{}':\n{}",
|
||||
contents.size(),
|
||||
query.text(),
|
||||
retriever.getClass().getName(),
|
||||
contentsSting);
|
||||
} else {
|
||||
log.trace("Retrieved 0 contents using query '{}' and retriever '{}'",
|
||||
query.text(),
|
||||
retriever.getClass().getName());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static void log(Map<Query, Collection<List<Content>>> queryToContents, List<Content> contents) {
|
||||
|
@ -287,15 +306,21 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
|
|||
|
||||
log.debug("Aggregated {} content(s) into {}", contentCount, contents.size());
|
||||
|
||||
log.trace("Aggregated {} content(s) into:\n{}",
|
||||
if (log.isTraceEnabled()) {
|
||||
log.trace("Aggregated {} content(s) into:\n{}",
|
||||
contentCount, contents.stream()
|
||||
.map(Content::textSegment)
|
||||
.map(segment -> "- " + escapeNewlines(segment.text()))
|
||||
.collect(joining("\n")));
|
||||
.map(Content::textSegment)
|
||||
.map(segment -> "- " + escapeNewlines(segment.text()))
|
||||
.collect(joining("\n")));
|
||||
}
|
||||
}
|
||||
|
||||
private static void log(ChatMessage augmentedChatMessage) {
|
||||
log.trace("Augmented chat message: {}", escapeNewlines(augmentedChatMessage.text()));
|
||||
if (log.isTraceEnabled()) {
|
||||
log.trace("Augmented chat message: {}",
|
||||
escapeNewlines(augmentedChatMessage.text())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private static String escapeNewlines(String text) {
|
||||
|
@ -308,9 +333,51 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
|
|||
|
||||
public static class DefaultRetrievalAugmentorBuilder {
|
||||
|
||||
private QueryTransformer queryTransformer;
|
||||
private QueryRouter queryRouter;
|
||||
private ContentAggregator contentAggregator;
|
||||
private ContentInjector contentInjector;
|
||||
private Executor executor;
|
||||
|
||||
DefaultRetrievalAugmentorBuilder() {
|
||||
}
|
||||
|
||||
public DefaultRetrievalAugmentorBuilder contentRetriever(ContentRetriever contentRetriever) {
|
||||
this.queryRouter = new DefaultQueryRouter(ensureNotNull(contentRetriever, "contentRetriever"));
|
||||
return this;
|
||||
}
|
||||
|
||||
public DefaultRetrievalAugmentorBuilder queryTransformer(QueryTransformer queryTransformer) {
|
||||
this.queryTransformer = queryTransformer;
|
||||
return this;
|
||||
}
|
||||
|
||||
public DefaultRetrievalAugmentorBuilder queryRouter(QueryRouter queryRouter) {
|
||||
this.queryRouter = queryRouter;
|
||||
return this;
|
||||
}
|
||||
|
||||
public DefaultRetrievalAugmentorBuilder contentAggregator(ContentAggregator contentAggregator) {
|
||||
this.contentAggregator = contentAggregator;
|
||||
return this;
|
||||
}
|
||||
|
||||
public DefaultRetrievalAugmentorBuilder contentInjector(ContentInjector contentInjector) {
|
||||
this.contentInjector = contentInjector;
|
||||
return this;
|
||||
}
|
||||
|
||||
public DefaultRetrievalAugmentorBuilder executor(Executor executor) {
|
||||
this.executor = executor;
|
||||
return this;
|
||||
}
|
||||
|
||||
public DefaultRetrievalAugmentor build() {
|
||||
return new DefaultRetrievalAugmentor(this.queryTransformer, this.queryRouter, this.contentAggregator, this.contentInjector, this.executor);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "DefaultRetrievalAugmentor.DefaultRetrievalAugmentorBuilder(queryTransformer=" + this.queryTransformer + ", queryRouter=" + this.queryRouter + ", contentAggregator=" + this.contentAggregator + ", contentInjector=" + this.contentInjector + ", executor=" + this.executor + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,9 +5,12 @@ import dev.langchain4j.model.scoring.ScoringModel;
|
|||
import dev.langchain4j.rag.content.Content;
|
||||
import dev.langchain4j.rag.query.Query;
|
||||
import dev.langchain4j.rag.query.transformer.ExpandingQueryTransformer;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
import static dev.langchain4j.internal.Exceptions.illegalArgument;
|
||||
|
@ -64,7 +67,6 @@ public class ReRankingContentAggregator implements ContentAggregator {
|
|||
this(scoringModel, DEFAULT_QUERY_SELECTOR, null);
|
||||
}
|
||||
|
||||
@Builder
|
||||
public ReRankingContentAggregator(ScoringModel scoringModel,
|
||||
Function<Map<Query, Collection<List<Content>>>, Query> querySelector,
|
||||
Double minScore) {
|
||||
|
@ -73,6 +75,10 @@ public class ReRankingContentAggregator implements ContentAggregator {
|
|||
this.minScore = minScore;
|
||||
}
|
||||
|
||||
public static ReRankingContentAggregatorBuilder builder() {
|
||||
return new ReRankingContentAggregatorBuilder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents) {
|
||||
|
||||
|
@ -126,4 +132,36 @@ public class ReRankingContentAggregator implements ContentAggregator {
|
|||
.map(Content::from)
|
||||
.collect(toList());
|
||||
}
|
||||
|
||||
public static class ReRankingContentAggregatorBuilder {
|
||||
private ScoringModel scoringModel;
|
||||
private Function<Map<Query, Collection<List<Content>>>, Query> querySelector;
|
||||
private Double minScore;
|
||||
|
||||
ReRankingContentAggregatorBuilder() {
|
||||
}
|
||||
|
||||
public ReRankingContentAggregatorBuilder scoringModel(ScoringModel scoringModel) {
|
||||
this.scoringModel = scoringModel;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ReRankingContentAggregatorBuilder querySelector(Function<Map<Query, Collection<List<Content>>>, Query> querySelector) {
|
||||
this.querySelector = querySelector;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ReRankingContentAggregatorBuilder minScore(Double minScore) {
|
||||
this.minScore = minScore;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ReRankingContentAggregator build() {
|
||||
return new ReRankingContentAggregator(this.scoringModel, this.querySelector, this.minScore);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "ReRankingContentAggregator.ReRankingContentAggregatorBuilder(scoringModel=" + this.scoringModel + ", querySelector=" + this.querySelector + ", minScore=" + this.minScore + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,12 @@ package dev.langchain4j.rag.content.aggregator;
|
|||
|
||||
import dev.langchain4j.rag.content.Content;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Comparator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
|
||||
|
||||
|
|
|
@ -7,14 +7,15 @@ import dev.langchain4j.data.segment.TextSegment;
|
|||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.rag.content.Content;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static dev.langchain4j.internal.Utils.*;
|
||||
import static dev.langchain4j.internal.Utils.copyIfNotNull;
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static java.util.stream.Collectors.joining;
|
||||
|
@ -47,10 +48,10 @@ public class DefaultContentInjector implements ContentInjector {
|
|||
|
||||
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
|
||||
"""
|
||||
{{userMessage}}
|
||||
{{userMessage}}
|
||||
|
||||
Answer using the following information:
|
||||
{{contents}}"""
|
||||
Answer using the following information:
|
||||
{{contents}}"""
|
||||
);
|
||||
|
||||
private final PromptTemplate promptTemplate;
|
||||
|
@ -68,12 +69,15 @@ public class DefaultContentInjector implements ContentInjector {
|
|||
this(ensureNotNull(promptTemplate, "promptTemplate"), null);
|
||||
}
|
||||
|
||||
@Builder
|
||||
public DefaultContentInjector(PromptTemplate promptTemplate, List<String> metadataKeysToInclude) {
|
||||
this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
|
||||
this.metadataKeysToInclude = copyIfNotNull(metadataKeysToInclude);
|
||||
}
|
||||
|
||||
public static DefaultContentInjectorBuilder builder() {
|
||||
return new DefaultContentInjectorBuilder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatMessage inject(List<Content> contents, ChatMessage chatMessage) {
|
||||
|
||||
|
@ -161,4 +165,30 @@ public class DefaultContentInjector implements ContentInjector {
|
|||
? segmentContent
|
||||
: "content: %s\n%s".formatted(segmentContent, segmentMetadata);
|
||||
}
|
||||
|
||||
public static class DefaultContentInjectorBuilder {
|
||||
private PromptTemplate promptTemplate;
|
||||
private List<String> metadataKeysToInclude;
|
||||
|
||||
DefaultContentInjectorBuilder() {
|
||||
}
|
||||
|
||||
public DefaultContentInjectorBuilder promptTemplate(PromptTemplate promptTemplate) {
|
||||
this.promptTemplate = promptTemplate;
|
||||
return this;
|
||||
}
|
||||
|
||||
public DefaultContentInjectorBuilder metadataKeysToInclude(List<String> metadataKeysToInclude) {
|
||||
this.metadataKeysToInclude = metadataKeysToInclude;
|
||||
return this;
|
||||
}
|
||||
|
||||
public DefaultContentInjector build() {
|
||||
return new DefaultContentInjector(this.promptTemplate, this.metadataKeysToInclude);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "DefaultContentInjector.DefaultContentInjectorBuilder(promptTemplate=" + this.promptTemplate + ", metadataKeysToInclude=" + this.metadataKeysToInclude + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,14 +11,15 @@ import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
|||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.ValidationUtils.*;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
|
@ -109,7 +110,6 @@ public class EmbeddingStoreContentRetriever implements ContentRetriever {
|
|||
);
|
||||
}
|
||||
|
||||
@Builder
|
||||
private EmbeddingStoreContentRetriever(String displayName,
|
||||
EmbeddingStore<TextSegment> embeddingStore,
|
||||
EmbeddingModel embeddingModel,
|
||||
|
@ -141,8 +141,22 @@ public class EmbeddingStoreContentRetriever implements ContentRetriever {
|
|||
return null;
|
||||
}
|
||||
|
||||
public static EmbeddingStoreContentRetrieverBuilder builder() {
|
||||
return new EmbeddingStoreContentRetrieverBuilder();
|
||||
}
|
||||
|
||||
public static class EmbeddingStoreContentRetrieverBuilder {
|
||||
|
||||
private String displayName;
|
||||
private EmbeddingStore<TextSegment> embeddingStore;
|
||||
private EmbeddingModel embeddingModel;
|
||||
private Function<Query, Integer> dynamicMaxResults;
|
||||
private Function<Query, Double> dynamicMinScore;
|
||||
private Function<Query, Filter> dynamicFilter;
|
||||
|
||||
EmbeddingStoreContentRetrieverBuilder() {
|
||||
}
|
||||
|
||||
public EmbeddingStoreContentRetrieverBuilder maxResults(Integer maxResults) {
|
||||
if (maxResults != null) {
|
||||
dynamicMaxResults = (query) -> ensureGreaterThanZero(maxResults, "maxResults");
|
||||
|
@ -163,6 +177,44 @@ public class EmbeddingStoreContentRetriever implements ContentRetriever {
|
|||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public EmbeddingStoreContentRetrieverBuilder displayName(String displayName) {
|
||||
this.displayName = displayName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public EmbeddingStoreContentRetrieverBuilder embeddingStore(EmbeddingStore<TextSegment> embeddingStore) {
|
||||
this.embeddingStore = embeddingStore;
|
||||
return this;
|
||||
}
|
||||
|
||||
public EmbeddingStoreContentRetrieverBuilder embeddingModel(EmbeddingModel embeddingModel) {
|
||||
this.embeddingModel = embeddingModel;
|
||||
return this;
|
||||
}
|
||||
|
||||
public EmbeddingStoreContentRetrieverBuilder dynamicMaxResults(Function<Query, Integer> dynamicMaxResults) {
|
||||
this.dynamicMaxResults = dynamicMaxResults;
|
||||
return this;
|
||||
}
|
||||
|
||||
public EmbeddingStoreContentRetrieverBuilder dynamicMinScore(Function<Query, Double> dynamicMinScore) {
|
||||
this.dynamicMinScore = dynamicMinScore;
|
||||
return this;
|
||||
}
|
||||
|
||||
public EmbeddingStoreContentRetrieverBuilder dynamicFilter(Function<Query, Filter> dynamicFilter) {
|
||||
this.dynamicFilter = dynamicFilter;
|
||||
return this;
|
||||
}
|
||||
|
||||
public EmbeddingStoreContentRetriever build() {
|
||||
return new EmbeddingStoreContentRetriever(this.displayName, this.embeddingStore, this.embeddingModel, this.dynamicMaxResults, this.dynamicMinScore, this.dynamicFilter);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "EmbeddingStoreContentRetriever.EmbeddingStoreContentRetrieverBuilder(displayName=" + this.displayName + ", embeddingStore=" + this.embeddingStore + ", embeddingModel=" + this.embeddingModel + ", dynamicMaxResults=" + this.dynamicMaxResults + ", dynamicMinScore=" + this.dynamicMinScore + ", dynamicFilter=" + this.dynamicFilter + ")";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -5,7 +5,6 @@ import dev.langchain4j.rag.query.Query;
|
|||
import dev.langchain4j.web.search.WebSearchEngine;
|
||||
import dev.langchain4j.web.search.WebSearchRequest;
|
||||
import dev.langchain4j.web.search.WebSearchResults;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
@ -26,12 +25,15 @@ public class WebSearchContentRetriever implements ContentRetriever {
|
|||
private final WebSearchEngine webSearchEngine;
|
||||
private final int maxResults;
|
||||
|
||||
@Builder
|
||||
public WebSearchContentRetriever(WebSearchEngine webSearchEngine, Integer maxResults) {
|
||||
this.webSearchEngine = ensureNotNull(webSearchEngine, "webSearchEngine");
|
||||
this.maxResults = getOrDefault(maxResults, 5);
|
||||
}
|
||||
|
||||
public static WebSearchContentRetrieverBuilder builder() {
|
||||
return new WebSearchContentRetrieverBuilder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Content> retrieve(Query query) {
|
||||
|
||||
|
@ -46,4 +48,30 @@ public class WebSearchContentRetriever implements ContentRetriever {
|
|||
.map(Content::from)
|
||||
.collect(toList());
|
||||
}
|
||||
|
||||
public static class WebSearchContentRetrieverBuilder {
|
||||
private WebSearchEngine webSearchEngine;
|
||||
private Integer maxResults;
|
||||
|
||||
WebSearchContentRetrieverBuilder() {
|
||||
}
|
||||
|
||||
public WebSearchContentRetrieverBuilder webSearchEngine(WebSearchEngine webSearchEngine) {
|
||||
this.webSearchEngine = webSearchEngine;
|
||||
return this;
|
||||
}
|
||||
|
||||
public WebSearchContentRetrieverBuilder maxResults(Integer maxResults) {
|
||||
this.maxResults = maxResults;
|
||||
return this;
|
||||
}
|
||||
|
||||
public WebSearchContentRetriever build() {
|
||||
return new WebSearchContentRetriever(this.webSearchEngine, this.maxResults);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "WebSearchContentRetriever.WebSearchContentRetrieverBuilder(webSearchEngine=" + this.webSearchEngine + ", maxResults=" + this.maxResults + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import dev.langchain4j.model.input.Prompt;
|
|||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.rag.content.retriever.ContentRetriever;
|
||||
import dev.langchain4j.rag.query.Query;
|
||||
import lombok.Builder;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
@ -16,7 +15,9 @@ import java.util.HashMap;
|
|||
import java.util.Map;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.ValidationUtils.*;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.rag.query.router.LanguageModelQueryRouter.FallbackStrategy.DO_NOT_ROUTE;
|
||||
import static java.util.Arrays.stream;
|
||||
import static java.util.Collections.emptyList;
|
||||
|
@ -46,12 +47,12 @@ public class LanguageModelQueryRouter implements QueryRouter {
|
|||
|
||||
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
|
||||
"""
|
||||
Based on the user query, determine the most suitable data source(s) \
|
||||
to retrieve relevant information from the following options:
|
||||
{{options}}
|
||||
It is very important that your answer consists of either a single number \
|
||||
or multiple numbers separated by commas and nothing else!
|
||||
User query: {{query}}"""
|
||||
Based on the user query, determine the most suitable data source(s) \
|
||||
to retrieve relevant information from the following options:
|
||||
{{options}}
|
||||
It is very important that your answer consists of either a single number \
|
||||
or multiple numbers separated by commas and nothing else!
|
||||
User query: {{query}}"""
|
||||
);
|
||||
|
||||
protected final ChatLanguageModel chatLanguageModel;
|
||||
|
@ -65,7 +66,6 @@ public class LanguageModelQueryRouter implements QueryRouter {
|
|||
this(chatLanguageModel, retrieverToDescription, DEFAULT_PROMPT_TEMPLATE, DO_NOT_ROUTE);
|
||||
}
|
||||
|
||||
@Builder
|
||||
public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel,
|
||||
Map<ContentRetriever, String> retrieverToDescription,
|
||||
PromptTemplate promptTemplate,
|
||||
|
@ -94,6 +94,10 @@ public class LanguageModelQueryRouter implements QueryRouter {
|
|||
this.fallbackStrategy = getOrDefault(fallbackStrategy, DO_NOT_ROUTE);
|
||||
}
|
||||
|
||||
public static LanguageModelQueryRouterBuilder builder() {
|
||||
return new LanguageModelQueryRouterBuilder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<ContentRetriever> route(Query query) {
|
||||
Prompt prompt = createPrompt(query);
|
||||
|
@ -107,17 +111,17 @@ public class LanguageModelQueryRouter implements QueryRouter {
|
|||
}
|
||||
|
||||
protected Collection<ContentRetriever> fallback(Query query, Exception e) {
|
||||
switch (fallbackStrategy) {
|
||||
case DO_NOT_ROUTE:
|
||||
return switch (fallbackStrategy) {
|
||||
case DO_NOT_ROUTE -> {
|
||||
log.debug("Fallback: query '{}' will not be routed", query.text());
|
||||
return emptyList();
|
||||
case ROUTE_TO_ALL:
|
||||
yield emptyList();
|
||||
}
|
||||
case ROUTE_TO_ALL -> {
|
||||
log.debug("Fallback: query '{}' will be routed to all available content retrievers", query.text());
|
||||
return new ArrayList<>(idToRetriever.values());
|
||||
case FAIL:
|
||||
default:
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
yield new ArrayList<>(idToRetriever.values());
|
||||
}
|
||||
default -> throw new RuntimeException(e);
|
||||
};
|
||||
}
|
||||
|
||||
protected Prompt createPrompt(Query query) {
|
||||
|
@ -157,4 +161,42 @@ public class LanguageModelQueryRouter implements QueryRouter {
|
|||
*/
|
||||
FAIL
|
||||
}
|
||||
|
||||
public static class LanguageModelQueryRouterBuilder {
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
private Map<ContentRetriever, String> retrieverToDescription;
|
||||
private PromptTemplate promptTemplate;
|
||||
private FallbackStrategy fallbackStrategy;
|
||||
|
||||
LanguageModelQueryRouterBuilder() {
|
||||
}
|
||||
|
||||
public LanguageModelQueryRouterBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
|
||||
this.chatLanguageModel = chatLanguageModel;
|
||||
return this;
|
||||
}
|
||||
|
||||
public LanguageModelQueryRouterBuilder retrieverToDescription(Map<ContentRetriever, String> retrieverToDescription) {
|
||||
this.retrieverToDescription = retrieverToDescription;
|
||||
return this;
|
||||
}
|
||||
|
||||
public LanguageModelQueryRouterBuilder promptTemplate(PromptTemplate promptTemplate) {
|
||||
this.promptTemplate = promptTemplate;
|
||||
return this;
|
||||
}
|
||||
|
||||
public LanguageModelQueryRouterBuilder fallbackStrategy(FallbackStrategy fallbackStrategy) {
|
||||
this.fallbackStrategy = fallbackStrategy;
|
||||
return this;
|
||||
}
|
||||
|
||||
public LanguageModelQueryRouter build() {
|
||||
return new LanguageModelQueryRouter(this.chatLanguageModel, this.retrieverToDescription, this.promptTemplate, this.fallbackStrategy);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "LanguageModelQueryRouter.LanguageModelQueryRouterBuilder(chatLanguageModel=" + this.chatLanguageModel + ", retrieverToDescription=" + this.retrieverToDescription + ", promptTemplate=" + this.promptTemplate + ", fallbackStrategy=" + this.fallbackStrategy + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,9 +8,12 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
|||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.rag.query.Query;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
|
@ -35,18 +38,18 @@ public class CompressingQueryTransformer implements QueryTransformer {
|
|||
|
||||
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
|
||||
"""
|
||||
Read and understand the conversation between the User and the AI. \
|
||||
Then, analyze the new query from the User. \
|
||||
Identify all relevant details, terms, and context from both the conversation and the new query. \
|
||||
Reformulate this query into a clear, concise, and self-contained format suitable for information retrieval.
|
||||
Read and understand the conversation between the User and the AI. \
|
||||
Then, analyze the new query from the User. \
|
||||
Identify all relevant details, terms, and context from both the conversation and the new query. \
|
||||
Reformulate this query into a clear, concise, and self-contained format suitable for information retrieval.
|
||||
|
||||
Conversation:
|
||||
{{chatMemory}}
|
||||
Conversation:
|
||||
{{chatMemory}}
|
||||
|
||||
User query: {{query}}
|
||||
User query: {{query}}
|
||||
|
||||
It is very important that you provide only reformulated query and nothing else! \
|
||||
Do not prepend a query with anything!"""
|
||||
It is very important that you provide only reformulated query and nothing else! \
|
||||
Do not prepend a query with anything!"""
|
||||
);
|
||||
|
||||
protected final PromptTemplate promptTemplate;
|
||||
|
@ -56,12 +59,15 @@ public class CompressingQueryTransformer implements QueryTransformer {
|
|||
this(chatLanguageModel, DEFAULT_PROMPT_TEMPLATE);
|
||||
}
|
||||
|
||||
@Builder
|
||||
public CompressingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate) {
|
||||
this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
|
||||
this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
|
||||
}
|
||||
|
||||
public static CompressingQueryTransformerBuilder builder() {
|
||||
return new CompressingQueryTransformerBuilder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<Query> transform(Query query) {
|
||||
|
||||
|
@ -105,4 +111,30 @@ public class CompressingQueryTransformer implements QueryTransformer {
|
|||
variables.put("chatMemory", chatMemory);
|
||||
return promptTemplate.apply(variables);
|
||||
}
|
||||
|
||||
public static class CompressingQueryTransformerBuilder {
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
private PromptTemplate promptTemplate;
|
||||
|
||||
CompressingQueryTransformerBuilder() {
|
||||
}
|
||||
|
||||
public CompressingQueryTransformerBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
|
||||
this.chatLanguageModel = chatLanguageModel;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CompressingQueryTransformerBuilder promptTemplate(PromptTemplate promptTemplate) {
|
||||
this.promptTemplate = promptTemplate;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CompressingQueryTransformer build() {
|
||||
return new CompressingQueryTransformer(this.chatLanguageModel, this.promptTemplate);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "CompressingQueryTransformer.CompressingQueryTransformerBuilder(chatLanguageModel=" + this.chatLanguageModel + ", promptTemplate=" + this.promptTemplate + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
|||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.rag.query.Query;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
|
@ -37,13 +36,13 @@ public class ExpandingQueryTransformer implements QueryTransformer {
|
|||
|
||||
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
|
||||
"""
|
||||
Generate {{n}} different versions of a provided user query. \
|
||||
Each version should be worded differently, using synonyms or alternative sentence structures, \
|
||||
but they should all retain the original meaning. \
|
||||
These versions will be used to retrieve relevant documents. \
|
||||
It is very important to provide each query version on a separate line, \
|
||||
without enumerations, hyphens, or any additional formatting!
|
||||
User query: {{query}}"""
|
||||
Generate {{n}} different versions of a provided user query. \
|
||||
Each version should be worded differently, using synonyms or alternative sentence structures, \
|
||||
but they should all retain the original meaning. \
|
||||
These versions will be used to retrieve relevant documents. \
|
||||
It is very important to provide each query version on a separate line, \
|
||||
without enumerations, hyphens, or any additional formatting!
|
||||
User query: {{query}}"""
|
||||
);
|
||||
public static final int DEFAULT_N = 3;
|
||||
|
||||
|
@ -63,13 +62,16 @@ public class ExpandingQueryTransformer implements QueryTransformer {
|
|||
this(chatLanguageModel, ensureNotNull(promptTemplate, "promptTemplate"), DEFAULT_N);
|
||||
}
|
||||
|
||||
@Builder
|
||||
public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate, Integer n) {
|
||||
this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
|
||||
this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
|
||||
this.n = ensureGreaterThanZero(getOrDefault(n, DEFAULT_N), "n");
|
||||
}
|
||||
|
||||
public static ExpandingQueryTransformerBuilder builder() {
|
||||
return new ExpandingQueryTransformerBuilder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<Query> transform(Query query) {
|
||||
Prompt prompt = createPrompt(query);
|
||||
|
@ -94,4 +96,36 @@ public class ExpandingQueryTransformer implements QueryTransformer {
|
|||
.filter(Utils::isNotNullOrBlank)
|
||||
.collect(toList());
|
||||
}
|
||||
|
||||
public static class ExpandingQueryTransformerBuilder {
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
private PromptTemplate promptTemplate;
|
||||
private Integer n;
|
||||
|
||||
ExpandingQueryTransformerBuilder() {
|
||||
}
|
||||
|
||||
public ExpandingQueryTransformerBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
|
||||
this.chatLanguageModel = chatLanguageModel;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ExpandingQueryTransformerBuilder promptTemplate(PromptTemplate promptTemplate) {
|
||||
this.promptTemplate = promptTemplate;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ExpandingQueryTransformerBuilder n(Integer n) {
|
||||
this.n = n;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ExpandingQueryTransformer build() {
|
||||
return new ExpandingQueryTransformer(this.chatLanguageModel, this.promptTemplate, this.n);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "ExpandingQueryTransformer.ExpandingQueryTransformerBuilder(chatLanguageModel=" + this.chatLanguageModel + ", promptTemplate=" + this.promptTemplate + ", n=" + this.n + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package dev.langchain4j.retriever;
|
||||
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
|
|
|
@ -4,18 +4,17 @@ import dev.langchain4j.data.document.Metadata;
|
|||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import lombok.Builder;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.ValidationUtils.*;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
|
||||
/**
|
||||
* Represents a request to search in an {@link EmbeddingStore}.
|
||||
*/
|
||||
@ToString
|
||||
@EqualsAndHashCode
|
||||
public class EmbeddingSearchRequest {
|
||||
|
||||
private final Embedding queryEmbedding;
|
||||
|
@ -38,7 +37,6 @@ public class EmbeddingSearchRequest {
|
|||
* Please note that not all {@link EmbeddingStore}s support this feature yet.
|
||||
* This is an optional parameter. Default: no filtering
|
||||
*/
|
||||
@Builder
|
||||
public EmbeddingSearchRequest(Embedding queryEmbedding, Integer maxResults, Double minScore, Filter filter) {
|
||||
this.queryEmbedding = ensureNotNull(queryEmbedding, "queryEmbedding");
|
||||
this.maxResults = ensureGreaterThanZero(getOrDefault(maxResults, 3), "maxResults");
|
||||
|
@ -46,6 +44,10 @@ public class EmbeddingSearchRequest {
|
|||
this.filter = filter;
|
||||
}
|
||||
|
||||
public static EmbeddingSearchRequestBuilder builder() {
|
||||
return new EmbeddingSearchRequestBuilder();
|
||||
}
|
||||
|
||||
public Embedding queryEmbedding() {
|
||||
return queryEmbedding;
|
||||
}
|
||||
|
@ -61,4 +63,59 @@ public class EmbeddingSearchRequest {
|
|||
public Filter filter() {
|
||||
return filter;
|
||||
}
|
||||
|
||||
public boolean equals(final Object o) {
|
||||
if (o == this) return true;
|
||||
if (!(o instanceof EmbeddingSearchRequest other)) return false;
|
||||
return this.maxResults == other.maxResults
|
||||
&& this.minScore == other.minScore
|
||||
&& Objects.equals(this.queryEmbedding, other.queryEmbedding)
|
||||
&& Objects.equals(this.filter, other.filter);
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
return Objects.hash(queryEmbedding, maxResults, minScore, filter);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "EmbeddingSearchRequest(queryEmbedding=" + this.queryEmbedding + ", maxResults=" + this.maxResults + ", minScore=" + this.minScore + ", filter=" + this.filter + ")";
|
||||
}
|
||||
|
||||
public static class EmbeddingSearchRequestBuilder {
|
||||
private Embedding queryEmbedding;
|
||||
private Integer maxResults;
|
||||
private Double minScore;
|
||||
private Filter filter;
|
||||
|
||||
EmbeddingSearchRequestBuilder() {
|
||||
}
|
||||
|
||||
public EmbeddingSearchRequestBuilder queryEmbedding(Embedding queryEmbedding) {
|
||||
this.queryEmbedding = queryEmbedding;
|
||||
return this;
|
||||
}
|
||||
|
||||
public EmbeddingSearchRequestBuilder maxResults(Integer maxResults) {
|
||||
this.maxResults = maxResults;
|
||||
return this;
|
||||
}
|
||||
|
||||
public EmbeddingSearchRequestBuilder minScore(Double minScore) {
|
||||
this.minScore = minScore;
|
||||
return this;
|
||||
}
|
||||
|
||||
public EmbeddingSearchRequestBuilder filter(Filter filter) {
|
||||
this.filter = filter;
|
||||
return this;
|
||||
}
|
||||
|
||||
public EmbeddingSearchRequest build() {
|
||||
return new EmbeddingSearchRequest(this.queryEmbedding, this.maxResults, this.minScore, this.filter);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "EmbeddingSearchRequest.EmbeddingSearchRequestBuilder(queryEmbedding=" + this.queryEmbedding + ", maxResults=" + this.maxResults + ", minScore=" + this.minScore + ", filter=" + this.filter + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,8 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
|
|||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.spi.data.document.splitter.DocumentSplitterFactory;
|
||||
import dev.langchain4j.spi.model.embedding.EmbeddingModelFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
@ -48,9 +49,10 @@ import static java.util.stream.Collectors.toList;
|
|||
* Including a document title or a short summary in each {@code TextSegment} is a common technique
|
||||
* to improve the quality of similarity searches.
|
||||
*/
|
||||
@Slf4j
|
||||
public class EmbeddingStoreIngestor {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(EmbeddingStoreIngestor.class);
|
||||
|
||||
private final DocumentTransformer documentTransformer;
|
||||
private final DocumentSplitter documentSplitter;
|
||||
private final TextSegmentTransformer textSegmentTransformer;
|
||||
|
@ -123,6 +125,7 @@ public class EmbeddingStoreIngestor {
|
|||
* <br>
|
||||
* For the "Easy RAG", import {@code langchain4j-easy-rag} module,
|
||||
* which contains a {@code DocumentSplitterFactory} and {@code EmbeddingModelFactory} implementations.
|
||||
*
|
||||
* @return result including information related to ingestion process.
|
||||
*/
|
||||
public static IngestionResult ingest(Document document, EmbeddingStore<TextSegment> embeddingStore) {
|
||||
|
@ -137,6 +140,7 @@ public class EmbeddingStoreIngestor {
|
|||
* <br>
|
||||
* For the "Easy RAG", import {@code langchain4j-easy-rag} module,
|
||||
* which contains a {@code DocumentSplitterFactory} and {@code EmbeddingModelFactory} implementations.
|
||||
*
|
||||
* @return result including information related to ingestion process.
|
||||
*/
|
||||
public static IngestionResult ingest(List<Document> documents, EmbeddingStore<TextSegment> embeddingStore) {
|
||||
|
|
|
@ -1,7 +1,14 @@
|
|||
package dev.langchain4j.store.embedding.filter;
|
||||
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.*;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsIn;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsNotIn;
|
||||
import dev.langchain4j.store.embedding.filter.logical.And;
|
||||
import dev.langchain4j.store.embedding.filter.logical.Not;
|
||||
import dev.langchain4j.store.embedding.filter.logical.Or;
|
||||
|
|
|
@ -1,7 +1,14 @@
|
|||
package dev.langchain4j.store.embedding.filter;
|
||||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.*;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsIn;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsNotIn;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
|
|
|
@ -2,9 +2,8 @@ package dev.langchain4j.store.embedding.filter.comparison;
|
|||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.UUID;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
|
@ -12,8 +11,6 @@ import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
|||
import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.compareAsBigDecimals;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible;
|
||||
|
||||
@ToString
|
||||
@EqualsAndHashCode
|
||||
public class IsEqualTo implements Filter {
|
||||
|
||||
private final String key;
|
||||
|
@ -34,11 +31,10 @@ public class IsEqualTo implements Filter {
|
|||
|
||||
@Override
|
||||
public boolean test(Object object) {
|
||||
if (!(object instanceof Metadata)) {
|
||||
if (!(object instanceof Metadata metadata)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Metadata metadata = (Metadata) object;
|
||||
if (!metadata.containsKey(key)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -56,4 +52,20 @@ public class IsEqualTo implements Filter {
|
|||
|
||||
return actualValue.equals(comparisonValue);
|
||||
}
|
||||
|
||||
public boolean equals(final Object o) {
|
||||
if (o == this) return true;
|
||||
if (!(o instanceof IsEqualTo other)) return false;
|
||||
|
||||
return Objects.equals(this.key, other.key)
|
||||
&& Objects.equals(this.comparisonValue, other.comparisonValue);
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
return Objects.hash(key, comparisonValue);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "IsEqualTo(key=" + this.key + ", comparisonValue=" + this.comparisonValue + ")";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,16 +2,14 @@ package dev.langchain4j.store.embedding.filter.comparison;
|
|||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.compareAsBigDecimals;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible;
|
||||
|
||||
@ToString
|
||||
@EqualsAndHashCode
|
||||
public class IsGreaterThan implements Filter {
|
||||
|
||||
private final String key;
|
||||
|
@ -32,11 +30,10 @@ public class IsGreaterThan implements Filter {
|
|||
|
||||
@Override
|
||||
public boolean test(Object object) {
|
||||
if (!(object instanceof Metadata)) {
|
||||
if (!(object instanceof Metadata metadata)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Metadata metadata = (Metadata) object;
|
||||
if (!metadata.containsKey(key)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -50,4 +47,20 @@ public class IsGreaterThan implements Filter {
|
|||
|
||||
return ((Comparable) actualValue).compareTo(comparisonValue) > 0;
|
||||
}
|
||||
|
||||
public boolean equals(final Object o) {
|
||||
if (o == this) return true;
|
||||
if (!(o instanceof IsGreaterThan other)) return false;
|
||||
|
||||
return Objects.equals(this.key, other.key)
|
||||
&& Objects.equals(this.comparisonValue, other.comparisonValue);
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
return Objects.hash(key, comparisonValue);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "IsGreaterThan(key=" + this.key + ", comparisonValue=" + this.comparisonValue + ")";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,16 +2,14 @@ package dev.langchain4j.store.embedding.filter.comparison;
|
|||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.compareAsBigDecimals;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible;
|
||||
|
||||
@ToString
|
||||
@EqualsAndHashCode
|
||||
public class IsGreaterThanOrEqualTo implements Filter {
|
||||
|
||||
private final String key;
|
||||
|
@ -32,11 +30,10 @@ public class IsGreaterThanOrEqualTo implements Filter {
|
|||
|
||||
@Override
|
||||
public boolean test(Object object) {
|
||||
if (!(object instanceof Metadata)) {
|
||||
if (!(object instanceof Metadata metadata)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Metadata metadata = (Metadata) object;
|
||||
if (!metadata.containsKey(key)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -50,4 +47,20 @@ public class IsGreaterThanOrEqualTo implements Filter {
|
|||
|
||||
return ((Comparable) actualValue).compareTo(comparisonValue) >= 0;
|
||||
}
|
||||
|
||||
public boolean equals(final Object o) {
|
||||
if (o == this) return true;
|
||||
if (!(o instanceof IsGreaterThanOrEqualTo other)) return false;
|
||||
|
||||
return Objects.equals(this.key, other.key)
|
||||
&& Objects.equals(this.comparisonValue, other.comparisonValue);
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
return Objects.hash(key, comparisonValue);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "IsGreaterThanOrEqualTo(key=" + this.key + ", comparisonValue=" + this.comparisonValue + ")";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,22 +2,21 @@ package dev.langchain4j.store.embedding.filter.comparison;
|
|||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.*;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.containsAsBigDecimals;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.UUIDComparator.containsAsUUID;
|
||||
import static java.util.Collections.unmodifiableSet;
|
||||
|
||||
@ToString
|
||||
@EqualsAndHashCode
|
||||
public class IsIn implements Filter {
|
||||
|
||||
private final String key;
|
||||
|
@ -40,11 +39,10 @@ public class IsIn implements Filter {
|
|||
|
||||
@Override
|
||||
public boolean test(Object object) {
|
||||
if (!(object instanceof Metadata)) {
|
||||
if (!(object instanceof Metadata metadata)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Metadata metadata = (Metadata) object;
|
||||
if (!metadata.containsKey(key)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -61,4 +59,21 @@ public class IsIn implements Filter {
|
|||
|
||||
return comparisonValues.contains(actualValue);
|
||||
}
|
||||
|
||||
public boolean equals(final Object o) {
|
||||
if (o == this) return true;
|
||||
if (!(o instanceof IsIn other)) return false;
|
||||
|
||||
return Objects.equals(this.key, other.key)
|
||||
&& Objects.equals(this.comparisonValues, other.comparisonValues);
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
return Objects.hash(key, comparisonValues);
|
||||
}
|
||||
|
||||
|
||||
public String toString() {
|
||||
return "IsIn(key=" + this.key + ", comparisonValues=" + this.comparisonValues + ")";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,16 +2,14 @@ package dev.langchain4j.store.embedding.filter.comparison;
|
|||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.compareAsBigDecimals;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible;
|
||||
|
||||
@ToString
|
||||
@EqualsAndHashCode
|
||||
public class IsLessThan implements Filter {
|
||||
|
||||
private final String key;
|
||||
|
@ -32,11 +30,10 @@ public class IsLessThan implements Filter {
|
|||
|
||||
@Override
|
||||
public boolean test(Object object) {
|
||||
if (!(object instanceof Metadata)) {
|
||||
if (!(object instanceof Metadata metadata)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Metadata metadata = (Metadata) object;
|
||||
if (!metadata.containsKey(key)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -50,4 +47,20 @@ public class IsLessThan implements Filter {
|
|||
|
||||
return ((Comparable) actualValue).compareTo(comparisonValue) < 0;
|
||||
}
|
||||
|
||||
public boolean equals(final Object o) {
|
||||
if (o == this) return true;
|
||||
if (!(o instanceof IsLessThan other)) return false;
|
||||
|
||||
return Objects.equals(this.key, other.key)
|
||||
&& Objects.equals(this.comparisonValue, other.comparisonValue);
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
return Objects.hash(key, comparisonValue);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "IsLessThan(key=" + this.key + ", comparisonValue=" + this.comparisonValue + ")";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,16 +2,14 @@ package dev.langchain4j.store.embedding.filter.comparison;
|
|||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.compareAsBigDecimals;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible;
|
||||
|
||||
@ToString
|
||||
@EqualsAndHashCode
|
||||
public class IsLessThanOrEqualTo implements Filter {
|
||||
|
||||
private final String key;
|
||||
|
@ -32,11 +30,10 @@ public class IsLessThanOrEqualTo implements Filter {
|
|||
|
||||
@Override
|
||||
public boolean test(Object object) {
|
||||
if (!(object instanceof Metadata)) {
|
||||
if (!(object instanceof Metadata metadata)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Metadata metadata = (Metadata) object;
|
||||
if (!metadata.containsKey(key)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -50,4 +47,21 @@ public class IsLessThanOrEqualTo implements Filter {
|
|||
|
||||
return ((Comparable) actualValue).compareTo(comparisonValue) <= 0;
|
||||
}
|
||||
|
||||
|
||||
public boolean equals(final Object o) {
|
||||
if (o == this) return true;
|
||||
if (!(o instanceof IsLessThanOrEqualTo other)) return false;
|
||||
|
||||
return Objects.equals(this.key, other.key)
|
||||
&& Objects.equals(this.comparisonValue, other.comparisonValue);
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
return Objects.hash(key, comparisonValue);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "IsLessThanOrEqualTo(key=" + this.key + ", comparisonValue=" + this.comparisonValue + ")";
|
||||
}
|
||||
}
|
|
@ -2,9 +2,8 @@ package dev.langchain4j.store.embedding.filter.comparison;
|
|||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.UUID;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
|
@ -12,8 +11,6 @@ import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
|||
import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.compareAsBigDecimals;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible;
|
||||
|
||||
@ToString
|
||||
@EqualsAndHashCode
|
||||
public class IsNotEqualTo implements Filter {
|
||||
|
||||
private final String key;
|
||||
|
@ -34,11 +31,10 @@ public class IsNotEqualTo implements Filter {
|
|||
|
||||
@Override
|
||||
public boolean test(Object object) {
|
||||
if (!(object instanceof Metadata)) {
|
||||
if (!(object instanceof Metadata metadata)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Metadata metadata = (Metadata) object;
|
||||
if (!metadata.containsKey(key)) {
|
||||
return true;
|
||||
}
|
||||
|
@ -56,4 +52,20 @@ public class IsNotEqualTo implements Filter {
|
|||
|
||||
return !actualValue.equals(comparisonValue);
|
||||
}
|
||||
|
||||
public boolean equals(final Object o) {
|
||||
if (o == this) return true;
|
||||
if (!(o instanceof IsNotEqualTo other)) return false;
|
||||
|
||||
return Objects.equals(this.key, other.key)
|
||||
&& Objects.equals(this.comparisonValue, other.comparisonValue);
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
return Objects.hash(key, comparisonValue);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "IsNotEqualTo(key=" + this.key + ", comparisonValue=" + this.comparisonValue + ")";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,22 +2,21 @@ package dev.langchain4j.store.embedding.filter.comparison;
|
|||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.*;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.NumberComparator.containsAsBigDecimals;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.TypeChecker.ensureTypesAreCompatible;
|
||||
import static dev.langchain4j.store.embedding.filter.comparison.UUIDComparator.containsAsUUID;
|
||||
import static java.util.Collections.unmodifiableSet;
|
||||
|
||||
@ToString
|
||||
@EqualsAndHashCode
|
||||
public class IsNotIn implements Filter {
|
||||
|
||||
private final String key;
|
||||
|
@ -40,11 +39,10 @@ public class IsNotIn implements Filter {
|
|||
|
||||
@Override
|
||||
public boolean test(Object object) {
|
||||
if (!(object instanceof Metadata)) {
|
||||
if (!(object instanceof Metadata metadata)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Metadata metadata = (Metadata) object;
|
||||
if (!metadata.containsKey(key)) {
|
||||
return true;
|
||||
}
|
||||
|
@ -61,4 +59,20 @@ public class IsNotIn implements Filter {
|
|||
|
||||
return !comparisonValues.contains(actualValue);
|
||||
}
|
||||
|
||||
public boolean equals(final Object o) {
|
||||
if (o == this) return true;
|
||||
if (!(o instanceof IsNotIn other)) return false;
|
||||
|
||||
return Objects.equals(this.key, other.key)
|
||||
&& Objects.equals(this.comparisonValues, other.comparisonValues);
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
return Objects.hash(key, comparisonValues);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "IsNotIn(key=" + this.key + ", comparisonValues=" + this.comparisonValues + ")";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
package dev.langchain4j.store.embedding.filter.logical;
|
||||
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
|
||||
@ToString
|
||||
@EqualsAndHashCode
|
||||
public class And implements Filter {
|
||||
|
||||
private final Filter left;
|
||||
|
@ -30,4 +28,18 @@ public class And implements Filter {
|
|||
public boolean test(Object object) {
|
||||
return left().test(object) && right().test(object);
|
||||
}
|
||||
|
||||
public boolean equals(final Object o) {
|
||||
if (o == this) return true;
|
||||
if (!(o instanceof And other)) return false;
|
||||
return Objects.equals(this.left, other.left) && Objects.equals(this.right, other.right);
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
return Objects.hash(left, right);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "And(left=" + this.left + ", right=" + this.right + ")";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
package dev.langchain4j.store.embedding.filter.logical;
|
||||
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
|
||||
@ToString
|
||||
@EqualsAndHashCode
|
||||
public class Not implements Filter {
|
||||
|
||||
private final Filter expression;
|
||||
|
@ -24,4 +22,18 @@ public class Not implements Filter {
|
|||
public boolean test(Object object) {
|
||||
return !expression.test(object);
|
||||
}
|
||||
|
||||
public boolean equals(final Object o) {
|
||||
if (o == this) return true;
|
||||
if (!(o instanceof Not other)) return false;
|
||||
return Objects.equals(this.expression, other.expression);
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
return Objects.hash(expression);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "Not(expression=" + this.expression + ")";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
package dev.langchain4j.store.embedding.filter.logical;
|
||||
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
|
||||
@ToString
|
||||
@EqualsAndHashCode
|
||||
public class Or implements Filter {
|
||||
|
||||
private final Filter left;
|
||||
|
@ -30,4 +28,18 @@ public class Or implements Filter {
|
|||
public boolean test(Object object) {
|
||||
return left().test(object) || right().test(object);
|
||||
}
|
||||
|
||||
public boolean equals(final Object o) {
|
||||
if (o == this) return true;
|
||||
if (!(o instanceof Or other)) return false;
|
||||
return Objects.equals(this.left, other.left) && Objects.equals(this.right, other.right);
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
return Objects.hash(left, right);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "Or(left=" + this.left + ", right=" + this.right + ")";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,7 +9,6 @@ import java.util.Map;
|
|||
import java.util.Objects;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
|
|
|
@ -9,7 +9,6 @@ import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
|
|||
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
|
||||
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
|
||||
import dev.langchain4j.model.output.structured.Description;
|
||||
import lombok.Data;
|
||||
import org.assertj.core.api.WithAssertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
@ -24,16 +23,15 @@ import java.util.Set;
|
|||
|
||||
class ToolSpecificationsTest implements WithAssertions {
|
||||
|
||||
@Data
|
||||
public static class Person {
|
||||
|
||||
public record Person(
|
||||
@Description("Name of the person")
|
||||
private String name;
|
||||
private List<String> aliases;
|
||||
private boolean active;
|
||||
private Person parent;
|
||||
private Address currentAddress;
|
||||
private List<Address> previousAddresses;
|
||||
String name,
|
||||
List<String> aliases,
|
||||
boolean active,
|
||||
Person parent,
|
||||
Address currentAddress,
|
||||
List<Address> previousAddresses
|
||||
) {
|
||||
}
|
||||
|
||||
public static class Address {
|
||||
|
@ -49,36 +47,36 @@ class ToolSpecificationsTest implements WithAssertions {
|
|||
public static class Wrapper {
|
||||
@Tool({"line1", "line2"})
|
||||
public int f(
|
||||
@P("foo") String p0,
|
||||
boolean p1,
|
||||
@P("b2") Boolean p2,
|
||||
byte p3,
|
||||
Byte p4,
|
||||
short p5,
|
||||
Short p6,
|
||||
int p7,
|
||||
Integer p8,
|
||||
long p9,
|
||||
Long p10,
|
||||
@P("biggy")
|
||||
BigInteger p11,
|
||||
float p12,
|
||||
Float p13,
|
||||
double p14,
|
||||
Double p15,
|
||||
@P("bigger") BigDecimal p16,
|
||||
String[] p17,
|
||||
Integer[] p18,
|
||||
Boolean[] p19,
|
||||
int[] p20,
|
||||
boolean[] p21,
|
||||
List<Integer> p22,
|
||||
Set<BigDecimal> p23,
|
||||
Collection<String> p24,
|
||||
E p25,
|
||||
Person p26,
|
||||
@P(value = "optional", required = false) int p27,
|
||||
@P(value = "required") int p28) {
|
||||
@P("foo") String p0,
|
||||
boolean p1,
|
||||
@P("b2") Boolean p2,
|
||||
byte p3,
|
||||
Byte p4,
|
||||
short p5,
|
||||
Short p6,
|
||||
int p7,
|
||||
Integer p8,
|
||||
long p9,
|
||||
Long p10,
|
||||
@P("biggy")
|
||||
BigInteger p11,
|
||||
float p12,
|
||||
Float p13,
|
||||
double p14,
|
||||
Double p15,
|
||||
@P("bigger") BigDecimal p16,
|
||||
String[] p17,
|
||||
Integer[] p18,
|
||||
Boolean[] p19,
|
||||
int[] p20,
|
||||
boolean[] p21,
|
||||
List<Integer> p22,
|
||||
Set<BigDecimal> p23,
|
||||
Collection<String> p24,
|
||||
E p25,
|
||||
Person p26,
|
||||
@P(value = "optional", required = false) int p27,
|
||||
@P(value = "required") int p28) {
|
||||
return 42;
|
||||
}
|
||||
|
||||
|
@ -122,35 +120,35 @@ class ToolSpecificationsTest implements WithAssertions {
|
|||
|
||||
private static Method getF() throws NoSuchMethodException {
|
||||
return Wrapper.class.getMethod("f",
|
||||
String.class,//0
|
||||
boolean.class,
|
||||
Boolean.class,
|
||||
byte.class,
|
||||
Byte.class,
|
||||
short.class,//5
|
||||
Short.class,
|
||||
int.class,
|
||||
Integer.class,
|
||||
long.class,
|
||||
Long.class, //10
|
||||
BigInteger.class,
|
||||
float.class,
|
||||
Float.class,
|
||||
double.class,
|
||||
Double.class, //15
|
||||
BigDecimal.class,
|
||||
String[].class,
|
||||
Integer[].class,
|
||||
Boolean[].class,
|
||||
int[].class,//20
|
||||
boolean[].class,
|
||||
List.class,
|
||||
Set.class,
|
||||
Collection.class,
|
||||
E.class,// 25
|
||||
Person.class,
|
||||
int.class,
|
||||
int.class);
|
||||
String.class,//0
|
||||
boolean.class,
|
||||
Boolean.class,
|
||||
byte.class,
|
||||
Byte.class,
|
||||
short.class,//5
|
||||
Short.class,
|
||||
int.class,
|
||||
Integer.class,
|
||||
long.class,
|
||||
Long.class, //10
|
||||
BigInteger.class,
|
||||
float.class,
|
||||
Float.class,
|
||||
double.class,
|
||||
Double.class, //15
|
||||
BigDecimal.class,
|
||||
String[].class,
|
||||
Integer[].class,
|
||||
Boolean[].class,
|
||||
int[].class,//20
|
||||
boolean[].class,
|
||||
List.class,
|
||||
Set.class,
|
||||
Collection.class,
|
||||
E.class,// 25
|
||||
Person.class,
|
||||
int.class,
|
||||
int.class);
|
||||
}
|
||||
|
||||
public static <K, V> Map<K, V> mapOf(K k1, V v1) {
|
||||
|
@ -192,24 +190,24 @@ class ToolSpecificationsTest implements WithAssertions {
|
|||
assertThat(specs).hasSize(2);
|
||||
|
||||
assertThat(specs).extracting(ToolSpecification::name)
|
||||
.containsExactlyInAnyOrder("f", "func_name");
|
||||
.containsExactlyInAnyOrder("f", "func_name");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_toolSpecificationsFrom_with_duplicate_method_names() {
|
||||
assertThatExceptionOfType(IllegalArgumentException.class)
|
||||
.isThrownBy(() -> ToolSpecifications.toolSpecificationsFrom(new InvalidToolsWithDuplicateMethodNames()))
|
||||
.withMessage("Tool names must be unique. The tool 'duplicateMethod' appears several times")
|
||||
.withNoCause();
|
||||
.isThrownBy(() -> ToolSpecifications.toolSpecificationsFrom(new InvalidToolsWithDuplicateMethodNames()))
|
||||
.withMessage("Tool names must be unique. The tool 'duplicateMethod' appears several times")
|
||||
.withNoCause();
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_toolSpecificationsFrom_with_duplicate_names() {
|
||||
assertThatExceptionOfType(IllegalArgumentException.class)
|
||||
.isThrownBy(() -> ToolSpecifications.toolSpecificationsFrom(new InvalidToolsWithDuplicateNames()))
|
||||
.withMessage("Tool names must be unique. The tool 'duplicate_name' appears several times")
|
||||
.withNoCause();
|
||||
.isThrownBy(() -> ToolSpecifications.toolSpecificationsFrom(new InvalidToolsWithDuplicateNames()))
|
||||
.withMessage("Tool names must be unique. The tool 'duplicate_name' appears several times")
|
||||
.withNoCause();
|
||||
|
||||
}
|
||||
|
||||
|
@ -238,73 +236,72 @@ class ToolSpecificationsTest implements WithAssertions {
|
|||
|
||||
assertThat(properties).hasSize(29);
|
||||
assertThat(properties)
|
||||
.containsEntry("arg0", JsonStringSchema.builder().description("foo").build())
|
||||
.containsEntry("arg1", new JsonBooleanSchema())
|
||||
.containsEntry("arg2", JsonBooleanSchema.builder().description("b2").build())
|
||||
.containsEntry("arg3", new JsonIntegerSchema())
|
||||
.containsEntry("arg4", new JsonIntegerSchema())
|
||||
.containsEntry("arg5", new JsonIntegerSchema())
|
||||
.containsEntry("arg6", new JsonIntegerSchema())
|
||||
.containsEntry("arg7", new JsonIntegerSchema())
|
||||
.containsEntry("arg8", new JsonIntegerSchema())
|
||||
.containsEntry("arg9", new JsonIntegerSchema())
|
||||
.containsEntry("arg10", new JsonIntegerSchema())
|
||||
.containsEntry("arg11", JsonIntegerSchema.builder().description("biggy").build())
|
||||
.containsEntry("arg12", new JsonNumberSchema())
|
||||
.containsEntry("arg13", new JsonNumberSchema())
|
||||
.containsEntry("arg14", new JsonNumberSchema())
|
||||
.containsEntry("arg15", new JsonNumberSchema())
|
||||
.containsEntry("arg16", JsonNumberSchema.builder().description("bigger").build())
|
||||
.containsEntry("arg17", JsonArraySchema.builder().items(new JsonStringSchema()).build())
|
||||
.containsEntry("arg18", JsonArraySchema.builder().items(new JsonIntegerSchema()).build())
|
||||
.containsEntry("arg19", JsonArraySchema.builder().items(new JsonBooleanSchema()).build())
|
||||
.containsEntry("arg20", JsonArraySchema.builder().items(new JsonIntegerSchema()).build())
|
||||
.containsEntry("arg21", JsonArraySchema.builder().items(new JsonBooleanSchema()).build())
|
||||
.containsEntry("arg22", JsonArraySchema.builder().items(new JsonIntegerSchema()).build())
|
||||
.containsEntry("arg23", JsonArraySchema.builder().items(new JsonNumberSchema()).build())
|
||||
.containsEntry("arg24", JsonArraySchema.builder().items(new JsonStringSchema()).build())
|
||||
.containsEntry("arg25", JsonEnumSchema.builder().enumValues("A", "B", "C").build())
|
||||
.containsEntry("arg27", JsonIntegerSchema.builder().description("optional").build())
|
||||
.containsEntry("arg28", JsonIntegerSchema.builder().description("required").build());
|
||||
.containsEntry("arg0", JsonStringSchema.builder().description("foo").build())
|
||||
.containsEntry("arg1", new JsonBooleanSchema())
|
||||
.containsEntry("arg2", JsonBooleanSchema.builder().description("b2").build())
|
||||
.containsEntry("arg3", new JsonIntegerSchema())
|
||||
.containsEntry("arg4", new JsonIntegerSchema())
|
||||
.containsEntry("arg5", new JsonIntegerSchema())
|
||||
.containsEntry("arg6", new JsonIntegerSchema())
|
||||
.containsEntry("arg7", new JsonIntegerSchema())
|
||||
.containsEntry("arg8", new JsonIntegerSchema())
|
||||
.containsEntry("arg9", new JsonIntegerSchema())
|
||||
.containsEntry("arg10", new JsonIntegerSchema())
|
||||
.containsEntry("arg11", JsonIntegerSchema.builder().description("biggy").build())
|
||||
.containsEntry("arg12", new JsonNumberSchema())
|
||||
.containsEntry("arg13", new JsonNumberSchema())
|
||||
.containsEntry("arg14", new JsonNumberSchema())
|
||||
.containsEntry("arg15", new JsonNumberSchema())
|
||||
.containsEntry("arg16", JsonNumberSchema.builder().description("bigger").build())
|
||||
.containsEntry("arg17", JsonArraySchema.builder().items(new JsonStringSchema()).build())
|
||||
.containsEntry("arg18", JsonArraySchema.builder().items(new JsonIntegerSchema()).build())
|
||||
.containsEntry("arg19", JsonArraySchema.builder().items(new JsonBooleanSchema()).build())
|
||||
.containsEntry("arg20", JsonArraySchema.builder().items(new JsonIntegerSchema()).build())
|
||||
.containsEntry("arg21", JsonArraySchema.builder().items(new JsonBooleanSchema()).build())
|
||||
.containsEntry("arg22", JsonArraySchema.builder().items(new JsonIntegerSchema()).build())
|
||||
.containsEntry("arg23", JsonArraySchema.builder().items(new JsonNumberSchema()).build())
|
||||
.containsEntry("arg24", JsonArraySchema.builder().items(new JsonStringSchema()).build())
|
||||
.containsEntry("arg25", JsonEnumSchema.builder().enumValues("A", "B", "C").build())
|
||||
.containsEntry("arg27", JsonIntegerSchema.builder().description("optional").build())
|
||||
.containsEntry("arg28", JsonIntegerSchema.builder().description("required").build());
|
||||
|
||||
assertThat(ts.parameters().required())
|
||||
.containsExactly("arg0",
|
||||
"arg1",
|
||||
"arg2",
|
||||
"arg3",
|
||||
"arg4",
|
||||
"arg5",
|
||||
"arg6",
|
||||
"arg7",
|
||||
"arg8",
|
||||
"arg9",
|
||||
"arg10",
|
||||
"arg11",
|
||||
"arg12",
|
||||
"arg13",
|
||||
"arg14",
|
||||
"arg15",
|
||||
"arg16",
|
||||
"arg17",
|
||||
"arg18",
|
||||
"arg19",
|
||||
"arg20",
|
||||
"arg21",
|
||||
"arg22",
|
||||
"arg23",
|
||||
"arg24",
|
||||
"arg25",
|
||||
"arg26",
|
||||
// "arg27", params with @P(required = false) are optional
|
||||
"arg28"
|
||||
);
|
||||
.containsExactly("arg0",
|
||||
"arg1",
|
||||
"arg2",
|
||||
"arg3",
|
||||
"arg4",
|
||||
"arg5",
|
||||
"arg6",
|
||||
"arg7",
|
||||
"arg8",
|
||||
"arg9",
|
||||
"arg10",
|
||||
"arg11",
|
||||
"arg12",
|
||||
"arg13",
|
||||
"arg14",
|
||||
"arg15",
|
||||
"arg16",
|
||||
"arg17",
|
||||
"arg18",
|
||||
"arg19",
|
||||
"arg20",
|
||||
"arg21",
|
||||
"arg22",
|
||||
"arg23",
|
||||
"arg24",
|
||||
"arg25",
|
||||
"arg26",
|
||||
// "arg27", params with @P(required = false) are optional
|
||||
"arg28"
|
||||
);
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class Customer {
|
||||
public String name;
|
||||
public Address billingAddress;
|
||||
public Address shippingAddress;
|
||||
record Customer(
|
||||
String name,
|
||||
Address billingAddress,
|
||||
Address shippingAddress) {
|
||||
}
|
||||
|
||||
public static class CustomerRegistration {
|
||||
|
@ -326,22 +323,22 @@ class ToolSpecificationsTest implements WithAssertions {
|
|||
assertThat(toolSpecification.name()).isEqualTo("registerCustomer");
|
||||
assertThat(toolSpecification.description()).isEqualTo("register a new customer");
|
||||
assertThat(toolSpecification.parameters()).isEqualTo(JsonObjectSchema.builder()
|
||||
.addProperty("arg0", JsonObjectSchema.builder()
|
||||
.addStringProperty("name")
|
||||
.addProperty("billingAddress", JsonObjectSchema.builder()
|
||||
.addStringProperty("street")
|
||||
.addStringProperty("city")
|
||||
.required("street", "city")
|
||||
.build())
|
||||
.addProperty("shippingAddress", JsonObjectSchema.builder()
|
||||
.addStringProperty("street")
|
||||
.addStringProperty("city")
|
||||
.required("street", "city")
|
||||
.build())
|
||||
.required("name", "billingAddress", "shippingAddress")
|
||||
.build())
|
||||
.required("arg0")
|
||||
.build());
|
||||
.addProperty("arg0", JsonObjectSchema.builder()
|
||||
.addStringProperty("name")
|
||||
.addProperty("billingAddress", JsonObjectSchema.builder()
|
||||
.addStringProperty("street")
|
||||
.addStringProperty("city")
|
||||
.required("street", "city")
|
||||
.build())
|
||||
.addProperty("shippingAddress", JsonObjectSchema.builder()
|
||||
.addStringProperty("street")
|
||||
.addStringProperty("city")
|
||||
.required("street", "city")
|
||||
.build())
|
||||
.required("name", "billingAddress", "shippingAddress")
|
||||
.build())
|
||||
.required("arg0")
|
||||
.build());
|
||||
assertThat(toolSpecification.toolParameters()).isNull();
|
||||
}
|
||||
}
|
|
@ -3,7 +3,10 @@ package dev.langchain4j.data.document;
|
|||
import org.assertj.core.api.WithAssertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.*;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
|
||||
|
@ -75,27 +78,27 @@ class DocumentLoaderTest implements WithAssertions {
|
|||
assertThat(document).isEqualTo(Document.from("Hello, world!", new Metadata().put("foo", "bar")));
|
||||
|
||||
assertThatExceptionOfType(RuntimeException.class)
|
||||
.isThrownBy(() -> DocumentLoader.load(new DocumentSource() {
|
||||
@Override
|
||||
public InputStream inputStream() throws IOException {
|
||||
throw new IOException("Failed to open input stream");
|
||||
}
|
||||
.isThrownBy(() -> DocumentLoader.load(new DocumentSource() {
|
||||
@Override
|
||||
public InputStream inputStream() throws IOException {
|
||||
throw new IOException("Failed to open input stream");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Metadata metadata() {
|
||||
return new Metadata();
|
||||
}
|
||||
}, new TrivialParser()))
|
||||
.withMessageContaining("Failed to load document");
|
||||
@Override
|
||||
public Metadata metadata() {
|
||||
return new Metadata();
|
||||
}
|
||||
}, new TrivialParser()))
|
||||
.withMessageContaining("Failed to load document");
|
||||
|
||||
assertThatExceptionOfType(RuntimeException.class)
|
||||
.isThrownBy(() -> DocumentLoader.load(
|
||||
source,
|
||||
inputStream -> {
|
||||
throw new RuntimeException("Failed to parse document");
|
||||
}
|
||||
.isThrownBy(() -> DocumentLoader.load(
|
||||
source,
|
||||
inputStream -> {
|
||||
throw new RuntimeException("Failed to parse document");
|
||||
}
|
||||
|
||||
))
|
||||
.withMessageContaining("Failed to load document");
|
||||
))
|
||||
.withMessageContaining("Failed to load document");
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package dev.langchain4j.internal;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import com.google.gson.annotations.SerializedName;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
|
@ -12,8 +13,7 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import com.google.gson.annotations.SerializedName;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class JsonTest {
|
||||
|
||||
|
|
|
@ -6,7 +6,11 @@ import java.util.concurrent.Callable;
|
|||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class RetryUtilsTest {
|
||||
@Test
|
||||
|
|
|
@ -9,13 +9,22 @@ import org.junit.jupiter.params.provider.MethodSource;
|
|||
import java.io.IOException;
|
||||
import java.net.HttpURLConnection;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.util.*;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.quoted;
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.*;
|
||||
import static org.assertj.core.api.Assertions.*;
|
||||
import static java.util.Collections.emptyList;
|
||||
import static java.util.Collections.emptyMap;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.Collections.singletonMap;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||
import static org.assertj.core.api.Assertions.entry;
|
||||
|
||||
@SuppressWarnings({"ObviousNullCheck", "ConstantValue"})
|
||||
class UtilsTest {
|
||||
|
|
|
@ -6,9 +6,15 @@ import org.junit.jupiter.params.ParameterizedTest;
|
|||
import org.junit.jupiter.params.provider.NullSource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.*;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureEq;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero;
|
||||
|
||||
@SuppressWarnings("ConstantConditions")
|
||||
class ValidationUtilsTest implements WithAssertions {
|
||||
|
|
|
@ -3,9 +3,10 @@ package dev.langchain4j.model.chat;
|
|||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.SneakyThrows;
|
||||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
|
||||
import static dev.langchain4j.internal.Exceptions.illegalArgument;
|
||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||
|
@ -26,8 +27,7 @@ public class TestStreamingResponseHandler<T> implements StreamingResponseHandler
|
|||
public void onComplete(Response<T> response) {
|
||||
|
||||
String expectedTextContent = textContentBuilder.toString();
|
||||
if (response.content() instanceof AiMessage) {
|
||||
AiMessage aiMessage = (AiMessage) response.content();
|
||||
if (response.content() instanceof AiMessage aiMessage) {
|
||||
if (aiMessage.hasToolExecutionRequests()){
|
||||
assertThat(aiMessage.toolExecutionRequests().size()).isGreaterThan(0);
|
||||
} else {
|
||||
|
@ -47,8 +47,14 @@ public class TestStreamingResponseHandler<T> implements StreamingResponseHandler
|
|||
futureResponse.completeExceptionally(error);
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
public Response<T> get() {
|
||||
return futureResponse.get(30, SECONDS);
|
||||
public Response<T> get() {
|
||||
try {
|
||||
return futureResponse.get(30, SECONDS);
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new RuntimeException(e);
|
||||
} catch (ExecutionException | TimeoutException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,9 +8,7 @@ import dev.langchain4j.model.output.TokenUsage;
|
|||
import org.assertj.core.api.WithAssertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
class DimensionAwareEmbeddingModelTest implements WithAssertions {
|
||||
|
|
|
@ -4,7 +4,12 @@ import org.junit.jupiter.api.Test;
|
|||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
|
||||
import java.time.*;
|
||||
import java.time.Clock;
|
||||
import java.time.Instant;
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.LocalTime;
|
||||
import java.time.ZoneOffset;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
package dev.langchain4j.model.input.structured;
|
||||
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.model.input.structured.StructuredPromptProcessor.toPrompt;
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import java.util.List;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class StructuredPromptProcessorTest {
|
||||
|
||||
@StructuredPrompt("Hello, my name is {{name}}")
|
||||
|
|
|
@ -30,7 +30,13 @@ import static java.util.Collections.singletonList;
|
|||
import static java.util.stream.Collectors.joining;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.mockito.Mockito.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.spy;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
|
||||
class DefaultRetrievalAugmentorTest {
|
||||
|
||||
|
|
|
@ -14,7 +14,10 @@ import java.util.Map;
|
|||
import java.util.stream.Stream;
|
||||
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.*;
|
||||
import static java.util.Collections.emptyList;
|
||||
import static java.util.Collections.emptyMap;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.Collections.singletonMap;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class DefaultContentAggregatorTest {
|
||||
|
|
|
@ -9,16 +9,25 @@ import org.junit.jupiter.params.ParameterizedTest;
|
|||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.*;
|
||||
import static java.util.Collections.emptyList;
|
||||
import static java.util.Collections.emptyMap;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.Collections.singletonMap;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class ReRankingContentAggregatorTest {
|
||||
|
||||
|
|
|
@ -10,7 +10,6 @@ import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
|||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
@ -19,7 +18,10 @@ import static java.util.Arrays.asList;
|
|||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class EmbeddingStoreContentRetrieverTest {
|
||||
|
||||
|
|
|
@ -4,7 +4,11 @@ import dev.langchain4j.data.document.Metadata;
|
|||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.rag.content.Content;
|
||||
import dev.langchain4j.rag.query.Query;
|
||||
import dev.langchain4j.web.search.*;
|
||||
import dev.langchain4j.web.search.WebSearchEngine;
|
||||
import dev.langchain4j.web.search.WebSearchInformationResult;
|
||||
import dev.langchain4j.web.search.WebSearchOrganicResult;
|
||||
import dev.langchain4j.web.search.WebSearchRequest;
|
||||
import dev.langchain4j.web.search.WebSearchResults;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
@ -15,7 +19,12 @@ import java.util.List;
|
|||
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.mockito.Mockito.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.reset;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class WebSearchContentRetrieverTest {
|
||||
|
||||
|
|
|
@ -16,7 +16,10 @@ import static dev.langchain4j.data.segment.TextSegment.textSegment;
|
|||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class EmbeddingStoreIngestorTest {
|
||||
|
||||
|
|
|
@ -4,14 +4,11 @@ import dev.langchain4j.data.embedding.Embedding;
|
|||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import org.awaitility.Awaitility;
|
||||
import org.awaitility.core.ThrowingRunnable;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.NullAndEmptySource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
|
|
|
@ -4,17 +4,12 @@ import dev.langchain4j.data.document.Metadata;
|
|||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.AbstractMap;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static java.util.Collections.emptyList;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.stream.Collectors.toMap;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
|
||||
import static org.mockito.ArgumentMatchers.anyList;
|
||||
|
||||
class WebSearchResultsTest {
|
||||
|
||||
|
|
|
@ -10,8 +10,11 @@ import java.util.HashMap;
|
|||
import static java.util.Arrays.asList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.reset;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class WebSearchToolTest {
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package dev.langchain4j.store.embedding.pgvector;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
import java.util.Collections;
|
||||
|
|
|
@ -4,7 +4,6 @@ import dev.langchain4j.data.message.AiMessage;
|
|||
import dev.langchain4j.data.message.SystemMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
|
@ -89,7 +88,7 @@ public class TestUtils {
|
|||
}
|
||||
|
||||
public static List<String> repeat(String s, int n) {
|
||||
val result = new ArrayList<String>();
|
||||
final var result = new ArrayList<String>();
|
||||
for (int i = 0; i < n; i++) {
|
||||
result.add(s);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue