Fix a token usage statistical issue in DefaultAiServices (#280)
This makes the token usage accounting consistent with AiServiceStreamingResponseHandler: the usage of every model call made during the tool-execution loop is now accumulated, instead of only the last response's usage being returned. Additionally, when rolling messages out of MessageWindowChatMemory/TokenWindowChatMemory, LinkedList performs better than ArrayList: removing the first element of an ArrayList shifts all remaining elements, while a LinkedList removes its head in constant time.
This commit is contained in:
parent 06ada5310d
commit 7cfa1a5dfc
MessageWindowChatMemory.java

@@ -8,7 +8,7 @@ import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.ArrayList;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Optional;
 
@@ -73,7 +73,7 @@ public class MessageWindowChatMemory implements ChatMemory {
 
     @Override
     public List<ChatMessage> messages() {
-        List<ChatMessage> messages = new ArrayList<>(store.getMessages(id));
+        List<ChatMessage> messages = new LinkedList<>(store.getMessages(id));
         ensureCapacity(messages, maxMessages);
         return messages;
     }
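Why the switch matters here: ensureCapacity (not shown in this hunk) evicts the oldest messages from the head of the list until the window fits. ArrayList.remove(0) shifts every remaining element left, O(n) per eviction, while LinkedList removes its head in O(1). A minimal sketch of that eviction pattern; evictOldest is a simplified stand-in for the real ensureCapacity, not the library code:

import java.util.LinkedList;
import java.util.List;

class EvictionSketch {

    // Simplified stand-in for MessageWindowChatMemory.ensureCapacity:
    // drop messages from the head until the window fits.
    static void evictOldest(List<String> messages, int maxMessages) {
        while (messages.size() > maxMessages) {
            // O(1) on LinkedList; on ArrayList this would shift every
            // remaining element one slot to the left, O(n) per removal.
            messages.remove(0);
        }
    }

    public static void main(String[] args) {
        List<String> messages = new LinkedList<>(List.of("m1", "m2", "m3", "m4"));
        evictOldest(messages, 2);
        System.out.println(messages); // [m3, m4]
    }
}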
TokenWindowChatMemory.java

@@ -9,7 +9,7 @@ import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.ArrayList;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Optional;
 
@@ -77,7 +77,7 @@ public class TokenWindowChatMemory implements ChatMemory {
 
     @Override
     public List<ChatMessage> messages() {
-        List<ChatMessage> messages = new ArrayList<>(store.getMessages(id));
+        List<ChatMessage> messages = new LinkedList<>(store.getMessages(id));
         ensureCapacity(messages, maxTokens, tokenizer);
         return messages;
     }
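TokenWindowChatMemory evicts with the same head-removal pattern, but the capacity check is in tokens rather than message count, via the tokenizer passed to ensureCapacity. A simplified sketch of token-based head eviction; countTokens is a hypothetical stand-in for the real Tokenizer, and the word-count lambda below is only an illustration:

import java.util.LinkedList;
import java.util.function.ToIntFunction;

class TokenEvictionSketch {

    // Simplified stand-in for TokenWindowChatMemory.ensureCapacity:
    // drop head messages until the estimated token total fits maxTokens.
    static void evictOldest(LinkedList<String> messages,
                            int maxTokens,
                            ToIntFunction<String> countTokens) {
        int total = messages.stream().mapToInt(countTokens).sum();
        while (total > maxTokens && !messages.isEmpty()) {
            // O(1) head removal is what motivates LinkedList here too.
            total -= countTokens.applyAsInt(messages.removeFirst());
        }
    }

    public static void main(String[] args) {
        LinkedList<String> messages = new LinkedList<>();
        messages.add("a long old message");
        messages.add("short");
        messages.add("latest reply");
        // Hypothetical tokenizer: roughly one token per whitespace-separated word.
        evictOldest(messages, 4, m -> m.split("\\s+").length);
        System.out.println(messages); // [short, latest reply]
    }
}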
DefaultAiServices.java

@@ -20,6 +20,7 @@ import dev.langchain4j.model.input.structured.StructuredPrompt;
 import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
 import dev.langchain4j.model.moderation.Moderation;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
 import java.lang.reflect.Array;
 import java.lang.reflect.InvocationHandler;
 import java.lang.reflect.Method;
@@ -33,6 +34,7 @@ import java.util.Optional;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -142,6 +144,7 @@ class DefaultAiServices<T> extends AiServices<T> {
 
         verifyModerationIfNeeded(moderationFuture);
 
+        TokenUsage tokenUsage = new TokenUsage();
         ToolExecutionRequest toolExecutionRequest;
         while (true) { // TODO limit number of cycles
 
@@ -149,6 +152,7 @@ class DefaultAiServices<T> extends AiServices<T> {
                 context.chatMemory(memoryId).add(response.content());
             }
 
+            tokenUsage = tokenUsage.add(response.tokenUsage());
             toolExecutionRequest = response.content().toolExecutionRequest();
             if (toolExecutionRequest == null) {
                 break;
@@ -165,6 +169,7 @@ class DefaultAiServices<T> extends AiServices<T> {
             response = context.chatModel.generate(chatMemory.messages(), context.toolSpecifications);
         }
 
+        response = Response.from(response.content(), tokenUsage, response.finishReason());
         return ServiceOutputParser.parse(response, method.getReturnType());
     }
 
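The net effect of the DefaultAiServices change: previously the Response returned from the tool-execution loop carried only the last model call's token usage; now the usage of every generate() call in the loop is accumulated and attached to the final Response. A standalone sketch of that accumulation, using a minimal TokenUsage-like value type rather than the library class (the counts below are made-up example numbers):

import java.util.List;

class TokenUsageSketch {

    // Minimal stand-in for dev.langchain4j.model.output.TokenUsage:
    // an immutable pair of counts with null-safe addition.
    record Usage(int inputTokens, int outputTokens) {
        Usage add(Usage other) {
            if (other == null) return this; // a model may report no usage
            return new Usage(inputTokens + other.inputTokens,
                             outputTokens + other.outputTokens);
        }
    }

    public static void main(String[] args) {
        // Usage reported by each generate() call in a three-iteration tool loop.
        List<Usage> perCall = List.of(
                new Usage(100, 20),  // initial call, returns a tool request
                new Usage(140, 25),  // call after first tool result
                new Usage(180, 60)); // final answer

        Usage total = new Usage(0, 0);
        for (Usage u : perCall) {
            // Mirrors tokenUsage = tokenUsage.add(response.tokenUsage()) in the diff.
            total = total.add(u);
        }
        // Before the fix only the last call's usage (180/60) was reported;
        // after, the caller sees the accumulated 420/105.
        System.out.println(total);
    }
}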