diff --git a/langchain4j-jlama/README.md b/langchain4j-jlama/README.md
index 8909b32f9..9c4eebd59 100644
--- a/langchain4j-jlama/README.md
+++ b/langchain4j-jlama/README.md
@@ -3,7 +3,7 @@
 [Jlama](https://github.com/tjake/Jlama) is a Java library that provides a simple way to integrate LLM models into Java applications.
 
-Jlama is built with Java 21 and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference.
+Jlama is built with Java 20+ and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference.
 
 Jlama uses huggingface models in safetensor format.
 Models must be specified using the `owner/model-name` format. For example, `meta-llama/Llama-2-7b-chat-hf`.
diff --git a/langchain4j-jlama/pom.xml b/langchain4j-jlama/pom.xml
index c55a8aee2..9ea8e7ac7 100644
--- a/langchain4j-jlama/pom.xml
+++ b/langchain4j-jlama/pom.xml
@@ -5,7 +5,7 @@
     4.0.0
     langchain4j-jlama
     LangChain4j :: Integration :: Jlama
-    Jlama: Pure Java LLM Inference Engine - Requires Java 21
+    Jlama: LLM Inference Engine for Java - Requires Java 20+
 
     dev.langchain4j
@@ -15,7 +15,7 @@
 
-        0.3.1
+        0.5.0
         2.16.1
         2.40.0
         21
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaChatModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaChatModel.java
index 54276dc85..666a4cb90 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaChatModel.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaChatModel.java
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.*;
 import com.github.tjake.jlama.util.JsonSupport;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
@@ -33,6 +34,7 @@ public class JlamaChatModel implements ChatLanguageModel {
                           Integer threadCount,
                           Boolean quantizeModelAtRuntime,
                           Path workingDirectory,
+                          DType workingQuantizedType,
                           Float temperature,
                           Integer maxTokens) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -42,6 +44,9 @@ public class JlamaChatModel implements ChatLanguageModel {
         if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
             loader = loader.quantized();
 
+        if (workingQuantizedType != null)
+            loader = loader.workingQuantizationType(workingQuantizedType);
+
         if (threadCount != null)
             loader = loader.threadCount(threadCount);
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaEmbeddingModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaEmbeddingModel.java
index cf71198ed..94fea97aa 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaEmbeddingModel.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaEmbeddingModel.java
@@ -3,6 +3,7 @@ package dev.langchain4j.model.jlama;
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.ModelSupport;
 import com.github.tjake.jlama.model.bert.BertModel;
+import com.github.tjake.jlama.model.functions.Generator;
 import dev.langchain4j.data.embedding.Embedding;
 import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.internal.RetryUtils;
@@ -20,6 +21,7 @@ import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
     private final BertModel model;
+    private final Generator.PoolingType poolingType;
 
     @Builder
     public JlamaEmbeddingModel(Path modelCachePath,
@@ -27,6 +29,7 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
                                String authToken,
                                Integer threadCount,
                                Boolean quantizeModelAtRuntime,
+                               Generator.PoolingType poolingType,
                                Path workingDirectory) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -46,10 +49,12 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
         if (workingDirectory != null)
             loader = loader.workingDirectory(workingDirectory);
 
-        loader = loader.inferenceType(AbstractModel.InferenceType.FORWARD_PASS);
+        loader = loader.inferenceType(AbstractModel.InferenceType.FULL_EMBEDDING);
 
         this.model = (BertModel) loader.load();
         this.dimension = model.getConfig().embeddingLength;
+
+        this.poolingType = poolingType == null ? Generator.PoolingType.MODEL : poolingType;
     }
 
     public static JlamaEmbeddingModelBuilder builder() {
@@ -64,7 +69,7 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
         List<Embedding> embeddings = new ArrayList<>();
 
         textSegments.forEach(textSegment -> {
-            embeddings.add(Embedding.from(model.embed(textSegment.text())));
+            embeddings.add(Embedding.from(model.embed(textSegment.text(), poolingType)));
         });
 
         return Response.from(embeddings);
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaLanguageModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaLanguageModel.java
index 1bf0d998e..8d81b3f83 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaLanguageModel.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaLanguageModel.java
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.PromptContext;
 import dev.langchain4j.internal.RetryUtils;
 import dev.langchain4j.model.jlama.spi.JlamaLanguageModelBuilderFactory;
@@ -30,6 +31,7 @@ public class JlamaLanguageModel implements LanguageModel {
                               Integer threadCount,
                               Boolean quantizeModelAtRuntime,
                               Path workingDirectory,
+                              DType workingQuantizedType,
                               Float temperature,
                               Integer maxTokens) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -39,6 +41,9 @@ public class JlamaLanguageModel implements LanguageModel {
         if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
             loader = loader.quantized();
 
+        if (workingQuantizedType != null)
+            loader = loader.workingQuantizationType(workingQuantizedType);
+
         if (threadCount != null)
             loader = loader.threadCount(threadCount);
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModel.java
index fc298900d..1b50b1fe6 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModel.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModel.java
@@ -60,6 +60,7 @@ class JlamaModel {
                 registry.getModelCachePath().toString(),
                 owner,
                 modelName,
+                true,
                 Optional.empty(),
                 authToken,
                 Optional.empty());
@@ -67,6 +68,7 @@ class JlamaModel {
 
     public class Loader {
         private Path workingDirectory;
+        private DType workingQuantizationType = DType.I8;
         private DType quantizationType;
         private Integer threadCount;
         private AbstractModel.InferenceType inferenceType = AbstractModel.InferenceType.FULL_GENERATION;
@@ -75,11 +77,19 @@ class JlamaModel {
         }
 
         public Loader quantized() {
-            //For now only allow Q4 quantization at load time
+            //For now only allow Q4 quantization at runtime
             this.quantizationType = DType.Q4;
             return this;
         }
 
+        /**
+         * Set the working quantization type. This is the type that the model will use for working inference memory.
+         */
+        public Loader workingQuantizationType(DType workingQuantizationType) {
+            this.workingQuantizationType = workingQuantizationType;
+            return this;
+        }
+
         public Loader workingDirectory(Path workingDirectory) {
             this.workingDirectory = workingDirectory;
             return this;
@@ -101,10 +111,11 @@ class JlamaModel {
                     new File(registry.getModelCachePath().toFile(), modelName),
                     workingDirectory == null ? null : workingDirectory.toFile(),
                     DType.F32,
-                    DType.I8,
+                    workingQuantizationType,
                     Optional.ofNullable(quantizationType),
                     Optional.ofNullable(threadCount),
-                    Optional.empty());
+                    Optional.empty(),
+                    SafeTensorSupport::loadWeights);
         }
     }
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModelRegistry.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModelRegistry.java
index 36251dd56..446b02dbb 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModelRegistry.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModelRegistry.java
@@ -90,7 +90,7 @@ class JlamaModelRegistry {
             name = parts[1];
         }
 
-        File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, Optional.empty(), authToken, Optional.empty());
+        File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, true, Optional.empty(), authToken, Optional.empty());
 
         File config = new File(modelDir, "config.json");
         ModelSupport.ModelType type = SafeTensorSupport.detectModel(config);
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingChatModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingChatModel.java
index 90a743690..5fe99a6c1 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingChatModel.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingChatModel.java
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
@@ -34,6 +35,7 @@ public class JlamaStreamingChatModel implements StreamingChatLanguageModel {
                                    Integer threadCount,
                                    Boolean quantizeModelAtRuntime,
                                    Path workingDirectory,
+                                   DType workingQuantizedType,
                                    Float temperature,
                                    Integer maxTokens) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -43,6 +45,9 @@ public class JlamaStreamingChatModel implements StreamingChatLanguageModel {
         if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
             loader = loader.quantized();
 
+        if (workingQuantizedType != null)
+            loader = loader.workingQuantizationType(workingQuantizedType);
+
         if (threadCount != null)
             loader = loader.threadCount(threadCount);
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModel.java
index 8a3e65280..0405f31c6 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModel.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModel.java
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.PromptContext;
 import dev.langchain4j.internal.RetryUtils;
 import dev.langchain4j.model.StreamingResponseHandler;
@@ -31,6 +32,7 @@ public class JlamaStreamingLanguageModel implements StreamingLanguageModel {
                                        Integer threadCount,
                                        Boolean quantizeModelAtRuntime,
                                        Path workingDirectory,
+                                       DType workingQuantizedType,
                                        Float temperature,
                                        Integer maxTokens) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -40,6 +42,9 @@ public class JlamaStreamingLanguageModel implements StreamingLanguageModel {
         if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
             loader = loader.quantized();
 
+        if (workingQuantizedType != null)
+            loader = loader.workingQuantizationType(workingQuantizedType);
+
         if (threadCount != null)
             loader = loader.threadCount(threadCount);
diff --git a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaChatModelIT.java b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaChatModelIT.java
index 8e715c1c9..31085aa42 100644
--- a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaChatModelIT.java
+++ b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaChatModelIT.java
@@ -29,7 +29,7 @@ class JlamaChatModelIT {
                 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                 .modelCachePath(tmpDir.toPath())
                 .temperature(0.0f)
-                .maxTokens(30)
+                .maxTokens(64)
                 .build();
     }
diff --git a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaLanguageModelIT.java b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaLanguageModelIT.java
index d4d614a06..c304e23e9 100644
--- a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaLanguageModelIT.java
+++ b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaLanguageModelIT.java
@@ -26,7 +26,7 @@ class JlamaLanguageModelIT {
                 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                 .modelCachePath(tmpDir.toPath())
                 .temperature(0.0f)
-                .maxTokens(30)
+                .maxTokens(64)
                 .build();
     }
diff --git a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingChatModelIT.java b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingChatModelIT.java
index e4bdc4baf..97e5fdd58 100644
--- a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingChatModelIT.java
+++ b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingChatModelIT.java
@@ -27,7 +27,7 @@ class JlamaStreamingChatModelIT {
         model = JlamaStreamingChatModel.builder()
                 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                 .modelCachePath(tmpDir.toPath())
-                .maxTokens(30)
+                .maxTokens(64)
                 .temperature(0.0f)
                 .build();
     }
diff --git a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModelIT.java b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModelIT.java
index 1d9c30217..818a6a2d5 100644
--- a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModelIT.java
+++ b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModelIT.java
@@ -26,7 +26,7 @@ class JlamaStreamingLanguageModelIT {
         model = JlamaStreamingLanguageModel.builder()
                 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                 .modelCachePath(tmpDir.toPath())
-                .maxTokens(30)
+                .maxTokens(64)
                 .temperature(0.0f)
                 .build();
     }
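
Note for reviewers: a minimal usage sketch of the new `workingQuantizedType` builder option this patch adds to the chat, language, and streaming models. The builder calls mirror the integration-test setup above; choosing `DType.I8` here is illustrative and simply matches the new default in `JlamaModel.Loader`.

    import com.github.tjake.jlama.safetensors.DType;
    import dev.langchain4j.model.jlama.JlamaChatModel;

    // Load a pre-quantized Q4 model but keep working inference memory in I8
    // (the new Loader default); passing DType.F32 would trade memory for accuracy.
    JlamaChatModel model = JlamaChatModel.builder()
            .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
            .workingQuantizedType(DType.I8)
            .temperature(0.0f)
            .maxTokens(64)
            .build();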
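Similarly, a sketch of the new `poolingType` option on `JlamaEmbeddingModel`; per the constructor change above, leaving it unset falls back to `Generator.PoolingType.MODEL`. The model name is an illustrative placeholder, not one exercised by this patch.

    import com.github.tjake.jlama.model.functions.Generator;
    import dev.langchain4j.model.jlama.JlamaEmbeddingModel;

    // Pooling controls how per-token outputs are reduced to a single vector;
    // MODEL defers to the model's own pooling configuration.
    JlamaEmbeddingModel embeddingModel = JlamaEmbeddingModel.builder()
            .modelName("intfloat/e5-small-v2") // illustrative BERT-style embedding model
            .poolingType(Generator.PoolingType.MODEL)
            .build();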