From 8ceb6ba3f7dd32eda61f0d8b4d6e7c07bb3274ad Mon Sep 17 00:00:00 2001 From: LangChain4j Date: Wed, 31 Jan 2024 10:11:33 +0100 Subject: [PATCH] OpenAiTokenizer: default ctor with GPT_3_5_TURBO by default --- .../SampleDocumentLoaderAndRagWithAstraTest.java | 2 +- .../chat/cassandra/ChatMemoryStoreAstraTest.java | 3 +-- .../langchain4j/model/openai/OpenAiTokenizer.java | 12 ++++++++++++ .../langchain4j/model/openai/OpenAiTokenizerIT.java | 3 ++- .../model/openai/OpenAiTokenizerTest.java | 3 +-- .../splitter/DocumentByParagraphSplitterTest.java | 11 +++++------ .../splitter/DocumentBySentenceSplitterTest.java | 3 +-- .../java/dev/langchain4j/internal/TestUtils.java | 3 +-- .../memory/chat/TokenWindowChatMemoryTest.java | 9 ++++----- .../dev/langchain4j/service/AiServicesWithRagIT.java | 4 ++-- 10 files changed, 30 insertions(+), 23 deletions(-) diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/SampleDocumentLoaderAndRagWithAstraTest.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/SampleDocumentLoaderAndRagWithAstraTest.java index 3813d23a0..4236e6140 100644 --- a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/SampleDocumentLoaderAndRagWithAstraTest.java +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/SampleDocumentLoaderAndRagWithAstraTest.java @@ -59,7 +59,7 @@ class SampleDocumentLoaderAndRagWithAstraTest { Path path = new File(getClass().getResource("/story-about-happy-carrot.txt").getFile()).toPath(); Document document = FileSystemDocumentLoader.loadDocument(path, new TextDocumentParser()); DocumentSplitter splitter = DocumentSplitters - .recursive(100, 10, new OpenAiTokenizer(GPT_3_5_TURBO)); + .recursive(100, 10, new OpenAiTokenizer()); // Embedding model (OpenAI) EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder() diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/memory/chat/cassandra/ChatMemoryStoreAstraTest.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/memory/chat/cassandra/ChatMemoryStoreAstraTest.java index 238429f4d..b9b09f0b7 100644 --- a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/memory/chat/cassandra/ChatMemoryStoreAstraTest.java +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/memory/chat/cassandra/ChatMemoryStoreAstraTest.java @@ -16,7 +16,6 @@ import java.util.UUID; import static com.dtsx.astra.sdk.utils.TestUtils.*; import static dev.langchain4j.data.message.AiMessage.aiMessage; import static dev.langchain4j.data.message.UserMessage.userMessage; -import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -53,7 +52,7 @@ public class ChatMemoryStoreAstraTest { ChatMemory chatMemory = TokenWindowChatMemory.builder() .chatMemoryStore(chatMemoryStore) .id(chatSessionId) - .maxTokens(300, new OpenAiTokenizer(GPT_3_5_TURBO)) + .maxTokens(300, new OpenAiTokenizer()) .build(); // When diff --git a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiTokenizer.java b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiTokenizer.java index 78fefe3bd..63310f409 100644 --- a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiTokenizer.java +++ b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiTokenizer.java @@ -30,6 +30,18 @@ public class OpenAiTokenizer implements Tokenizer { private final String modelName; private final Optional encoding; + public OpenAiTokenizer() { + this(GPT_3_5_TURBO); + } + + public OpenAiTokenizer(OpenAiChatModelName modelName) { + this(modelName.toString()); + } + + public OpenAiTokenizer(OpenAiLanguageModelName modelName) { + this(modelName.toString()); + } + public OpenAiTokenizer(String modelName) { this.modelName = ensureNotBlank(modelName, "modelName"); // If the model is unknown, we should NOT fail fast during the creation of OpenAiTokenizer. diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerIT.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerIT.java index 0ad289766..415296230 100644 --- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerIT.java +++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerIT.java @@ -25,6 +25,7 @@ import static dev.langchain4j.data.message.AiMessage.aiMessage; import static dev.langchain4j.data.message.SystemMessage.systemMessage; import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage; import static dev.langchain4j.data.message.UserMessage.userMessage; +import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO; import static java.util.Arrays.asList; import static java.util.Arrays.stream; import static java.util.Collections.singletonList; @@ -1088,7 +1089,7 @@ class OpenAiTokenizerIT { // TODO remove once they fix it e.printStackTrace(); // there is some pattern to it, so we are going to check if this is really the case or our calculation is wrong - Tokenizer tokenizer2 = new OpenAiTokenizer(GPT_3_5_TURBO.toString()); + Tokenizer tokenizer2 = new OpenAiTokenizer(GPT_3_5_TURBO); int tokenCount2 = tokenizer2.estimateTokenCountInToolExecutionRequests(toolExecutionRequests); assertThat(tokenCount2).isEqualTo(expectedTokenCount - 3); } else { diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerTest.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerTest.java index 3c5130eff..0d1713ff1 100644 --- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerTest.java +++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerTest.java @@ -9,13 +9,12 @@ import org.junit.jupiter.params.provider.EnumSource; import java.util.ArrayList; import java.util.List; -import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static dev.langchain4j.model.openai.OpenAiTokenizer.countArguments; import static org.assertj.core.api.Assertions.assertThat; class OpenAiTokenizerTest { - OpenAiTokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); + OpenAiTokenizer tokenizer = new OpenAiTokenizer(); @Test void should_encode_and_decode_text() { diff --git a/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentByParagraphSplitterTest.java b/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentByParagraphSplitterTest.java index 70c9c4164..171d6a06b 100644 --- a/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentByParagraphSplitterTest.java +++ b/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentByParagraphSplitterTest.java @@ -12,7 +12,6 @@ import java.util.List; import static dev.langchain4j.data.document.Metadata.metadata; import static dev.langchain4j.data.segment.TextSegment.textSegment; -import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; @@ -123,7 +122,7 @@ class DocumentByParagraphSplitterTest { void should_split_sample_text_containing_multiple_paragraphs() { int maxSegmentSize = 65; - Tokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); + Tokenizer tokenizer = new OpenAiTokenizer(); String p1 = "In a small town nestled between two vast mountains, there was a shop unlike any other. " + "A unique haven. " + @@ -200,7 +199,7 @@ class DocumentByParagraphSplitterTest { int maxSegmentSize = 65; int maxOverlapSize = 15; - Tokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); + Tokenizer tokenizer = new OpenAiTokenizer(); String s1 = "In a small town nestled between two vast mountains, there was a shop unlike any other."; String s2 = "A unique haven."; @@ -269,7 +268,7 @@ class DocumentByParagraphSplitterTest { void should_split_sample_text_without_paragraphs() { int maxSegmentSize = 100; - Tokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); + Tokenizer tokenizer = new OpenAiTokenizer(); String segment1 = "In a small town nestled between two vast mountains, there was a shop unlike any other. " + "A unique haven. " + @@ -332,7 +331,7 @@ class DocumentByParagraphSplitterTest { // given int maxSegmentSize = 100; int maxOverlapSize = 25; - Tokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); + Tokenizer tokenizer = new OpenAiTokenizer(); DocumentSplitter splitter = new DocumentByParagraphSplitter(maxSegmentSize, maxOverlapSize, tokenizer); Document document = Document.from(sentences(0, 28), Metadata.from("document", "0")); @@ -364,7 +363,7 @@ class DocumentByParagraphSplitterTest { // given int maxSegmentSize = 100; int maxOverlapSize = 80; - Tokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); + Tokenizer tokenizer = new OpenAiTokenizer(); DocumentSplitter splitter = new DocumentByParagraphSplitter(maxSegmentSize, maxOverlapSize, tokenizer); Document document = Document.from(sentences(0, 28), Metadata.from("document", "0")); diff --git a/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentBySentenceSplitterTest.java b/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentBySentenceSplitterTest.java index fbfa57ab3..051c87a9c 100644 --- a/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentBySentenceSplitterTest.java +++ b/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentBySentenceSplitterTest.java @@ -10,7 +10,6 @@ import java.util.List; import static dev.langchain4j.data.document.Metadata.metadata; import static dev.langchain4j.data.segment.TextSegment.textSegment; -import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; @@ -152,7 +151,7 @@ class DocumentBySentenceSplitterTest { ); int maxSegmentSize = 26; - OpenAiTokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); + OpenAiTokenizer tokenizer = new OpenAiTokenizer(); DocumentSplitter splitter = new DocumentBySentenceSplitter(maxSegmentSize, 0, tokenizer); List segments = splitter.split(document); diff --git a/langchain4j/src/test/java/dev/langchain4j/internal/TestUtils.java b/langchain4j/src/test/java/dev/langchain4j/internal/TestUtils.java index 860dac93f..645b7807c 100644 --- a/langchain4j/src/test/java/dev/langchain4j/internal/TestUtils.java +++ b/langchain4j/src/test/java/dev/langchain4j/internal/TestUtils.java @@ -15,14 +15,13 @@ import java.util.List; import static dev.langchain4j.data.message.AiMessage.aiMessage; import static dev.langchain4j.data.message.SystemMessage.systemMessage; import static dev.langchain4j.data.message.UserMessage.userMessage; -import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static org.assertj.core.api.Assertions.assertThat; public class TestUtils { private static final int EXTRA_TOKENS_PER_EACH_MESSAGE = 3 /* extra tokens for each message */ + 1 /* extra token for 'role' */; - private static final OpenAiTokenizer TOKENIZER = new OpenAiTokenizer(GPT_3_5_TURBO); + private static final OpenAiTokenizer TOKENIZER = new OpenAiTokenizer(); @ParameterizedTest @ValueSource(ints = {5, 10, 25, 50, 100, 250, 500, 1000}) diff --git a/langchain4j/src/test/java/dev/langchain4j/memory/chat/TokenWindowChatMemoryTest.java b/langchain4j/src/test/java/dev/langchain4j/memory/chat/TokenWindowChatMemoryTest.java index 95a925bb3..e6c51bbaf 100644 --- a/langchain4j/src/test/java/dev/langchain4j/memory/chat/TokenWindowChatMemoryTest.java +++ b/langchain4j/src/test/java/dev/langchain4j/memory/chat/TokenWindowChatMemoryTest.java @@ -9,7 +9,6 @@ import org.junit.jupiter.api.Test; import static dev.langchain4j.data.message.SystemMessage.systemMessage; import static dev.langchain4j.internal.TestUtils.*; -import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static org.assertj.core.api.Assertions.assertThat; class TokenWindowChatMemoryTest { @@ -17,7 +16,7 @@ class TokenWindowChatMemoryTest { @Test void should_keep_specified_number_of_tokens_in_chat_memory() { - OpenAiTokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); + OpenAiTokenizer tokenizer = new OpenAiTokenizer(); ChatMemory chatMemory = TokenWindowChatMemory.withMaxTokens(33, tokenizer); UserMessage firstUserMessage = userMessageWithTokens(10); @@ -78,7 +77,7 @@ class TokenWindowChatMemoryTest { @Test void should_not_remove_system_message_from_chat_memory() { - OpenAiTokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); + OpenAiTokenizer tokenizer = new OpenAiTokenizer(); ChatMemory chatMemory = TokenWindowChatMemory.withMaxTokens(33, tokenizer); SystemMessage systemMessage = systemMessageWithTokens(10); @@ -118,7 +117,7 @@ class TokenWindowChatMemoryTest { @Test void should_keep_only_the_latest_system_message_in_chat_memory() { - OpenAiTokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); + OpenAiTokenizer tokenizer = new OpenAiTokenizer(); ChatMemory chatMemory = TokenWindowChatMemory.withMaxTokens(40, tokenizer); SystemMessage firstSystemMessage = systemMessage("You are a helpful assistant"); @@ -149,7 +148,7 @@ class TokenWindowChatMemoryTest { @Test void should_not_add_the_same_system_message_to_chat_memory_if_it_is_already_there() { - OpenAiTokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); + OpenAiTokenizer tokenizer = new OpenAiTokenizer(); ChatMemory chatMemory = TokenWindowChatMemory.withMaxTokens(33, tokenizer); SystemMessage systemMessage = systemMessageWithTokens(10); diff --git a/langchain4j/src/test/java/dev/langchain4j/service/AiServicesWithRagIT.java b/langchain4j/src/test/java/dev/langchain4j/service/AiServicesWithRagIT.java index 0875cb349..e5f6f1eb7 100644 --- a/langchain4j/src/test/java/dev/langchain4j/service/AiServicesWithRagIT.java +++ b/langchain4j/src/test/java/dev/langchain4j/service/AiServicesWithRagIT.java @@ -8,6 +8,7 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.Tokenizer; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; @@ -44,7 +45,6 @@ import java.util.Map; import java.util.stream.Stream; import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.loadDocument; -import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static java.util.Arrays.asList; import static java.util.Collections.emptyList; import static org.assertj.core.api.Assertions.assertThat; @@ -329,7 +329,7 @@ class AiServicesWithRagIT { } private void ingest(EmbeddingStore embeddingStore, EmbeddingModel embeddingModel) { - OpenAiTokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); + Tokenizer tokenizer = new OpenAiTokenizer(); DocumentSplitter splitter = DocumentSplitters.recursive(100, 0, tokenizer); EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder() .documentSplitter(splitter)