OpenAI: return token usage returned by OpenAI (#1622)

LangChain4j 2024-09-22 10:39:00 +02:00
parent 33199dc588
commit 10ea33fe26
8 changed files with 38 additions and 132 deletions
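
The gist of the change: instead of estimating token counts locally with a Tokenizer, the streaming models now set stream_options.include_usage on the request and keep the exact usage OpenAI reports in the final streamed chunk. A minimal sketch of the opt-in, using only the builder calls that appear in the diff below (the model name and prompt are placeholders):

    CompletionRequest request = CompletionRequest.builder()
            .stream(true)
            .streamOptions(StreamOptions.builder()
                    .includeUsage(true) // asks OpenAI to send usage in the last streamed chunk
                    .build())
            .model("gpt-3.5-turbo-instruct") // placeholder
            .prompt("Hello")                 // placeholder
            .build();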

LocalAiStreamingChatModel.java

@@ -102,7 +102,7 @@ public class LocalAiStreamingChatModel implements StreamingChatLanguageModel {
         ChatCompletionRequest request = requestBuilder.build();

-        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(null);
+        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();

         client.chatCompletion(request)
                 .onPartialResponse(partialResponse -> {
@@ -110,7 +110,7 @@ public class LocalAiStreamingChatModel implements StreamingChatLanguageModel {
                     handle(partialResponse, handler);
                 })
                 .onComplete(() -> {
-                    Response<AiMessage> response = responseBuilder.build(null, false);
+                    Response<AiMessage> response = responseBuilder.build();
                     handler.onComplete(response);
                 })
                 .onError(handler::onError)

LocalAiStreamingLanguageModel.java

@@ -67,7 +67,7 @@ public class LocalAiStreamingLanguageModel implements StreamingLanguageModel {
                 .maxTokens(maxTokens)
                 .build();

-        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(null);
+        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();

         client.completion(request)
                 .onPartialResponse(partialResponse -> {
@@ -78,7 +78,7 @@ public class LocalAiStreamingLanguageModel implements StreamingLanguageModel {
                     }
                 })
                 .onComplete(() -> {
-                    Response<AiMessage> response = responseBuilder.build(null, false);
+                    Response<AiMessage> response = responseBuilder.build();
                     handler.onComplete(Response.from(
                             response.content().text(),
                             response.tokenUsage(),

InternalOpenAiHelper.java

@@ -385,22 +385,6 @@ public class InternalOpenAiHelper {
         }
     }

-    static boolean isOpenAiModel(String modelName) {
-        if (modelName == null) {
-            return false;
-        }
-        for (OpenAiChatModelName openAiChatModelName : OpenAiChatModelName.values()) {
-            if (modelName.contains(openAiChatModelName.toString())) {
-                return true;
-            }
-        }
-        return false;
-    }
-
-    static Response<AiMessage> removeTokenUsage(Response<AiMessage> response) {
-        return Response.from(response.content(), null, response.finishReason());
-    }
-
     static ChatModelRequest createModelListenerRequest(ChatCompletionRequest request,
                                                        List<ChatMessage> messages,
                                                        List<ToolSpecification> toolSpecifications) {

OpenAiStreamingChatModel.java

@@ -7,7 +7,7 @@ import dev.ai4j.openai4j.chat.ChatCompletionResponse;
 import dev.ai4j.openai4j.chat.Delta;
 import dev.ai4j.openai4j.chat.ResponseFormat;
 import dev.ai4j.openai4j.chat.ResponseFormatType;
-import dev.ai4j.openai4j.chat.StreamOptions;
+import dev.ai4j.openai4j.shared.StreamOptions;
 import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
@@ -42,8 +42,6 @@ import static dev.langchain4j.model.openai.InternalOpenAiHelper.DEFAULT_USER_AGENT;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerRequest;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerResponse;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.isOpenAiModel;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.removeTokenUsage;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
 import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
@@ -76,7 +74,6 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
     private final Boolean strictTools;
     private final Boolean parallelToolCalls;
     private final Tokenizer tokenizer;
-    private final boolean isOpenAiModel;
     private final List<ChatModelListener> listeners;

     @Builder
@@ -138,7 +135,6 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
         this.strictTools = getOrDefault(strictTools, false);
         this.parallelToolCalls = parallelToolCalls;
         this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
-        this.isOpenAiModel = isOpenAiModel(this.modelName);
         this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
     }
@@ -206,8 +202,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
             }
         });

-        int inputTokenCount = countInputTokens(messages, toolSpecifications, toolThatMustBeExecuted);
-        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(inputTokenCount);
+        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();

         AtomicReference<String> responseId = new AtomicReference<>();
         AtomicReference<String> responseModel = new AtomicReference<>();
@@ -225,7 +220,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
                     }
                 })
                 .onComplete(() -> {
-                    Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
+                    Response<AiMessage> response = responseBuilder.build();

                     ChatModelResponse modelListenerResponse = createModelListenerResponse(
                             responseId.get(),
@@ -248,7 +243,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
                     handler.onComplete(response);
                 })
                 .onError(error -> {
-                    Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
+                    Response<AiMessage> response = responseBuilder.build();

                     ChatModelResponse modelListenerPartialResponse = createModelListenerResponse(
                             responseId.get(),
@@ -276,27 +271,6 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
                 .execute();
     }

-    private Response<AiMessage> createResponse(OpenAiStreamingResponseBuilder responseBuilder,
-                                               ToolSpecification toolThatMustBeExecuted) {
-        Response<AiMessage> response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null);
-        if (isOpenAiModel) {
-            return response;
-        }
-        return removeTokenUsage(response);
-    }
-
-    private int countInputTokens(List<ChatMessage> messages,
-                                 List<ToolSpecification> toolSpecifications,
-                                 ToolSpecification toolThatMustBeExecuted) {
-        int inputTokenCount = tokenizer.estimateTokenCountInMessages(messages);
-        if (toolThatMustBeExecuted != null) {
-            inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
-        } else if (!isNullOrEmpty(toolSpecifications)) {
-            inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
-        }
-        return inputTokenCount;
-    }
-
     private static void handle(ChatCompletionResponse partialResponse,
                                StreamingResponseHandler<AiMessage> handler) {
         List<ChatCompletionChoice> choices = partialResponse.choices();
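
For callers, the API shape is unchanged; what changes is that tokenUsage() on the completed response now carries OpenAI's own counts instead of tokenizer estimates. A hypothetical caller-side sketch, with model and messages construction elided:

    model.generate(messages, new StreamingResponseHandler<AiMessage>() {

        @Override
        public void onNext(String token) {
            System.out.print(token);
        }

        @Override
        public void onComplete(Response<AiMessage> response) {
            TokenUsage usage = response.tokenUsage(); // OpenAI's counts, not an estimate
            if (usage != null) { // null when the API returned no usage block
                System.out.println("\nin=" + usage.inputTokenCount() + " out=" + usage.outputTokenCount());
            }
        }

        @Override
        public void onError(Throwable error) {
            error.printStackTrace();
        }
    });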

OpenAiStreamingLanguageModel.java

@@ -1,7 +1,9 @@
 package dev.langchain4j.model.openai;

 import dev.ai4j.openai4j.OpenAiClient;
+import dev.ai4j.openai4j.completion.CompletionChoice;
 import dev.ai4j.openai4j.completion.CompletionRequest;
+import dev.ai4j.openai4j.shared.StreamOptions;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.Tokenizer;
@@ -16,6 +18,7 @@ import java.time.Duration;
 import java.util.Map;

 import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.DEFAULT_USER_AGENT;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
 import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO_INSTRUCT;
@@ -77,24 +80,29 @@ public class OpenAiStreamingLanguageModel implements StreamingLanguageModel, TokenCountEstimator {
     public void generate(String prompt, StreamingResponseHandler<String> handler) {
         CompletionRequest request = CompletionRequest.builder()
                 .stream(true)
+                .streamOptions(StreamOptions.builder()
+                        .includeUsage(true)
+                        .build())
                 .model(modelName)
                 .prompt(prompt)
                 .temperature(temperature)
                 .build();

-        int inputTokenCount = tokenizer.estimateTokenCountInText(prompt);
-        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(inputTokenCount);
+        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();

         client.completion(request)
                 .onPartialResponse(partialResponse -> {
                     responseBuilder.append(partialResponse);
-                    String token = partialResponse.text();
-                    if (token != null) {
-                        handler.onNext(token);
+                    for (CompletionChoice choice : partialResponse.choices()) {
+                        String token = choice.text();
+                        if (isNotNullOrEmpty(token)) {
+                            handler.onNext(token);
+                        }
                     }
                 })
                 .onComplete(() -> {
-                    Response<AiMessage> response = responseBuilder.build(tokenizer, false);
+                    Response<AiMessage> response = responseBuilder.build();
                     handler.onComplete(Response.from(
                             response.content().text(),
                             response.tokenUsage(),

OpenAiStreamingResponseBuilder.java

@@ -10,7 +10,6 @@ import dev.ai4j.openai4j.completion.CompletionResponse;
 import dev.ai4j.openai4j.shared.Usage;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.data.message.AiMessage;
-import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.output.FinishReason;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
@@ -21,7 +20,6 @@ import java.util.concurrent.ConcurrentHashMap;

 import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
-import static java.util.Collections.singletonList;
 import static java.util.stream.Collectors.toList;

 /**
@@ -41,12 +39,6 @@ public class OpenAiStreamingResponseBuilder {
     private volatile TokenUsage tokenUsage;
     private volatile FinishReason finishReason;

-    private final Integer inputTokenCount;
-
-    public OpenAiStreamingResponseBuilder(Integer inputTokenCount) {
-        this.inputTokenCount = inputTokenCount;
-    }
-
     public void append(ChatCompletionResponse partialResponse) {
         if (partialResponse == null) {
             return;
@@ -122,6 +114,11 @@
             return;
         }

+        Usage usage = partialResponse.usage();
+        if (usage != null) {
+            this.tokenUsage = tokenUsageFrom(usage);
+        }
+
         List<CompletionChoice> choices = partialResponse.choices();
         if (choices == null || choices.isEmpty()) {
             return;
@@ -143,13 +140,13 @@
         }
     }

-    public Response<AiMessage> build(Tokenizer tokenizer, boolean forcefulToolExecution) {
+    public Response<AiMessage> build() {

         String content = contentBuilder.toString();
         if (!content.isEmpty()) {
             return Response.from(
                     AiMessage.from(content),
-                    tokenUsage(content, tokenizer),
+                    tokenUsage,
                     finishReason
             );
         }
@@ -162,7 +159,7 @@
                     .build();
             return Response.from(
                     AiMessage.from(toolExecutionRequest),
-                    tokenUsage(singletonList(toolExecutionRequest), tokenizer, forcefulToolExecution),
+                    tokenUsage,
                     finishReason
             );
         }
@@ -177,7 +174,7 @@
                     .collect(toList());
             return Response.from(
                     AiMessage.from(toolExecutionRequests),
-                    tokenUsage(toolExecutionRequests, tokenizer, forcefulToolExecution),
+                    tokenUsage,
                     finishReason
             );
         }
@@ -185,41 +182,6 @@
         return null;
     }

-    private TokenUsage tokenUsage(String content, Tokenizer tokenizer) {
-        if (tokenUsage != null) {
-            return tokenUsage;
-        }
-
-        if (tokenizer == null) {
-            return null;
-        }
-
-        int outputTokenCount = tokenizer.estimateTokenCountInText(content);
-        return new TokenUsage(inputTokenCount, outputTokenCount);
-    }
-
-    private TokenUsage tokenUsage(List<ToolExecutionRequest> toolExecutionRequests, Tokenizer tokenizer, boolean forcefulToolExecution) {
-        if (tokenUsage != null) {
-            return tokenUsage;
-        }
-
-        if (tokenizer == null) {
-            return null;
-        }
-
-        int outputTokenCount = 0;
-        if (forcefulToolExecution) {
-            // OpenAI calculates output tokens differently when tool is executed forcefully
-            for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
-                outputTokenCount += tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest);
-            }
-        } else {
-            outputTokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(toolExecutionRequests);
-        }
-
-        return new TokenUsage(inputTokenCount, outputTokenCount);
-    }
-
     private static class ToolExecutionRequestBuilder {

         private final StringBuffer idBuilder = new StringBuffer();
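
The builder's contract is now simply: append every streamed chunk, then build(); if a chunk carried a usage block (OpenAI sends it in the final one), it becomes the response's TokenUsage. A sketch of the flow, mirroring the wiring in LocalAiStreamingChatModel above (handle() is that class's existing chunk dispatcher):

    OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();
    client.chatCompletion(request)
            .onPartialResponse(partialResponse -> {
                responseBuilder.append(partialResponse); // also picks up usage from the final chunk
                handle(partialResponse, handler);
            })
            .onComplete(() -> {
                Response<AiMessage> response = responseBuilder.build(); // tokenUsage is null if none was sent
                handler.onComplete(response);
            })
            .onError(handler::onError)
            .execute();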

InternalOpenAiHelperTest.java

@@ -1,15 +1,16 @@
 package dev.langchain4j.model.openai;

-import dev.ai4j.openai4j.chat.*;
+import dev.ai4j.openai4j.chat.AssistantMessage;
+import dev.ai4j.openai4j.chat.ChatCompletionChoice;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import dev.ai4j.openai4j.chat.FunctionCall;
+import dev.ai4j.openai4j.chat.ToolCall;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.data.message.AiMessage;
-import dev.langchain4j.model.output.Response;
-import dev.langchain4j.model.output.TokenUsage;
 import org.junit.jupiter.api.Test;

 import static dev.ai4j.openai4j.chat.ToolType.FUNCTION;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
-import static dev.langchain4j.model.output.FinishReason.STOP;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.aiMessageFrom;
 import static java.util.Collections.singletonList;
 import static org.assertj.core.api.Assertions.assertThat;
@@ -134,27 +135,4 @@
                         .build()
         );
     }
-
-    @Test
-    void test_isOpenAiModel() {
-        assertThat(isOpenAiModel(null)).isFalse();
-        assertThat(isOpenAiModel("")).isFalse();
-        assertThat(isOpenAiModel(" ")).isFalse();
-        assertThat(isOpenAiModel("llama2")).isFalse();
-        assertThat(isOpenAiModel("gpt-3.5-turbo")).isTrue();
-        assertThat(isOpenAiModel("ft:gpt-3.5-turbo:my-org:custom_suffix:id")).isTrue();
-    }
-
-    @Test
-    void test_removeTokenUsage() {
-        assertThat(removeTokenUsage(Response.from(AiMessage.from("Hello"))))
-                .isEqualTo(Response.from(AiMessage.from("Hello")));
-        assertThat(removeTokenUsage(Response.from(AiMessage.from("Hello"), new TokenUsage(42))))
-                .isEqualTo(Response.from(AiMessage.from("Hello")));
-        assertThat(removeTokenUsage(Response.from(AiMessage.from("Hello"), new TokenUsage(42), STOP)))
-                .isEqualTo(Response.from(AiMessage.from("Hello"), null, STOP));
-    }
 }

pom.xml

@@ -18,7 +18,7 @@
         <maven.compiler.target>8</maven.compiler.target>
         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
         <project.build.outputTimestamp>1714382357</project.build.outputTimestamp>
-        <openai4j.version>0.21.0</openai4j.version>
+        <openai4j.version>0.22.0</openai4j.version>
         <azure-ai-openai.version>1.0.0-beta.11</azure-ai-openai.version>
         <azure-ai-search.version>11.7.1</azure-ai-search.version>
         <azure.storage-blob.version>12.28.0</azure.storage-blob.version>