Fix a token usage statistical issue in DefaultAiServices (#280)
This change makes the token usage accounting in DefaultAiServices consistent with AiServiceStreamingResponseHandler: usage is now accumulated across every model call in the tool-execution loop instead of reporting only the last response's usage. In addition, when rolling messages out of MessageWindowChatMemory/TokenWindowChatMemory, LinkedList performs better than ArrayList: deleting the first element of an ArrayList shifts all remaining elements, whereas a LinkedList unlinks its head in constant time.
parent 06ada5310d · commit 7cfa1a5dfc
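For illustration (not part of the commit; class name and sizes are invented): a minimal micro-demo of why draining a list from the front favors LinkedList. Each ArrayList.remove(0) shifts every remaining element left, so the drain is O(n^2) overall, while LinkedList only unlinks its head node each time.

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

// Illustrative micro-demo (not from the commit): compare draining a list
// from the front. Timings via System.nanoTime are rough, but the asymptotic
// gap shows up clearly at this size.
public class FrontRemovalDemo {

    public static void main(String[] args) {
        int n = 100_000;

        List<Integer> arrayList = new ArrayList<>();
        List<Integer> linkedList = new LinkedList<>();
        for (int i = 0; i < n; i++) {
            arrayList.add(i);
            linkedList.add(i);
        }

        long t0 = System.nanoTime();
        while (!arrayList.isEmpty()) {
            arrayList.remove(0); // shifts all remaining elements left
        }
        System.out.printf("ArrayList:  %d ms%n", (System.nanoTime() - t0) / 1_000_000);

        long t1 = System.nanoTime();
        while (!linkedList.isEmpty()) {
            linkedList.remove(0); // unlinks the head node only
        }
        System.out.printf("LinkedList: %d ms%n", (System.nanoTime() - t1) / 1_000_000);
    }
}

The trade-off runs the other way for indexed access, which is O(n) on LinkedList, so the switch pays off here because front eviction is the dominant operation on the copied list.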
MessageWindowChatMemory.java

@@ -8,7 +8,7 @@ import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.ArrayList;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Optional;
 
@@ -73,7 +73,7 @@ public class MessageWindowChatMemory implements ChatMemory {
 
     @Override
     public List<ChatMessage> messages() {
-        List<ChatMessage> messages = new ArrayList<>(store.getMessages(id));
+        List<ChatMessage> messages = new LinkedList<>(store.getMessages(id));
         ensureCapacity(messages, maxMessages);
         return messages;
     }
TokenWindowChatMemory.java

@@ -9,7 +9,7 @@ import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.ArrayList;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Optional;
 
@@ -77,7 +77,7 @@ public class TokenWindowChatMemory implements ChatMemory {
 
     @Override
     public List<ChatMessage> messages() {
-        List<ChatMessage> messages = new ArrayList<>(store.getMessages(id));
+        List<ChatMessage> messages = new LinkedList<>(store.getMessages(id));
         ensureCapacity(messages, maxTokens, tokenizer);
         return messages;
     }
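The ensureCapacity helper itself is not part of this diff. As a hedged sketch of the eviction it presumably performs (an assumption, simplified and ignoring details such as preserving a system message), the rolling window reduces to repeatedly dropping the oldest message, which is exactly the remove(0) pattern the LinkedList copy speeds up:

import dev.langchain4j.data.message.ChatMessage;
import java.util.List;

class EvictionSketch {

    // Simplified sketch (assumption, not the library's actual code): evict the
    // oldest messages from the front until the window fits within maxMessages.
    // Each remove(0) is O(1) on a LinkedList copy but O(n) on an ArrayList.
    static void ensureCapacity(List<ChatMessage> messages, int maxMessages) {
        while (messages.size() > maxMessages) {
            messages.remove(0); // drop the oldest message first
        }
    }
}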
DefaultAiServices.java

@@ -20,6 +20,7 @@ import dev.langchain4j.model.input.structured.StructuredPrompt;
 import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
 import dev.langchain4j.model.moderation.Moderation;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
 import java.lang.reflect.Array;
 import java.lang.reflect.InvocationHandler;
 import java.lang.reflect.Method;
@@ -33,6 +34,7 @@ import java.util.Optional;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -142,6 +144,7 @@ class DefaultAiServices<T> extends AiServices<T> {
 
                         verifyModerationIfNeeded(moderationFuture);
 
+                        TokenUsage tokenUsage = new TokenUsage();
                         ToolExecutionRequest toolExecutionRequest;
                         while (true) { // TODO limit number of cycles
 
@@ -149,6 +152,7 @@ class DefaultAiServices<T> extends AiServices<T> {
                                 context.chatMemory(memoryId).add(response.content());
                             }
 
+                            tokenUsage = tokenUsage.add(response.tokenUsage());
                             toolExecutionRequest = response.content().toolExecutionRequest();
                             if (toolExecutionRequest == null) {
                                 break;
@@ -165,6 +169,7 @@ class DefaultAiServices<T> extends AiServices<T> {
                                 response = context.chatModel.generate(chatMemory.messages(), context.toolSpecifications);
                             }
 
+                        response = Response.from(response.content(), tokenUsage, response.finishReason());
                         return ServiceOutputParser.parse(response, method.getReturnType());
                     }
 
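Taken together, the DefaultAiServices hunks fold per-call usage into one running total, so the Response handed to the caller reports the sum over all calls in the tool loop rather than only the last one. A small self-contained sketch of that accumulation (the figures are invented; this assumes, as the commit's use of a default-constructed TokenUsage implies, that add(..) treats null counts as zero):

import dev.langchain4j.model.output.TokenUsage;
import java.util.List;

public class TokenUsageAccumulationSketch {

    public static void main(String[] args) {
        // Hypothetical usages from two model calls made during the tool loop.
        List<TokenUsage> perCallUsages = List.of(
                new TokenUsage(120, 35),
                new TokenUsage(180, 20));

        // Mirror the loop in DefaultAiServices: TokenUsage is immutable,
        // so add(..) returns a new instance with the summed counts.
        TokenUsage tokenUsage = new TokenUsage();
        for (TokenUsage perCall : perCallUsages) {
            tokenUsage = tokenUsage.add(perCall);
        }

        // Expect an accumulated total of 300 input and 55 output tokens.
        System.out.println(tokenUsage);
    }
}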