Ollama: test that OpenAI API (OpenAiChatModel) works

This commit is contained in:
LangChain4j 2024-03-22 11:46:28 +01:00
parent e0b7a2816b
commit fbced4e70e
2 changed files with 61 additions and 0 deletions

View File

@ -76,6 +76,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>

View File

@ -0,0 +1,55 @@
package dev.langchain4j.model.ollama;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
import static dev.langchain4j.model.ollama.OllamaImage.TINY_DOLPHIN_MODEL;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests if Ollama can be used via OpenAI API (langchain4j-open-ai module)
* See https://github.com/ollama/ollama/blob/main/docs/openai.md
*/
class OllamaOpenAiChatModelIT extends AbstractOllamaLanguageModelInfrastructure {
ChatLanguageModel model = OpenAiChatModel.builder()
.apiKey("does not matter") // TODO make apiKey optional when using custom baseUrl?
.baseUrl(ollama.getEndpoint() + "/v1") // TODO add "/v1" by default?
.modelName(TINY_DOLPHIN_MODEL)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build();
@Test
void should_generate_response() {
// given
UserMessage userMessage = UserMessage.from("What is the capital of Germany?");
// when
Response<AiMessage> response = model.generate(userMessage);
System.out.println(response);
// then
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).contains("Berlin");
assertThat(aiMessage.toolExecutionRequests()).isNull();
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(35);
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
assertThat(response.finishReason()).isEqualTo(STOP);
}
// TODO add more tests
}