Azure OpenAI: migrate to the new API, from "function" to "tools" (#529)

As discussed in #521, there is a new API for calling functions, documented at
https://platform.openai.com/docs/api-reference/chat/create
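
For context, here is a minimal before/after sketch of the request construction (a sketch only, assembled from the SDK types used in this diff; "options" is assumed to be a ChatCompletionsOptions, and the function name and arguments are illustrative):

// Old "functions" API, now deprecated:
FunctionDefinition function = new FunctionDefinition("getCurrentWeather");
function.setDescription("Returns the current weather for a location");
options.setFunctions(singletonList(function));
options.setFunctionCall(new FunctionCallConfig("getCurrentWeather"));

// New "tools" API: each function is wrapped in a tool definition,
// and forcing a specific call goes through setToolChoice(...):
options.setTools(singletonList(new ChatCompletionsFunctionToolDefinition(function)));
options.setToolChoice(BinaryData.fromObject(
        new ChatCompletionsFunctionToolCall("getCurrentWeather", // call id; the name is reused here, as toToolChoice below does
                new FunctionCall("getCurrentWeather", "{}")))); // "{}" is an illustrative arguments payload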

Fix #521
Julien Dubois 2024-05-17 07:24:23 +02:00 committed by GitHub
parent 0efda674f9
commit a5b168f061
9 changed files with 407 additions and 71 deletions

View File

@@ -24,6 +24,7 @@ import java.util.Map;
import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.util.Collections.singletonList;
@@ -244,11 +245,12 @@ public class AzureOpenAiChatModel implements ChatLanguageModel, TokenCountEstima
.setSeed(seed)
.setResponseFormat(responseFormat);
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
options.setFunctions(toFunctions(toolSpecifications));
}
if (toolThatMustBeExecuted != null) {
options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name()));
options.setTools(toToolDefinitions(singletonList(toolThatMustBeExecuted)));
options.setToolChoice(toToolChoice(toolThatMustBeExecuted));
}
if (!isNullOrEmpty(toolSpecifications)) {
options.setTools(toToolDefinitions(toolSpecifications));
}
try {

View File

@@ -274,16 +274,13 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages);
if (toolThatMustBeExecuted != null) {
options.setFunctions(toFunctions(singletonList(toolThatMustBeExecuted)));
options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name()));
if (tokenizer != null) {
inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
}
} else if (!isNullOrEmpty(toolSpecifications)) {
options.setFunctions(toFunctions(toolSpecifications));
if (tokenizer != null) {
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
}
options.setTools(toToolDefinitions(singletonList(toolThatMustBeExecuted)));
options.setToolChoice(toToolChoice(toolThatMustBeExecuted));
inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
}
if (!isNullOrEmpty(toolSpecifications)) {
options.setTools(toToolDefinitions(toolSpecifications));
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
}
AzureOpenAiStreamingResponseBuilder responseBuilder = new AzureOpenAiStreamingResponseBuilder(inputTokenCount);

View File

@@ -1,7 +1,9 @@
package dev.langchain4j.model.azure;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.*;
import com.azure.ai.openai.models.Choice;
import com.azure.ai.openai.models.Completions;
import com.azure.ai.openai.models.CompletionsOptions;
import com.azure.core.credential.KeyCredential;
import com.azure.core.credential.TokenCredential;
import com.azure.core.exception.HttpResponseException;

View File

@@ -6,11 +6,16 @@ import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.finishReasonFrom;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
/**
* This class needs to be thread safe because it is called when a streaming result comes back
@@ -19,9 +24,13 @@ import static java.util.Collections.singletonList;
*/
class AzureOpenAiStreamingResponseBuilder {
Logger logger = LoggerFactory.getLogger(AzureOpenAiStreamingResponseBuilder.class);
private final StringBuffer contentBuilder = new StringBuffer();
private final StringBuffer toolNameBuilder = new StringBuffer();
private final StringBuffer toolArgumentsBuilder = new StringBuffer();
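// Id of the tool call currently being streamed; argument deltas that arrive without an id are appended to the builder registered under this id.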
private String toolExecutionsIndex = "call_undefined";
private final Map<String, ToolExecutionRequestBuilder> toolExecutionRequestBuilderHashMap = new HashMap<>();
private volatile CompletionsFinishReason finishReason;
private final Integer inputTokenCount;
@@ -61,14 +70,29 @@ class AzureOpenAiStreamingResponseBuilder {
return;
}
FunctionCall functionCall = delta.getFunctionCall();
if (functionCall != null) {
if (functionCall.getName() != null) {
toolNameBuilder.append(functionCall.getName());
}
if (functionCall.getArguments() != null) {
toolArgumentsBuilder.append(functionCall.getArguments());
if (delta.getToolCalls() != null && !delta.getToolCalls().isEmpty()) {
for (ChatCompletionsToolCall toolCall : delta.getToolCalls()) {
ToolExecutionRequestBuilder toolExecutionRequestBuilder;
if (toolCall.getId() != null) {
toolExecutionsIndex = toolCall.getId();
toolExecutionRequestBuilder = new ToolExecutionRequestBuilder();
toolExecutionRequestBuilder.idBuilder.append(toolExecutionsIndex);
toolExecutionRequestBuilderHashMap.put(toolExecutionsIndex, toolExecutionRequestBuilder);
} else {
toolExecutionRequestBuilder = toolExecutionRequestBuilderHashMap.get(toolExecutionsIndex);
if (toolExecutionRequestBuilder == null) {
throw new IllegalStateException("Function without an id defined in the tool call");
}
}
if (toolCall instanceof ChatCompletionsFunctionToolCall) {
ChatCompletionsFunctionToolCall functionCall = (ChatCompletionsFunctionToolCall) toolCall;
if (functionCall.getFunction().getName() != null) {
toolExecutionRequestBuilder.nameBuilder.append(functionCall.getFunction().getName());
}
if (functionCall.getFunction().getArguments() != null) {
toolExecutionRequestBuilder.argumentsBuilder.append(functionCall.getFunction().getArguments());
}
}
}
}
}
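
As an example of what this accumulation handles: a single call to the "sum" tool from the tests below may arrive as one delta carrying the call id and the function name, followed by deltas carrying only argument fragments such as {"first": 2, and "second": 2}; the map keyed by id reassembles them into one ToolExecutionRequest per call. The exact chunking is illustrative, but the id-arrives-first assumption is precisely what the IllegalStateException above enforces.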
@@ -118,7 +142,22 @@ class AzureOpenAiStreamingResponseBuilder {
.build();
return Response.from(
AiMessage.from(toolExecutionRequest),
tokenUsage(toolExecutionRequest, tokenizer, forcefulToolExecution),
tokenUsage(singletonList(toolExecutionRequest), tokenizer, forcefulToolExecution),
finishReasonFrom(finishReason)
);
}
if (!toolExecutionRequestBuilderHashMap.isEmpty()) {
List<ToolExecutionRequest> toolExecutionRequests = toolExecutionRequestBuilderHashMap.values().stream()
.map(it -> ToolExecutionRequest.builder()
.id(it.idBuilder.toString())
.name(it.nameBuilder.toString())
.arguments(it.argumentsBuilder.toString())
.build())
.collect(toList());
return Response.from(
AiMessage.from(toolExecutionRequests),
tokenUsage(toolExecutionRequests, tokenizer, forcefulToolExecution),
finishReasonFrom(finishReason)
);
}
@@ -134,7 +173,7 @@ class AzureOpenAiStreamingResponseBuilder {
return new TokenUsage(inputTokenCount, outputTokenCount);
}
private TokenUsage tokenUsage(ToolExecutionRequest toolExecutionRequest, Tokenizer tokenizer, boolean forcefulToolExecution) {
private TokenUsage tokenUsage(List<ToolExecutionRequest> toolExecutionRequests, Tokenizer tokenizer, boolean forcefulToolExecution) {
if (tokenizer == null) {
return null;
}
@@ -142,11 +181,20 @@ class AzureOpenAiStreamingResponseBuilder {
int outputTokenCount = 0;
if (forcefulToolExecution) {
// OpenAI calculates output tokens differently when tool is executed forcefully
outputTokenCount += tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest);
for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
outputTokenCount += tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest);
}
} else {
outputTokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(singletonList(toolExecutionRequest));
outputTokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(toolExecutionRequests);
}
return new TokenUsage(inputTokenCount, outputTokenCount);
}
private static class ToolExecutionRequestBuilder {
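// Accumulates the id, name, and argument fragments of a single streamed tool call.
// StringBuffer rather than StringBuilder, matching the thread-safety note in the class javadoc above.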
private final StringBuffer idBuilder = new StringBuffer();
private final StringBuffer nameBuilder = new StringBuffer();
private final StringBuffer argumentsBuilder = new StringBuffer();
}
}

View File

@@ -127,11 +127,11 @@ class InternalAzureOpenAiHelper {
if (message instanceof AiMessage) {
AiMessage aiMessage = (AiMessage) message;
ChatRequestAssistantMessage chatRequestAssistantMessage = new ChatRequestAssistantMessage(getOrDefault(aiMessage.text(), ""));
chatRequestAssistantMessage.setFunctionCall(functionCallFrom(message));
chatRequestAssistantMessage.setToolCalls(toolExecutionRequestsFrom(message));
return chatRequestAssistantMessage;
} else if (message instanceof ToolExecutionResultMessage) {
ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage) message;
return new ChatRequestFunctionMessage(nameFrom(message), toolExecutionResultMessage.text());
return new ChatRequestToolMessage(toolExecutionResultMessage.text(), toolExecutionResultMessage.id());
} else if (message instanceof SystemMessage) {
SystemMessage systemMessage = (SystemMessage) message;
return new ChatRequestSystemMessage(systemMessage.text());
@@ -178,37 +178,45 @@ class InternalAzureOpenAiHelper {
return null;
}
private static FunctionCall functionCallFrom(ChatMessage message) {
private static List<ChatCompletionsToolCall> toolExecutionRequestsFrom(ChatMessage message) {
if (message instanceof AiMessage) {
AiMessage aiMessage = (AiMessage) message;
if (aiMessage.hasToolExecutionRequests()) {
// TODO switch to tools once supported
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
return new FunctionCall(toolExecutionRequest.name(), toolExecutionRequest.arguments());
return aiMessage.toolExecutionRequests().stream()
.map(toolExecutionRequest -> new ChatCompletionsFunctionToolCall(toolExecutionRequest.id(), new FunctionCall(toolExecutionRequest.name(), toolExecutionRequest.arguments())))
.collect(toList());
}
}
return null;
}
public static List<FunctionDefinition> toFunctions(Collection<ToolSpecification> toolSpecifications) {
public static List<ChatCompletionsToolDefinition> toToolDefinitions(Collection<ToolSpecification> toolSpecifications) {
return toolSpecifications.stream()
.map(InternalAzureOpenAiHelper::toFunction)
.map(InternalAzureOpenAiHelper::toToolDefinition)
.collect(toList());
}
private static FunctionDefinition toFunction(ToolSpecification toolSpecification) {
private static ChatCompletionsToolDefinition toToolDefinition(ToolSpecification toolSpecification) {
FunctionDefinition functionDefinition = new FunctionDefinition(toolSpecification.name());
functionDefinition.setDescription(toolSpecification.description());
functionDefinition.setParameters(toOpenAiParameters(toolSpecification.parameters()));
return functionDefinition;
return new ChatCompletionsFunctionToolDefinition(functionDefinition);
}
public static BinaryData toToolChoice(ToolSpecification toolThatMustBeExecuted) {
FunctionCall functionCall = new FunctionCall(toolThatMustBeExecuted.name(), toOpenAiParameters(toolThatMustBeExecuted.parameters()).toString());
ChatCompletionsToolCall toolToCall = new ChatCompletionsFunctionToolCall(toolThatMustBeExecuted.name(), functionCall);
return BinaryData.fromObject(toolToCall);
}
private static final Map<String, Object> NO_PARAMETER_DATA = new HashMap<>();
static {
NO_PARAMETER_DATA.put("type", "object");
NO_PARAMETER_DATA.put("properties", new HashMap<>());
}
private static BinaryData toOpenAiParameters(ToolParameters toolParameters) {
Parameters parameters = new Parameters();
if (toolParameters == null) {
@@ -224,6 +232,7 @@ class InternalAzureOpenAiHelper {
private final String type = "object";
private Map<String, Map<String, Object>> properties = new HashMap<>();
private List<String> required = new ArrayList<>();
public String getType() {
@@ -251,14 +260,19 @@ class InternalAzureOpenAiHelper {
if (chatResponseMessage.getContent() != null) {
return aiMessage(chatResponseMessage.getContent());
} else {
FunctionCall functionCall = chatResponseMessage.getFunctionCall();
List<ToolExecutionRequest> toolExecutionRequests = chatResponseMessage.getToolCalls()
.stream()
.filter(toolCall -> toolCall instanceof ChatCompletionsFunctionToolCall)
.map(toolCall -> (ChatCompletionsFunctionToolCall) toolCall)
.map(chatCompletionsFunctionToolCall ->
ToolExecutionRequest.builder()
.id(chatCompletionsFunctionToolCall.getId())
.name(chatCompletionsFunctionToolCall.getFunction().getName())
.arguments(chatCompletionsFunctionToolCall.getFunction().getArguments())
.build())
.collect(toList());
ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder()
.name(functionCall.getName())
.arguments(functionCall.getArguments())
.build();
return aiMessage(toolExecutionRequest);
return aiMessage(toolExecutionRequests);
}
}

View File

@@ -1,7 +1,6 @@
package dev.langchain4j.model.azure;
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsResponseFormat;
import com.azure.core.util.BinaryData;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
@@ -9,27 +8,38 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.assertj.core.data.Percentage;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;
public class AzureOpenAiChatModelIT {
Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModelIT.class);
Percentage tokenizerPrecision = withPercentage(5);
@ParameterizedTest(name = "Deployment name {0} using {1}")
@CsvSource({
"gpt-35-turbo, gpt-3.5-turbo",
@@ -121,6 +131,7 @@ public class AzureOpenAiChatModelIT {
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isNull();
assertThat(response.finishReason()).isEqualTo(STOP);
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
@@ -131,8 +142,7 @@ public class AzureOpenAiChatModelIT {
// We can now call the function with the correct parameters.
WeatherLocation weatherLocation = BinaryData.fromString(toolExecutionRequest.arguments()).toObject(WeatherLocation.class);
int currentWeather = 0;
currentWeather = getCurrentWeather(weatherLocation);
int currentWeather = getCurrentWeather(weatherLocation);
String weather = String.format("The weather in %s is %d degrees %s.",
weatherLocation.getLocation(), currentWeather, weatherLocation.getUnit());
@@ -193,6 +203,85 @@ public class AzureOpenAiChatModelIT {
assertThat(toolExecutionRequest.arguments()).isEqualTo("{}");
}
@ParameterizedTest(name = "Deployment name {0} using {1}")
@CsvSource({
"gpt-35-turbo, gpt-3.5-turbo",
"gpt-4, gpt-4"
})
void should_call_three_functions_in_parallel(String deploymentName, String gptVersion) throws Exception {
ChatLanguageModel model = AzureOpenAiChatModel.builder()
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.logRequestsAndResponses(true)
.build();
UserMessage userMessage = userMessage("Give three numbers, ordered by size: the sum of two plus two, the square of four, and finally the cube of eight.");
List<ToolSpecification> toolSpecifications = asList(
ToolSpecification.builder()
.name("sum")
.description("returns a sum of two numbers")
.addParameter("first", INTEGER)
.addParameter("second", INTEGER)
.build(),
ToolSpecification.builder()
.name("square")
.description("returns the square of one number")
.addParameter("number", INTEGER)
.build(),
ToolSpecification.builder()
.name("cube")
.description("returns the cube of one number")
.addParameter("number", INTEGER)
.build()
);
Response<AiMessage> response = model.generate(Collections.singletonList(userMessage), toolSpecifications);
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isNull();
List<ChatMessage> messages = new ArrayList<>();
messages.add(userMessage);
messages.add(aiMessage);
assertThat(aiMessage.toolExecutionRequests()).hasSize(3);
for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
assertThat(toolExecutionRequest.name()).isNotEmpty();
ToolExecutionResultMessage toolExecutionResultMessage;
if (toolExecutionRequest.name().equals("sum")) {
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "4");
} else if (toolExecutionRequest.name().equals("square")) {
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 4}");
toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "16");
} else if (toolExecutionRequest.name().equals("cube")) {
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 8}");
toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "512");
} else {
throw new AssertionError("Unexpected tool name: " + toolExecutionRequest.name());
}
messages.add(toolExecutionResultMessage);
}
Response<AiMessage> response2 = model.generate(messages);
AiMessage aiMessage2 = response2.content();
// then
logger.debug("Final answer is: " + aiMessage2);
assertThat(aiMessage2.text()).contains("4", "16", "512");
assertThat(aiMessage2.toolExecutionRequests()).isNull();
TokenUsage tokenUsage2 = response2.tokenUsage();
assertThat(tokenUsage2.inputTokenCount()).isCloseTo(112, tokenizerPrecision);
assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage2.totalTokenCount())
.isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
assertThat(response2.finishReason()).isEqualTo(STOP);
}
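
Together with the switch to ChatRequestToolMessage in InternalAzureOpenAiHelper, this test covers the full tools round trip: the model replies with three parallel tool calls, each ToolExecutionResultMessage is sent back keyed by its tool call id, and the follow-up generation folds the three results into the final answer.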
@ParameterizedTest(name = "Deployment name {0} using {1}")
@CsvSource({
"gpt-35-turbo, gpt-3.5-turbo",

View File

@@ -6,12 +6,16 @@ import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.assertj.core.data.Percentage;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
@@ -19,20 +23,26 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;
class AzureOpenAiStreamingChatModelIT {
Logger logger = LoggerFactory.getLogger(AzureOpenAiStreamingChatModelIT.class);
Percentage tokenizerPrecision = withPercentage(5);
@ParameterizedTest(name = "Deployment name {0} using {1} with async client set to {2}")
@CsvSource({
"gpt-35-turbo, gpt-3.5-turbo, true",
@@ -188,16 +198,7 @@ class AzureOpenAiStreamingChatModelIT {
"gpt-35-turbo, gpt-3.5-turbo",
"gpt-4, gpt-4"
})
void should_return_tool_execution_request(String deploymentName, String gptVersion) throws Exception {
ToolSpecification toolSpecification = ToolSpecification.builder()
.name("calculator")
.description("returns a sum of two numbers")
.addParameter("first", INTEGER)
.addParameter("second", INTEGER)
.build();
UserMessage userMessage = userMessage("Two plus two?");
void should_call_function_with_argument(String deploymentName, String gptVersion) throws Exception {
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
@@ -209,7 +210,18 @@ class AzureOpenAiStreamingChatModelIT {
.logRequestsAndResponses(true)
.build();
model.generate(singletonList(userMessage), singletonList(toolSpecification), new StreamingResponseHandler<AiMessage>() {
UserMessage userMessage = userMessage("Two plus two?");
String toolName = "calculator";
ToolSpecification toolSpecification = ToolSpecification.builder()
.name(toolName)
.description("returns a sum of two numbers")
.addParameter("first", INTEGER)
.addParameter("second", INTEGER)
.build();
model.generate(singletonList(userMessage), toolSpecification, new StreamingResponseHandler<AiMessage>() {
@Override
public void onNext(String token) {
@@ -237,14 +249,175 @@ class AzureOpenAiStreamingChatModelIT {
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
assertThat(toolExecutionRequest.name()).isEqualTo("calculator");
assertThat(toolExecutionRequest.name()).isEqualTo(toolName);
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(53);
assertThat(response.tokenUsage().inputTokenCount()).isCloseTo(58, tokenizerPrecision);
assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
assertThat(response.tokenUsage().totalTokenCount())
.isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount());
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
assertThat(response.finishReason()).isEqualTo(STOP);
ToolExecutionResultMessage toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "four");
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
CompletableFuture<Response<AiMessage>> futureResponse2 = new CompletableFuture<>();
model.generate(messages, new StreamingResponseHandler<AiMessage>() {
@Override
public void onNext(String token) {
logger.info("onNext: '" + token + "'");
}
@Override
public void onComplete(Response<AiMessage> response) {
logger.info("onComplete: '" + response + "'");
futureResponse2.complete(response);
}
@Override
public void onError(Throwable error) {
futureResponse2.completeExceptionally(error);
}
});
Response<AiMessage> response2 = futureResponse2.get(30, SECONDS);
AiMessage aiMessage2 = response2.content();
// then
assertThat(aiMessage2.text()).contains("four");
assertThat(aiMessage2.toolExecutionRequests()).isNull();
TokenUsage tokenUsage2 = response2.tokenUsage();
assertThat(tokenUsage2.inputTokenCount()).isCloseTo(33, tokenizerPrecision);
assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage2.totalTokenCount())
.isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
assertThat(response2.finishReason()).isEqualTo(STOP);
}
@ParameterizedTest(name = "Deployment name {0} using {1}")
@CsvSource({
"gpt-35-turbo, gpt-3.5-turbo",
"gpt-4, gpt-4"
})
void should_call_three_functions_in_parallel(String deploymentName, String gptVersion) throws Exception {
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
StreamingChatLanguageModel model = AzureOpenAiStreamingChatModel.builder()
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.logRequestsAndResponses(true)
.build();
UserMessage userMessage = userMessage("Give three numbers, ordered by size: the sum of two plus two, the square of four, and finally the cube of eight.");
List<ToolSpecification> toolSpecifications = asList(
ToolSpecification.builder()
.name("sum")
.description("returns a sum of two numbers")
.addParameter("first", INTEGER)
.addParameter("second", INTEGER)
.build(),
ToolSpecification.builder()
.name("square")
.description("returns the square of one number")
.addParameter("number", INTEGER)
.build(),
ToolSpecification.builder()
.name("cube")
.description("returns the cube of one number")
.addParameter("number", INTEGER)
.build()
);
model.generate(singletonList(userMessage), toolSpecifications, new StreamingResponseHandler<AiMessage>() {
@Override
public void onNext(String token) {
logger.info("onNext: '" + token + "'");
Exception e = new IllegalStateException("onNext() should never be called when tool is executed");
futureResponse.completeExceptionally(e);
}
@Override
public void onComplete(Response<AiMessage> response) {
logger.info("onComplete: '" + response + "'");
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureResponse.completeExceptionally(error);
}
});
Response<AiMessage> response = futureResponse.get(30, SECONDS);
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isNull();
List<ChatMessage> messages = new ArrayList<>();
messages.add(userMessage);
messages.add(aiMessage);
assertThat(aiMessage.toolExecutionRequests()).hasSize(3);
for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
assertThat(toolExecutionRequest.name()).isNotEmpty();
ToolExecutionResultMessage toolExecutionResultMessage;
if (toolExecutionRequest.name().equals("sum")) {
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "4");
} else if (toolExecutionRequest.name().equals("square")) {
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 4}");
toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "16");
} else if (toolExecutionRequest.name().equals("cube")) {
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 8}");
toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "512");
} else {
throw new AssertionError("Unexpected tool name: " + toolExecutionRequest.name());
}
messages.add(toolExecutionResultMessage);
}
CompletableFuture<Response<AiMessage>> futureResponse2 = new CompletableFuture<>();
model.generate(messages, new StreamingResponseHandler<AiMessage>() {
@Override
public void onNext(String token) {
logger.info("onNext: '" + token + "'");
}
@Override
public void onComplete(Response<AiMessage> response) {
logger.info("onComplete: '" + response + "'");
futureResponse2.complete(response);
}
@Override
public void onError(Throwable error) {
futureResponse2.completeExceptionally(error);
}
});
Response<AiMessage> response2 = futureResponse2.get(30, SECONDS);
AiMessage aiMessage2 = response2.content();
// then
logger.debug("Final answer is: " + aiMessage2);
assertThat(aiMessage2.text()).contains("4", "16", "512");
assertThat(aiMessage2.toolExecutionRequests()).isNull();
TokenUsage tokenUsage2 = response2.tokenUsage();
assertThat(tokenUsage2.inputTokenCount()).isCloseTo(119, tokenizerPrecision);
assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage2.totalTokenCount())
.isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
assertThat(response2.finishReason()).isEqualTo(STOP);
}
}

View File

@@ -3,10 +3,7 @@ package dev.langchain4j.model.azure;
import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIServiceVersion;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.CompletionsFinishReason;
import com.azure.ai.openai.models.FunctionDefinition;
import com.azure.ai.openai.models.*;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.ChatMessage;
@@ -20,6 +17,7 @@ import java.util.Collection;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
class InternalAzureOpenAiHelperTest {
@@ -83,7 +81,7 @@ class InternalAzureOpenAiHelperTest {
}
@Test
void toFunctionsShouldReturnCorrectFunctions() {
void toToolDefinitionsShouldReturnCorrectToolDefinition() {
Collection<ToolSpecification> toolSpecifications = new ArrayList<>();
toolSpecifications.add(ToolSpecification.builder()
.name("test-tool")
@@ -91,10 +89,11 @@ class InternalAzureOpenAiHelperTest {
.parameters(ToolParameters.builder().build())
.build());
List<FunctionDefinition> functions = InternalAzureOpenAiHelper.toFunctions(toolSpecifications);
List<ChatCompletionsToolDefinition> tools = InternalAzureOpenAiHelper.toToolDefinitions(toolSpecifications);
assertThat(functions).hasSize(toolSpecifications.size());
assertThat(functions.get(0).getName()).isEqualTo(toolSpecifications.iterator().next().name());
assertEquals(toolSpecifications.size(), tools.size());
assertInstanceOf(ChatCompletionsFunctionToolDefinition.class, tools.get(0));
assertEquals(toolSpecifications.iterator().next().name(), ((ChatCompletionsFunctionToolDefinition) tools.get(0)).getFunction().getName());
}
@Test

View File

@@ -301,6 +301,18 @@
<type>pom</type>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.dataformat</groupId>
<artifactId>jackson-dataformat-xml</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>org.apache.opennlp</groupId>
<artifactId>opennlp-tools</artifactId>