Fix a token usage statistical issue in DefaultAiServices (#280)

The statistical logic on token usage is consistent with AiServiceStreamingResponseHandler.

And more:
When it comes to rolling messages in
MessageWindowChatMemory/TokenWindowChatMemory, LinkedList offers
superior performance. ArrayList moves all the elements when the first
element is deleted.
This commit is contained in:
jiangsier-xyz 2023-11-24 19:48:45 +08:00 committed by GitHub
parent 06ada5310d
commit 7cfa1a5dfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 4 deletions

View File

@ -8,7 +8,7 @@ import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.util.ArrayList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
@ -73,7 +73,7 @@ public class MessageWindowChatMemory implements ChatMemory {
@Override @Override
public List<ChatMessage> messages() { public List<ChatMessage> messages() {
List<ChatMessage> messages = new ArrayList<>(store.getMessages(id)); List<ChatMessage> messages = new LinkedList<>(store.getMessages(id));
ensureCapacity(messages, maxMessages); ensureCapacity(messages, maxMessages);
return messages; return messages;
} }

View File

@ -9,7 +9,7 @@ import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.util.ArrayList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
@ -77,7 +77,7 @@ public class TokenWindowChatMemory implements ChatMemory {
@Override @Override
public List<ChatMessage> messages() { public List<ChatMessage> messages() {
List<ChatMessage> messages = new ArrayList<>(store.getMessages(id)); List<ChatMessage> messages = new LinkedList<>(store.getMessages(id));
ensureCapacity(messages, maxTokens, tokenizer); ensureCapacity(messages, maxTokens, tokenizer);
return messages; return messages;
} }

View File

@ -20,6 +20,7 @@ import dev.langchain4j.model.input.structured.StructuredPrompt;
import dev.langchain4j.model.input.structured.StructuredPromptProcessor; import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
import dev.langchain4j.model.moderation.Moderation; import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.lang.reflect.Array; import java.lang.reflect.Array;
import java.lang.reflect.InvocationHandler; import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method; import java.lang.reflect.Method;
@ -33,6 +34,7 @@ import java.util.Optional;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -142,6 +144,7 @@ class DefaultAiServices<T> extends AiServices<T> {
verifyModerationIfNeeded(moderationFuture); verifyModerationIfNeeded(moderationFuture);
TokenUsage tokenUsage = new TokenUsage();
ToolExecutionRequest toolExecutionRequest; ToolExecutionRequest toolExecutionRequest;
while (true) { // TODO limit number of cycles while (true) { // TODO limit number of cycles
@ -149,6 +152,7 @@ class DefaultAiServices<T> extends AiServices<T> {
context.chatMemory(memoryId).add(response.content()); context.chatMemory(memoryId).add(response.content());
} }
tokenUsage = tokenUsage.add(response.tokenUsage());
toolExecutionRequest = response.content().toolExecutionRequest(); toolExecutionRequest = response.content().toolExecutionRequest();
if (toolExecutionRequest == null) { if (toolExecutionRequest == null) {
break; break;
@ -165,6 +169,7 @@ class DefaultAiServices<T> extends AiServices<T> {
response = context.chatModel.generate(chatMemory.messages(), context.toolSpecifications); response = context.chatModel.generate(chatMemory.messages(), context.toolSpecifications);
} }
response = Response.from(response.content(), tokenUsage, response.finishReason());
return ServiceOutputParser.parse(response, method.getReturnType()); return ServiceOutputParser.parse(response, method.getReturnType());
} }