Azure OpenAI: migrate to the new API, from "function" to "tools" (#529)

As discussed in #521, there is a new API for calling functions ("tools" instead of the deprecated "functions"), documented at https://platform.openai.com/docs/api-reference/chat/create

Fix #521
Julien Dubois 2024-05-17 07:24:23 +02:00 committed by GitHub
parent 0efda674f9
commit a5b168f061
9 changed files with 407 additions and 71 deletions
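
For orientation before the per-file diffs, here is a minimal sketch of the change at a call site, using the Azure OpenAI SDK types that appear below. The getWeather tool name and the messages parameter are illustrative only, not taken from the commit:

import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.FunctionDefinition;

import java.util.List;

import static java.util.Collections.singletonList;

class ToolsMigrationSketch {

    static ChatCompletionsOptions requestWithTools(List<ChatRequestMessage> messages) {
        ChatCompletionsOptions options = new ChatCompletionsOptions(messages);

        // Old "functions" API, removed by this commit:
        //   options.setFunctions(functionDefinitions);
        //   options.setFunctionCall(new FunctionCallConfig("getWeather"));

        // New "tools" API: each FunctionDefinition is wrapped in a tool definition,
        // and forcing a specific call goes through setToolChoice() instead of setFunctionCall().
        FunctionDefinition function = new FunctionDefinition("getWeather");
        options.setTools(singletonList(new ChatCompletionsFunctionToolDefinition(function)));
        return options;
    }
}

The response side changes symmetrically: tool calls come back as a list of ChatCompletionsToolCall entries, each carrying an id, instead of a single FunctionCall, which is what makes the parallel-call handling below possible.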

View File: AzureOpenAiChatModel.java

@@ -24,6 +24,7 @@ import java.util.Map;
 
 import static dev.langchain4j.data.message.AiMessage.aiMessage;
 import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.internal.Utils.isNullOrEmpty;
 import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*;
 import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 import static java.util.Collections.singletonList;

@@ -244,11 +245,12 @@ public class AzureOpenAiChatModel implements ChatLanguageModel, TokenCountEstimator
                 .setSeed(seed)
                 .setResponseFormat(responseFormat);
-        if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
-            options.setFunctions(toFunctions(toolSpecifications));
-        }
         if (toolThatMustBeExecuted != null) {
-            options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name()));
+            options.setTools(toToolDefinitions(singletonList(toolThatMustBeExecuted)));
+            options.setToolChoice(toToolChoice(toolThatMustBeExecuted));
+        }
+        if (!isNullOrEmpty(toolSpecifications)) {
+            options.setTools(toToolDefinitions(toolSpecifications));
         }
 
         try {

View File: AzureOpenAiStreamingChatModel.java

@@ -274,17 +274,14 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
         Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages);
 
         if (toolThatMustBeExecuted != null) {
-            options.setFunctions(toFunctions(singletonList(toolThatMustBeExecuted)));
-            options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name()));
-            if (tokenizer != null) {
-                inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
-            }
-        } else if (!isNullOrEmpty(toolSpecifications)) {
-            options.setFunctions(toFunctions(toolSpecifications));
-            if (tokenizer != null) {
-                inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
-            }
+            options.setTools(toToolDefinitions(singletonList(toolThatMustBeExecuted)));
+            options.setToolChoice(toToolChoice(toolThatMustBeExecuted));
+            inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
+        }
+        if (!isNullOrEmpty(toolSpecifications)) {
+            options.setTools(toToolDefinitions(toolSpecifications));
+            inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
         }
 
         AzureOpenAiStreamingResponseBuilder responseBuilder = new AzureOpenAiStreamingResponseBuilder(inputTokenCount);

View File

@@ -1,7 +1,9 @@
 package dev.langchain4j.model.azure;
 
 import com.azure.ai.openai.OpenAIClient;
-import com.azure.ai.openai.models.*;
+import com.azure.ai.openai.models.Choice;
+import com.azure.ai.openai.models.Completions;
+import com.azure.ai.openai.models.CompletionsOptions;
 import com.azure.core.credential.KeyCredential;
 import com.azure.core.credential.TokenCredential;
 import com.azure.core.exception.HttpResponseException;

View File: AzureOpenAiStreamingResponseBuilder.java

@@ -6,11 +6,16 @@ import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.finishReasonFrom;
 import static java.util.Collections.singletonList;
+import static java.util.stream.Collectors.toList;
 
 /**
  * This class needs to be thread safe because it is called when a streaming result comes back

@@ -19,9 +24,13 @@ import static java.util.Collections.singletonList;
  */
 class AzureOpenAiStreamingResponseBuilder {
 
+    Logger logger = LoggerFactory.getLogger(AzureOpenAiStreamingResponseBuilder.class);
+
     private final StringBuffer contentBuilder = new StringBuffer();
     private final StringBuffer toolNameBuilder = new StringBuffer();
     private final StringBuffer toolArgumentsBuilder = new StringBuffer();
+    private String toolExecutionsIndex = "call_undefined";
+    private final Map<String, ToolExecutionRequestBuilder> toolExecutionRequestBuilderHashMap = new HashMap<>();
 
     private volatile CompletionsFinishReason finishReason;
     private final Integer inputTokenCount;

@@ -61,14 +70,29 @@ class AzureOpenAiStreamingResponseBuilder {
             return;
         }
 
-        FunctionCall functionCall = delta.getFunctionCall();
-        if (functionCall != null) {
-            if (functionCall.getName() != null) {
-                toolNameBuilder.append(functionCall.getName());
-            }
-            if (functionCall.getArguments() != null) {
-                toolArgumentsBuilder.append(functionCall.getArguments());
+        if (delta.getToolCalls() != null && !delta.getToolCalls().isEmpty()) {
+            for (ChatCompletionsToolCall toolCall : delta.getToolCalls()) {
+                ToolExecutionRequestBuilder toolExecutionRequestBuilder;
+                if (toolCall.getId() != null) {
+                    toolExecutionsIndex = toolCall.getId();
+                    toolExecutionRequestBuilder = new ToolExecutionRequestBuilder();
+                    toolExecutionRequestBuilder.idBuilder.append(toolExecutionsIndex);
+                    toolExecutionRequestBuilderHashMap.put(toolExecutionsIndex, toolExecutionRequestBuilder);
+                } else {
+                    toolExecutionRequestBuilder = toolExecutionRequestBuilderHashMap.get(toolExecutionsIndex);
+                    if (toolExecutionRequestBuilder == null) {
+                        throw new IllegalStateException("Function without an id defined in the tool call");
+                    }
+                }
+                if (toolCall instanceof ChatCompletionsFunctionToolCall) {
+                    ChatCompletionsFunctionToolCall functionCall = (ChatCompletionsFunctionToolCall) toolCall;
+                    if (functionCall.getFunction().getName() != null) {
+                        toolExecutionRequestBuilder.nameBuilder.append(functionCall.getFunction().getName());
+                    }
+                    if (functionCall.getFunction().getArguments() != null) {
+                        toolExecutionRequestBuilder.argumentsBuilder.append(functionCall.getFunction().getArguments());
+                    }
+                }
             }
         }
     }

@@ -118,7 +142,22 @@ class AzureOpenAiStreamingResponseBuilder {
                     .build();
             return Response.from(
                     AiMessage.from(toolExecutionRequest),
-                    tokenUsage(toolExecutionRequest, tokenizer, forcefulToolExecution),
+                    tokenUsage(singletonList(toolExecutionRequest), tokenizer, forcefulToolExecution),
+                    finishReasonFrom(finishReason)
+            );
+        }
+
+        if (!toolExecutionRequestBuilderHashMap.isEmpty()) {
+            List<ToolExecutionRequest> toolExecutionRequests = toolExecutionRequestBuilderHashMap.values().stream()
+                    .map(it -> ToolExecutionRequest.builder()
+                            .id(it.idBuilder.toString())
+                            .name(it.nameBuilder.toString())
+                            .arguments(it.argumentsBuilder.toString())
+                            .build())
+                    .collect(toList());
+            return Response.from(
+                    AiMessage.from(toolExecutionRequests),
+                    tokenUsage(toolExecutionRequests, tokenizer, forcefulToolExecution),
                     finishReasonFrom(finishReason)
             );
         }

@@ -134,7 +173,7 @@ class AzureOpenAiStreamingResponseBuilder {
         return new TokenUsage(inputTokenCount, outputTokenCount);
     }
 
-    private TokenUsage tokenUsage(ToolExecutionRequest toolExecutionRequest, Tokenizer tokenizer, boolean forcefulToolExecution) {
+    private TokenUsage tokenUsage(List<ToolExecutionRequest> toolExecutionRequests, Tokenizer tokenizer, boolean forcefulToolExecution) {
         if (tokenizer == null) {
             return null;
         }

@@ -142,11 +181,20 @@ class AzureOpenAiStreamingResponseBuilder {
         int outputTokenCount = 0;
         if (forcefulToolExecution) {
             // OpenAI calculates output tokens differently when tool is executed forcefully
-            outputTokenCount += tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest);
+            for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
+                outputTokenCount += tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest);
+            }
         } else {
-            outputTokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(singletonList(toolExecutionRequest));
+            outputTokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(toolExecutionRequests);
         }
 
         return new TokenUsage(inputTokenCount, outputTokenCount);
     }
+
+    private static class ToolExecutionRequestBuilder {
+
+        private final StringBuffer idBuilder = new StringBuffer();
+        private final StringBuffer nameBuilder = new StringBuffer();
+        private final StringBuffer argumentsBuilder = new StringBuffer();
+    }
 }
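
The per-id bookkeeping above follows from how chunked tool calls arrive on the stream: the first delta of a call carries its id and function name, later deltas carry only argument fragments with a null id. A hypothetical chunk sequence for one call (shapes illustrative, not captured from a real stream):

// delta 1: toolCall.getId() = "call_abc123", function name = "getCurrentWeather", arguments = ""
// delta 2: toolCall.getId() = null, argument fragment = "{\"location\": \"Par"
// delta 3: toolCall.getId() = null, argument fragment = "is\"}"
//
// The ToolExecutionRequestBuilder registered under "call_abc123" accumulates:
//   nameBuilder      -> "getCurrentWeather"
//   argumentsBuilder -> "{\"location\": \"Paris\"}"

Keeping toolExecutionsIndex pointed at the last id seen routes id-less fragments to the builder of the call currently being streamed, and the IllegalStateException guards against a fragment arriving before any id.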

View File: InternalAzureOpenAiHelper.java

@@ -127,11 +127,11 @@ class InternalAzureOpenAiHelper {
         if (message instanceof AiMessage) {
             AiMessage aiMessage = (AiMessage) message;
             ChatRequestAssistantMessage chatRequestAssistantMessage = new ChatRequestAssistantMessage(getOrDefault(aiMessage.text(), ""));
-            chatRequestAssistantMessage.setFunctionCall(functionCallFrom(message));
+            chatRequestAssistantMessage.setToolCalls(toolExecutionRequestsFrom(message));
             return chatRequestAssistantMessage;
         } else if (message instanceof ToolExecutionResultMessage) {
             ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage) message;
-            return new ChatRequestFunctionMessage(nameFrom(message), toolExecutionResultMessage.text());
+            return new ChatRequestToolMessage(toolExecutionResultMessage.text(), toolExecutionResultMessage.id());
         } else if (message instanceof SystemMessage) {
             SystemMessage systemMessage = (SystemMessage) message;
             return new ChatRequestSystemMessage(systemMessage.text());

@@ -178,37 +178,45 @@ class InternalAzureOpenAiHelper {
         return null;
     }
 
-    private static FunctionCall functionCallFrom(ChatMessage message) {
+    private static List<ChatCompletionsToolCall> toolExecutionRequestsFrom(ChatMessage message) {
         if (message instanceof AiMessage) {
             AiMessage aiMessage = (AiMessage) message;
             if (aiMessage.hasToolExecutionRequests()) {
-                // TODO switch to tools once supported
-                ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
-                return new FunctionCall(toolExecutionRequest.name(), toolExecutionRequest.arguments());
+                return aiMessage.toolExecutionRequests().stream()
+                        .map(toolExecutionRequest -> new ChatCompletionsFunctionToolCall(toolExecutionRequest.id(), new FunctionCall(toolExecutionRequest.name(), toolExecutionRequest.arguments())))
+                        .collect(toList());
             }
         }
         return null;
     }
 
-    public static List<FunctionDefinition> toFunctions(Collection<ToolSpecification> toolSpecifications) {
+    public static List<ChatCompletionsToolDefinition> toToolDefinitions(Collection<ToolSpecification> toolSpecifications) {
         return toolSpecifications.stream()
-                .map(InternalAzureOpenAiHelper::toFunction)
+                .map(InternalAzureOpenAiHelper::toToolDefinition)
                 .collect(toList());
     }
 
-    private static FunctionDefinition toFunction(ToolSpecification toolSpecification) {
+    private static ChatCompletionsToolDefinition toToolDefinition(ToolSpecification toolSpecification) {
         FunctionDefinition functionDefinition = new FunctionDefinition(toolSpecification.name());
         functionDefinition.setDescription(toolSpecification.description());
         functionDefinition.setParameters(toOpenAiParameters(toolSpecification.parameters()));
-        return functionDefinition;
+        return new ChatCompletionsFunctionToolDefinition(functionDefinition);
+    }
+
+    public static BinaryData toToolChoice(ToolSpecification toolThatMustBeExecuted) {
+        FunctionCall functionCall = new FunctionCall(toolThatMustBeExecuted.name(), toOpenAiParameters(toolThatMustBeExecuted.parameters()).toString());
+        ChatCompletionsToolCall toolToCall = new ChatCompletionsFunctionToolCall(toolThatMustBeExecuted.name(), functionCall);
+        return BinaryData.fromObject(toolToCall);
     }
 
     private static final Map<String, Object> NO_PARAMETER_DATA = new HashMap<>();
 
     static {
         NO_PARAMETER_DATA.put("type", "object");
         NO_PARAMETER_DATA.put("properties", new HashMap<>());
     }
 
     private static BinaryData toOpenAiParameters(ToolParameters toolParameters) {
         Parameters parameters = new Parameters();
         if (toolParameters == null) {

@@ -224,6 +232,7 @@ class InternalAzureOpenAiHelper {
         private final String type = "object";
         private Map<String, Map<String, Object>> properties = new HashMap<>();
+
         private List<String> required = new ArrayList<>();
 
         public String getType() {

@@ -251,14 +260,19 @@ class InternalAzureOpenAiHelper {
         if (chatResponseMessage.getContent() != null) {
             return aiMessage(chatResponseMessage.getContent());
         } else {
-            FunctionCall functionCall = chatResponseMessage.getFunctionCall();
-
-            ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder()
-                    .name(functionCall.getName())
-                    .arguments(functionCall.getArguments())
-                    .build();
-
-            return aiMessage(toolExecutionRequest);
+            List<ToolExecutionRequest> toolExecutionRequests = chatResponseMessage.getToolCalls()
+                    .stream()
+                    .filter(toolCall -> toolCall instanceof ChatCompletionsFunctionToolCall)
+                    .map(toolCall -> (ChatCompletionsFunctionToolCall) toolCall)
+                    .map(chatCompletionsFunctionToolCall ->
+                            ToolExecutionRequest.builder()
+                                    .id(chatCompletionsFunctionToolCall.getId())
+                                    .name(chatCompletionsFunctionToolCall.getFunction().getName())
+                                    .arguments(chatCompletionsFunctionToolCall.getFunction().getArguments())
+                                    .build())
+                    .collect(toList());
+
+            return aiMessage(toolExecutionRequests);
         }
     }
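
A note on the new toToolChoice() helper: setToolChoice() accepts raw BinaryData, so the helper serializes a ChatCompletionsFunctionToolCall with BinaryData.fromObject(). Per the API reference linked in the commit message, a forced tool choice is roughly {"type": "function", "function": {"name": "..."}} on the wire (exact field layout assumed here, not captured from a request). Both chat model classes pair it with setTools(), along these lines, where toolSpec stands for the tool that must be executed:

// Force the model to execute one specific tool: advertise it as the only
// available tool and point the tool choice at it, as AzureOpenAiChatModel does.
options.setTools(toToolDefinitions(singletonList(toolSpec)));
options.setToolChoice(toToolChoice(toolSpec));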

View File: AzureOpenAiChatModelIT.java

@@ -1,7 +1,6 @@
 package dev.langchain4j.model.azure;
 
 import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
-import com.azure.ai.openai.models.ChatCompletionsResponseFormat;
 import com.azure.core.util.BinaryData;
 import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonProperty;

@@ -9,27 +8,38 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.agent.tool.ToolParameters;
 import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.*;
+import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.openai.OpenAiTokenizer;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
+import org.assertj.core.data.Percentage;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.CsvSource;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.*;
+import java.util.concurrent.CompletableFuture;
 
+import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
 import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
 import static dev.langchain4j.data.message.UserMessage.userMessage;
 import static dev.langchain4j.model.output.FinishReason.LENGTH;
 import static dev.langchain4j.model.output.FinishReason.STOP;
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+import static java.util.concurrent.TimeUnit.SECONDS;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.data.Percentage.withPercentage;
 
 public class AzureOpenAiChatModelIT {
 
     Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModelIT.class);
+    Percentage tokenizerPrecision = withPercentage(5);
 
     @ParameterizedTest(name = "Deployment name {0} using {1}")
     @CsvSource({
             "gpt-35-turbo, gpt-3.5-turbo",

@@ -121,6 +131,7 @@ public class AzureOpenAiChatModelIT {
         AiMessage aiMessage = response.content();
         assertThat(aiMessage.text()).isNull();
+        assertThat(response.finishReason()).isEqualTo(STOP);
 
         assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
         ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);

@@ -131,8 +142,7 @@ public class AzureOpenAiChatModelIT {
         // We can now call the function with the correct parameters.
         WeatherLocation weatherLocation = BinaryData.fromString(toolExecutionRequest.arguments()).toObject(WeatherLocation.class);
-        int currentWeather = 0;
-        currentWeather = getCurrentWeather(weatherLocation);
+        int currentWeather = getCurrentWeather(weatherLocation);
 
         String weather = String.format("The weather in %s is %d degrees %s.",
                 weatherLocation.getLocation(), currentWeather, weatherLocation.getUnit());

@@ -193,6 +203,85 @@ public class AzureOpenAiChatModelIT {
         assertThat(toolExecutionRequest.arguments()).isEqualTo("{}");
     }
 
+    @ParameterizedTest(name = "Deployment name {0} using {1}")
+    @CsvSource({
+            "gpt-35-turbo, gpt-3.5-turbo",
+            "gpt-4, gpt-4"
+    })
+    void should_call_three_functions_in_parallel(String deploymentName, String gptVersion) throws Exception {
+
+        ChatLanguageModel model = AzureOpenAiChatModel.builder()
+                .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
+                .apiKey(System.getenv("AZURE_OPENAI_KEY"))
+                .deploymentName(deploymentName)
+                .tokenizer(new OpenAiTokenizer(gptVersion))
+                .logRequestsAndResponses(true)
+                .build();
+
+        UserMessage userMessage = userMessage("Give three numbers, ordered by size: the sum of two plus two, the square of four, and finally the cube of eight.");
+
+        List<ToolSpecification> toolSpecifications = asList(
+                ToolSpecification.builder()
+                        .name("sum")
+                        .description("returns a sum of two numbers")
+                        .addParameter("first", INTEGER)
+                        .addParameter("second", INTEGER)
+                        .build(),
+                ToolSpecification.builder()
+                        .name("square")
+                        .description("returns the square of one number")
+                        .addParameter("number", INTEGER)
+                        .build(),
+                ToolSpecification.builder()
+                        .name("cube")
+                        .description("returns the cube of one number")
+                        .addParameter("number", INTEGER)
+                        .build()
+        );
+
+        Response<AiMessage> response = model.generate(Collections.singletonList(userMessage), toolSpecifications);
+
+        AiMessage aiMessage = response.content();
+        assertThat(aiMessage.text()).isNull();
+
+        List<ChatMessage> messages = new ArrayList<>();
+        messages.add(userMessage);
+        messages.add(aiMessage);
+
+        assertThat(aiMessage.toolExecutionRequests()).hasSize(3);
+        for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
+            assertThat(toolExecutionRequest.name()).isNotEmpty();
+            ToolExecutionResultMessage toolExecutionResultMessage;
+            if (toolExecutionRequest.name().equals("sum")) {
+                assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
+                toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "4");
+            } else if (toolExecutionRequest.name().equals("square")) {
+                assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 4}");
+                toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "16");
+            } else if (toolExecutionRequest.name().equals("cube")) {
+                assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 8}");
+                toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "512");
+            } else {
+                throw new AssertionError("Unexpected tool name: " + toolExecutionRequest.name());
+            }
+            messages.add(toolExecutionResultMessage);
+        }
+
+        Response<AiMessage> response2 = model.generate(messages);
+        AiMessage aiMessage2 = response2.content();
+
+        // then
+        logger.debug("Final answer is: " + aiMessage2);
+        assertThat(aiMessage2.text()).contains("4", "16", "512");
+        assertThat(aiMessage2.toolExecutionRequests()).isNull();
+
+        TokenUsage tokenUsage2 = response2.tokenUsage();
+        assertThat(tokenUsage2.inputTokenCount()).isCloseTo(112, tokenizerPrecision);
+        assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
+        assertThat(tokenUsage2.totalTokenCount())
+                .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
+
+        assertThat(response2.finishReason()).isEqualTo(STOP);
+    }
+
     @ParameterizedTest(name = "Deployment name {0} using {1}")
     @CsvSource({
             "gpt-35-turbo, gpt-3.5-turbo",

View File: AzureOpenAiStreamingChatModelIT.java

@@ -6,12 +6,16 @@ import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.ToolExecutionResultMessage;
 import dev.langchain4j.data.message.UserMessage;
 import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.chat.TestStreamingResponseHandler;
 import dev.langchain4j.model.openai.OpenAiTokenizer;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
+import org.assertj.core.data.Percentage;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.CsvSource;
 import org.junit.jupiter.params.provider.ValueSource;

@@ -19,20 +23,26 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.concurrent.CompletableFuture;
 
 import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
+import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
 import static dev.langchain4j.data.message.UserMessage.userMessage;
 import static dev.langchain4j.model.output.FinishReason.STOP;
-import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
+import static java.util.Arrays.asList;
 import static java.util.Collections.singletonList;
 import static java.util.concurrent.TimeUnit.SECONDS;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.data.Percentage.withPercentage;
 
 class AzureOpenAiStreamingChatModelIT {
 
     Logger logger = LoggerFactory.getLogger(AzureOpenAiStreamingChatModelIT.class);
+    Percentage tokenizerPrecision = withPercentage(5);
 
     @ParameterizedTest(name = "Deployment name {0} using {1} with async client set to {2}")
     @CsvSource({
             "gpt-35-turbo, gpt-3.5-turbo, true",

@@ -188,16 +198,7 @@ class AzureOpenAiStreamingChatModelIT {
             "gpt-35-turbo, gpt-3.5-turbo",
             "gpt-4, gpt-4"
     })
-    void should_return_tool_execution_request(String deploymentName, String gptVersion) throws Exception {
-
-        ToolSpecification toolSpecification = ToolSpecification.builder()
-                .name("calculator")
-                .description("returns a sum of two numbers")
-                .addParameter("first", INTEGER)
-                .addParameter("second", INTEGER)
-                .build();
-
-        UserMessage userMessage = userMessage("Two plus two?");
+    void should_call_function_with_argument(String deploymentName, String gptVersion) throws Exception {
 
         CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();

@@ -209,7 +210,18 @@ class AzureOpenAiStreamingChatModelIT {
                 .logRequestsAndResponses(true)
                 .build();
 
-        model.generate(singletonList(userMessage), singletonList(toolSpecification), new StreamingResponseHandler<AiMessage>() {
+        UserMessage userMessage = userMessage("Two plus two?");
+
+        String toolName = "calculator";
+        ToolSpecification toolSpecification = ToolSpecification.builder()
+                .name(toolName)
+                .description("returns a sum of two numbers")
+                .addParameter("first", INTEGER)
+                .addParameter("second", INTEGER)
+                .build();
+
+        model.generate(singletonList(userMessage), toolSpecification, new StreamingResponseHandler<AiMessage>() {
 
             @Override
             public void onNext(String token) {

@@ -237,14 +249,175 @@ class AzureOpenAiStreamingChatModelIT {
         assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
         ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
-        assertThat(toolExecutionRequest.name()).isEqualTo("calculator");
+        assertThat(toolExecutionRequest.name()).isEqualTo(toolName);
         assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
 
-        assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(53);
+        assertThat(response.tokenUsage().inputTokenCount()).isCloseTo(58, tokenizerPrecision);
         assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
         assertThat(response.tokenUsage().totalTokenCount())
                 .isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount());
 
-        assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
+        assertThat(response.finishReason()).isEqualTo(STOP);
+
+        ToolExecutionResultMessage toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "four");
+        List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
+
+        CompletableFuture<Response<AiMessage>> futureResponse2 = new CompletableFuture<>();
+
+        model.generate(messages, new StreamingResponseHandler<AiMessage>() {
+
+            @Override
+            public void onNext(String token) {
+                logger.info("onNext: '" + token + "'");
+            }
+
+            @Override
+            public void onComplete(Response<AiMessage> response) {
+                logger.info("onComplete: '" + response + "'");
+                futureResponse2.complete(response);
+            }
+
+            @Override
+            public void onError(Throwable error) {
+                futureResponse2.completeExceptionally(error);
+            }
+        });
+
+        Response<AiMessage> response2 = futureResponse2.get(30, SECONDS);
+        AiMessage aiMessage2 = response2.content();
+
+        // then
+        assertThat(aiMessage2.text()).contains("four");
+        assertThat(aiMessage2.toolExecutionRequests()).isNull();
+
+        TokenUsage tokenUsage2 = response2.tokenUsage();
+        assertThat(tokenUsage2.inputTokenCount()).isCloseTo(33, tokenizerPrecision);
+        assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
+        assertThat(tokenUsage2.totalTokenCount())
+                .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
+
+        assertThat(response2.finishReason()).isEqualTo(STOP);
+    }
+
+    @ParameterizedTest(name = "Deployment name {0} using {1}")
+    @CsvSource({
+            "gpt-35-turbo, gpt-3.5-turbo",
+            "gpt-4, gpt-4"
+    })
+    void should_call_three_functions_in_parallel(String deploymentName, String gptVersion) throws Exception {
+
+        CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
+
+        StreamingChatLanguageModel model = AzureOpenAiStreamingChatModel.builder()
+                .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
+                .apiKey(System.getenv("AZURE_OPENAI_KEY"))
+                .deploymentName(deploymentName)
+                .tokenizer(new OpenAiTokenizer(gptVersion))
+                .logRequestsAndResponses(true)
+                .build();
+
+        UserMessage userMessage = userMessage("Give three numbers, ordered by size: the sum of two plus two, the square of four, and finally the cube of eight.");
+
+        List<ToolSpecification> toolSpecifications = asList(
+                ToolSpecification.builder()
+                        .name("sum")
+                        .description("returns a sum of two numbers")
+                        .addParameter("first", INTEGER)
+                        .addParameter("second", INTEGER)
+                        .build(),
+                ToolSpecification.builder()
+                        .name("square")
+                        .description("returns the square of one number")
+                        .addParameter("number", INTEGER)
+                        .build(),
+                ToolSpecification.builder()
+                        .name("cube")
+                        .description("returns the cube of one number")
+                        .addParameter("number", INTEGER)
+                        .build()
+        );
+
+        model.generate(singletonList(userMessage), toolSpecifications, new StreamingResponseHandler<AiMessage>() {
+
+            @Override
+            public void onNext(String token) {
+                logger.info("onNext: '" + token + "'");
+                Exception e = new IllegalStateException("onNext() should never be called when tool is executed");
+                futureResponse.completeExceptionally(e);
+            }
+
+            @Override
+            public void onComplete(Response<AiMessage> response) {
+                logger.info("onComplete: '" + response + "'");
+                futureResponse.complete(response);
+            }
+
+            @Override
+            public void onError(Throwable error) {
+                futureResponse.completeExceptionally(error);
+            }
+        });
+
+        Response<AiMessage> response = futureResponse.get(30, SECONDS);
+        AiMessage aiMessage = response.content();
+        assertThat(aiMessage.text()).isNull();
+
+        List<ChatMessage> messages = new ArrayList<>();
+        messages.add(userMessage);
+        messages.add(aiMessage);
+
+        assertThat(aiMessage.toolExecutionRequests()).hasSize(3);
+        for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
+            assertThat(toolExecutionRequest.name()).isNotEmpty();
+            ToolExecutionResultMessage toolExecutionResultMessage;
+            if (toolExecutionRequest.name().equals("sum")) {
+                assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
+                toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "4");
+            } else if (toolExecutionRequest.name().equals("square")) {
+                assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 4}");
+                toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "16");
+            } else if (toolExecutionRequest.name().equals("cube")) {
+                assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 8}");
+                toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "512");
+            } else {
+                throw new AssertionError("Unexpected tool name: " + toolExecutionRequest.name());
+            }
+            messages.add(toolExecutionResultMessage);
+        }
+
+        CompletableFuture<Response<AiMessage>> futureResponse2 = new CompletableFuture<>();
+
+        model.generate(messages, new StreamingResponseHandler<AiMessage>() {
+
+            @Override
+            public void onNext(String token) {
+                logger.info("onNext: '" + token + "'");
+            }
+
+            @Override
+            public void onComplete(Response<AiMessage> response) {
+                logger.info("onComplete: '" + response + "'");
+                futureResponse2.complete(response);
+            }
+
+            @Override
+            public void onError(Throwable error) {
+                futureResponse2.completeExceptionally(error);
+            }
+        });
+
+        Response<AiMessage> response2 = futureResponse2.get(30, SECONDS);
+        AiMessage aiMessage2 = response2.content();
+
+        // then
+        logger.debug("Final answer is: " + aiMessage2);
+        assertThat(aiMessage2.text()).contains("4", "16", "512");
+        assertThat(aiMessage2.toolExecutionRequests()).isNull();
+
+        TokenUsage tokenUsage2 = response2.tokenUsage();
+        assertThat(tokenUsage2.inputTokenCount()).isCloseTo(119, tokenizerPrecision);
+        assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
+        assertThat(tokenUsage2.totalTokenCount())
+                .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
+
+        assertThat(response2.finishReason()).isEqualTo(STOP);
     }
 }

View File: InternalAzureOpenAiHelperTest.java

@@ -3,10 +3,7 @@ package dev.langchain4j.model.azure;
 import com.azure.ai.openai.OpenAIAsyncClient;
 import com.azure.ai.openai.OpenAIClient;
 import com.azure.ai.openai.OpenAIServiceVersion;
-import com.azure.ai.openai.models.ChatRequestMessage;
-import com.azure.ai.openai.models.ChatRequestUserMessage;
-import com.azure.ai.openai.models.CompletionsFinishReason;
-import com.azure.ai.openai.models.FunctionDefinition;
+import com.azure.ai.openai.models.*;
 import dev.langchain4j.agent.tool.ToolParameters;
 import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.ChatMessage;

@@ -20,6 +17,7 @@ import java.util.Collection;
 import java.util.List;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertInstanceOf;
 
 class InternalAzureOpenAiHelperTest {

@@ -83,7 +81,7 @@ class InternalAzureOpenAiHelperTest {
     }
 
     @Test
-    void toFunctionsShouldReturnCorrectFunctions() {
+    void toToolDefinitionsShouldReturnCorrectToolDefinition() {
        Collection<ToolSpecification> toolSpecifications = new ArrayList<>();
         toolSpecifications.add(ToolSpecification.builder()
                 .name("test-tool")

@@ -91,10 +89,11 @@ class InternalAzureOpenAiHelperTest {
                 .parameters(ToolParameters.builder().build())
                 .build());
 
-        List<FunctionDefinition> functions = InternalAzureOpenAiHelper.toFunctions(toolSpecifications);
+        List<ChatCompletionsToolDefinition> tools = InternalAzureOpenAiHelper.toToolDefinitions(toolSpecifications);
 
-        assertThat(functions).hasSize(toolSpecifications.size());
-        assertThat(functions.get(0).getName()).isEqualTo(toolSpecifications.iterator().next().name());
+        assertEquals(toolSpecifications.size(), tools.size());
+        assertInstanceOf(ChatCompletionsFunctionToolDefinition.class, tools.get(0));
+        assertEquals(toolSpecifications.iterator().next().name(), ((ChatCompletionsFunctionToolDefinition) tools.get(0)).getFunction().getName());
     }
 
     @Test

View File: pom.xml

@@ -301,6 +301,18 @@
             <type>pom</type>
         </dependency>
 
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-databind</artifactId>
+            <version>${jackson.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.fasterxml.jackson.dataformat</groupId>
+            <artifactId>jackson-dataformat-xml</artifactId>
+            <version>${jackson.version}</version>
+        </dependency>
+
         <dependency>
             <groupId>org.apache.opennlp</groupId>
             <artifactId>opennlp-tools</artifactId>