OpenAI: return token usage returned by OpenAI (#1622)
This commit is contained in:
parent
33199dc588
commit
10ea33fe26
|
@ -102,7 +102,7 @@ public class LocalAiStreamingChatModel implements StreamingChatLanguageModel {
|
|||
|
||||
ChatCompletionRequest request = requestBuilder.build();
|
||||
|
||||
OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(null);
|
||||
OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();
|
||||
|
||||
client.chatCompletion(request)
|
||||
.onPartialResponse(partialResponse -> {
|
||||
|
@ -110,7 +110,7 @@ public class LocalAiStreamingChatModel implements StreamingChatLanguageModel {
|
|||
handle(partialResponse, handler);
|
||||
})
|
||||
.onComplete(() -> {
|
||||
Response<AiMessage> response = responseBuilder.build(null, false);
|
||||
Response<AiMessage> response = responseBuilder.build();
|
||||
handler.onComplete(response);
|
||||
})
|
||||
.onError(handler::onError)
|
||||
|
|
|
@ -67,7 +67,7 @@ public class LocalAiStreamingLanguageModel implements StreamingLanguageModel {
|
|||
.maxTokens(maxTokens)
|
||||
.build();
|
||||
|
||||
OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(null);
|
||||
OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();
|
||||
|
||||
client.completion(request)
|
||||
.onPartialResponse(partialResponse -> {
|
||||
|
@ -78,7 +78,7 @@ public class LocalAiStreamingLanguageModel implements StreamingLanguageModel {
|
|||
}
|
||||
})
|
||||
.onComplete(() -> {
|
||||
Response<AiMessage> response = responseBuilder.build(null, false);
|
||||
Response<AiMessage> response = responseBuilder.build();
|
||||
handler.onComplete(Response.from(
|
||||
response.content().text(),
|
||||
response.tokenUsage(),
|
||||
|
|
|
@ -385,22 +385,6 @@ public class InternalOpenAiHelper {
|
|||
}
|
||||
}
|
||||
|
||||
static boolean isOpenAiModel(String modelName) {
|
||||
if (modelName == null) {
|
||||
return false;
|
||||
}
|
||||
for (OpenAiChatModelName openAiChatModelName : OpenAiChatModelName.values()) {
|
||||
if (modelName.contains(openAiChatModelName.toString())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static Response<AiMessage> removeTokenUsage(Response<AiMessage> response) {
|
||||
return Response.from(response.content(), null, response.finishReason());
|
||||
}
|
||||
|
||||
static ChatModelRequest createModelListenerRequest(ChatCompletionRequest request,
|
||||
List<ChatMessage> messages,
|
||||
List<ToolSpecification> toolSpecifications) {
|
||||
|
|
|
@ -7,7 +7,7 @@ import dev.ai4j.openai4j.chat.ChatCompletionResponse;
|
|||
import dev.ai4j.openai4j.chat.Delta;
|
||||
import dev.ai4j.openai4j.chat.ResponseFormat;
|
||||
import dev.ai4j.openai4j.chat.ResponseFormatType;
|
||||
import dev.ai4j.openai4j.chat.StreamOptions;
|
||||
import dev.ai4j.openai4j.shared.StreamOptions;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
|
@ -42,8 +42,6 @@ import static dev.langchain4j.model.openai.InternalOpenAiHelper.DEFAULT_USER_AGE
|
|||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerRequest;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerResponse;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.isOpenAiModel;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.removeTokenUsage;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||
|
@ -76,7 +74,6 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
|
|||
private final Boolean strictTools;
|
||||
private final Boolean parallelToolCalls;
|
||||
private final Tokenizer tokenizer;
|
||||
private final boolean isOpenAiModel;
|
||||
private final List<ChatModelListener> listeners;
|
||||
|
||||
@Builder
|
||||
|
@ -138,7 +135,6 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
|
|||
this.strictTools = getOrDefault(strictTools, false);
|
||||
this.parallelToolCalls = parallelToolCalls;
|
||||
this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
|
||||
this.isOpenAiModel = isOpenAiModel(this.modelName);
|
||||
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
|
||||
}
|
||||
|
||||
|
@ -206,8 +202,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
|
|||
}
|
||||
});
|
||||
|
||||
int inputTokenCount = countInputTokens(messages, toolSpecifications, toolThatMustBeExecuted);
|
||||
OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(inputTokenCount);
|
||||
OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();
|
||||
|
||||
AtomicReference<String> responseId = new AtomicReference<>();
|
||||
AtomicReference<String> responseModel = new AtomicReference<>();
|
||||
|
@ -225,7 +220,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
|
|||
}
|
||||
})
|
||||
.onComplete(() -> {
|
||||
Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
|
||||
Response<AiMessage> response = responseBuilder.build();
|
||||
|
||||
ChatModelResponse modelListenerResponse = createModelListenerResponse(
|
||||
responseId.get(),
|
||||
|
@ -248,7 +243,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
|
|||
handler.onComplete(response);
|
||||
})
|
||||
.onError(error -> {
|
||||
Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
|
||||
Response<AiMessage> response = responseBuilder.build();
|
||||
|
||||
ChatModelResponse modelListenerPartialResponse = createModelListenerResponse(
|
||||
responseId.get(),
|
||||
|
@ -276,27 +271,6 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
|
|||
.execute();
|
||||
}
|
||||
|
||||
private Response<AiMessage> createResponse(OpenAiStreamingResponseBuilder responseBuilder,
|
||||
ToolSpecification toolThatMustBeExecuted) {
|
||||
Response<AiMessage> response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null);
|
||||
if (isOpenAiModel) {
|
||||
return response;
|
||||
}
|
||||
return removeTokenUsage(response);
|
||||
}
|
||||
|
||||
private int countInputTokens(List<ChatMessage> messages,
|
||||
List<ToolSpecification> toolSpecifications,
|
||||
ToolSpecification toolThatMustBeExecuted) {
|
||||
int inputTokenCount = tokenizer.estimateTokenCountInMessages(messages);
|
||||
if (toolThatMustBeExecuted != null) {
|
||||
inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
|
||||
} else if (!isNullOrEmpty(toolSpecifications)) {
|
||||
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
|
||||
}
|
||||
return inputTokenCount;
|
||||
}
|
||||
|
||||
private static void handle(ChatCompletionResponse partialResponse,
|
||||
StreamingResponseHandler<AiMessage> handler) {
|
||||
List<ChatCompletionChoice> choices = partialResponse.choices();
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
package dev.langchain4j.model.openai;
|
||||
|
||||
import dev.ai4j.openai4j.OpenAiClient;
|
||||
import dev.ai4j.openai4j.completion.CompletionChoice;
|
||||
import dev.ai4j.openai4j.completion.CompletionRequest;
|
||||
import dev.ai4j.openai4j.shared.StreamOptions;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.Tokenizer;
|
||||
|
@ -16,6 +18,7 @@ import java.time.Duration;
|
|||
import java.util.Map;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.DEFAULT_USER_AGENT;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO_INSTRUCT;
|
||||
|
@ -77,24 +80,29 @@ public class OpenAiStreamingLanguageModel implements StreamingLanguageModel, Tok
|
|||
public void generate(String prompt, StreamingResponseHandler<String> handler) {
|
||||
|
||||
CompletionRequest request = CompletionRequest.builder()
|
||||
.stream(true)
|
||||
.streamOptions(StreamOptions.builder()
|
||||
.includeUsage(true)
|
||||
.build())
|
||||
.model(modelName)
|
||||
.prompt(prompt)
|
||||
.temperature(temperature)
|
||||
.build();
|
||||
|
||||
int inputTokenCount = tokenizer.estimateTokenCountInText(prompt);
|
||||
OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(inputTokenCount);
|
||||
OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();
|
||||
|
||||
client.completion(request)
|
||||
.onPartialResponse(partialResponse -> {
|
||||
responseBuilder.append(partialResponse);
|
||||
String token = partialResponse.text();
|
||||
if (token != null) {
|
||||
handler.onNext(token);
|
||||
for (CompletionChoice choice : partialResponse.choices()) {
|
||||
String token = choice.text();
|
||||
if (isNotNullOrEmpty(token)) {
|
||||
handler.onNext(token);
|
||||
}
|
||||
}
|
||||
})
|
||||
.onComplete(() -> {
|
||||
Response<AiMessage> response = responseBuilder.build(tokenizer, false);
|
||||
Response<AiMessage> response = responseBuilder.build();
|
||||
handler.onComplete(Response.from(
|
||||
response.content().text(),
|
||||
response.tokenUsage(),
|
||||
|
|
|
@ -10,7 +10,6 @@ import dev.ai4j.openai4j.completion.CompletionResponse;
|
|||
import dev.ai4j.openai4j.shared.Usage;
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.Tokenizer;
|
||||
import dev.langchain4j.model.output.FinishReason;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
|
@ -21,7 +20,6 @@ import java.util.concurrent.ConcurrentHashMap;
|
|||
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
/**
|
||||
|
@ -41,12 +39,6 @@ public class OpenAiStreamingResponseBuilder {
|
|||
private volatile TokenUsage tokenUsage;
|
||||
private volatile FinishReason finishReason;
|
||||
|
||||
private final Integer inputTokenCount;
|
||||
|
||||
public OpenAiStreamingResponseBuilder(Integer inputTokenCount) {
|
||||
this.inputTokenCount = inputTokenCount;
|
||||
}
|
||||
|
||||
public void append(ChatCompletionResponse partialResponse) {
|
||||
if (partialResponse == null) {
|
||||
return;
|
||||
|
@ -122,6 +114,11 @@ public class OpenAiStreamingResponseBuilder {
|
|||
return;
|
||||
}
|
||||
|
||||
Usage usage = partialResponse.usage();
|
||||
if (usage != null) {
|
||||
this.tokenUsage = tokenUsageFrom(usage);
|
||||
}
|
||||
|
||||
List<CompletionChoice> choices = partialResponse.choices();
|
||||
if (choices == null || choices.isEmpty()) {
|
||||
return;
|
||||
|
@ -143,13 +140,13 @@ public class OpenAiStreamingResponseBuilder {
|
|||
}
|
||||
}
|
||||
|
||||
public Response<AiMessage> build(Tokenizer tokenizer, boolean forcefulToolExecution) {
|
||||
public Response<AiMessage> build() {
|
||||
|
||||
String content = contentBuilder.toString();
|
||||
if (!content.isEmpty()) {
|
||||
return Response.from(
|
||||
AiMessage.from(content),
|
||||
tokenUsage(content, tokenizer),
|
||||
tokenUsage,
|
||||
finishReason
|
||||
);
|
||||
}
|
||||
|
@ -162,7 +159,7 @@ public class OpenAiStreamingResponseBuilder {
|
|||
.build();
|
||||
return Response.from(
|
||||
AiMessage.from(toolExecutionRequest),
|
||||
tokenUsage(singletonList(toolExecutionRequest), tokenizer, forcefulToolExecution),
|
||||
tokenUsage,
|
||||
finishReason
|
||||
);
|
||||
}
|
||||
|
@ -177,7 +174,7 @@ public class OpenAiStreamingResponseBuilder {
|
|||
.collect(toList());
|
||||
return Response.from(
|
||||
AiMessage.from(toolExecutionRequests),
|
||||
tokenUsage(toolExecutionRequests, tokenizer, forcefulToolExecution),
|
||||
tokenUsage,
|
||||
finishReason
|
||||
);
|
||||
}
|
||||
|
@ -185,41 +182,6 @@ public class OpenAiStreamingResponseBuilder {
|
|||
return null;
|
||||
}
|
||||
|
||||
private TokenUsage tokenUsage(String content, Tokenizer tokenizer) {
|
||||
if (tokenUsage != null) {
|
||||
return tokenUsage;
|
||||
}
|
||||
|
||||
if (tokenizer == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
int outputTokenCount = tokenizer.estimateTokenCountInText(content);
|
||||
return new TokenUsage(inputTokenCount, outputTokenCount);
|
||||
}
|
||||
|
||||
private TokenUsage tokenUsage(List<ToolExecutionRequest> toolExecutionRequests, Tokenizer tokenizer, boolean forcefulToolExecution) {
|
||||
if (tokenUsage != null) {
|
||||
return tokenUsage;
|
||||
}
|
||||
|
||||
if (tokenizer == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
int outputTokenCount = 0;
|
||||
if (forcefulToolExecution) {
|
||||
// OpenAI calculates output tokens differently when tool is executed forcefully
|
||||
for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
|
||||
outputTokenCount += tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest);
|
||||
}
|
||||
} else {
|
||||
outputTokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(toolExecutionRequests);
|
||||
}
|
||||
|
||||
return new TokenUsage(inputTokenCount, outputTokenCount);
|
||||
}
|
||||
|
||||
private static class ToolExecutionRequestBuilder {
|
||||
|
||||
private final StringBuffer idBuilder = new StringBuffer();
|
||||
|
|
|
@ -1,15 +1,16 @@
|
|||
package dev.langchain4j.model.openai;
|
||||
|
||||
import dev.ai4j.openai4j.chat.*;
|
||||
import dev.ai4j.openai4j.chat.AssistantMessage;
|
||||
import dev.ai4j.openai4j.chat.ChatCompletionChoice;
|
||||
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
|
||||
import dev.ai4j.openai4j.chat.FunctionCall;
|
||||
import dev.ai4j.openai4j.chat.ToolCall;
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static dev.ai4j.openai4j.chat.ToolType.FUNCTION;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
|
||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.aiMessageFrom;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
|
@ -134,27 +135,4 @@ class InternalOpenAiHelperTest {
|
|||
.build()
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void test_isOpenAiModel() {
|
||||
|
||||
assertThat(isOpenAiModel(null)).isFalse();
|
||||
assertThat(isOpenAiModel("")).isFalse();
|
||||
assertThat(isOpenAiModel(" ")).isFalse();
|
||||
assertThat(isOpenAiModel("llama2")).isFalse();
|
||||
|
||||
assertThat(isOpenAiModel("gpt-3.5-turbo")).isTrue();
|
||||
assertThat(isOpenAiModel("ft:gpt-3.5-turbo:my-org:custom_suffix:id")).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
void test_removeTokenUsage() {
|
||||
|
||||
assertThat(removeTokenUsage(Response.from(AiMessage.from("Hello"))))
|
||||
.isEqualTo(Response.from(AiMessage.from("Hello")));
|
||||
assertThat(removeTokenUsage(Response.from(AiMessage.from("Hello"), new TokenUsage(42))))
|
||||
.isEqualTo(Response.from(AiMessage.from("Hello")));
|
||||
assertThat(removeTokenUsage(Response.from(AiMessage.from("Hello"), new TokenUsage(42), STOP)))
|
||||
.isEqualTo(Response.from(AiMessage.from("Hello"), null, STOP));
|
||||
}
|
||||
}
|
|
@ -18,7 +18,7 @@
|
|||
<maven.compiler.target>8</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<project.build.outputTimestamp>1714382357</project.build.outputTimestamp>
|
||||
<openai4j.version>0.21.0</openai4j.version>
|
||||
<openai4j.version>0.22.0</openai4j.version>
|
||||
<azure-ai-openai.version>1.0.0-beta.11</azure-ai-openai.version>
|
||||
<azure-ai-search.version>11.7.1</azure-ai-search.version>
|
||||
<azure.storage-blob.version>12.28.0</azure.storage-blob.version>
|
||||
|
|
Loading…
Reference in New Issue