add maven test github action (#11)
This PR: adds a github action for running unit tests tests that require an OpenAI/HuggingFace token and hit their API are now considered integration tests (and have been renamed to end in IT) integration tests are now run through a separate goal (mvn integration-test) via the maven-failsafe-plugin to fix the PromptTemplate tests a Clock has been added to that class. Its constructor is now private: whether this is the convention we want to follow can be discussed
This commit is contained in:
parent
fb47dbd1b3
commit
d427b7ba06
|
@ -0,0 +1,22 @@
|
|||
name: Java CI
|
||||
|
||||
on: [push]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up JDK 8
|
||||
uses: actions/setup-java@v3
|
||||
with:
|
||||
java-version: '8'
|
||||
distribution: 'temurin'
|
||||
- name: Test
|
||||
run: mvn --batch-mode test
|
||||
|
||||
# TODO's
|
||||
# - setup integration tests
|
||||
# - these require an openAI (and hugging face, etc) token
|
||||
# - do so that they always run for commits on main
|
||||
# - make the running be manually triggered for PRs (we don't want to burn through credits)
|
|
@ -6,6 +6,7 @@ import com.github.mustachejava.MustacheFactory;
|
|||
|
||||
import java.io.StringReader;
|
||||
import java.io.StringWriter;
|
||||
import java.time.Clock;
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.LocalTime;
|
||||
|
@ -21,12 +22,11 @@ public class PromptTemplate {
|
|||
private static final MustacheFactory MUSTACHE_FACTORY = new DefaultMustacheFactory();
|
||||
|
||||
private final Mustache mustache;
|
||||
private final Clock clock;
|
||||
|
||||
public PromptTemplate(String template) {
|
||||
if (isNullOrBlank(template)) {
|
||||
throw illegalArgument("Prompt template cannot be null or empty");
|
||||
}
|
||||
this.mustache = MUSTACHE_FACTORY.compile(new StringReader(template), "template");
|
||||
private PromptTemplate(Mustache mustache, Clock clock) {
|
||||
this.mustache = mustache;
|
||||
this.clock = clock;
|
||||
}
|
||||
|
||||
public Prompt apply(Object value) {
|
||||
|
@ -39,15 +39,25 @@ public class PromptTemplate {
|
|||
return Prompt.from(writer.toString());
|
||||
}
|
||||
|
||||
private static Map<String, Object> injectDateTimeVariables(Map<String, Object> variables) {
|
||||
private Map<String, Object> injectDateTimeVariables(Map<String, Object> variables) {
|
||||
Map<String, Object> variablesCopy = new HashMap<>(variables);
|
||||
variablesCopy.put("current_date", LocalDate.now());
|
||||
variablesCopy.put("current_time", LocalTime.now());
|
||||
variablesCopy.put("current_date_time", LocalDateTime.now());
|
||||
variablesCopy.put("current_date", LocalDate.now(clock));
|
||||
variablesCopy.put("current_time", LocalTime.now(clock));
|
||||
variablesCopy.put("current_date_time", LocalDateTime.now(clock));
|
||||
return variablesCopy;
|
||||
}
|
||||
|
||||
public static PromptTemplate from(String template) {
|
||||
return new PromptTemplate(template);
|
||||
return from(template, Clock.systemDefaultZone());
|
||||
}
|
||||
|
||||
public static PromptTemplate from(String template, Clock clock) {
|
||||
if (isNullOrBlank(template)) {
|
||||
throw illegalArgument("Prompt template cannot be null or empty");
|
||||
}
|
||||
return new PromptTemplate(
|
||||
MUSTACHE_FACTORY.compile(new StringReader(template), "template"),
|
||||
clock
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,9 +2,7 @@ package dev.langchain4j.model.input;
|
|||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.LocalTime;
|
||||
import java.time.*;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -15,7 +13,7 @@ class PromptTemplateTest {
|
|||
@Test
|
||||
void should_create_prompt_from_template_with_single_variable() {
|
||||
|
||||
PromptTemplate promptTemplate = new PromptTemplate("My name is {{it}}.");
|
||||
PromptTemplate promptTemplate = PromptTemplate.from("My name is {{it}}.");
|
||||
|
||||
Prompt prompt = promptTemplate.apply("Klaus");
|
||||
|
||||
|
@ -25,7 +23,7 @@ class PromptTemplateTest {
|
|||
@Test
|
||||
void should_create_prompt_from_template_with_multiple_variables() {
|
||||
|
||||
PromptTemplate promptTemplate = new PromptTemplate("My name is {{name}} {{surname}}.");
|
||||
PromptTemplate promptTemplate = PromptTemplate.from("My name is {{name}} {{surname}}.");
|
||||
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("name", "Klaus");
|
||||
|
@ -41,7 +39,7 @@ class PromptTemplateTest {
|
|||
@Test
|
||||
void should_provide_date_automatically() {
|
||||
|
||||
PromptTemplate promptTemplate = new PromptTemplate("My name is {{it}} and today is {{current_date}}");
|
||||
PromptTemplate promptTemplate = PromptTemplate.from("My name is {{it}} and today is {{current_date}}");
|
||||
|
||||
Prompt prompt = promptTemplate.apply("Klaus");
|
||||
|
||||
|
@ -51,20 +49,24 @@ class PromptTemplateTest {
|
|||
@Test
|
||||
void should_provide_time_automatically() {
|
||||
|
||||
PromptTemplate promptTemplate = new PromptTemplate("My name is {{it}} and now is {{current_time}}");
|
||||
Clock clock = Clock.fixed(Instant.now(), ZoneOffset.UTC);
|
||||
|
||||
PromptTemplate promptTemplate = PromptTemplate.from("My name is {{it}} and now is {{current_time}}", clock);
|
||||
|
||||
Prompt prompt = promptTemplate.apply("Klaus");
|
||||
|
||||
assertThat(prompt.text()).isEqualTo("My name is Klaus and now is " + LocalTime.now());
|
||||
assertThat(prompt.text()).isEqualTo("My name is Klaus and now is " + LocalTime.now(clock));
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_provide_date_and_time_automatically() {
|
||||
|
||||
PromptTemplate promptTemplate = new PromptTemplate("My name is {{it}} and now is {{current_date_time}}");
|
||||
Clock clock = Clock.fixed(Instant.now(), ZoneOffset.UTC);
|
||||
|
||||
PromptTemplate promptTemplate = PromptTemplate.from("My name is {{it}} and now is {{current_date_time}}", clock);
|
||||
|
||||
Prompt prompt = promptTemplate.apply("Klaus");
|
||||
|
||||
assertThat(prompt.text()).isEqualTo("My name is Klaus and now is " + LocalDateTime.now());
|
||||
assertThat(prompt.text()).isEqualTo("My name is Klaus and now is " + LocalDateTime.now(clock));
|
||||
}
|
||||
}
|
|
@ -208,6 +208,20 @@
|
|||
</executions>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<!-- failsafe will be in charge of running the integration tests (everything that ends in IT) -->
|
||||
<artifactId>maven-failsafe-plugin</artifactId>
|
||||
<version>3.1.2</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<goals>
|
||||
<goal>integration-test</goal>
|
||||
<goal>verify</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
</plugins>
|
||||
</build>
|
||||
|
||||
|
@ -239,7 +253,6 @@
|
|||
</plugins>
|
||||
</build>
|
||||
</profile>
|
||||
|
||||
</profiles>
|
||||
|
||||
</project>
|
|
@ -28,7 +28,7 @@ public class ConversationalRetrievalChain implements Chain<String, String> {
|
|||
|
||||
private static final DocumentSplitter DEFAULT_DOCUMENT_SPLITTER = new ParagraphSplitter();
|
||||
private static final EmbeddingStore<DocumentSegment> DEFAULT_EMBEDDING_STORE = new InMemoryEmbeddingStore();
|
||||
private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("Answer the following question to the best of your ability: {{question}}\n\nBase your answer on the following information:\n{{information}}");
|
||||
private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from("Answer the following question to the best of your ability: {{question}}\n\nBase your answer on the following information:\n{{information}}");
|
||||
|
||||
private final DocumentLoader documentLoader;
|
||||
private final DocumentSplitter documentSplitter;
|
||||
|
|
|
@ -12,7 +12,7 @@ import static dev.langchain4j.model.huggingface.HuggingFaceModelName.TII_UAE_FAL
|
|||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
class HuggingFaceChatModelTest {
|
||||
class HuggingFaceChatModelIT {
|
||||
|
||||
@Test
|
||||
public void testWhenNullAccessToken() {
|
|
@ -10,7 +10,7 @@ import java.util.List;
|
|||
import static java.util.Arrays.asList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class HuggingFaceEmbeddingModelTest {
|
||||
class HuggingFaceEmbeddingModelIT {
|
||||
|
||||
HuggingFaceEmbeddingModel model = HuggingFaceEmbeddingModel.builder()
|
||||
.accessToken(System.getenv("HF_API_KEY"))
|
|
@ -6,7 +6,7 @@ import org.junit.jupiter.api.Test;
|
|||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
class HuggingFaceLanguageModelTest {
|
||||
class HuggingFaceLanguageModelIT {
|
||||
|
||||
@Test
|
||||
public void testWhenNullAccessToken() {
|
|
@ -6,7 +6,7 @@ import org.junit.jupiter.api.Test;
|
|||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class OpenAiModerationModelTest {
|
||||
class OpenAiModerationModelIT {
|
||||
|
||||
ModerationModel model = OpenAiModerationModel.builder()
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
|
@ -27,7 +27,7 @@ import java.util.List;
|
|||
import static dev.langchain4j.data.message.AiMessage.aiMessage;
|
||||
import static dev.langchain4j.data.message.SystemMessage.systemMessage;
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static dev.langchain4j.service.AiServicesTest.Sentiment.POSITIVE;
|
||||
import static dev.langchain4j.service.AiServicesIT.Sentiment.POSITIVE;
|
||||
import static java.time.Month.JULY;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
@ -35,7 +35,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
|||
import static org.mockito.Mockito.*;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
public class AiServicesTest {
|
||||
public class AiServicesIT {
|
||||
|
||||
@Spy
|
||||
ChatLanguageModel chatLanguageModel = OpenAiChatModel.builder()
|
Loading…
Reference in New Issue