parent
d64f02fe24
commit
11502be4a4
|
@ -18,6 +18,7 @@ import static dev.ai4j.openai4j.chat.Role.*;
|
|||
import static dev.ai4j.openai4j.chat.ToolType.FUNCTION;
|
||||
import static dev.langchain4j.data.message.AiMessage.aiMessage;
|
||||
import static dev.langchain4j.internal.Exceptions.illegalArgument;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.model.output.FinishReason.*;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
|
@ -139,13 +140,8 @@ public class InternalOpenAiHelper {
|
|||
public static AiMessage aiMessageFrom(ChatCompletionResponse response) {
|
||||
AssistantMessage assistantMessage = response.choices().get(0).message();
|
||||
|
||||
String content = assistantMessage.content();
|
||||
if (content != null) {
|
||||
return aiMessage(content);
|
||||
}
|
||||
|
||||
List<ToolCall> toolCalls = assistantMessage.toolCalls();
|
||||
if (toolCalls != null) {
|
||||
if (!isNullOrEmpty(toolCalls)) {
|
||||
List<ToolExecutionRequest> toolExecutionRequests = toolCalls.stream()
|
||||
.filter(toolCall -> toolCall.type() == FUNCTION)
|
||||
.map(InternalOpenAiHelper::toToolExecutionRequest)
|
||||
|
@ -162,7 +158,7 @@ public class InternalOpenAiHelper {
|
|||
return aiMessage(toolExecutionRequest);
|
||||
}
|
||||
|
||||
throw illegalArgument("Unexpected response: " + response);
|
||||
return aiMessage(assistantMessage.content());
|
||||
}
|
||||
|
||||
private static ToolExecutionRequest toToolExecutionRequest(ToolCall toolCall) {
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
package dev.langchain4j.model.openai;
|
||||
|
||||
import dev.ai4j.openai4j.chat.*;
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static dev.ai4j.openai4j.chat.ToolType.FUNCTION;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.aiMessageFrom;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class InternalOpenAiHelperTest {
|
||||
|
||||
@Test
|
||||
void should_return_ai_message_with_text_when_no_functions_and_tool_calls_are_present() {
|
||||
|
||||
// given
|
||||
String messageContent = "hello";
|
||||
|
||||
ChatCompletionResponse response = ChatCompletionResponse.builder()
|
||||
.choices(singletonList(ChatCompletionChoice.builder()
|
||||
.message(AssistantMessage.builder()
|
||||
.content(messageContent)
|
||||
.build())
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
// when
|
||||
AiMessage aiMessage = aiMessageFrom(response);
|
||||
|
||||
// then
|
||||
assertThat(aiMessage.text()).contains(messageContent);
|
||||
assertThat(aiMessage.toolExecutionRequests()).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_return_ai_message_with_toolExecutionRequests_when_function_is_present() {
|
||||
|
||||
// given
|
||||
String functionName = "current_time";
|
||||
String functionArguments = "{}";
|
||||
|
||||
ChatCompletionResponse response = ChatCompletionResponse.builder()
|
||||
.choices(singletonList(ChatCompletionChoice.builder()
|
||||
.message(AssistantMessage.builder()
|
||||
.content("unexpected text")
|
||||
.functionCall(FunctionCall.builder()
|
||||
.name(functionName)
|
||||
.arguments(functionArguments)
|
||||
.build())
|
||||
.build())
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
// when
|
||||
AiMessage aiMessage = aiMessageFrom(response);
|
||||
|
||||
// then
|
||||
assertThat(aiMessage.text()).isNull();
|
||||
assertThat(aiMessage.toolExecutionRequests()).containsExactly(ToolExecutionRequest
|
||||
.builder()
|
||||
.name(functionName)
|
||||
.arguments(functionArguments)
|
||||
.build()
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_return_ai_message_with_toolExecutionRequests_when_tool_calls_are_present() {
|
||||
|
||||
// given
|
||||
String functionName = "current_time";
|
||||
String functionArguments = "{}";
|
||||
|
||||
ChatCompletionResponse response = ChatCompletionResponse.builder()
|
||||
.choices(singletonList(ChatCompletionChoice.builder()
|
||||
.message(AssistantMessage.builder()
|
||||
.content("unexpected text")
|
||||
.toolCalls(ToolCall.builder()
|
||||
.type(FUNCTION)
|
||||
.function(FunctionCall.builder()
|
||||
.name(functionName)
|
||||
.arguments(functionArguments)
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
// when
|
||||
AiMessage aiMessage = aiMessageFrom(response);
|
||||
|
||||
// then
|
||||
assertThat(aiMessage.text()).isNull();
|
||||
assertThat(aiMessage.toolExecutionRequests()).containsExactly(ToolExecutionRequest
|
||||
.builder()
|
||||
.name(functionName)
|
||||
.arguments(functionArguments)
|
||||
.build()
|
||||
);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue