add maven test github action (#11)

This PR:

    adds a github action for running unit tests
    tests that require an OpenAI/HuggingFace token and hit their API are now considered integration tests (and have been renamed to end in IT)
    integration tests are now run through a separate goal (mvn integration-test) via the maven-failsafe-plugin
    to fix the PromptTemplate tests a Clock has been added to that class. Its constructor is now private: whether this is the convention we want to follow can be discussed
This commit is contained in:
Julien Perrochet 2023-07-05 21:55:49 +02:00 committed by GitHub
parent fb47dbd1b3
commit d427b7ba06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 75 additions and 28 deletions

22
.github/workflows/main.yaml vendored Normal file
View File

@ -0,0 +1,22 @@
name: Java CI
on: [push]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up JDK 8
uses: actions/setup-java@v3
with:
java-version: '8'
distribution: 'temurin'
- name: Test
run: mvn --batch-mode test
# TODO's
# - setup integration tests
# - these require an openAI (and hugging face, etc) token
# - do so that they always run for commits on main
# - make the running be manually triggered for PRs (we don't want to burn through credits)

View File

@ -6,6 +6,7 @@ import com.github.mustachejava.MustacheFactory;
import java.io.StringReader;
import java.io.StringWriter;
import java.time.Clock;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
@ -21,12 +22,11 @@ public class PromptTemplate {
private static final MustacheFactory MUSTACHE_FACTORY = new DefaultMustacheFactory();
private final Mustache mustache;
private final Clock clock;
public PromptTemplate(String template) {
if (isNullOrBlank(template)) {
throw illegalArgument("Prompt template cannot be null or empty");
}
this.mustache = MUSTACHE_FACTORY.compile(new StringReader(template), "template");
private PromptTemplate(Mustache mustache, Clock clock) {
this.mustache = mustache;
this.clock = clock;
}
public Prompt apply(Object value) {
@ -39,15 +39,25 @@ public class PromptTemplate {
return Prompt.from(writer.toString());
}
private static Map<String, Object> injectDateTimeVariables(Map<String, Object> variables) {
private Map<String, Object> injectDateTimeVariables(Map<String, Object> variables) {
Map<String, Object> variablesCopy = new HashMap<>(variables);
variablesCopy.put("current_date", LocalDate.now());
variablesCopy.put("current_time", LocalTime.now());
variablesCopy.put("current_date_time", LocalDateTime.now());
variablesCopy.put("current_date", LocalDate.now(clock));
variablesCopy.put("current_time", LocalTime.now(clock));
variablesCopy.put("current_date_time", LocalDateTime.now(clock));
return variablesCopy;
}
public static PromptTemplate from(String template) {
return new PromptTemplate(template);
return from(template, Clock.systemDefaultZone());
}
public static PromptTemplate from(String template, Clock clock) {
if (isNullOrBlank(template)) {
throw illegalArgument("Prompt template cannot be null or empty");
}
return new PromptTemplate(
MUSTACHE_FACTORY.compile(new StringReader(template), "template"),
clock
);
}
}

View File

@ -2,9 +2,7 @@ package dev.langchain4j.model.input;
import org.junit.jupiter.api.Test;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.*;
import java.util.HashMap;
import java.util.Map;
@ -15,7 +13,7 @@ class PromptTemplateTest {
@Test
void should_create_prompt_from_template_with_single_variable() {
PromptTemplate promptTemplate = new PromptTemplate("My name is {{it}}.");
PromptTemplate promptTemplate = PromptTemplate.from("My name is {{it}}.");
Prompt prompt = promptTemplate.apply("Klaus");
@ -25,7 +23,7 @@ class PromptTemplateTest {
@Test
void should_create_prompt_from_template_with_multiple_variables() {
PromptTemplate promptTemplate = new PromptTemplate("My name is {{name}} {{surname}}.");
PromptTemplate promptTemplate = PromptTemplate.from("My name is {{name}} {{surname}}.");
Map<String, Object> variables = new HashMap<>();
variables.put("name", "Klaus");
@ -41,7 +39,7 @@ class PromptTemplateTest {
@Test
void should_provide_date_automatically() {
PromptTemplate promptTemplate = new PromptTemplate("My name is {{it}} and today is {{current_date}}");
PromptTemplate promptTemplate = PromptTemplate.from("My name is {{it}} and today is {{current_date}}");
Prompt prompt = promptTemplate.apply("Klaus");
@ -51,20 +49,24 @@ class PromptTemplateTest {
@Test
void should_provide_time_automatically() {
PromptTemplate promptTemplate = new PromptTemplate("My name is {{it}} and now is {{current_time}}");
Clock clock = Clock.fixed(Instant.now(), ZoneOffset.UTC);
PromptTemplate promptTemplate = PromptTemplate.from("My name is {{it}} and now is {{current_time}}", clock);
Prompt prompt = promptTemplate.apply("Klaus");
assertThat(prompt.text()).isEqualTo("My name is Klaus and now is " + LocalTime.now());
assertThat(prompt.text()).isEqualTo("My name is Klaus and now is " + LocalTime.now(clock));
}
@Test
void should_provide_date_and_time_automatically() {
PromptTemplate promptTemplate = new PromptTemplate("My name is {{it}} and now is {{current_date_time}}");
Clock clock = Clock.fixed(Instant.now(), ZoneOffset.UTC);
PromptTemplate promptTemplate = PromptTemplate.from("My name is {{it}} and now is {{current_date_time}}", clock);
Prompt prompt = promptTemplate.apply("Klaus");
assertThat(prompt.text()).isEqualTo("My name is Klaus and now is " + LocalDateTime.now());
assertThat(prompt.text()).isEqualTo("My name is Klaus and now is " + LocalDateTime.now(clock));
}
}

View File

@ -208,6 +208,20 @@
</executions>
</plugin>
<plugin>
<!-- failsafe will be in charge of running the integration tests (everything that ends in IT) -->
<artifactId>maven-failsafe-plugin</artifactId>
<version>3.1.2</version>
<executions>
<execution>
<goals>
<goal>integration-test</goal>
<goal>verify</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
@ -239,7 +253,6 @@
</plugins>
</build>
</profile>
</profiles>
</project>

View File

@ -28,7 +28,7 @@ public class ConversationalRetrievalChain implements Chain<String, String> {
private static final DocumentSplitter DEFAULT_DOCUMENT_SPLITTER = new ParagraphSplitter();
private static final EmbeddingStore<DocumentSegment> DEFAULT_EMBEDDING_STORE = new InMemoryEmbeddingStore();
private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("Answer the following question to the best of your ability: {{question}}\n\nBase your answer on the following information:\n{{information}}");
private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from("Answer the following question to the best of your ability: {{question}}\n\nBase your answer on the following information:\n{{information}}");
private final DocumentLoader documentLoader;
private final DocumentSplitter documentSplitter;

View File

@ -12,7 +12,7 @@ import static dev.langchain4j.model.huggingface.HuggingFaceModelName.TII_UAE_FAL
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
class HuggingFaceChatModelTest {
class HuggingFaceChatModelIT {
@Test
public void testWhenNullAccessToken() {

View File

@ -10,7 +10,7 @@ import java.util.List;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
class HuggingFaceEmbeddingModelTest {
class HuggingFaceEmbeddingModelIT {
HuggingFaceEmbeddingModel model = HuggingFaceEmbeddingModel.builder()
.accessToken(System.getenv("HF_API_KEY"))

View File

@ -6,7 +6,7 @@ import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
class HuggingFaceLanguageModelTest {
class HuggingFaceLanguageModelIT {
@Test
public void testWhenNullAccessToken() {

View File

@ -6,7 +6,7 @@ import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
class OpenAiModerationModelTest {
class OpenAiModerationModelIT {
ModerationModel model = OpenAiModerationModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))

View File

@ -27,7 +27,7 @@ import java.util.List;
import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static dev.langchain4j.data.message.SystemMessage.systemMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.service.AiServicesTest.Sentiment.POSITIVE;
import static dev.langchain4j.service.AiServicesIT.Sentiment.POSITIVE;
import static java.time.Month.JULY;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
@ -35,7 +35,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
public class AiServicesTest {
public class AiServicesIT {
@Spy
ChatLanguageModel chatLanguageModel = OpenAiChatModel.builder()