This commit is contained in:
LangChain4j 2023-12-22 15:35:50 +01:00 committed by GitHub
parent d64f02fe24
commit 11502be4a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 106 additions and 7 deletions

View File

@ -18,6 +18,7 @@ import static dev.ai4j.openai4j.chat.Role.*;
import static dev.ai4j.openai4j.chat.ToolType.FUNCTION;
import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.output.FinishReason.*;
import static java.util.stream.Collectors.toList;
@ -139,13 +140,8 @@ public class InternalOpenAiHelper {
public static AiMessage aiMessageFrom(ChatCompletionResponse response) {
AssistantMessage assistantMessage = response.choices().get(0).message();
String content = assistantMessage.content();
if (content != null) {
return aiMessage(content);
}
List<ToolCall> toolCalls = assistantMessage.toolCalls();
if (toolCalls != null) {
if (!isNullOrEmpty(toolCalls)) {
List<ToolExecutionRequest> toolExecutionRequests = toolCalls.stream()
.filter(toolCall -> toolCall.type() == FUNCTION)
.map(InternalOpenAiHelper::toToolExecutionRequest)
@ -162,7 +158,7 @@ public class InternalOpenAiHelper {
return aiMessage(toolExecutionRequest);
}
throw illegalArgument("Unexpected response: " + response);
return aiMessage(assistantMessage.content());
}
private static ToolExecutionRequest toToolExecutionRequest(ToolCall toolCall) {

View File

@ -0,0 +1,103 @@
package dev.langchain4j.model.openai;
import dev.ai4j.openai4j.chat.*;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import org.junit.jupiter.api.Test;
import static dev.ai4j.openai4j.chat.ToolType.FUNCTION;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.aiMessageFrom;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
class InternalOpenAiHelperTest {
@Test
void should_return_ai_message_with_text_when_no_functions_and_tool_calls_are_present() {
// given
String messageContent = "hello";
ChatCompletionResponse response = ChatCompletionResponse.builder()
.choices(singletonList(ChatCompletionChoice.builder()
.message(AssistantMessage.builder()
.content(messageContent)
.build())
.build()))
.build();
// when
AiMessage aiMessage = aiMessageFrom(response);
// then
assertThat(aiMessage.text()).contains(messageContent);
assertThat(aiMessage.toolExecutionRequests()).isNull();
}
@Test
void should_return_ai_message_with_toolExecutionRequests_when_function_is_present() {
// given
String functionName = "current_time";
String functionArguments = "{}";
ChatCompletionResponse response = ChatCompletionResponse.builder()
.choices(singletonList(ChatCompletionChoice.builder()
.message(AssistantMessage.builder()
.content("unexpected text")
.functionCall(FunctionCall.builder()
.name(functionName)
.arguments(functionArguments)
.build())
.build())
.build()))
.build();
// when
AiMessage aiMessage = aiMessageFrom(response);
// then
assertThat(aiMessage.text()).isNull();
assertThat(aiMessage.toolExecutionRequests()).containsExactly(ToolExecutionRequest
.builder()
.name(functionName)
.arguments(functionArguments)
.build()
);
}
@Test
void should_return_ai_message_with_toolExecutionRequests_when_tool_calls_are_present() {
// given
String functionName = "current_time";
String functionArguments = "{}";
ChatCompletionResponse response = ChatCompletionResponse.builder()
.choices(singletonList(ChatCompletionChoice.builder()
.message(AssistantMessage.builder()
.content("unexpected text")
.toolCalls(ToolCall.builder()
.type(FUNCTION)
.function(FunctionCall.builder()
.name(functionName)
.arguments(functionArguments)
.build())
.build())
.build())
.build()))
.build();
// when
AiMessage aiMessage = aiMessageFrom(response);
// then
assertThat(aiMessage.text()).isNull();
assertThat(aiMessage.toolExecutionRequests()).containsExactly(ToolExecutionRequest
.builder()
.name(functionName)
.arguments(functionArguments)
.build()
);
}
}