From 10ea33fe262a4e3595033cde86d0bd0266470527 Mon Sep 17 00:00:00 2001
From: LangChain4j
Date: Sun, 22 Sep 2024 10:39:00 +0200
Subject: [PATCH] OpenAI: return token usage returned by OpenAI (#1622)

---
 .../localai/LocalAiStreamingChatModel.java    |  4 +-
 .../LocalAiStreamingLanguageModel.java        |  4 +-
 .../model/openai/InternalOpenAiHelper.java    | 16 ------
 .../openai/OpenAiStreamingChatModel.java      | 34 ++---------
 .../openai/OpenAiStreamingLanguageModel.java  | 20 +++++--
 .../OpenAiStreamingResponseBuilder.java       | 56 +++----------
 .../openai/InternalOpenAiHelperTest.java      | 34 ++---------
 langchain4j-parent/pom.xml                    |  2 +-
 8 files changed, 38 insertions(+), 132 deletions(-)

diff --git a/langchain4j-local-ai/src/main/java/dev/langchain4j/model/localai/LocalAiStreamingChatModel.java b/langchain4j-local-ai/src/main/java/dev/langchain4j/model/localai/LocalAiStreamingChatModel.java
index abc25b9e8..d22143b19 100644
--- a/langchain4j-local-ai/src/main/java/dev/langchain4j/model/localai/LocalAiStreamingChatModel.java
+++ b/langchain4j-local-ai/src/main/java/dev/langchain4j/model/localai/LocalAiStreamingChatModel.java
@@ -102,7 +102,7 @@ public class LocalAiStreamingChatModel implements StreamingChatLanguageModel {
 
         ChatCompletionRequest request = requestBuilder.build();
 
-        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(null);
+        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();
 
         client.chatCompletion(request)
                 .onPartialResponse(partialResponse -> {
@@ -110,7 +110,7 @@ public class LocalAiStreamingChatModel implements StreamingChatLanguageModel {
                     handle(partialResponse, handler);
                 })
                 .onComplete(() -> {
-                    Response<AiMessage> response = responseBuilder.build(null, false);
+                    Response<AiMessage> response = responseBuilder.build();
                     handler.onComplete(response);
                 })
                 .onError(handler::onError)
diff --git a/langchain4j-local-ai/src/main/java/dev/langchain4j/model/localai/LocalAiStreamingLanguageModel.java b/langchain4j-local-ai/src/main/java/dev/langchain4j/model/localai/LocalAiStreamingLanguageModel.java
index 876f1da08..59cc6ec31 100644
--- a/langchain4j-local-ai/src/main/java/dev/langchain4j/model/localai/LocalAiStreamingLanguageModel.java
+++ b/langchain4j-local-ai/src/main/java/dev/langchain4j/model/localai/LocalAiStreamingLanguageModel.java
@@ -67,7 +67,7 @@ public class LocalAiStreamingLanguageModel implements StreamingLanguageModel {
                 .maxTokens(maxTokens)
                 .build();
 
-        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(null);
+        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();
 
         client.completion(request)
                 .onPartialResponse(partialResponse -> {
@@ -78,7 +78,7 @@ public class LocalAiStreamingLanguageModel implements StreamingLanguageModel {
                     }
                 })
                 .onComplete(() -> {
-                    Response<AiMessage> response = responseBuilder.build(null, false);
+                    Response<AiMessage> response = responseBuilder.build();
                     handler.onComplete(Response.from(
                             response.content().text(),
                             response.tokenUsage(),
diff --git a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/InternalOpenAiHelper.java b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/InternalOpenAiHelper.java
index 3046dd98a..bdfe1f4e4 100644
--- a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/InternalOpenAiHelper.java
+++ b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/InternalOpenAiHelper.java
@@ -385,22 +385,6 @@ public class InternalOpenAiHelper {
         }
     }
 
-    static boolean isOpenAiModel(String modelName) {
-        if (modelName == null) {
-            return false;
-        }
-        for (OpenAiChatModelName openAiChatModelName : OpenAiChatModelName.values()) {
-            if (modelName.contains(openAiChatModelName.toString())) {
-                return true;
-            }
-        }
-        return false;
-    }
-
-    static Response<AiMessage> removeTokenUsage(Response<AiMessage> response) {
-        return Response.from(response.content(), null, response.finishReason());
-    }
-
     static ChatModelRequest createModelListenerRequest(ChatCompletionRequest request,
                                                        List<ChatMessage> messages,
                                                        List<ToolSpecification> toolSpecifications) {
diff --git a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingChatModel.java b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingChatModel.java
index 9d9f2b208..2afb3155f 100644
--- a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingChatModel.java
+++ b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingChatModel.java
@@ -7,7 +7,7 @@ import dev.ai4j.openai4j.chat.ChatCompletionResponse;
 import dev.ai4j.openai4j.chat.Delta;
 import dev.ai4j.openai4j.chat.ResponseFormat;
 import dev.ai4j.openai4j.chat.ResponseFormatType;
-import dev.ai4j.openai4j.chat.StreamOptions;
+import dev.ai4j.openai4j.shared.StreamOptions;
 import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
@@ -42,8 +42,6 @@ import static dev.langchain4j.model.openai.InternalOpenAiHelper.DEFAULT_USER_AGENT;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerRequest;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerResponse;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.isOpenAiModel;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.removeTokenUsage;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
 import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
@@ -76,7 +74,6 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
     private final Boolean strictTools;
     private final Boolean parallelToolCalls;
     private final Tokenizer tokenizer;
-    private final boolean isOpenAiModel;
     private final List<ChatModelListener> listeners;
 
     @Builder
@@ -138,7 +135,6 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
         this.strictTools = getOrDefault(strictTools, false);
         this.parallelToolCalls = parallelToolCalls;
         this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
-        this.isOpenAiModel = isOpenAiModel(this.modelName);
         this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
     }
 
@@ -206,8 +202,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
             }
         });
 
-        int inputTokenCount = countInputTokens(messages, toolSpecifications, toolThatMustBeExecuted);
-        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(inputTokenCount);
+        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();
 
         AtomicReference<String> responseId = new AtomicReference<>();
         AtomicReference<String> responseModel = new AtomicReference<>();
@@ -225,7 +220,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
                     }
                 })
                 .onComplete(() -> {
-                    Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
+                    Response<AiMessage> response = responseBuilder.build();
 
                     ChatModelResponse modelListenerResponse = createModelListenerResponse(
                             responseId.get(),
@@ -248,7 +243,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
                     handler.onComplete(response);
                 })
                 .onError(error -> {
-                    Response<AiMessage> response = createResponse(responseBuilder, toolThatMustBeExecuted);
+                    Response<AiMessage> response = responseBuilder.build();
 
                     ChatModelResponse modelListenerPartialResponse = createModelListenerResponse(
                             responseId.get(),
@@ -276,27 +271,6 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
                 .execute();
     }
 
-    private Response<AiMessage> createResponse(OpenAiStreamingResponseBuilder responseBuilder,
-                                               ToolSpecification toolThatMustBeExecuted) {
-        Response<AiMessage> response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null);
-        if (isOpenAiModel) {
-            return response;
-        }
-        return removeTokenUsage(response);
-    }
-
-    private int countInputTokens(List<ChatMessage> messages,
-                                 List<ToolSpecification> toolSpecifications,
-                                 ToolSpecification toolThatMustBeExecuted) {
-        int inputTokenCount = tokenizer.estimateTokenCountInMessages(messages);
-        if (toolThatMustBeExecuted != null) {
-            inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
-        } else if (!isNullOrEmpty(toolSpecifications)) {
-            inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
-        }
-        return inputTokenCount;
-    }
-
     private static void handle(ChatCompletionResponse partialResponse,
                                StreamingResponseHandler<AiMessage> handler) {
         List<ChatCompletionChoice> choices = partialResponse.choices();
diff --git a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingLanguageModel.java b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingLanguageModel.java
index 8c9c5823e..68eb8a562 100644
--- a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingLanguageModel.java
+++ b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingLanguageModel.java
@@ -1,7 +1,9 @@
 package dev.langchain4j.model.openai;
 
 import dev.ai4j.openai4j.OpenAiClient;
+import dev.ai4j.openai4j.completion.CompletionChoice;
 import dev.ai4j.openai4j.completion.CompletionRequest;
+import dev.ai4j.openai4j.shared.StreamOptions;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.Tokenizer;
@@ -16,6 +18,7 @@ import java.time.Duration;
 import java.util.Map;
 
 import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.DEFAULT_USER_AGENT;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
 import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO_INSTRUCT;
@@ -77,24 +80,29 @@ public class OpenAiStreamingLanguageModel implements StreamingLanguageModel, TokenCountEstimator {
     public void generate(String prompt, StreamingResponseHandler<String> handler) {
 
         CompletionRequest request = CompletionRequest.builder()
+                .stream(true)
+                .streamOptions(StreamOptions.builder()
+                        .includeUsage(true)
+                        .build())
                 .model(modelName)
                 .prompt(prompt)
                 .temperature(temperature)
                 .build();
 
-        int inputTokenCount = tokenizer.estimateTokenCountInText(prompt);
-        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(inputTokenCount);
+        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();
 
         client.completion(request)
                 .onPartialResponse(partialResponse -> {
                     responseBuilder.append(partialResponse);
-                    String token = partialResponse.text();
-                    if (token != null) {
-                        handler.onNext(token);
+                    for (CompletionChoice choice : partialResponse.choices()) {
+                        String token = choice.text();
+                        if (isNotNullOrEmpty(token)) {
+                            handler.onNext(token);
+                        }
                     }
                 })
                 .onComplete(() -> {
-                    Response<AiMessage> response = responseBuilder.build(tokenizer, false);
+                    Response<AiMessage> response = responseBuilder.build();
                     handler.onComplete(Response.from(
                             response.content().text(),
                             response.tokenUsage(),
diff --git a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingResponseBuilder.java b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingResponseBuilder.java
index 9d2ca58ad..690221c31 100644
--- a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingResponseBuilder.java
+++ b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingResponseBuilder.java
@@ -10,7 +10,6 @@ import dev.ai4j.openai4j.completion.CompletionResponse;
 import dev.ai4j.openai4j.shared.Usage;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.data.message.AiMessage;
-import dev.langchain4j.model.Tokenizer;
 import dev.langchain4j.model.output.FinishReason;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
@@ -21,7 +20,6 @@ import java.util.concurrent.ConcurrentHashMap;
 
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
-import static java.util.Collections.singletonList;
 import static java.util.stream.Collectors.toList;
 
 /**
@@ -41,12 +39,6 @@ public class OpenAiStreamingResponseBuilder {
     private volatile TokenUsage tokenUsage;
     private volatile FinishReason finishReason;
 
-    private final Integer inputTokenCount;
-
-    public OpenAiStreamingResponseBuilder(Integer inputTokenCount) {
-        this.inputTokenCount = inputTokenCount;
-    }
-
     public void append(ChatCompletionResponse partialResponse) {
         if (partialResponse == null) {
             return;
@@ -122,6 +114,11 @@ public class OpenAiStreamingResponseBuilder {
             return;
         }
 
+        Usage usage = partialResponse.usage();
+        if (usage != null) {
+            this.tokenUsage = tokenUsageFrom(usage);
+        }
+
         List<CompletionChoice> choices = partialResponse.choices();
         if (choices == null || choices.isEmpty()) {
             return;
@@ -143,13 +140,13 @@ public class OpenAiStreamingResponseBuilder {
         }
     }
 
-    public Response<AiMessage> build(Tokenizer tokenizer, boolean forcefulToolExecution) {
+    public Response<AiMessage> build() {
 
         String content = contentBuilder.toString();
         if (!content.isEmpty()) {
             return Response.from(
                     AiMessage.from(content),
-                    tokenUsage(content, tokenizer),
+                    tokenUsage,
                     finishReason
             );
         }
@@ -162,7 +159,7 @@ public class OpenAiStreamingResponseBuilder {
                     .build();
             return Response.from(
                     AiMessage.from(toolExecutionRequest),
-                    tokenUsage(singletonList(toolExecutionRequest), tokenizer, forcefulToolExecution),
+                    tokenUsage,
                     finishReason
             );
         }
@@ -177,7 +174,7 @@ public class OpenAiStreamingResponseBuilder {
                     .collect(toList());
             return Response.from(
                     AiMessage.from(toolExecutionRequests),
-                    tokenUsage(toolExecutionRequests, tokenizer, forcefulToolExecution),
+                    tokenUsage,
                     finishReason
             );
         }
@@ -185,41 +182,6 @@ public class OpenAiStreamingResponseBuilder {
         return null;
     }
 
-    private TokenUsage tokenUsage(String content, Tokenizer tokenizer) {
-        if (tokenUsage != null) {
-            return tokenUsage;
-        }
-
-        if (tokenizer == null) {
-            return null;
-        }
-
-        int outputTokenCount = tokenizer.estimateTokenCountInText(content);
-        return new TokenUsage(inputTokenCount, outputTokenCount);
-    }
-
-    private TokenUsage tokenUsage(List<ToolExecutionRequest> toolExecutionRequests, Tokenizer tokenizer, boolean forcefulToolExecution) {
-        if (tokenUsage != null) {
-            return tokenUsage;
-        }
-
-        if (tokenizer == null) {
-            return null;
-        }
-
-        int outputTokenCount = 0;
-        if (forcefulToolExecution) {
-            // OpenAI calculates output tokens differently when tool is executed forcefully
-            for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
-                outputTokenCount += tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest);
-            }
-        } else {
-            outputTokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(toolExecutionRequests);
-        }
-
-        return new TokenUsage(inputTokenCount, outputTokenCount);
-    }
-
     private static class ToolExecutionRequestBuilder {
 
         private final StringBuffer idBuilder = new StringBuffer();
diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/InternalOpenAiHelperTest.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/InternalOpenAiHelperTest.java
index e0cd51e28..2d93d38ca 100644
--- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/InternalOpenAiHelperTest.java
+++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/InternalOpenAiHelperTest.java
@@ -1,15 +1,16 @@
 package dev.langchain4j.model.openai;
 
-import dev.ai4j.openai4j.chat.*;
+import dev.ai4j.openai4j.chat.AssistantMessage;
+import dev.ai4j.openai4j.chat.ChatCompletionChoice;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import dev.ai4j.openai4j.chat.FunctionCall;
+import dev.ai4j.openai4j.chat.ToolCall;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.data.message.AiMessage;
-import dev.langchain4j.model.output.Response;
-import dev.langchain4j.model.output.TokenUsage;
 import org.junit.jupiter.api.Test;
 
 import static dev.ai4j.openai4j.chat.ToolType.FUNCTION;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
-import static dev.langchain4j.model.output.FinishReason.STOP;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.aiMessageFrom;
 import static java.util.Collections.singletonList;
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -134,27 +135,4 @@ class InternalOpenAiHelperTest {
                         .build()
         );
     }
-
-    @Test
-    void test_isOpenAiModel() {
-
-        assertThat(isOpenAiModel(null)).isFalse();
-        assertThat(isOpenAiModel("")).isFalse();
-        assertThat(isOpenAiModel(" ")).isFalse();
-        assertThat(isOpenAiModel("llama2")).isFalse();
-
-        assertThat(isOpenAiModel("gpt-3.5-turbo")).isTrue();
-        assertThat(isOpenAiModel("ft:gpt-3.5-turbo:my-org:custom_suffix:id")).isTrue();
-    }
-
-    @Test
-    void test_removeTokenUsage() {
-
-        assertThat(removeTokenUsage(Response.from(AiMessage.from("Hello"))))
-                .isEqualTo(Response.from(AiMessage.from("Hello")));
-        assertThat(removeTokenUsage(Response.from(AiMessage.from("Hello"), new TokenUsage(42))))
-                .isEqualTo(Response.from(AiMessage.from("Hello")));
-        assertThat(removeTokenUsage(Response.from(AiMessage.from("Hello"), new TokenUsage(42), STOP)))
-                .isEqualTo(Response.from(AiMessage.from("Hello"), null, STOP));
-    }
 }
\ No newline at end of file
diff --git a/langchain4j-parent/pom.xml b/langchain4j-parent/pom.xml
index 755938e57..d5a473e23 100644
--- a/langchain4j-parent/pom.xml
+++ b/langchain4j-parent/pom.xml
@@ -18,7 +18,7 @@
         <java.version>8</java.version>
         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
         <project.build.outputTimestamp>1714382357</project.build.outputTimestamp>
-        <openai4j.version>0.21.0</openai4j.version>
+        <openai4j.version>0.22.0</openai4j.version>
        <azure-ai-openai.version>1.0.0-beta.11</azure-ai-openai.version>
        <azure-ai-search.version>11.7.1</azure-ai-search.version>
        <azure-storage-blob.version>12.28.0</azure-storage-blob.version>
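-- 

Usage note (reviewer sketch, not part of the diff): with stream_options.include_usage enabled, OpenAI sends one final streaming chunk whose choices array is empty but whose usage field carries the exact prompt and completion token counts. OpenAiStreamingResponseBuilder now stores that value via tokenUsageFrom(usage) and no longer falls back to Tokenizer-based estimation, which is why the tokenUsage(...) helpers and the isOpenAiModel/removeTokenUsage workaround can be deleted. A minimal caller-side sketch of the resulting behavior (model name, prompt, and the OPENAI_API_KEY environment variable are assumptions, not part of this patch):

    import dev.langchain4j.data.message.AiMessage;
    import dev.langchain4j.model.StreamingResponseHandler;
    import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
    import dev.langchain4j.model.output.Response;

    OpenAiStreamingChatModel model = OpenAiStreamingChatModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .modelName("gpt-4o-mini")
            .build();

    model.generate("Tell me a joke", new StreamingResponseHandler<AiMessage>() {

        @Override
        public void onNext(String token) {
            System.out.print(token);
        }

        @Override
        public void onComplete(Response<AiMessage> response) {
            // After this change, tokenUsage() holds the exact counts reported by
            // OpenAI in the final stream chunk, not a local tokenizer estimate.
            System.out.println("\n" + response.tokenUsage());
        }

        @Override
        public void onError(Throwable error) {
            error.printStackTrace();
        }
    });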