From 053a35d5a394db095ea23420da4c4760f6d1614a Mon Sep 17 00:00:00 2001 From: Jansen Ang <41965778+jmgang@users.noreply.github.com> Date: Fri, 27 Oct 2023 20:00:16 +0800 Subject: [PATCH] Add support for loading documents from s3 (#197) Adding two separate loaders that load a single document or multiple documents from S3 respectively. They also contain different parameters to support different configurations. However, the document type is dependent on the current parsers that langchain4j supports, but I am planning to help in adding more parsers in the future. --- .../java/dev/langchain4j/internal/Utils.java | 19 ++ langchain4j-parent/pom.xml | 30 +++ langchain4j/pom.xml | 27 +++ .../data/document/AbstractS3Loader.java | 176 ++++++++++++++++++ .../data/document/AwsCredentials.java | 45 +++++ .../data/document/S3DirectoryLoader.java | 100 ++++++++++ .../data/document/S3FileLoader.java | 70 +++++++ .../data/document/source/S3Source.java | 35 ++++ .../data/document/S3DirectoryLoaderIT.java | 136 ++++++++++++++ .../data/document/S3DirectoryLoaderTest.java | 118 ++++++++++++ .../data/document/S3FileLoaderIT.java | 97 ++++++++++ .../data/document/S3FileLoaderTest.java | 73 ++++++++ .../org.mockito.plugins.MockMaker | 1 + 13 files changed, 927 insertions(+) create mode 100644 langchain4j/src/main/java/dev/langchain4j/data/document/AbstractS3Loader.java create mode 100644 langchain4j/src/main/java/dev/langchain4j/data/document/AwsCredentials.java create mode 100644 langchain4j/src/main/java/dev/langchain4j/data/document/S3DirectoryLoader.java create mode 100644 langchain4j/src/main/java/dev/langchain4j/data/document/S3FileLoader.java create mode 100644 langchain4j/src/main/java/dev/langchain4j/data/document/source/S3Source.java create mode 100644 langchain4j/src/test/java/dev/langchain4j/data/document/S3DirectoryLoaderIT.java create mode 100644 langchain4j/src/test/java/dev/langchain4j/data/document/S3DirectoryLoaderTest.java create mode 100644 langchain4j/src/test/java/dev/langchain4j/data/document/S3FileLoaderIT.java create mode 100644 langchain4j/src/test/java/dev/langchain4j/data/document/S3FileLoaderTest.java create mode 100644 langchain4j/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker diff --git a/langchain4j-core/src/main/java/dev/langchain4j/internal/Utils.java b/langchain4j-core/src/main/java/dev/langchain4j/internal/Utils.java index 100630459..0db880ba9 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/internal/Utils.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/internal/Utils.java @@ -17,6 +17,25 @@ public class Utils { return string == null || string.trim().isEmpty(); } + public static boolean isNotNullOrBlank(String string) { + return !isNullOrBlank(string); + } + + public static boolean areNotNullOrBlank(String... strings) { + if (strings == null || strings.length == 0) { + return false; + } + + for (String string : strings) { + if (isNullOrBlank(string)) { + return false; + } + } + + return true; + } + + public static boolean isCollectionEmpty(Collection collection) { return collection == null || collection.isEmpty(); } diff --git a/langchain4j-parent/pom.xml b/langchain4j-parent/pom.xml index ad2f2f558..4562e7089 100644 --- a/langchain4j-parent/pom.xml +++ b/langchain4j-parent/pom.xml @@ -42,6 +42,8 @@ 8.9.0 2.12.7.1 5.0.0 + 2.20.149 + 1.19.0 4.1.100.Final @@ -264,6 +266,34 @@ ${jedis.version} + + software.amazon.awssdk + bom + ${aws.java.sdk.version} + pom + import + + + + io.netty + netty-handler + + + + + + org.testcontainers + testcontainers + ${testcontainers.version} + test + + + + org.testcontainers + localstack + ${testcontainers.version} + test + diff --git a/langchain4j/pom.xml b/langchain4j/pom.xml index d0d41a402..2ef0582c8 100644 --- a/langchain4j/pom.xml +++ b/langchain4j/pom.xml @@ -135,6 +135,33 @@ test + + software.amazon.awssdk + s3 + + + + io.netty + netty-handler + + + + io.netty + netty-codec-http2 + + + + + + org.testcontainers + testcontainers + + + + org.testcontainers + localstack + + diff --git a/langchain4j/src/main/java/dev/langchain4j/data/document/AbstractS3Loader.java b/langchain4j/src/main/java/dev/langchain4j/data/document/AbstractS3Loader.java new file mode 100644 index 000000000..7f86db6ef --- /dev/null +++ b/langchain4j/src/main/java/dev/langchain4j/data/document/AbstractS3Loader.java @@ -0,0 +1,176 @@ +package dev.langchain4j.data.document; + + +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3ClientBuilder; + +import java.net.URI; +import java.net.URISyntaxException; + +import static dev.langchain4j.internal.Utils.isNotNullOrBlank; +import static dev.langchain4j.internal.Utils.isNullOrBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; + +public abstract class AbstractS3Loader { + protected final String bucket; + protected final String region; + protected final String endpointUrl; + protected final String profile; + + protected final boolean forcePathStyle; + protected final AwsCredentials awsCredentials; + + protected AbstractS3Loader(Builder builder) { + this.bucket = ensureNotBlank(builder.bucket, "bucket"); + this.region = builder.region; + this.endpointUrl = builder.endpointUrl; + this.profile = builder.profile; + this.forcePathStyle = builder.forcePathStyle; + this.awsCredentials = builder.awsCredentials; + } + + /** + * Initiates the loading process by configuring the AWS credentials and S3 client, + * then loads either a document or a list of documents. + * + * @return A generic object of type T, which could be a Document or a list of Documents + * @throws RuntimeException if there are issues with AWS credentials or S3 client configuration + */ + public T load() { + AwsCredentialsProvider awsCredentialsProvider = configureCredentialsProvider(); + S3Client s3Client = configureS3Client(awsCredentialsProvider); + return load(s3Client); + } + + private static AwsSessionCredentials toAwsSessionCredentials(AwsCredentials awsCredentials) { + return AwsSessionCredentials.create(awsCredentials.accessKeyId(), awsCredentials.secretAccessKey(), awsCredentials.sessionToken()); + } + + private static software.amazon.awssdk.auth.credentials.AwsCredentials toAwsCredentials(AwsCredentials awsCredentials) { + return AwsBasicCredentials.create(awsCredentials.accessKeyId(), awsCredentials.secretAccessKey()); + } + + protected abstract T load(S3Client s3Client); + + private AwsCredentialsProvider configureCredentialsProvider() { + AwsCredentialsProvider provider = DefaultCredentialsProvider.builder().build(); + + if (awsCredentials != null) { + if (awsCredentials.hasAllCredentials()) { + provider = StaticCredentialsProvider.create(toAwsSessionCredentials(awsCredentials)); + }else if (awsCredentials.hasAccessKeyIdAndSecretKey()) { + provider = StaticCredentialsProvider.create(toAwsCredentials(awsCredentials)); + } + } + + if( isNotNullOrBlank(profile) ) { + provider = ProfileCredentialsProvider.create(profile); + } + + return provider; + } + + private S3Client configureS3Client(AwsCredentialsProvider provider) { + S3ClientBuilder s3ClientBuilder = S3Client.builder() + .region(Region.of(isNotNullOrBlank(region) ? region : Region.US_EAST_1.id())) + .forcePathStyle(forcePathStyle) + .credentialsProvider(provider); + + if (!isNullOrBlank(endpointUrl)) { + try { + s3ClientBuilder.endpointOverride(new URI(endpointUrl)); + } catch (URISyntaxException e) { + throw new IllegalArgumentException("Invalid URL: " + endpointUrl, e); + } + } + + return s3ClientBuilder.build(); + } + + public static abstract class Builder> { + private String bucket; + private String region; + private String endpointUrl; + private String profile; + + private boolean forcePathStyle; + private AwsCredentials awsCredentials; + + /** + * Set the AWS bucket. + * + * @return The builder instance. + */ + public T bucket(String bucket) { + this.bucket = bucket; + return self(); + } + + /** + * Set the AWS region. Defaults to US_EAST_1 + * + * @param region The AWS region. + * @return The builder instance. + */ + public T region(String region) { + this.region = region; + return self(); + } + + /** + * Specifies a custom endpoint URL to override the default service URL. + * + * @param endpointUrl The endpoint URL. + * @return The builder instance. + */ + public T endpointUrl(String endpointUrl) { + this.endpointUrl = endpointUrl; + return self(); + } + + /** + * Set the profile defined in AWS credentials. If not set, it will use the default profile. + * + * @param profile The profile defined in AWS credentials. + * @return The builder instance. + */ + public T profile(String profile) { + this.profile = profile; + return self(); + } + + /** + * Set the forcePathStyle. When enabled, it will use the path-style URL + * + * @param forcePathStyle The forcePathStyle. + * @return The builder instance. + */ + public T forcePathStyle(boolean forcePathStyle) { + this.forcePathStyle = forcePathStyle; + return self(); + } + + /** + * Set the AWS credentials. If not set, it will use the default credentials. + * + * @param awsCredentials The AWS credentials. + * @return The builder instance. + */ + public T awsCredentials(AwsCredentials awsCredentials) { + this.awsCredentials = awsCredentials; + return self(); + } + + public abstract AbstractS3Loader build(); + + protected abstract T self(); + } +} + diff --git a/langchain4j/src/main/java/dev/langchain4j/data/document/AwsCredentials.java b/langchain4j/src/main/java/dev/langchain4j/data/document/AwsCredentials.java new file mode 100644 index 000000000..988eae11c --- /dev/null +++ b/langchain4j/src/main/java/dev/langchain4j/data/document/AwsCredentials.java @@ -0,0 +1,45 @@ +package dev.langchain4j.data.document; + +import static dev.langchain4j.internal.Utils.areNotNullOrBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; + +/** + * Represents an AWS credentials object, including access key ID, secret access key, and optional session token. + */ +public class AwsCredentials { + + private final String accessKeyId; + private final String secretAccessKey; + private String sessionToken; + + public AwsCredentials(String accessKeyId, String secretAccessKey, String sessionToken) { + this.accessKeyId = ensureNotBlank(accessKeyId, "accessKeyId"); + this.secretAccessKey = ensureNotBlank(secretAccessKey, "secretAccessKey"); + this.sessionToken = sessionToken; + } + + public AwsCredentials(String accessKeyId, String secretAccessKey) { + this(accessKeyId, secretAccessKey, null); + } + + + public String accessKeyId() { + return accessKeyId; + } + + public String secretAccessKey() { + return secretAccessKey; + } + + public String sessionToken() { + return sessionToken; + } + + public boolean hasAccessKeyIdAndSecretKey() { + return areNotNullOrBlank(accessKeyId, secretAccessKey); + } + + public boolean hasAllCredentials() { + return areNotNullOrBlank(accessKeyId, secretAccessKey, sessionToken); + } +} diff --git a/langchain4j/src/main/java/dev/langchain4j/data/document/S3DirectoryLoader.java b/langchain4j/src/main/java/dev/langchain4j/data/document/S3DirectoryLoader.java new file mode 100644 index 000000000..250802a7e --- /dev/null +++ b/langchain4j/src/main/java/dev/langchain4j/data/document/S3DirectoryLoader.java @@ -0,0 +1,100 @@ +package dev.langchain4j.data.document; + +import dev.langchain4j.data.document.source.S3Source; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.*; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import static dev.langchain4j.data.document.DocumentLoaderUtils.parserFor; +/** + * S3 Directory Loader Implementation + */ +public class S3DirectoryLoader extends AbstractS3Loader> { + + private static final Logger log = LoggerFactory.getLogger(S3DirectoryLoader.class); + + private final String prefix; + + private S3DirectoryLoader(Builder builder) { + super(builder); + this.prefix = builder.prefix; + } + + + /** + * Loads a list of documents from an S3 bucket, ignoring unsupported document types. + * If a prefix is specified, only objects with that prefix will be loaded. + * + * @param s3Client The S3 client used for the operation + * @return A list of Document objects containing the content and metadata of the S3 objects + * @throws RuntimeException if an S3 exception occurs during the operation + */ + @Override + protected List load(S3Client s3Client) { + List documents = new ArrayList<>(); + + ListObjectsV2Request listObjectsV2Request = ListObjectsV2Request.builder() + .bucket(bucket) + .prefix(prefix) + .build(); + + ListObjectsV2Response listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request); + List filteredS3Objects = listObjectsV2Response.contents().stream() + .filter(s3Object -> !s3Object.key().endsWith("/") && s3Object.size() > 0) + .collect(Collectors.toList()); + + for (S3Object s3Object : filteredS3Objects) { + String key = s3Object.key(); + + GetObjectRequest getObjectRequest = GetObjectRequest.builder() + .bucket(bucket) + .key(key) + .build(); + + ResponseInputStream inputStream = s3Client.getObject(getObjectRequest); + + try { + documents.add(DocumentLoaderUtils.load(new S3Source(bucket, key, inputStream), parserFor(DocumentType.of(key)))); + } catch (Exception e) { + log.warn("Failed to load document from S3", e); + } + } + + return documents; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder extends AbstractS3Loader.Builder { + private String prefix = ""; + + /** + * Set the prefix. + * + * @param prefix Prefix. + */ + public Builder prefix(String prefix) { + this.prefix = prefix; + return this; + } + + @Override + public S3DirectoryLoader build() { + return new S3DirectoryLoader(this); + } + + @Override + protected Builder self() { + return this; + } + } +} + diff --git a/langchain4j/src/main/java/dev/langchain4j/data/document/S3FileLoader.java b/langchain4j/src/main/java/dev/langchain4j/data/document/S3FileLoader.java new file mode 100644 index 000000000..9c9915bab --- /dev/null +++ b/langchain4j/src/main/java/dev/langchain4j/data/document/S3FileLoader.java @@ -0,0 +1,70 @@ +package dev.langchain4j.data.document; + +import dev.langchain4j.data.document.source.S3Source; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; + + +import static dev.langchain4j.data.document.DocumentLoaderUtils.parserFor; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; + +/** + * S3 File Loader Implementation + */ +public class S3FileLoader extends AbstractS3Loader { + private final String key; + + private S3FileLoader(Builder builder) { + super(builder); + this.key = ensureNotBlank(builder.key, "key"); + } + + /** + * Loads a document from an S3 bucket based on the specified key. + * + * @param s3Client The S3 client used for the operation + * @return A Document object containing the content and metadata of the S3 object + * @throws RuntimeException if an S3 exception occurs during the operation + */ + @Override + protected Document load(S3Client s3Client) { + try { + GetObjectRequest objectRequest = GetObjectRequest.builder().bucket(bucket).key(key).build(); + ResponseInputStream inputStream = s3Client.getObject(objectRequest); + return DocumentLoaderUtils.load(new S3Source(bucket, key, inputStream), parserFor(DocumentType.of(key))); + } catch (S3Exception e) { + throw new RuntimeException("Failed to load document from s3", e); + } + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder extends AbstractS3Loader.Builder { + private String key; + + /** + * Set the object key. + * + * @param key Key. + */ + public Builder key(String key) { + this.key = key; + return this; + } + + @Override + public S3FileLoader build() { + return new S3FileLoader(this); + } + + @Override + protected Builder self() { + return this; + } + } +} diff --git a/langchain4j/src/main/java/dev/langchain4j/data/document/source/S3Source.java b/langchain4j/src/main/java/dev/langchain4j/data/document/source/S3Source.java new file mode 100644 index 000000000..b3a577640 --- /dev/null +++ b/langchain4j/src/main/java/dev/langchain4j/data/document/source/S3Source.java @@ -0,0 +1,35 @@ +package dev.langchain4j.data.document.source; + +import dev.langchain4j.data.document.DocumentSource; +import dev.langchain4j.data.document.Metadata; + +import java.io.IOException; +import java.io.InputStream; + +public class S3Source implements DocumentSource { + + private static final String SOURCE = "source"; + + private InputStream inputStream; + + private final String bucket; + + private final String key; + + public S3Source(String bucket, String key, InputStream inputStream) { + this.inputStream = inputStream; + this.bucket = bucket; + this.key = key; + } + + @Override + public InputStream inputStream() throws IOException { + return inputStream; + } + + @Override + public Metadata metadata() { + return new Metadata() + .add(SOURCE, String.format("s3://%s/%s", bucket, key)); + } +} diff --git a/langchain4j/src/test/java/dev/langchain4j/data/document/S3DirectoryLoaderIT.java b/langchain4j/src/test/java/dev/langchain4j/data/document/S3DirectoryLoaderIT.java new file mode 100644 index 000000000..7320923f8 --- /dev/null +++ b/langchain4j/src/test/java/dev/langchain4j/data/document/S3DirectoryLoaderIT.java @@ -0,0 +1,136 @@ +package dev.langchain4j.data.document; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.localstack.LocalStackContainer; +import org.testcontainers.utility.DockerImageName; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.CreateBucketRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.core.sync.RequestBody; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3; + +@Disabled("To run this test, you need a Docker-API compatible container runtime, such as using Testcontainers Cloud or installing Docker locally.") +public class S3DirectoryLoaderIT { + + private LocalStackContainer s3Container; + private S3Client s3Client; + + @BeforeAll + public static void setUpClass() { + System.setProperty("aws.region", "us-east-1"); + } + + @BeforeEach + public void setUp() { + s3Container = new LocalStackContainer(DockerImageName.parse("localstack/localstack:2.0")) + .withServices(S3) + .withEnv("DEFAULT_REGION", "us-east-1"); + s3Container.start(); + + s3Client = S3Client.builder() + .endpointOverride(s3Container.getEndpointOverride(S3)) + .build(); + + s3Client.createBucket(CreateBucketRequest.builder().bucket("test-bucket").build()); + } + + @Test + public void should_load_empty_list() { + S3DirectoryLoader s3DirectoryLoader = S3DirectoryLoader.builder() + .bucket("test-bucket") + .endpointUrl(s3Container.getEndpointOverride(S3).toString()) + .build(); + + List documents = s3DirectoryLoader.load(); + + assertTrue(documents.isEmpty()); + } + + @Test + public void should_load_multiple_files_without_prefix() { + s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("file1.txt").build(), + RequestBody.fromString("Hello, World!")); + s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("directory/file2.txt").build(), + RequestBody.fromString("Hello, again!")); + + S3DirectoryLoader s3DirectoryLoader = S3DirectoryLoader.builder() + .bucket("test-bucket") + .endpointUrl(s3Container.getEndpointOverride(S3).toString()) + .build(); + + List documents = s3DirectoryLoader.load(); + + assertEquals(2, documents.size()); + assertEquals("Hello, again!", documents.get(0).text()); + assertEquals("s3://test-bucket/directory/file2.txt", documents.get(0).metadata("source")); + assertEquals("Hello, World!", documents.get(1).text()); + assertEquals("s3://test-bucket/file1.txt", documents.get(1).metadata("source")); + } + + @Test + public void should_load_multiple_files_with_prefix() { + s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("other_directory/file1.txt").build(), + RequestBody.fromString("You cannot load me!")); + s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("directory/file2.txt").build(), + RequestBody.fromString("Hello, World!")); + s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("directory/file3.txt").build(), + RequestBody.fromString("Hello, again!")); + + S3DirectoryLoader s3DirectoryLoader = S3DirectoryLoader.builder() + .bucket("test-bucket") + .prefix("directory") + .endpointUrl(s3Container.getEndpointOverride(S3).toString()) + .build(); + + List documents = s3DirectoryLoader.load(); + + assertEquals(2, documents.size()); + assertEquals("Hello, World!", documents.get(0).text()); + assertEquals("s3://test-bucket/directory/file2.txt", documents.get(0).metadata("source")); + assertEquals("Hello, again!", documents.get(1).text()); + assertEquals("s3://test-bucket/directory/file3.txt", documents.get(1).metadata("source")); + } + + @Test + public void should_load_accepting_unknown_types() { + s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("directory/file2.unknown").build(), + RequestBody.fromString("I am unknown.")); + s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("directory/file3.txt").build(), + RequestBody.fromString("Hello, World!")); + + S3DirectoryLoader s3DirectoryLoader = S3DirectoryLoader.builder() + .bucket("test-bucket") + .prefix("directory") + .endpointUrl(s3Container.getEndpointOverride(S3).toString()) + .build(); + + List documents = s3DirectoryLoader.load(); + + assertEquals(2, documents.size()); + assertEquals("I am unknown.", documents.get(0).text()); + assertEquals("s3://test-bucket/directory/file2.unknown", documents.get(0).metadata("source")); + assertEquals("Hello, World!", documents.get(1).text()); + assertEquals("s3://test-bucket/directory/file3.txt", documents.get(1).metadata("source")); + } + + @AfterEach + public void tearDown() { + s3Container.stop(); + } + + @AfterAll + public static void tearDownClass() { + System.clearProperty("aws.region"); + } +} + diff --git a/langchain4j/src/test/java/dev/langchain4j/data/document/S3DirectoryLoaderTest.java b/langchain4j/src/test/java/dev/langchain4j/data/document/S3DirectoryLoaderTest.java new file mode 100644 index 000000000..0486cde00 --- /dev/null +++ b/langchain4j/src/test/java/dev/langchain4j/data/document/S3DirectoryLoaderTest.java @@ -0,0 +1,118 @@ +package dev.langchain4j.data.document; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Request; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Response; +import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.services.s3.model.S3Object; + +import java.io.ByteArrayInputStream; +import java.util.Arrays; +import java.util.List; + + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class S3DirectoryLoaderTest { + + @Mock + private S3Client s3Client; + + @Mock + private ListObjectsV2Response listObjectsV2Response; + + @Mock + private GetObjectResponse getObjectResponse; + + private S3DirectoryLoader s3DirectoryLoader; + + @BeforeEach + public void setUp() { + s3DirectoryLoader = S3DirectoryLoader.builder() + .bucket("langchain4j") + .prefix("testPrefix") + .build(); + } + + @Test + public void should_load_documents_from_directory() { + S3Object s3Object1 = S3Object.builder().key("testPrefix/testKey1.txt").size(10L).build(); + S3Object s3Object2 = S3Object.builder().key("testPrefix/testKey2.txt").size(20L).build(); + when(listObjectsV2Response.contents()).thenReturn(Arrays.asList(s3Object1, s3Object2)); + when(s3Client.listObjectsV2(any(ListObjectsV2Request.class))).thenReturn(listObjectsV2Response); + + ResponseInputStream responseInputStream1 = new ResponseInputStream<>(getObjectResponse, new ByteArrayInputStream("test1".getBytes())); + ResponseInputStream responseInputStream2 = new ResponseInputStream<>(getObjectResponse, new ByteArrayInputStream("test2".getBytes())); + when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(responseInputStream1).thenReturn(responseInputStream2); + + List documents = s3DirectoryLoader.load(s3Client); + + assertEquals(2, documents.size()); + assertEquals("test1", documents.get(0).text()); + assertEquals("test2", documents.get(1).text()); + assertEquals("s3://langchain4j/testPrefix/testKey1.txt", documents.get(0).metadata("source")); + assertEquals("s3://langchain4j/testPrefix/testKey2.txt", documents.get(1).metadata("source")); + } + + @Test + public void should_load_documents_from_directory_accepting_unknown_types() { + S3Object s3Object1 = S3Object.builder().key("testPrefix/testKey1.txt").size(10L).build(); + S3Object s3Object2 = S3Object.builder().key("testPrefix/testKey2.txt").size(20L).build(); + S3Object s3Object3 = S3Object.builder().key("testPrefix/testKey3.unknown").size(30L).build(); + + when(listObjectsV2Response.contents()).thenReturn(Arrays.asList(s3Object1, s3Object2, s3Object3)); + when(s3Client.listObjectsV2(any(ListObjectsV2Request.class))).thenReturn(listObjectsV2Response); + + ResponseInputStream responseInputStream1 = new ResponseInputStream<>(getObjectResponse, new ByteArrayInputStream("test1".getBytes())); + ResponseInputStream responseInputStream2 = new ResponseInputStream<>(getObjectResponse, new ByteArrayInputStream("test2".getBytes())); + ResponseInputStream responseInputStream3 = new ResponseInputStream<>(getObjectResponse, new ByteArrayInputStream("unknown".getBytes())); + when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(responseInputStream1).thenReturn(responseInputStream2).thenReturn(responseInputStream3); + + List documents = s3DirectoryLoader.load(s3Client); + + assertEquals(3, documents.size()); + assertEquals("test1", documents.get(0).text()); + assertEquals("test2", documents.get(1).text()); + assertEquals("unknown", documents.get(2).text()); + assertEquals("s3://langchain4j/testPrefix/testKey1.txt", documents.get(0).metadata("source")); + assertEquals("s3://langchain4j/testPrefix/testKey2.txt", documents.get(1).metadata("source")); + assertEquals("s3://langchain4j/testPrefix/testKey3.unknown", documents.get(2).metadata("source")); + } + + @Test + public void should_return_empty_list_when_no_objects() { + when(listObjectsV2Response.contents()).thenReturn(Arrays.asList()); + when(s3Client.listObjectsV2(any(ListObjectsV2Request.class))).thenReturn(listObjectsV2Response); + + List documents = s3DirectoryLoader.load(s3Client); + + assertTrue(documents.isEmpty()); + } + + @Test + public void should_throw_s3_exception() { + when(s3Client.listObjectsV2(any(ListObjectsV2Request.class))).thenThrow(S3Exception.builder().message("S3 error").build()); + + assertThrows(RuntimeException.class, () -> s3DirectoryLoader.load(s3Client)); + } + + @Test + public void should_throw_invalid_bucket() { + assertThrows(IllegalArgumentException.class, () -> S3DirectoryLoader.builder() + .prefix("testPrefix") + .build()); + } +} diff --git a/langchain4j/src/test/java/dev/langchain4j/data/document/S3FileLoaderIT.java b/langchain4j/src/test/java/dev/langchain4j/data/document/S3FileLoaderIT.java new file mode 100644 index 000000000..8a71ca6cc --- /dev/null +++ b/langchain4j/src/test/java/dev/langchain4j/data/document/S3FileLoaderIT.java @@ -0,0 +1,97 @@ +package dev.langchain4j.data.document; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.localstack.LocalStackContainer; +import org.testcontainers.utility.DockerImageName; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.CreateBucketRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.core.sync.RequestBody; + +import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3; +import static org.junit.jupiter.api.Assertions.*; + +@Disabled("To run this test, you need a Docker-API compatible container runtime, such as using Testcontainers Cloud or installing Docker locally.") +public class S3FileLoaderIT { + + private LocalStackContainer s3Container; + + private S3Client s3Client; + + private static final DockerImageName localstackImage = DockerImageName.parse("localstack/localstack:2.0"); + + @BeforeAll + public static void setUpClass() { + System.setProperty("aws.region", "us-east-1"); + } + + @BeforeEach + public void setUp() { + s3Container = new LocalStackContainer(localstackImage) + .withServices(S3) + .withEnv("DEFAULT_REGION", "us-east-1"); + s3Container.start(); + + s3Client = S3Client.builder() + .endpointOverride(s3Container.getEndpointOverride(LocalStackContainer.Service.S3)) + .build(); + + s3Client.createBucket(CreateBucketRequest.builder().bucket("test-bucket").build()); + s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("test-file.txt").build(), + RequestBody.fromString("Hello, World!")); + } + + @Test + public void should_load_document() { + S3FileLoader s3FileLoader = S3FileLoader.builder() + .bucket("test-bucket") + .key("test-file.txt") + .endpointUrl(s3Container.getEndpointOverride(LocalStackContainer.Service.S3).toString()) + .build(); + + Document document = s3FileLoader.load(); + + assertNotNull(document); + assertEquals("Hello, World!", document.text()); + assertEquals("s3://test-bucket/test-file.txt", document.metadata("source")); + } + + @Test + public void should_load_document_unknown_type() { + S3Client s3Client = S3Client.builder() + .endpointOverride(s3Container.getEndpointOverride(LocalStackContainer.Service.S3)) + .build(); + + s3Client.createBucket(CreateBucketRequest.builder().bucket("test-bucket").build()); + s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("unknown-test-file.unknown").build(), + RequestBody.fromString("Hello, World! I am Unknown")); + + S3FileLoader s3FileLoader = S3FileLoader.builder() + .bucket("test-bucket") + .key("unknown-test-file.unknown") + .endpointUrl(s3Container.getEndpointOverride(LocalStackContainer.Service.S3).toString()) + .build(); + + Document document = s3FileLoader.load(); + + assertNotNull(document); + assertEquals("Hello, World! I am Unknown", document.text()); + assertEquals("s3://test-bucket/unknown-test-file.unknown", document.metadata("source")); + } + + @AfterEach + public void tearDown() { + s3Container.stop(); + } + + @AfterAll + public static void tearDownClass() { + System.clearProperty("aws.region"); + } + +} diff --git a/langchain4j/src/test/java/dev/langchain4j/data/document/S3FileLoaderTest.java b/langchain4j/src/test/java/dev/langchain4j/data/document/S3FileLoaderTest.java new file mode 100644 index 000000000..1ec2a8e15 --- /dev/null +++ b/langchain4j/src/test/java/dev/langchain4j/data/document/S3FileLoaderTest.java @@ -0,0 +1,73 @@ +package dev.langchain4j.data.document; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; + +import java.io.ByteArrayInputStream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class S3FileLoaderTest { + + @Mock + private S3Client s3Client; + + @Mock + private GetObjectResponse getObjectResponse; + + private S3FileLoader s3FileLoader; + + @BeforeEach + public void setUp() { + s3FileLoader = S3FileLoader.builder() + .bucket("langchain4j") + .key("key.txt") + .build(); + } + + @Test + public void should_load_document() { + ResponseInputStream responseInputStream = new ResponseInputStream<>(getObjectResponse, new ByteArrayInputStream("test".getBytes())); + when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(responseInputStream); + + Document result = s3FileLoader.load(s3Client); + + assertNotNull(result); + assertEquals("test", result.text()); + assertEquals("s3://langchain4j/key.txt", result.metadata("source")); + } + + @Test + public void should_throw_s3_exception() { + when(s3Client.getObject(any(GetObjectRequest.class))).thenThrow(S3Exception.builder().message("S3 error").build()); + + assertThrows(RuntimeException.class, () -> s3FileLoader.load(s3Client)); + } + + @Test + public void should_throw_invalid_key() { + assertThrows(IllegalArgumentException.class, () -> S3FileLoader.builder() + .bucket("testBucket") + .build()); + } + + @Test + public void should_throw_invalid_bucket() { + assertThrows(IllegalArgumentException.class, () -> S3FileLoader.builder() + .key("testKey") + .build()); + } +} diff --git a/langchain4j/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker b/langchain4j/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker new file mode 100644 index 000000000..ca6ee9cea --- /dev/null +++ b/langchain4j/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker @@ -0,0 +1 @@ +mock-maker-inline \ No newline at end of file