Azure OpenAI: migrate to the new API, from "function" to "tools" (#529)
As discussed in #521 there's a new API to call functions, which is documented at https://platform.openai.com/docs/api-reference/chat/create Fix #521
This commit is contained in:
parent
0efda674f9
commit
a5b168f061
|
@ -24,6 +24,7 @@ import java.util.Map;
|
||||||
|
|
||||||
import static dev.langchain4j.data.message.AiMessage.aiMessage;
|
import static dev.langchain4j.data.message.AiMessage.aiMessage;
|
||||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||||
|
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||||
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*;
|
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*;
|
||||||
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
||||||
import static java.util.Collections.singletonList;
|
import static java.util.Collections.singletonList;
|
||||||
|
@ -244,11 +245,12 @@ public class AzureOpenAiChatModel implements ChatLanguageModel, TokenCountEstima
|
||||||
.setSeed(seed)
|
.setSeed(seed)
|
||||||
.setResponseFormat(responseFormat);
|
.setResponseFormat(responseFormat);
|
||||||
|
|
||||||
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
|
|
||||||
options.setFunctions(toFunctions(toolSpecifications));
|
|
||||||
}
|
|
||||||
if (toolThatMustBeExecuted != null) {
|
if (toolThatMustBeExecuted != null) {
|
||||||
options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name()));
|
options.setTools(toToolDefinitions(singletonList(toolThatMustBeExecuted)));
|
||||||
|
options.setToolChoice(toToolChoice(toolThatMustBeExecuted));
|
||||||
|
}
|
||||||
|
if (!isNullOrEmpty(toolSpecifications)) {
|
||||||
|
options.setTools(toToolDefinitions(toolSpecifications));
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
|
@ -274,17 +274,14 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
|
||||||
Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages);
|
Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages);
|
||||||
|
|
||||||
if (toolThatMustBeExecuted != null) {
|
if (toolThatMustBeExecuted != null) {
|
||||||
options.setFunctions(toFunctions(singletonList(toolThatMustBeExecuted)));
|
options.setTools(toToolDefinitions(singletonList(toolThatMustBeExecuted)));
|
||||||
options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name()));
|
options.setToolChoice(toToolChoice(toolThatMustBeExecuted));
|
||||||
if (tokenizer != null) {
|
|
||||||
inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
|
inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
|
||||||
}
|
}
|
||||||
} else if (!isNullOrEmpty(toolSpecifications)) {
|
if (!isNullOrEmpty(toolSpecifications)) {
|
||||||
options.setFunctions(toFunctions(toolSpecifications));
|
options.setTools(toToolDefinitions(toolSpecifications));
|
||||||
if (tokenizer != null) {
|
|
||||||
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
|
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
AzureOpenAiStreamingResponseBuilder responseBuilder = new AzureOpenAiStreamingResponseBuilder(inputTokenCount);
|
AzureOpenAiStreamingResponseBuilder responseBuilder = new AzureOpenAiStreamingResponseBuilder(inputTokenCount);
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
package dev.langchain4j.model.azure;
|
package dev.langchain4j.model.azure;
|
||||||
|
|
||||||
import com.azure.ai.openai.OpenAIClient;
|
import com.azure.ai.openai.OpenAIClient;
|
||||||
import com.azure.ai.openai.models.*;
|
import com.azure.ai.openai.models.Choice;
|
||||||
|
import com.azure.ai.openai.models.Completions;
|
||||||
|
import com.azure.ai.openai.models.CompletionsOptions;
|
||||||
import com.azure.core.credential.KeyCredential;
|
import com.azure.core.credential.KeyCredential;
|
||||||
import com.azure.core.credential.TokenCredential;
|
import com.azure.core.credential.TokenCredential;
|
||||||
import com.azure.core.exception.HttpResponseException;
|
import com.azure.core.exception.HttpResponseException;
|
||||||
|
|
|
@ -6,11 +6,16 @@ import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.model.Tokenizer;
|
import dev.langchain4j.model.Tokenizer;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.model.output.TokenUsage;
|
import dev.langchain4j.model.output.TokenUsage;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.finishReasonFrom;
|
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.finishReasonFrom;
|
||||||
import static java.util.Collections.singletonList;
|
import static java.util.Collections.singletonList;
|
||||||
|
import static java.util.stream.Collectors.toList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This class needs to be thread safe because it is called when a streaming result comes back
|
* This class needs to be thread safe because it is called when a streaming result comes back
|
||||||
|
@ -19,9 +24,13 @@ import static java.util.Collections.singletonList;
|
||||||
*/
|
*/
|
||||||
class AzureOpenAiStreamingResponseBuilder {
|
class AzureOpenAiStreamingResponseBuilder {
|
||||||
|
|
||||||
|
Logger logger = LoggerFactory.getLogger(AzureOpenAiStreamingResponseBuilder.class);
|
||||||
|
|
||||||
private final StringBuffer contentBuilder = new StringBuffer();
|
private final StringBuffer contentBuilder = new StringBuffer();
|
||||||
private final StringBuffer toolNameBuilder = new StringBuffer();
|
private final StringBuffer toolNameBuilder = new StringBuffer();
|
||||||
private final StringBuffer toolArgumentsBuilder = new StringBuffer();
|
private final StringBuffer toolArgumentsBuilder = new StringBuffer();
|
||||||
|
private String toolExecutionsIndex = "call_undefined";
|
||||||
|
private final Map<String, ToolExecutionRequestBuilder> toolExecutionRequestBuilderHashMap = new HashMap<>();
|
||||||
private volatile CompletionsFinishReason finishReason;
|
private volatile CompletionsFinishReason finishReason;
|
||||||
|
|
||||||
private final Integer inputTokenCount;
|
private final Integer inputTokenCount;
|
||||||
|
@ -61,14 +70,29 @@ class AzureOpenAiStreamingResponseBuilder {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionCall functionCall = delta.getFunctionCall();
|
if (delta.getToolCalls() != null && !delta.getToolCalls().isEmpty()) {
|
||||||
if (functionCall != null) {
|
for (ChatCompletionsToolCall toolCall : delta.getToolCalls()) {
|
||||||
if (functionCall.getName() != null) {
|
ToolExecutionRequestBuilder toolExecutionRequestBuilder;
|
||||||
toolNameBuilder.append(functionCall.getName());
|
if (toolCall.getId() != null) {
|
||||||
|
toolExecutionsIndex = toolCall.getId();
|
||||||
|
toolExecutionRequestBuilder = new ToolExecutionRequestBuilder();
|
||||||
|
toolExecutionRequestBuilder.idBuilder.append(toolExecutionsIndex);
|
||||||
|
toolExecutionRequestBuilderHashMap.put(toolExecutionsIndex, toolExecutionRequestBuilder);
|
||||||
|
} else {
|
||||||
|
toolExecutionRequestBuilder = toolExecutionRequestBuilderHashMap.get(toolExecutionsIndex);
|
||||||
|
if (toolExecutionRequestBuilder == null) {
|
||||||
|
throw new IllegalStateException("Function without an id defined in the tool call");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (toolCall instanceof ChatCompletionsFunctionToolCall) {
|
||||||
|
ChatCompletionsFunctionToolCall functionCall = (ChatCompletionsFunctionToolCall) toolCall;
|
||||||
|
if (functionCall.getFunction().getName() != null) {
|
||||||
|
toolExecutionRequestBuilder.nameBuilder.append(functionCall.getFunction().getName());
|
||||||
|
}
|
||||||
|
if (functionCall.getFunction().getArguments() != null) {
|
||||||
|
toolExecutionRequestBuilder.argumentsBuilder.append(functionCall.getFunction().getArguments());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (functionCall.getArguments() != null) {
|
|
||||||
toolArgumentsBuilder.append(functionCall.getArguments());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -118,7 +142,22 @@ class AzureOpenAiStreamingResponseBuilder {
|
||||||
.build();
|
.build();
|
||||||
return Response.from(
|
return Response.from(
|
||||||
AiMessage.from(toolExecutionRequest),
|
AiMessage.from(toolExecutionRequest),
|
||||||
tokenUsage(toolExecutionRequest, tokenizer, forcefulToolExecution),
|
tokenUsage(singletonList(toolExecutionRequest), tokenizer, forcefulToolExecution),
|
||||||
|
finishReasonFrom(finishReason)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!toolExecutionRequestBuilderHashMap.isEmpty()) {
|
||||||
|
List<ToolExecutionRequest> toolExecutionRequests = toolExecutionRequestBuilderHashMap.values().stream()
|
||||||
|
.map(it -> ToolExecutionRequest.builder()
|
||||||
|
.id(it.idBuilder.toString())
|
||||||
|
.name(it.nameBuilder.toString())
|
||||||
|
.arguments(it.argumentsBuilder.toString())
|
||||||
|
.build())
|
||||||
|
.collect(toList());
|
||||||
|
return Response.from(
|
||||||
|
AiMessage.from(toolExecutionRequests),
|
||||||
|
tokenUsage(toolExecutionRequests, tokenizer, forcefulToolExecution),
|
||||||
finishReasonFrom(finishReason)
|
finishReasonFrom(finishReason)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -134,7 +173,7 @@ class AzureOpenAiStreamingResponseBuilder {
|
||||||
return new TokenUsage(inputTokenCount, outputTokenCount);
|
return new TokenUsage(inputTokenCount, outputTokenCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
private TokenUsage tokenUsage(ToolExecutionRequest toolExecutionRequest, Tokenizer tokenizer, boolean forcefulToolExecution) {
|
private TokenUsage tokenUsage(List<ToolExecutionRequest> toolExecutionRequests, Tokenizer tokenizer, boolean forcefulToolExecution) {
|
||||||
if (tokenizer == null) {
|
if (tokenizer == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
@ -142,11 +181,20 @@ class AzureOpenAiStreamingResponseBuilder {
|
||||||
int outputTokenCount = 0;
|
int outputTokenCount = 0;
|
||||||
if (forcefulToolExecution) {
|
if (forcefulToolExecution) {
|
||||||
// OpenAI calculates output tokens differently when tool is executed forcefully
|
// OpenAI calculates output tokens differently when tool is executed forcefully
|
||||||
|
for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
|
||||||
outputTokenCount += tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest);
|
outputTokenCount += tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
outputTokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(singletonList(toolExecutionRequest));
|
outputTokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(toolExecutionRequests);
|
||||||
}
|
}
|
||||||
|
|
||||||
return new TokenUsage(inputTokenCount, outputTokenCount);
|
return new TokenUsage(inputTokenCount, outputTokenCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class ToolExecutionRequestBuilder {
|
||||||
|
|
||||||
|
private final StringBuffer idBuilder = new StringBuffer();
|
||||||
|
private final StringBuffer nameBuilder = new StringBuffer();
|
||||||
|
private final StringBuffer argumentsBuilder = new StringBuffer();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -127,11 +127,11 @@ class InternalAzureOpenAiHelper {
|
||||||
if (message instanceof AiMessage) {
|
if (message instanceof AiMessage) {
|
||||||
AiMessage aiMessage = (AiMessage) message;
|
AiMessage aiMessage = (AiMessage) message;
|
||||||
ChatRequestAssistantMessage chatRequestAssistantMessage = new ChatRequestAssistantMessage(getOrDefault(aiMessage.text(), ""));
|
ChatRequestAssistantMessage chatRequestAssistantMessage = new ChatRequestAssistantMessage(getOrDefault(aiMessage.text(), ""));
|
||||||
chatRequestAssistantMessage.setFunctionCall(functionCallFrom(message));
|
chatRequestAssistantMessage.setToolCalls(toolExecutionRequestsFrom(message));
|
||||||
return chatRequestAssistantMessage;
|
return chatRequestAssistantMessage;
|
||||||
} else if (message instanceof ToolExecutionResultMessage) {
|
} else if (message instanceof ToolExecutionResultMessage) {
|
||||||
ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage) message;
|
ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage) message;
|
||||||
return new ChatRequestFunctionMessage(nameFrom(message), toolExecutionResultMessage.text());
|
return new ChatRequestToolMessage(toolExecutionResultMessage.text(), toolExecutionResultMessage.id());
|
||||||
} else if (message instanceof SystemMessage) {
|
} else if (message instanceof SystemMessage) {
|
||||||
SystemMessage systemMessage = (SystemMessage) message;
|
SystemMessage systemMessage = (SystemMessage) message;
|
||||||
return new ChatRequestSystemMessage(systemMessage.text());
|
return new ChatRequestSystemMessage(systemMessage.text());
|
||||||
|
@ -178,37 +178,45 @@ class InternalAzureOpenAiHelper {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static FunctionCall functionCallFrom(ChatMessage message) {
|
private static List<ChatCompletionsToolCall> toolExecutionRequestsFrom(ChatMessage message) {
|
||||||
if (message instanceof AiMessage) {
|
if (message instanceof AiMessage) {
|
||||||
AiMessage aiMessage = (AiMessage) message;
|
AiMessage aiMessage = (AiMessage) message;
|
||||||
if (aiMessage.hasToolExecutionRequests()) {
|
if (aiMessage.hasToolExecutionRequests()) {
|
||||||
// TODO switch to tools once supported
|
return aiMessage.toolExecutionRequests().stream()
|
||||||
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
|
.map(toolExecutionRequest -> new ChatCompletionsFunctionToolCall(toolExecutionRequest.id(), new FunctionCall(toolExecutionRequest.name(), toolExecutionRequest.arguments())))
|
||||||
return new FunctionCall(toolExecutionRequest.name(), toolExecutionRequest.arguments());
|
.collect(toList());
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<FunctionDefinition> toFunctions(Collection<ToolSpecification> toolSpecifications) {
|
public static List<ChatCompletionsToolDefinition> toToolDefinitions(Collection<ToolSpecification> toolSpecifications) {
|
||||||
return toolSpecifications.stream()
|
return toolSpecifications.stream()
|
||||||
.map(InternalAzureOpenAiHelper::toFunction)
|
.map(InternalAzureOpenAiHelper::toToolDefinition)
|
||||||
.collect(toList());
|
.collect(toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static FunctionDefinition toFunction(ToolSpecification toolSpecification) {
|
private static ChatCompletionsToolDefinition toToolDefinition(ToolSpecification toolSpecification) {
|
||||||
FunctionDefinition functionDefinition = new FunctionDefinition(toolSpecification.name());
|
FunctionDefinition functionDefinition = new FunctionDefinition(toolSpecification.name());
|
||||||
functionDefinition.setDescription(toolSpecification.description());
|
functionDefinition.setDescription(toolSpecification.description());
|
||||||
functionDefinition.setParameters(toOpenAiParameters(toolSpecification.parameters()));
|
functionDefinition.setParameters(toOpenAiParameters(toolSpecification.parameters()));
|
||||||
return functionDefinition;
|
return new ChatCompletionsFunctionToolDefinition(functionDefinition);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static BinaryData toToolChoice(ToolSpecification toolThatMustBeExecuted) {
|
||||||
|
FunctionCall functionCall = new FunctionCall(toolThatMustBeExecuted.name(), toOpenAiParameters(toolThatMustBeExecuted.parameters()).toString());
|
||||||
|
ChatCompletionsToolCall toolToCall = new ChatCompletionsFunctionToolCall(toolThatMustBeExecuted.name(), functionCall);
|
||||||
|
return BinaryData.fromObject(toolToCall);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final Map<String, Object> NO_PARAMETER_DATA = new HashMap<>();
|
private static final Map<String, Object> NO_PARAMETER_DATA = new HashMap<>();
|
||||||
|
|
||||||
static {
|
static {
|
||||||
NO_PARAMETER_DATA.put("type", "object");
|
NO_PARAMETER_DATA.put("type", "object");
|
||||||
NO_PARAMETER_DATA.put("properties", new HashMap<>());
|
NO_PARAMETER_DATA.put("properties", new HashMap<>());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static BinaryData toOpenAiParameters(ToolParameters toolParameters) {
|
private static BinaryData toOpenAiParameters(ToolParameters toolParameters) {
|
||||||
Parameters parameters = new Parameters();
|
Parameters parameters = new Parameters();
|
||||||
if (toolParameters == null) {
|
if (toolParameters == null) {
|
||||||
|
@ -224,6 +232,7 @@ class InternalAzureOpenAiHelper {
|
||||||
private final String type = "object";
|
private final String type = "object";
|
||||||
|
|
||||||
private Map<String, Map<String, Object>> properties = new HashMap<>();
|
private Map<String, Map<String, Object>> properties = new HashMap<>();
|
||||||
|
|
||||||
private List<String> required = new ArrayList<>();
|
private List<String> required = new ArrayList<>();
|
||||||
|
|
||||||
public String getType() {
|
public String getType() {
|
||||||
|
@ -251,14 +260,19 @@ class InternalAzureOpenAiHelper {
|
||||||
if (chatResponseMessage.getContent() != null) {
|
if (chatResponseMessage.getContent() != null) {
|
||||||
return aiMessage(chatResponseMessage.getContent());
|
return aiMessage(chatResponseMessage.getContent());
|
||||||
} else {
|
} else {
|
||||||
FunctionCall functionCall = chatResponseMessage.getFunctionCall();
|
List<ToolExecutionRequest> toolExecutionRequests = chatResponseMessage.getToolCalls()
|
||||||
|
.stream()
|
||||||
|
.filter(toolCall -> toolCall instanceof ChatCompletionsFunctionToolCall)
|
||||||
|
.map(toolCall -> (ChatCompletionsFunctionToolCall) toolCall)
|
||||||
|
.map(chatCompletionsFunctionToolCall ->
|
||||||
|
ToolExecutionRequest.builder()
|
||||||
|
.id(chatCompletionsFunctionToolCall.getId())
|
||||||
|
.name(chatCompletionsFunctionToolCall.getFunction().getName())
|
||||||
|
.arguments(chatCompletionsFunctionToolCall.getFunction().getArguments())
|
||||||
|
.build())
|
||||||
|
.collect(toList());
|
||||||
|
|
||||||
ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder()
|
return aiMessage(toolExecutionRequests);
|
||||||
.name(functionCall.getName())
|
|
||||||
.arguments(functionCall.getArguments())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
return aiMessage(toolExecutionRequest);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package dev.langchain4j.model.azure;
|
package dev.langchain4j.model.azure;
|
||||||
|
|
||||||
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
|
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
|
||||||
import com.azure.ai.openai.models.ChatCompletionsResponseFormat;
|
|
||||||
import com.azure.core.util.BinaryData;
|
import com.azure.core.util.BinaryData;
|
||||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
@ -9,27 +8,38 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||||
import dev.langchain4j.agent.tool.ToolParameters;
|
import dev.langchain4j.agent.tool.ToolParameters;
|
||||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||||
import dev.langchain4j.data.message.*;
|
import dev.langchain4j.data.message.*;
|
||||||
|
import dev.langchain4j.model.StreamingResponseHandler;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
|
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.model.output.TokenUsage;
|
import dev.langchain4j.model.output.TokenUsage;
|
||||||
|
import org.assertj.core.data.Percentage;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.CsvSource;
|
import org.junit.jupiter.params.provider.CsvSource;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
|
||||||
|
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
|
||||||
import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
|
import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
|
||||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||||
import static dev.langchain4j.model.output.FinishReason.LENGTH;
|
import static dev.langchain4j.model.output.FinishReason.LENGTH;
|
||||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||||
|
import static java.util.Arrays.asList;
|
||||||
|
import static java.util.Collections.singletonList;
|
||||||
|
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.assertj.core.data.Percentage.withPercentage;
|
||||||
|
|
||||||
public class AzureOpenAiChatModelIT {
|
public class AzureOpenAiChatModelIT {
|
||||||
|
|
||||||
Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModelIT.class);
|
Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModelIT.class);
|
||||||
|
|
||||||
|
Percentage tokenizerPrecision = withPercentage(5);
|
||||||
|
|
||||||
@ParameterizedTest(name = "Deployment name {0} using {1}")
|
@ParameterizedTest(name = "Deployment name {0} using {1}")
|
||||||
@CsvSource({
|
@CsvSource({
|
||||||
"gpt-35-turbo, gpt-3.5-turbo",
|
"gpt-35-turbo, gpt-3.5-turbo",
|
||||||
|
@ -121,6 +131,7 @@ public class AzureOpenAiChatModelIT {
|
||||||
|
|
||||||
AiMessage aiMessage = response.content();
|
AiMessage aiMessage = response.content();
|
||||||
assertThat(aiMessage.text()).isNull();
|
assertThat(aiMessage.text()).isNull();
|
||||||
|
assertThat(response.finishReason()).isEqualTo(STOP);
|
||||||
|
|
||||||
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
|
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
|
||||||
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
|
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
|
||||||
|
@ -131,8 +142,7 @@ public class AzureOpenAiChatModelIT {
|
||||||
|
|
||||||
// We can now call the function with the correct parameters.
|
// We can now call the function with the correct parameters.
|
||||||
WeatherLocation weatherLocation = BinaryData.fromString(toolExecutionRequest.arguments()).toObject(WeatherLocation.class);
|
WeatherLocation weatherLocation = BinaryData.fromString(toolExecutionRequest.arguments()).toObject(WeatherLocation.class);
|
||||||
int currentWeather = 0;
|
int currentWeather = getCurrentWeather(weatherLocation);
|
||||||
currentWeather = getCurrentWeather(weatherLocation);
|
|
||||||
|
|
||||||
String weather = String.format("The weather in %s is %d degrees %s.",
|
String weather = String.format("The weather in %s is %d degrees %s.",
|
||||||
weatherLocation.getLocation(), currentWeather, weatherLocation.getUnit());
|
weatherLocation.getLocation(), currentWeather, weatherLocation.getUnit());
|
||||||
|
@ -193,6 +203,85 @@ public class AzureOpenAiChatModelIT {
|
||||||
assertThat(toolExecutionRequest.arguments()).isEqualTo("{}");
|
assertThat(toolExecutionRequest.arguments()).isEqualTo("{}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest(name = "Deployment name {0} using {1}")
|
||||||
|
@CsvSource({
|
||||||
|
"gpt-35-turbo, gpt-3.5-turbo",
|
||||||
|
"gpt-4, gpt-4"
|
||||||
|
})
|
||||||
|
void should_call_three_functions_in_parallel(String deploymentName, String gptVersion) throws Exception {
|
||||||
|
|
||||||
|
ChatLanguageModel model = AzureOpenAiChatModel.builder()
|
||||||
|
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||||
|
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||||
|
.deploymentName(deploymentName)
|
||||||
|
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||||
|
.logRequestsAndResponses(true)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
UserMessage userMessage = userMessage("Give three numbers, ordered by size: the sum of two plus two, the square of four, and finally the cube of eight.");
|
||||||
|
|
||||||
|
List<ToolSpecification> toolSpecifications = asList(
|
||||||
|
ToolSpecification.builder()
|
||||||
|
.name("sum")
|
||||||
|
.description("returns a sum of two numbers")
|
||||||
|
.addParameter("first", INTEGER)
|
||||||
|
.addParameter("second", INTEGER)
|
||||||
|
.build(),
|
||||||
|
ToolSpecification.builder()
|
||||||
|
.name("square")
|
||||||
|
.description("returns the square of one number")
|
||||||
|
.addParameter("number", INTEGER)
|
||||||
|
.build(),
|
||||||
|
ToolSpecification.builder()
|
||||||
|
.name("cube")
|
||||||
|
.description("returns the cube of one number")
|
||||||
|
.addParameter("number", INTEGER)
|
||||||
|
.build()
|
||||||
|
);
|
||||||
|
|
||||||
|
Response<AiMessage> response = model.generate(Collections.singletonList(userMessage), toolSpecifications);
|
||||||
|
|
||||||
|
AiMessage aiMessage = response.content();
|
||||||
|
assertThat(aiMessage.text()).isNull();
|
||||||
|
List<ChatMessage> messages = new ArrayList<>();
|
||||||
|
messages.add(userMessage);
|
||||||
|
messages.add(aiMessage);
|
||||||
|
assertThat(aiMessage.toolExecutionRequests()).hasSize(3);
|
||||||
|
for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
|
||||||
|
assertThat(toolExecutionRequest.name()).isNotEmpty();
|
||||||
|
ToolExecutionResultMessage toolExecutionResultMessage;
|
||||||
|
if (toolExecutionRequest.name().equals("sum")) {
|
||||||
|
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
|
||||||
|
toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "4");
|
||||||
|
} else if (toolExecutionRequest.name().equals("square")) {
|
||||||
|
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 4}");
|
||||||
|
toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "16");
|
||||||
|
} else if (toolExecutionRequest.name().equals("cube")) {
|
||||||
|
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 8}");
|
||||||
|
toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "512");
|
||||||
|
} else {
|
||||||
|
throw new AssertionError("Unexpected tool name: " + toolExecutionRequest.name());
|
||||||
|
}
|
||||||
|
messages.add(toolExecutionResultMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
Response<AiMessage> response2 = model.generate(messages);
|
||||||
|
AiMessage aiMessage2 = response2.content();
|
||||||
|
|
||||||
|
// then
|
||||||
|
logger.debug("Final answer is: " + aiMessage2);
|
||||||
|
assertThat(aiMessage2.text()).contains("4", "16", "512");
|
||||||
|
assertThat(aiMessage2.toolExecutionRequests()).isNull();
|
||||||
|
|
||||||
|
TokenUsage tokenUsage2 = response2.tokenUsage();
|
||||||
|
assertThat(tokenUsage2.inputTokenCount()).isCloseTo(112, tokenizerPrecision);
|
||||||
|
assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
|
||||||
|
assertThat(tokenUsage2.totalTokenCount())
|
||||||
|
.isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
|
||||||
|
|
||||||
|
assertThat(response2.finishReason()).isEqualTo(STOP);
|
||||||
|
}
|
||||||
|
|
||||||
@ParameterizedTest(name = "Deployment name {0} using {1}")
|
@ParameterizedTest(name = "Deployment name {0} using {1}")
|
||||||
@CsvSource({
|
@CsvSource({
|
||||||
"gpt-35-turbo, gpt-3.5-turbo",
|
"gpt-35-turbo, gpt-3.5-turbo",
|
||||||
|
|
|
@ -6,12 +6,16 @@ import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
|
||||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
|
import dev.langchain4j.data.message.ChatMessage;
|
||||||
|
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||||
import dev.langchain4j.data.message.UserMessage;
|
import dev.langchain4j.data.message.UserMessage;
|
||||||
import dev.langchain4j.model.StreamingResponseHandler;
|
import dev.langchain4j.model.StreamingResponseHandler;
|
||||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||||
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
|
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
|
||||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
|
import dev.langchain4j.model.output.TokenUsage;
|
||||||
|
import org.assertj.core.data.Percentage;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.CsvSource;
|
import org.junit.jupiter.params.provider.CsvSource;
|
||||||
import org.junit.jupiter.params.provider.ValueSource;
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
|
@ -19,20 +23,26 @@ import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
|
||||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
|
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
|
||||||
|
import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
|
||||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||||
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
|
import static java.util.Arrays.asList;
|
||||||
import static java.util.Collections.singletonList;
|
import static java.util.Collections.singletonList;
|
||||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.assertj.core.data.Percentage.withPercentage;
|
||||||
|
|
||||||
class AzureOpenAiStreamingChatModelIT {
|
class AzureOpenAiStreamingChatModelIT {
|
||||||
|
|
||||||
Logger logger = LoggerFactory.getLogger(AzureOpenAiStreamingChatModelIT.class);
|
Logger logger = LoggerFactory.getLogger(AzureOpenAiStreamingChatModelIT.class);
|
||||||
|
|
||||||
|
Percentage tokenizerPrecision = withPercentage(5);
|
||||||
|
|
||||||
@ParameterizedTest(name = "Deployment name {0} using {1} with async client set to {2}")
|
@ParameterizedTest(name = "Deployment name {0} using {1} with async client set to {2}")
|
||||||
@CsvSource({
|
@CsvSource({
|
||||||
"gpt-35-turbo, gpt-3.5-turbo, true",
|
"gpt-35-turbo, gpt-3.5-turbo, true",
|
||||||
|
@ -188,16 +198,7 @@ class AzureOpenAiStreamingChatModelIT {
|
||||||
"gpt-35-turbo, gpt-3.5-turbo",
|
"gpt-35-turbo, gpt-3.5-turbo",
|
||||||
"gpt-4, gpt-4"
|
"gpt-4, gpt-4"
|
||||||
})
|
})
|
||||||
void should_return_tool_execution_request(String deploymentName, String gptVersion) throws Exception {
|
void should_call_function_with_argument(String deploymentName, String gptVersion) throws Exception {
|
||||||
|
|
||||||
ToolSpecification toolSpecification = ToolSpecification.builder()
|
|
||||||
.name("calculator")
|
|
||||||
.description("returns a sum of two numbers")
|
|
||||||
.addParameter("first", INTEGER)
|
|
||||||
.addParameter("second", INTEGER)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
UserMessage userMessage = userMessage("Two plus two?");
|
|
||||||
|
|
||||||
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
|
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
|
||||||
|
|
||||||
|
@ -209,7 +210,18 @@ class AzureOpenAiStreamingChatModelIT {
|
||||||
.logRequestsAndResponses(true)
|
.logRequestsAndResponses(true)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
model.generate(singletonList(userMessage), singletonList(toolSpecification), new StreamingResponseHandler<AiMessage>() {
|
UserMessage userMessage = userMessage("Two plus two?");
|
||||||
|
|
||||||
|
String toolName = "calculator";
|
||||||
|
|
||||||
|
ToolSpecification toolSpecification = ToolSpecification.builder()
|
||||||
|
.name(toolName)
|
||||||
|
.description("returns a sum of two numbers")
|
||||||
|
.addParameter("first", INTEGER)
|
||||||
|
.addParameter("second", INTEGER)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
model.generate(singletonList(userMessage), toolSpecification, new StreamingResponseHandler<AiMessage>() {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onNext(String token) {
|
public void onNext(String token) {
|
||||||
|
@ -237,14 +249,175 @@ class AzureOpenAiStreamingChatModelIT {
|
||||||
|
|
||||||
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
|
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
|
||||||
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
|
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
|
||||||
assertThat(toolExecutionRequest.name()).isEqualTo("calculator");
|
assertThat(toolExecutionRequest.name()).isEqualTo(toolName);
|
||||||
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
|
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
|
||||||
|
|
||||||
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(53);
|
assertThat(response.tokenUsage().inputTokenCount()).isCloseTo(58, tokenizerPrecision);
|
||||||
assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
|
assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
|
||||||
assertThat(response.tokenUsage().totalTokenCount())
|
assertThat(response.tokenUsage().totalTokenCount())
|
||||||
.isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount());
|
.isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount());
|
||||||
|
|
||||||
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
|
assertThat(response.finishReason()).isEqualTo(STOP);
|
||||||
|
|
||||||
|
ToolExecutionResultMessage toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "four");
|
||||||
|
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
|
||||||
|
|
||||||
|
CompletableFuture<Response<AiMessage>> futureResponse2 = new CompletableFuture<>();
|
||||||
|
|
||||||
|
model.generate(messages, new StreamingResponseHandler<AiMessage>() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onNext(String token) {
|
||||||
|
logger.info("onNext: '" + token + "'");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onComplete(Response<AiMessage> response) {
|
||||||
|
logger.info("onComplete: '" + response + "'");
|
||||||
|
futureResponse2.complete(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onError(Throwable error) {
|
||||||
|
futureResponse2.completeExceptionally(error);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Response<AiMessage> response2 = futureResponse2.get(30, SECONDS);
|
||||||
|
AiMessage aiMessage2 = response2.content();
|
||||||
|
|
||||||
|
// then
|
||||||
|
assertThat(aiMessage2.text()).contains("four");
|
||||||
|
assertThat(aiMessage2.toolExecutionRequests()).isNull();
|
||||||
|
|
||||||
|
TokenUsage tokenUsage2 = response2.tokenUsage();
|
||||||
|
assertThat(tokenUsage2.inputTokenCount()).isCloseTo(33, tokenizerPrecision);
|
||||||
|
assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
|
||||||
|
assertThat(tokenUsage2.totalTokenCount())
|
||||||
|
.isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
|
||||||
|
|
||||||
|
assertThat(response2.finishReason()).isEqualTo(STOP);
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest(name = "Deployment name {0} using {1}")
|
||||||
|
@CsvSource({
|
||||||
|
"gpt-35-turbo, gpt-3.5-turbo",
|
||||||
|
"gpt-4, gpt-4"
|
||||||
|
})
|
||||||
|
void should_call_three_functions_in_parallel(String deploymentName, String gptVersion) throws Exception {
|
||||||
|
|
||||||
|
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
|
||||||
|
|
||||||
|
StreamingChatLanguageModel model = AzureOpenAiStreamingChatModel.builder()
|
||||||
|
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||||
|
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||||
|
.deploymentName(deploymentName)
|
||||||
|
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||||
|
.logRequestsAndResponses(true)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
UserMessage userMessage = userMessage("Give three numbers, ordered by size: the sum of two plus two, the square of four, and finally the cube of eight.");
|
||||||
|
|
||||||
|
List<ToolSpecification> toolSpecifications = asList(
|
||||||
|
ToolSpecification.builder()
|
||||||
|
.name("sum")
|
||||||
|
.description("returns a sum of two numbers")
|
||||||
|
.addParameter("first", INTEGER)
|
||||||
|
.addParameter("second", INTEGER)
|
||||||
|
.build(),
|
||||||
|
ToolSpecification.builder()
|
||||||
|
.name("square")
|
||||||
|
.description("returns the square of one number")
|
||||||
|
.addParameter("number", INTEGER)
|
||||||
|
.build(),
|
||||||
|
ToolSpecification.builder()
|
||||||
|
.name("cube")
|
||||||
|
.description("returns the cube of one number")
|
||||||
|
.addParameter("number", INTEGER)
|
||||||
|
.build()
|
||||||
|
);
|
||||||
|
|
||||||
|
model.generate(singletonList(userMessage), toolSpecifications, new StreamingResponseHandler<AiMessage>() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onNext(String token) {
|
||||||
|
logger.info("onNext: '" + token + "'");
|
||||||
|
Exception e = new IllegalStateException("onNext() should never be called when tool is executed");
|
||||||
|
futureResponse.completeExceptionally(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onComplete(Response<AiMessage> response) {
|
||||||
|
logger.info("onComplete: '" + response + "'");
|
||||||
|
futureResponse.complete(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onError(Throwable error) {
|
||||||
|
futureResponse.completeExceptionally(error);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Response<AiMessage> response = futureResponse.get(30, SECONDS);
|
||||||
|
|
||||||
|
AiMessage aiMessage = response.content();
|
||||||
|
assertThat(aiMessage.text()).isNull();
|
||||||
|
List<ChatMessage> messages = new ArrayList<>();
|
||||||
|
messages.add(userMessage);
|
||||||
|
messages.add(aiMessage);
|
||||||
|
assertThat(aiMessage.toolExecutionRequests()).hasSize(3);
|
||||||
|
for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
|
||||||
|
assertThat(toolExecutionRequest.name()).isNotEmpty();
|
||||||
|
ToolExecutionResultMessage toolExecutionResultMessage;
|
||||||
|
if (toolExecutionRequest.name().equals("sum")) {
|
||||||
|
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
|
||||||
|
toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "4");
|
||||||
|
} else if (toolExecutionRequest.name().equals("square")) {
|
||||||
|
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 4}");
|
||||||
|
toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "16");
|
||||||
|
} else if (toolExecutionRequest.name().equals("cube")) {
|
||||||
|
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 8}");
|
||||||
|
toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "512");
|
||||||
|
} else {
|
||||||
|
throw new AssertionError("Unexpected tool name: " + toolExecutionRequest.name());
|
||||||
|
}
|
||||||
|
messages.add(toolExecutionResultMessage);
|
||||||
|
}
|
||||||
|
CompletableFuture<Response<AiMessage>> futureResponse2 = new CompletableFuture<>();
|
||||||
|
|
||||||
|
model.generate(messages, new StreamingResponseHandler<AiMessage>() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onNext(String token) {
|
||||||
|
logger.info("onNext: '" + token + "'");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onComplete(Response<AiMessage> response) {
|
||||||
|
logger.info("onComplete: '" + response + "'");
|
||||||
|
futureResponse2.complete(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onError(Throwable error) {
|
||||||
|
futureResponse2.completeExceptionally(error);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Response<AiMessage> response2 = futureResponse2.get(30, SECONDS);
|
||||||
|
AiMessage aiMessage2 = response2.content();
|
||||||
|
|
||||||
|
// then
|
||||||
|
logger.debug("Final answer is: " + aiMessage2);
|
||||||
|
assertThat(aiMessage2.text()).contains("4", "16", "512");
|
||||||
|
assertThat(aiMessage2.toolExecutionRequests()).isNull();
|
||||||
|
|
||||||
|
TokenUsage tokenUsage2 = response2.tokenUsage();
|
||||||
|
assertThat(tokenUsage2.inputTokenCount()).isCloseTo(119, tokenizerPrecision);
|
||||||
|
assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
|
||||||
|
assertThat(tokenUsage2.totalTokenCount())
|
||||||
|
.isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
|
||||||
|
|
||||||
|
assertThat(response2.finishReason()).isEqualTo(STOP);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,10 +3,7 @@ package dev.langchain4j.model.azure;
|
||||||
import com.azure.ai.openai.OpenAIAsyncClient;
|
import com.azure.ai.openai.OpenAIAsyncClient;
|
||||||
import com.azure.ai.openai.OpenAIClient;
|
import com.azure.ai.openai.OpenAIClient;
|
||||||
import com.azure.ai.openai.OpenAIServiceVersion;
|
import com.azure.ai.openai.OpenAIServiceVersion;
|
||||||
import com.azure.ai.openai.models.ChatRequestMessage;
|
import com.azure.ai.openai.models.*;
|
||||||
import com.azure.ai.openai.models.ChatRequestUserMessage;
|
|
||||||
import com.azure.ai.openai.models.CompletionsFinishReason;
|
|
||||||
import com.azure.ai.openai.models.FunctionDefinition;
|
|
||||||
import dev.langchain4j.agent.tool.ToolParameters;
|
import dev.langchain4j.agent.tool.ToolParameters;
|
||||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||||
import dev.langchain4j.data.message.ChatMessage;
|
import dev.langchain4j.data.message.ChatMessage;
|
||||||
|
@ -20,6 +17,7 @@ import java.util.Collection;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||||
|
|
||||||
class InternalAzureOpenAiHelperTest {
|
class InternalAzureOpenAiHelperTest {
|
||||||
|
@ -83,7 +81,7 @@ class InternalAzureOpenAiHelperTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void toFunctionsShouldReturnCorrectFunctions() {
|
void toToolDefinitionsShouldReturnCorrectToolDefinition() {
|
||||||
Collection<ToolSpecification> toolSpecifications = new ArrayList<>();
|
Collection<ToolSpecification> toolSpecifications = new ArrayList<>();
|
||||||
toolSpecifications.add(ToolSpecification.builder()
|
toolSpecifications.add(ToolSpecification.builder()
|
||||||
.name("test-tool")
|
.name("test-tool")
|
||||||
|
@ -91,10 +89,11 @@ class InternalAzureOpenAiHelperTest {
|
||||||
.parameters(ToolParameters.builder().build())
|
.parameters(ToolParameters.builder().build())
|
||||||
.build());
|
.build());
|
||||||
|
|
||||||
List<FunctionDefinition> functions = InternalAzureOpenAiHelper.toFunctions(toolSpecifications);
|
List<ChatCompletionsToolDefinition> tools = InternalAzureOpenAiHelper.toToolDefinitions(toolSpecifications);
|
||||||
|
|
||||||
assertThat(functions).hasSize(toolSpecifications.size());
|
assertEquals(toolSpecifications.size(), tools.size());
|
||||||
assertThat(functions.get(0).getName()).isEqualTo(toolSpecifications.iterator().next().name());
|
assertInstanceOf(ChatCompletionsFunctionToolDefinition.class, tools.get(0));
|
||||||
|
assertEquals(toolSpecifications.iterator().next().name(), ((ChatCompletionsFunctionToolDefinition) tools.get(0)).getFunction().getName());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -301,6 +301,18 @@
|
||||||
<type>pom</type>
|
<type>pom</type>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.fasterxml.jackson.core</groupId>
|
||||||
|
<artifactId>jackson-databind</artifactId>
|
||||||
|
<version>${jackson.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.fasterxml.jackson.dataformat</groupId>
|
||||||
|
<artifactId>jackson-dataformat-xml</artifactId>
|
||||||
|
<version>${jackson.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.opennlp</groupId>
|
<groupId>org.apache.opennlp</groupId>
|
||||||
<artifactId>opennlp-tools</artifactId>
|
<artifactId>opennlp-tools</artifactId>
|
||||||
|
|
Loading…
Reference in New Issue