diff --git a/langchain4j/src/main/java/dev/langchain4j/memory/chat/MessageWindowChatMemory.java b/langchain4j/src/main/java/dev/langchain4j/memory/chat/MessageWindowChatMemory.java index 1fd5e96ef..9c2f81b14 100644 --- a/langchain4j/src/main/java/dev/langchain4j/memory/chat/MessageWindowChatMemory.java +++ b/langchain4j/src/main/java/dev/langchain4j/memory/chat/MessageWindowChatMemory.java @@ -8,7 +8,7 @@ import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.ArrayList; +import java.util.LinkedList; import java.util.List; import java.util.Optional; @@ -73,7 +73,7 @@ public class MessageWindowChatMemory implements ChatMemory { @Override public List messages() { - List messages = new ArrayList<>(store.getMessages(id)); + List messages = new LinkedList<>(store.getMessages(id)); ensureCapacity(messages, maxMessages); return messages; } diff --git a/langchain4j/src/main/java/dev/langchain4j/memory/chat/TokenWindowChatMemory.java b/langchain4j/src/main/java/dev/langchain4j/memory/chat/TokenWindowChatMemory.java index 30717db8a..394dadbc6 100644 --- a/langchain4j/src/main/java/dev/langchain4j/memory/chat/TokenWindowChatMemory.java +++ b/langchain4j/src/main/java/dev/langchain4j/memory/chat/TokenWindowChatMemory.java @@ -9,7 +9,7 @@ import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.ArrayList; +import java.util.LinkedList; import java.util.List; import java.util.Optional; @@ -77,7 +77,7 @@ public class TokenWindowChatMemory implements ChatMemory { @Override public List messages() { - List messages = new ArrayList<>(store.getMessages(id)); + List messages = new LinkedList<>(store.getMessages(id)); ensureCapacity(messages, maxTokens, tokenizer); return messages; } diff --git a/langchain4j/src/main/java/dev/langchain4j/service/DefaultAiServices.java b/langchain4j/src/main/java/dev/langchain4j/service/DefaultAiServices.java index 7f4689289..84ecc48c0 100644 --- a/langchain4j/src/main/java/dev/langchain4j/service/DefaultAiServices.java +++ b/langchain4j/src/main/java/dev/langchain4j/service/DefaultAiServices.java @@ -20,6 +20,7 @@ import dev.langchain4j.model.input.structured.StructuredPrompt; import dev.langchain4j.model.input.structured.StructuredPromptProcessor; import dev.langchain4j.model.moderation.Moderation; import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; import java.lang.reflect.Array; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; @@ -33,6 +34,7 @@ import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -142,6 +144,7 @@ class DefaultAiServices extends AiServices { verifyModerationIfNeeded(moderationFuture); + TokenUsage tokenUsage = new TokenUsage(); ToolExecutionRequest toolExecutionRequest; while (true) { // TODO limit number of cycles @@ -149,6 +152,7 @@ class DefaultAiServices extends AiServices { context.chatMemory(memoryId).add(response.content()); } + tokenUsage = tokenUsage.add(response.tokenUsage()); toolExecutionRequest = response.content().toolExecutionRequest(); if (toolExecutionRequest == null) { break; @@ -165,6 +169,7 @@ class DefaultAiServices extends AiServices { response = context.chatModel.generate(chatMemory.messages(), context.toolSpecifications); } + response = Response.from(response.content(), tokenUsage, response.finishReason()); return ServiceOutputParser.parse(response, method.getReturnType()); }