Add support for loading documents from S3 (#197)

Adding two separate loaders: one loads a single document from S3, the other loads multiple
documents from a bucket (optionally filtered by a key prefix). Each exposes its own
parameters to support different configurations. The supported document types are limited
to the parsers langchain4j currently provides, but I am planning to help add more parsers
in the future. A usage sketch follows the commit metadata below.
Jansen Ang 2023-10-27 20:00:16 +08:00 committed by GitHub
parent 65ef6554b6
commit 053a35d5a3
13 changed files with 927 additions and 0 deletions
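
Usage sketch (illustration only, not part of the diff below): bucket names, keys, and the region are placeholders, the .txt keys assume a document type langchain4j can already parse, and credentials are assumed to come from the AWS default credential chain.

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.S3DirectoryLoader;
import dev.langchain4j.data.document.S3FileLoader;

import java.util.List;

// Hypothetical example class, not part of this commit.
public class S3LoaderUsageSketch {

    public static void main(String[] args) {
        // Load a single object as one Document; bucket, key and region are placeholders.
        Document document = S3FileLoader.builder()
                .bucket("my-bucket")
                .key("docs/readme.txt")
                .region("us-east-1")
                .build()
                .load();
        System.out.println(document.text());

        // Load every supported object under a prefix as a list of Documents.
        List<Document> documents = S3DirectoryLoader.builder()
                .bucket("my-bucket")
                .prefix("docs/")
                .build()
                .load();
        System.out.println(documents.size() + " documents loaded");
    }
}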

View File

@@ -17,6 +17,25 @@ public class Utils {
return string == null || string.trim().isEmpty();
}
public static boolean isNotNullOrBlank(String string) {
return !isNullOrBlank(string);
}
public static boolean areNotNullOrBlank(String... strings) {
if (strings == null || strings.length == 0) {
return false;
}
for (String string : strings) {
if (isNullOrBlank(string)) {
return false;
}
}
return true;
}
public static boolean isCollectionEmpty(Collection<?> collection) {
return collection == null || collection.isEmpty();
}

View File

@@ -42,6 +42,8 @@
<elastic.version>8.9.0</elastic.version>
<jackson.version>2.12.7.1</jackson.version>
<jedis.version>5.0.0</jedis.version>
<aws.java.sdk.version>2.20.149</aws.java.sdk.version>
<testcontainers.version>1.19.0</testcontainers.version>
<netty.version>4.1.100.Final</netty.version>
</properties>
@@ -264,6 +266,34 @@
<version>${jedis.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bom</artifactId>
<version>${aws.java.sdk.version}</version>
<type>pom</type>
<scope>import</scope>
<exclusions>
<!-- Exclusion due to CWE-295 vulnerability -->
<exclusion>
<groupId>io.netty</groupId>
<artifactId>netty-handler</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>testcontainers</artifactId>
<version>${testcontainers.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>localstack</artifactId>
<version>${testcontainers.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</dependencyManagement>

View File

@@ -135,6 +135,33 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>s3</artifactId>
<exclusions>
<!-- Exclusion due to CWE-295 vulnerability -->
<exclusion>
<groupId>io.netty</groupId>
<artifactId>netty-handler</artifactId>
</exclusion>
<!-- due to CVE-2023-44487 vulnerability -->
<exclusion>
<groupId>io.netty</groupId>
<artifactId>netty-codec-http2</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>testcontainers</artifactId>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>localstack</artifactId>
</dependency>
</dependencies>
<licenses>

View File

@@ -0,0 +1,176 @@
package dev.langchain4j.data.document;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;
import java.net.URI;
import java.net.URISyntaxException;
import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
public abstract class AbstractS3Loader<T> {
protected final String bucket;
protected final String region;
protected final String endpointUrl;
protected final String profile;
protected final boolean forcePathStyle;
protected final AwsCredentials awsCredentials;
protected AbstractS3Loader(Builder builder) {
this.bucket = ensureNotBlank(builder.bucket, "bucket");
this.region = builder.region;
this.endpointUrl = builder.endpointUrl;
this.profile = builder.profile;
this.forcePathStyle = builder.forcePathStyle;
this.awsCredentials = builder.awsCredentials;
}
/**
* Initiates the loading process by configuring the AWS credentials and S3 client,
* then loads either a document or a list of documents.
*
* @return A generic object of type T, which could be a Document or a list of Documents
* @throws RuntimeException if there are issues with AWS credentials or S3 client configuration
*/
public T load() {
AwsCredentialsProvider awsCredentialsProvider = configureCredentialsProvider();
S3Client s3Client = configureS3Client(awsCredentialsProvider);
return load(s3Client);
}
private static AwsSessionCredentials toAwsSessionCredentials(AwsCredentials awsCredentials) {
return AwsSessionCredentials.create(awsCredentials.accessKeyId(), awsCredentials.secretAccessKey(), awsCredentials.sessionToken());
}
private static software.amazon.awssdk.auth.credentials.AwsCredentials toAwsCredentials(AwsCredentials awsCredentials) {
return AwsBasicCredentials.create(awsCredentials.accessKeyId(), awsCredentials.secretAccessKey());
}
protected abstract T load(S3Client s3Client);
private AwsCredentialsProvider configureCredentialsProvider() {
AwsCredentialsProvider provider = DefaultCredentialsProvider.builder().build();
if (awsCredentials != null) {
if (awsCredentials.hasAllCredentials()) {
provider = StaticCredentialsProvider.create(toAwsSessionCredentials(awsCredentials));
} else if (awsCredentials.hasAccessKeyIdAndSecretKey()) {
provider = StaticCredentialsProvider.create(toAwsCredentials(awsCredentials));
}
}
if (isNotNullOrBlank(profile)) {
provider = ProfileCredentialsProvider.create(profile);
}
return provider;
}
private S3Client configureS3Client(AwsCredentialsProvider provider) {
S3ClientBuilder s3ClientBuilder = S3Client.builder()
.region(Region.of(isNotNullOrBlank(region) ? region : Region.US_EAST_1.id()))
.forcePathStyle(forcePathStyle)
.credentialsProvider(provider);
if (!isNullOrBlank(endpointUrl)) {
try {
s3ClientBuilder.endpointOverride(new URI(endpointUrl));
} catch (URISyntaxException e) {
throw new IllegalArgumentException("Invalid URL: " + endpointUrl, e);
}
}
return s3ClientBuilder.build();
}
public static abstract class Builder<T extends Builder<T>> {
private String bucket;
private String region;
private String endpointUrl;
private String profile;
private boolean forcePathStyle;
private AwsCredentials awsCredentials;
/**
* Set the name of the S3 bucket to load from.
*
* @param bucket The bucket name.
* @return The builder instance.
*/
public T bucket(String bucket) {
this.bucket = bucket;
return self();
}
/**
* Set the AWS region. Defaults to US_EAST_1.
*
* @param region The AWS region.
* @return The builder instance.
*/
public T region(String region) {
this.region = region;
return self();
}
/**
* Specifies a custom endpoint URL to override the default service URL.
*
* @param endpointUrl The endpoint URL.
* @return The builder instance.
*/
public T endpointUrl(String endpointUrl) {
this.endpointUrl = endpointUrl;
return self();
}
/**
* Set the named profile from the AWS credentials file. When set, it takes precedence over explicitly supplied credentials.
*
* @param profile The profile defined in AWS credentials.
* @return The builder instance.
*/
public T profile(String profile) {
this.profile = profile;
return self();
}
/**
* Set forcePathStyle. When enabled, requests use path-style URLs instead of virtual-hosted-style URLs.
*
* @param forcePathStyle The forcePathStyle.
* @return The builder instance.
*/
public T forcePathStyle(boolean forcePathStyle) {
this.forcePathStyle = forcePathStyle;
return self();
}
/**
* Set the AWS credentials explicitly. If not set, the AWS default credentials provider chain is used.
*
* @param awsCredentials The AWS credentials.
* @return The builder instance.
*/
public T awsCredentials(AwsCredentials awsCredentials) {
this.awsCredentials = awsCredentials;
return self();
}
public abstract AbstractS3Loader<?> build();
protected abstract T self();
}
}
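
A small illustration of the credential precedence implemented in configureCredentialsProvider() above (a sketch, not part of the commit): a named profile, when set, takes precedence over static credentials, and with neither set the AWS default provider chain applies. The bucket, key, profile name, and key pair below are placeholders; S3FileLoader is the concrete loader added later in this commit.

    // Sketch only: all values below are placeholders.
    static Document loadWithExplicitCredentials() {
        return S3FileLoader.builder()
                .bucket("my-bucket")
                .key("docs/readme.txt")
                .awsCredentials(new AwsCredentials("ACCESS_KEY_ID", "SECRET_ACCESS_KEY"))
                .profile("my-profile") // the profile wins over the static credentials above
                .build()
                .load();
    }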

View File

@@ -0,0 +1,45 @@
package dev.langchain4j.data.document;
import static dev.langchain4j.internal.Utils.areNotNullOrBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
/**
* Represents an AWS credentials object, including access key ID, secret access key, and optional session token.
*/
public class AwsCredentials {
private final String accessKeyId;
private final String secretAccessKey;
private final String sessionToken;
public AwsCredentials(String accessKeyId, String secretAccessKey, String sessionToken) {
this.accessKeyId = ensureNotBlank(accessKeyId, "accessKeyId");
this.secretAccessKey = ensureNotBlank(secretAccessKey, "secretAccessKey");
this.sessionToken = sessionToken;
}
public AwsCredentials(String accessKeyId, String secretAccessKey) {
this(accessKeyId, secretAccessKey, null);
}
public String accessKeyId() {
return accessKeyId;
}
public String secretAccessKey() {
return secretAccessKey;
}
public String sessionToken() {
return sessionToken;
}
public boolean hasAccessKeyIdAndSecretKey() {
return areNotNullOrBlank(accessKeyId, secretAccessKey);
}
public boolean hasAllCredentials() {
return areNotNullOrBlank(accessKeyId, secretAccessKey, sessionToken);
}
}
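
For example, temporary (session) credentials can be supplied like this (a sketch; all values are placeholders). Because all three values are present, hasAllCredentials() returns true and AbstractS3Loader wraps them in AwsSessionCredentials behind a StaticCredentialsProvider.

    // Sketch only: placeholder values for temporary STS credentials.
    AwsCredentials sessionCredentials =
            new AwsCredentials("ACCESS_KEY_ID", "SECRET_ACCESS_KEY", "SESSION_TOKEN");
    Document document = S3FileLoader.builder()
            .bucket("my-bucket")
            .key("docs/readme.txt")
            .awsCredentials(sessionCredentials)
            .build()
            .load();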

View File

@@ -0,0 +1,100 @@
package dev.langchain4j.data.document;
import dev.langchain4j.data.document.source.S3Source;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import static dev.langchain4j.data.document.DocumentLoaderUtils.parserFor;
/**
* S3 Directory Loader Implementation
*/
public class S3DirectoryLoader extends AbstractS3Loader<List<Document>> {
private static final Logger log = LoggerFactory.getLogger(S3DirectoryLoader.class);
private final String prefix;
private S3DirectoryLoader(Builder builder) {
super(builder);
this.prefix = builder.prefix;
}
/**
* Loads a list of documents from an S3 bucket, ignoring unsupported document types.
* If a prefix is specified, only objects with that prefix will be loaded.
*
* @param s3Client The S3 client used for the operation
* @return A list of Document objects containing the content and metadata of the S3 objects
* @throws RuntimeException if an S3 exception occurs during the operation
*/
@Override
protected List<Document> load(S3Client s3Client) {
List<Document> documents = new ArrayList<>();
ListObjectsV2Request listObjectsV2Request = ListObjectsV2Request.builder()
.bucket(bucket)
.prefix(prefix)
.build();
ListObjectsV2Response listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);
List<S3Object> filteredS3Objects = listObjectsV2Response.contents().stream()
.filter(s3Object -> !s3Object.key().endsWith("/") && s3Object.size() > 0)
.collect(Collectors.toList());
for (S3Object s3Object : filteredS3Objects) {
String key = s3Object.key();
GetObjectRequest getObjectRequest = GetObjectRequest.builder()
.bucket(bucket)
.key(key)
.build();
ResponseInputStream<GetObjectResponse> inputStream = s3Client.getObject(getObjectRequest);
try {
documents.add(DocumentLoaderUtils.load(new S3Source(bucket, key, inputStream), parserFor(DocumentType.of(key))));
} catch (Exception e) {
log.warn("Failed to load document from S3", e);
}
}
return documents;
}
public static Builder builder() {
return new Builder();
}
public static final class Builder extends AbstractS3Loader.Builder<Builder> {
private String prefix = "";
/**
* Set the key prefix. Only objects whose keys start with this prefix are loaded.
*
* @param prefix The key prefix.
* @return The builder instance.
*/
public Builder prefix(String prefix) {
this.prefix = prefix;
return this;
}
@Override
public S3DirectoryLoader build() {
return new S3DirectoryLoader(this);
}
@Override
protected Builder self() {
return this;
}
}
}

View File

@@ -0,0 +1,70 @@
package dev.langchain4j.data.document;
import dev.langchain4j.data.document.source.S3Source;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.S3Exception;
import static dev.langchain4j.data.document.DocumentLoaderUtils.parserFor;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
/**
* S3 File Loader Implementation
*/
public class S3FileLoader extends AbstractS3Loader<Document> {
private final String key;
private S3FileLoader(Builder builder) {
super(builder);
this.key = ensureNotBlank(builder.key, "key");
}
/**
* Loads a document from an S3 bucket based on the specified key.
*
* @param s3Client The S3 client used for the operation
* @return A Document object containing the content and metadata of the S3 object
* @throws RuntimeException if an S3 exception occurs during the operation
*/
@Override
protected Document load(S3Client s3Client) {
try {
GetObjectRequest objectRequest = GetObjectRequest.builder().bucket(bucket).key(key).build();
ResponseInputStream<GetObjectResponse> inputStream = s3Client.getObject(objectRequest);
return DocumentLoaderUtils.load(new S3Source(bucket, key, inputStream), parserFor(DocumentType.of(key)));
} catch (S3Exception e) {
throw new RuntimeException("Failed to load document from S3", e);
}
}
public static Builder builder() {
return new Builder();
}
public static final class Builder extends AbstractS3Loader.Builder<Builder> {
private String key;
/**
* Set the S3 object key.
*
* @param key The object key.
* @return The builder instance.
*/
public Builder key(String key) {
this.key = key;
return this;
}
@Override
public S3FileLoader build() {
return new S3FileLoader(this);
}
@Override
protected Builder self() {
return this;
}
}
}
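
A sketch of pointing the file loader at an S3-compatible endpoint, as the LocalStack integration tests below do; the endpoint URL and bucket are placeholders, and forcePathStyle(true) is often needed for such endpoints.

    // Sketch only: endpoint URL and bucket are placeholders (e.g. LocalStack or MinIO).
    Document document = S3FileLoader.builder()
            .bucket("test-bucket")
            .key("test-file.txt")
            .endpointUrl("http://localhost:4566")
            .forcePathStyle(true)
            .region("us-east-1")
            .build()
            .load();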

View File

@@ -0,0 +1,35 @@
package dev.langchain4j.data.document.source;
import dev.langchain4j.data.document.DocumentSource;
import dev.langchain4j.data.document.Metadata;
import java.io.IOException;
import java.io.InputStream;
public class S3Source implements DocumentSource {
private static final String SOURCE = "source";
private final InputStream inputStream;
private final String bucket;
private final String key;
public S3Source(String bucket, String key, InputStream inputStream) {
this.inputStream = inputStream;
this.bucket = bucket;
this.key = key;
}
@Override
public InputStream inputStream() throws IOException {
return inputStream;
}
@Override
public Metadata metadata() {
return new Metadata()
.add(SOURCE, String.format("s3://%s/%s", bucket, key));
}
}

View File

@@ -0,0 +1,136 @@
package dev.langchain4j.data.document;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.localstack.LocalStackContainer;
import org.testcontainers.utility.DockerImageName;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.CreateBucketRequest;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.core.sync.RequestBody;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3;
@Disabled("To run this test, you need a Docker-API compatible container runtime, such as using Testcontainers Cloud or installing Docker locally.")
public class S3DirectoryLoaderIT {
private LocalStackContainer s3Container;
private S3Client s3Client;
@BeforeAll
public static void setUpClass() {
System.setProperty("aws.region", "us-east-1");
}
@BeforeEach
public void setUp() {
s3Container = new LocalStackContainer(DockerImageName.parse("localstack/localstack:2.0"))
.withServices(S3)
.withEnv("DEFAULT_REGION", "us-east-1");
s3Container.start();
s3Client = S3Client.builder()
.endpointOverride(s3Container.getEndpointOverride(S3))
.build();
s3Client.createBucket(CreateBucketRequest.builder().bucket("test-bucket").build());
}
@Test
public void should_load_empty_list() {
S3DirectoryLoader s3DirectoryLoader = S3DirectoryLoader.builder()
.bucket("test-bucket")
.endpointUrl(s3Container.getEndpointOverride(S3).toString())
.build();
List<Document> documents = s3DirectoryLoader.load();
assertTrue(documents.isEmpty());
}
@Test
public void should_load_multiple_files_without_prefix() {
s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("file1.txt").build(),
RequestBody.fromString("Hello, World!"));
s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("directory/file2.txt").build(),
RequestBody.fromString("Hello, again!"));
S3DirectoryLoader s3DirectoryLoader = S3DirectoryLoader.builder()
.bucket("test-bucket")
.endpointUrl(s3Container.getEndpointOverride(S3).toString())
.build();
List<Document> documents = s3DirectoryLoader.load();
assertEquals(2, documents.size());
assertEquals("Hello, again!", documents.get(0).text());
assertEquals("s3://test-bucket/directory/file2.txt", documents.get(0).metadata("source"));
assertEquals("Hello, World!", documents.get(1).text());
assertEquals("s3://test-bucket/file1.txt", documents.get(1).metadata("source"));
}
@Test
public void should_load_multiple_files_with_prefix() {
s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("other_directory/file1.txt").build(),
RequestBody.fromString("You cannot load me!"));
s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("directory/file2.txt").build(),
RequestBody.fromString("Hello, World!"));
s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("directory/file3.txt").build(),
RequestBody.fromString("Hello, again!"));
S3DirectoryLoader s3DirectoryLoader = S3DirectoryLoader.builder()
.bucket("test-bucket")
.prefix("directory")
.endpointUrl(s3Container.getEndpointOverride(S3).toString())
.build();
List<Document> documents = s3DirectoryLoader.load();
assertEquals(2, documents.size());
assertEquals("Hello, World!", documents.get(0).text());
assertEquals("s3://test-bucket/directory/file2.txt", documents.get(0).metadata("source"));
assertEquals("Hello, again!", documents.get(1).text());
assertEquals("s3://test-bucket/directory/file3.txt", documents.get(1).metadata("source"));
}
@Test
public void should_load_accepting_unknown_types() {
s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("directory/file2.unknown").build(),
RequestBody.fromString("I am unknown."));
s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("directory/file3.txt").build(),
RequestBody.fromString("Hello, World!"));
S3DirectoryLoader s3DirectoryLoader = S3DirectoryLoader.builder()
.bucket("test-bucket")
.prefix("directory")
.endpointUrl(s3Container.getEndpointOverride(S3).toString())
.build();
List<Document> documents = s3DirectoryLoader.load();
assertEquals(2, documents.size());
assertEquals("I am unknown.", documents.get(0).text());
assertEquals("s3://test-bucket/directory/file2.unknown", documents.get(0).metadata("source"));
assertEquals("Hello, World!", documents.get(1).text());
assertEquals("s3://test-bucket/directory/file3.txt", documents.get(1).metadata("source"));
}
@AfterEach
public void tearDown() {
s3Container.stop();
}
@AfterAll
public static void tearDownClass() {
System.clearProperty("aws.region");
}
}

View File

@@ -0,0 +1,118 @@
package dev.langchain4j.data.document;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Response;
import software.amazon.awssdk.services.s3.model.S3Exception;
import software.amazon.awssdk.services.s3.model.S3Object;
import java.io.ByteArrayInputStream;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class S3DirectoryLoaderTest {
@Mock
private S3Client s3Client;
@Mock
private ListObjectsV2Response listObjectsV2Response;
@Mock
private GetObjectResponse getObjectResponse;
private S3DirectoryLoader s3DirectoryLoader;
@BeforeEach
public void setUp() {
s3DirectoryLoader = S3DirectoryLoader.builder()
.bucket("langchain4j")
.prefix("testPrefix")
.build();
}
@Test
public void should_load_documents_from_directory() {
S3Object s3Object1 = S3Object.builder().key("testPrefix/testKey1.txt").size(10L).build();
S3Object s3Object2 = S3Object.builder().key("testPrefix/testKey2.txt").size(20L).build();
when(listObjectsV2Response.contents()).thenReturn(Arrays.asList(s3Object1, s3Object2));
when(s3Client.listObjectsV2(any(ListObjectsV2Request.class))).thenReturn(listObjectsV2Response);
ResponseInputStream<GetObjectResponse> responseInputStream1 = new ResponseInputStream<>(getObjectResponse, new ByteArrayInputStream("test1".getBytes()));
ResponseInputStream<GetObjectResponse> responseInputStream2 = new ResponseInputStream<>(getObjectResponse, new ByteArrayInputStream("test2".getBytes()));
when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(responseInputStream1).thenReturn(responseInputStream2);
List<Document> documents = s3DirectoryLoader.load(s3Client);
assertEquals(2, documents.size());
assertEquals("test1", documents.get(0).text());
assertEquals("test2", documents.get(1).text());
assertEquals("s3://langchain4j/testPrefix/testKey1.txt", documents.get(0).metadata("source"));
assertEquals("s3://langchain4j/testPrefix/testKey2.txt", documents.get(1).metadata("source"));
}
@Test
public void should_load_documents_from_directory_accepting_unknown_types() {
S3Object s3Object1 = S3Object.builder().key("testPrefix/testKey1.txt").size(10L).build();
S3Object s3Object2 = S3Object.builder().key("testPrefix/testKey2.txt").size(20L).build();
S3Object s3Object3 = S3Object.builder().key("testPrefix/testKey3.unknown").size(30L).build();
when(listObjectsV2Response.contents()).thenReturn(Arrays.asList(s3Object1, s3Object2, s3Object3));
when(s3Client.listObjectsV2(any(ListObjectsV2Request.class))).thenReturn(listObjectsV2Response);
ResponseInputStream<GetObjectResponse> responseInputStream1 = new ResponseInputStream<>(getObjectResponse, new ByteArrayInputStream("test1".getBytes()));
ResponseInputStream<GetObjectResponse> responseInputStream2 = new ResponseInputStream<>(getObjectResponse, new ByteArrayInputStream("test2".getBytes()));
ResponseInputStream<GetObjectResponse> responseInputStream3 = new ResponseInputStream<>(getObjectResponse, new ByteArrayInputStream("unknown".getBytes()));
when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(responseInputStream1).thenReturn(responseInputStream2).thenReturn(responseInputStream3);
List<Document> documents = s3DirectoryLoader.load(s3Client);
assertEquals(3, documents.size());
assertEquals("test1", documents.get(0).text());
assertEquals("test2", documents.get(1).text());
assertEquals("unknown", documents.get(2).text());
assertEquals("s3://langchain4j/testPrefix/testKey1.txt", documents.get(0).metadata("source"));
assertEquals("s3://langchain4j/testPrefix/testKey2.txt", documents.get(1).metadata("source"));
assertEquals("s3://langchain4j/testPrefix/testKey3.unknown", documents.get(2).metadata("source"));
}
@Test
public void should_return_empty_list_when_no_objects() {
when(listObjectsV2Response.contents()).thenReturn(Arrays.asList());
when(s3Client.listObjectsV2(any(ListObjectsV2Request.class))).thenReturn(listObjectsV2Response);
List<Document> documents = s3DirectoryLoader.load(s3Client);
assertTrue(documents.isEmpty());
}
@Test
public void should_throw_s3_exception() {
when(s3Client.listObjectsV2(any(ListObjectsV2Request.class))).thenThrow(S3Exception.builder().message("S3 error").build());
assertThrows(RuntimeException.class, () -> s3DirectoryLoader.load(s3Client));
}
@Test
public void should_throw_invalid_bucket() {
assertThrows(IllegalArgumentException.class, () -> S3DirectoryLoader.builder()
.prefix("testPrefix")
.build());
}
}

View File

@@ -0,0 +1,97 @@
package dev.langchain4j.data.document;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.localstack.LocalStackContainer;
import org.testcontainers.utility.DockerImageName;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.CreateBucketRequest;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.core.sync.RequestBody;
import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3;
import static org.junit.jupiter.api.Assertions.*;
@Disabled("To run this test, you need a Docker-API compatible container runtime, such as using Testcontainers Cloud or installing Docker locally.")
public class S3FileLoaderIT {
private LocalStackContainer s3Container;
private S3Client s3Client;
private static final DockerImageName localstackImage = DockerImageName.parse("localstack/localstack:2.0");
@BeforeAll
public static void setUpClass() {
System.setProperty("aws.region", "us-east-1");
}
@BeforeEach
public void setUp() {
s3Container = new LocalStackContainer(localstackImage)
.withServices(S3)
.withEnv("DEFAULT_REGION", "us-east-1");
s3Container.start();
s3Client = S3Client.builder()
.endpointOverride(s3Container.getEndpointOverride(LocalStackContainer.Service.S3))
.build();
s3Client.createBucket(CreateBucketRequest.builder().bucket("test-bucket").build());
s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("test-file.txt").build(),
RequestBody.fromString("Hello, World!"));
}
@Test
public void should_load_document() {
S3FileLoader s3FileLoader = S3FileLoader.builder()
.bucket("test-bucket")
.key("test-file.txt")
.endpointUrl(s3Container.getEndpointOverride(LocalStackContainer.Service.S3).toString())
.build();
Document document = s3FileLoader.load();
assertNotNull(document);
assertEquals("Hello, World!", document.text());
assertEquals("s3://test-bucket/test-file.txt", document.metadata("source"));
}
@Test
public void should_load_document_unknown_type() {
S3Client s3Client = S3Client.builder()
.endpointOverride(s3Container.getEndpointOverride(LocalStackContainer.Service.S3))
.build();
s3Client.createBucket(CreateBucketRequest.builder().bucket("test-bucket").build());
s3Client.putObject(PutObjectRequest.builder().bucket("test-bucket").key("unknown-test-file.unknown").build(),
RequestBody.fromString("Hello, World! I am Unknown"));
S3FileLoader s3FileLoader = S3FileLoader.builder()
.bucket("test-bucket")
.key("unknown-test-file.unknown")
.endpointUrl(s3Container.getEndpointOverride(LocalStackContainer.Service.S3).toString())
.build();
Document document = s3FileLoader.load();
assertNotNull(document);
assertEquals("Hello, World! I am Unknown", document.text());
assertEquals("s3://test-bucket/unknown-test-file.unknown", document.metadata("source"));
}
@AfterEach
public void tearDown() {
s3Container.stop();
}
@AfterAll
public static void tearDownClass() {
System.clearProperty("aws.region");
}
}

View File

@@ -0,0 +1,73 @@
package dev.langchain4j.data.document;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.S3Exception;
import java.io.ByteArrayInputStream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class S3FileLoaderTest {
@Mock
private S3Client s3Client;
@Mock
private GetObjectResponse getObjectResponse;
private S3FileLoader s3FileLoader;
@BeforeEach
public void setUp() {
s3FileLoader = S3FileLoader.builder()
.bucket("langchain4j")
.key("key.txt")
.build();
}
@Test
public void should_load_document() {
ResponseInputStream<GetObjectResponse> responseInputStream = new ResponseInputStream<>(getObjectResponse, new ByteArrayInputStream("test".getBytes()));
when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(responseInputStream);
Document result = s3FileLoader.load(s3Client);
assertNotNull(result);
assertEquals("test", result.text());
assertEquals("s3://langchain4j/key.txt", result.metadata("source"));
}
@Test
public void should_throw_s3_exception() {
when(s3Client.getObject(any(GetObjectRequest.class))).thenThrow(S3Exception.builder().message("S3 error").build());
assertThrows(RuntimeException.class, () -> s3FileLoader.load(s3Client));
}
@Test
public void should_throw_invalid_key() {
assertThrows(IllegalArgumentException.class, () -> S3FileLoader.builder()
.bucket("testBucket")
.build());
}
@Test
public void should_throw_invalid_bucket() {
assertThrows(IllegalArgumentException.class, () -> S3FileLoader.builder()
.key("testKey")
.build());
}
}

View File

@@ -0,0 +1 @@
mock-maker-inline