diff --git a/langchain4j-jlama/README.md b/langchain4j-jlama/README.md
index 8909b32f9..9c4eebd59 100644
--- a/langchain4j-jlama/README.md
+++ b/langchain4j-jlama/README.md
@@ -3,7 +3,7 @@
[Jlama](https://github.com/tjake/Jlama) is a Java library that provides a simple way to integrate LLMs into Java
applications.
-Jlama is built with Java 21 and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference.
+Jlama is built with Java 20+ and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference.
Jlama uses Hugging Face models in SafeTensors format.
Models must be specified using the `owner/model-name` format. For example, `meta-llama/Llama-2-7b-chat-hf`.
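
For orientation, here is a minimal sketch of the integration this PR touches. The model name is the one used by the integration tests below; the `generate(String)` convenience overload comes from LangChain4j's `ChatLanguageModel` interface.

```java
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.jlama.JlamaChatModel;

public class JlamaQuickstart {
    public static void main(String[] args) {
        // Model name follows the `owner/model-name` format described above;
        // this is the model referenced by the integration tests in this PR.
        ChatLanguageModel model = JlamaChatModel.builder()
                .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                .temperature(0.0f)
                .maxTokens(64)
                .build();

        // ChatLanguageModel provides a String-in/String-out convenience overload.
        System.out.println(model.generate("Say hello in one short sentence."));
    }
}
```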
diff --git a/langchain4j-jlama/pom.xml b/langchain4j-jlama/pom.xml
index c55a8aee2..9ea8e7ac7 100644
--- a/langchain4j-jlama/pom.xml
+++ b/langchain4j-jlama/pom.xml
@@ -5,7 +5,7 @@
    <modelVersion>4.0.0</modelVersion>
    <artifactId>langchain4j-jlama</artifactId>
    <name>LangChain4j :: Integration :: Jlama</name>
-    <description>Jlama: Pure Java LLM Inference Engine - Requires Java 21</description>
+    <description>Jlama: LLM Inference Engine for Java - Requires Java 20+</description>
        <groupId>dev.langchain4j</groupId>
@@ -15,7 +15,7 @@
-        <jlama.version>0.3.1</jlama.version>
+        <jlama.version>0.5.0</jlama.version>
2.16.1
2.40.0
21
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaChatModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaChatModel.java
index 54276dc85..666a4cb90 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaChatModel.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaChatModel.java
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.prompt.*;
import com.github.tjake.jlama.util.JsonSupport;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
@@ -33,6 +34,7 @@ public class JlamaChatModel implements ChatLanguageModel {
Integer threadCount,
Boolean quantizeModelAtRuntime,
Path workingDirectory,
+ DType workingQuantizedType,
Float temperature,
Integer maxTokens) {
JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -42,6 +44,9 @@ public class JlamaChatModel implements ChatLanguageModel {
if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
loader = loader.quantized();
+ if (workingQuantizedType != null)
+ loader = loader.workingQuantizationType(workingQuantizedType);
+
if (threadCount != null)
loader = loader.threadCount(threadCount);
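
A hedged sketch of the new option on the chat model, assuming Lombok's `@Builder` exposes the constructor parameter under the same name (`workingQuantizedType`):

```java
import com.github.tjake.jlama.safetensors.DType;
import dev.langchain4j.model.jlama.JlamaChatModel;

// workingQuantizedType controls the DType used for working inference memory
// (intermediate buffers), independently of any weight quantization.
JlamaChatModel model = JlamaChatModel.builder()
        .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
        .workingQuantizedType(DType.I8)   // I8 is also the loader's default
        .build();
```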
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaEmbeddingModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaEmbeddingModel.java
index cf71198ed..94fea97aa 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaEmbeddingModel.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaEmbeddingModel.java
@@ -3,6 +3,7 @@ package dev.langchain4j.model.jlama;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.bert.BertModel;
+import com.github.tjake.jlama.model.functions.Generator;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.RetryUtils;
@@ -20,6 +21,7 @@ import static dev.langchain4j.spi.ServiceHelper.loadFactories;
public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
private final BertModel model;
+ private final Generator.PoolingType poolingType;
@Builder
public JlamaEmbeddingModel(Path modelCachePath,
@@ -27,6 +29,7 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
String authToken,
Integer threadCount,
Boolean quantizeModelAtRuntime,
+ Generator.PoolingType poolingType,
Path workingDirectory) {
JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -46,10 +49,12 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
if (workingDirectory != null)
loader = loader.workingDirectory(workingDirectory);
- loader = loader.inferenceType(AbstractModel.InferenceType.FORWARD_PASS);
+ loader = loader.inferenceType(AbstractModel.InferenceType.FULL_EMBEDDING);
this.model = (BertModel) loader.load();
this.dimension = model.getConfig().embeddingLength;
+
+ this.poolingType = poolingType == null ? Generator.PoolingType.MODEL : poolingType;
}
public static JlamaEmbeddingModelBuilder builder() {
@@ -64,7 +69,7 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
List<Embedding> embeddings = new ArrayList<>();
textSegments.forEach(textSegment -> {
- embeddings.add(Embedding.from(model.embed(textSegment.text())));
+ embeddings.add(Embedding.from(model.embed(textSegment.text(), poolingType)));
});
return Response.from(embeddings);
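
A sketch of the new pooling option. `PoolingType.MODEL` (the default wired in above) defers to the model's own pooling; `AVG` is assumed here as a mean-pooling alternative (only `MODEL` appears in this diff), and the embedding model name is illustrative.

```java
import com.github.tjake.jlama.model.functions.Generator;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.jlama.JlamaEmbeddingModel;
import dev.langchain4j.model.output.Response;

// AVG (mean pooling) is an assumed enum constant; MODEL is the default.
JlamaEmbeddingModel model = JlamaEmbeddingModel.builder()
        .modelName("intfloat/e5-small-v2")   // illustrative BERT-style embedding model
        .poolingType(Generator.PoolingType.AVG)
        .build();

Response<Embedding> response = model.embed("Hello world");
System.out.println(response.content().dimension());
```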
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaLanguageModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaLanguageModel.java
index 1bf0d998e..8d81b3f83 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaLanguageModel.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaLanguageModel.java
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.prompt.PromptContext;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.model.jlama.spi.JlamaLanguageModelBuilderFactory;
@@ -30,6 +31,7 @@ public class JlamaLanguageModel implements LanguageModel {
Integer threadCount,
Boolean quantizeModelAtRuntime,
Path workingDirectory,
+ DType workingQuantizedType,
Float temperature,
Integer maxTokens) {
JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -39,6 +41,9 @@ public class JlamaLanguageModel implements LanguageModel {
if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
loader = loader.quantized();
+ if (workingQuantizedType != null)
+ loader = loader.workingQuantizationType(workingQuantizedType);
+
if (threadCount != null)
loader = loader.threadCount(threadCount);
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModel.java
index fc298900d..1b50b1fe6 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModel.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModel.java
@@ -60,6 +60,7 @@ class JlamaModel {
registry.getModelCachePath().toString(),
owner,
modelName,
+ true,
Optional.empty(),
authToken,
Optional.empty());
@@ -67,6 +68,7 @@ class JlamaModel {
public class Loader {
private Path workingDirectory;
+ private DType workingQuantizationType = DType.I8;
private DType quantizationType;
private Integer threadCount;
private AbstractModel.InferenceType inferenceType = AbstractModel.InferenceType.FULL_GENERATION;
@@ -75,11 +77,19 @@ class JlamaModel {
}
public Loader quantized() {
- //For now only allow Q4 quantization at load time
+ // For now, only Q4 quantization is allowed at runtime
this.quantizationType = DType.Q4;
return this;
}
+ /**
+ * Sets the working quantization type: the data type the model uses for its working memory during inference.
+ */
+ public Loader workingQuantizationType(DType workingQuantizationType) {
+ this.workingQuantizationType = workingQuantizationType;
+ return this;
+ }
+
public Loader workingDirectory(Path workingDirectory) {
this.workingDirectory = workingDirectory;
return this;
@@ -101,10 +111,11 @@ class JlamaModel {
new File(registry.getModelCachePath().toFile(), modelName),
workingDirectory == null ? null : workingDirectory.toFile(),
DType.F32,
- DType.I8,
+ workingQuantizationType,
Optional.ofNullable(quantizationType),
Optional.ofNullable(threadCount),
- Optional.empty());
+ Optional.empty(),
+ SafeTensorSupport::loadWeights);
}
}
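
The Loader now defaults working memory to `DType.I8` while weights load as `F32` (or `Q4` when `quantized()` is requested). A hedged sketch of overriding that default through the public `JlamaLanguageModel` builder, which wires the same parameter (names taken from this diff):

```java
import com.github.tjake.jlama.safetensors.DType;
import dev.langchain4j.model.jlama.JlamaLanguageModel;
import dev.langchain4j.model.output.Response;

// Trade memory for precision: F32 working buffers instead of the I8 default.
JlamaLanguageModel model = JlamaLanguageModel.builder()
        .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
        .workingQuantizedType(DType.F32)   // overrides the loader's DType.I8 default
        .temperature(0.0f)
        .maxTokens(64)
        .build();

Response<String> answer = model.generate("What is the capital of France?");
System.out.println(answer.content());
```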
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModelRegistry.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModelRegistry.java
index 36251dd56..446b02dbb 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModelRegistry.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModelRegistry.java
@@ -90,7 +90,7 @@ class JlamaModelRegistry {
name = parts[1];
}
- File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, Optional.empty(), authToken, Optional.empty());
+ File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, true, Optional.empty(), authToken, Optional.empty());
File config = new File(modelDir, "config.json");
ModelSupport.ModelType type = SafeTensorSupport.detectModel(config);
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingChatModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingChatModel.java
index 90a743690..5fe99a6c1 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingChatModel.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingChatModel.java
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
@@ -34,6 +35,7 @@ public class JlamaStreamingChatModel implements StreamingChatLanguageModel {
Integer threadCount,
Boolean quantizeModelAtRuntime,
Path workingDirectory,
+ DType workingQuantizedType,
Float temperature,
Integer maxTokens) {
JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -43,6 +45,9 @@ public class JlamaStreamingChatModel implements StreamingChatLanguageModel {
if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
loader = loader.quantized();
+ if (workingQuantizedType != null)
+ loader = loader.workingQuantizationType(workingQuantizedType);
+
if (threadCount != null)
loader = loader.threadCount(threadCount);
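
The streaming chat model accepts the same new parameter. A minimal sketch of streaming usage with LangChain4j's `StreamingResponseHandler`, mirroring the builder setup from the tests in this PR:

```java
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.jlama.JlamaStreamingChatModel;
import dev.langchain4j.model.output.Response;

JlamaStreamingChatModel model = JlamaStreamingChatModel.builder()
        .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
        .temperature(0.0f)
        .build();

model.generate("Tell me a joke.", new StreamingResponseHandler<AiMessage>() {
    @Override
    public void onNext(String token) {
        System.out.print(token);   // tokens arrive incrementally
    }

    @Override
    public void onComplete(Response<AiMessage> response) {
        System.out.println();      // full AiMessage is available here
    }

    @Override
    public void onError(Throwable error) {
        error.printStackTrace();
    }
});
```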
diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModel.java
index 8a3e65280..0405f31c6 100644
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModel.java
+++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModel.java
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.prompt.PromptContext;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.model.StreamingResponseHandler;
@@ -31,6 +32,7 @@ public class JlamaStreamingLanguageModel implements StreamingLanguageModel {
Integer threadCount,
Boolean quantizeModelAtRuntime,
Path workingDirectory,
+ DType workingQuantizedType,
Float temperature,
Integer maxTokens) {
JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -40,6 +42,9 @@ public class JlamaStreamingLanguageModel implements StreamingLanguageModel {
if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
loader = loader.quantized();
+ if (workingQuantizedType != null)
+ loader = loader.workingQuantizationType(workingQuantizedType);
+
if (threadCount != null)
loader = loader.threadCount(threadCount);
diff --git a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaChatModelIT.java b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaChatModelIT.java
index 8e715c1c9..31085aa42 100644
--- a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaChatModelIT.java
+++ b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaChatModelIT.java
@@ -29,7 +29,7 @@ class JlamaChatModelIT {
.modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
.modelCachePath(tmpDir.toPath())
.temperature(0.0f)
- .maxTokens(30)
+ .maxTokens(64)
.build();
}
diff --git a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaLanguageModelIT.java b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaLanguageModelIT.java
index d4d614a06..c304e23e9 100644
--- a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaLanguageModelIT.java
+++ b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaLanguageModelIT.java
@@ -26,7 +26,7 @@ class JlamaLanguageModelIT {
.modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
.modelCachePath(tmpDir.toPath())
.temperature(0.0f)
- .maxTokens(30)
+ .maxTokens(64)
.build();
}
diff --git a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingChatModelIT.java b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingChatModelIT.java
index e4bdc4baf..97e5fdd58 100644
--- a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingChatModelIT.java
+++ b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingChatModelIT.java
@@ -27,7 +27,7 @@ class JlamaStreamingChatModelIT {
model = JlamaStreamingChatModel.builder()
.modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
.modelCachePath(tmpDir.toPath())
- .maxTokens(30)
+ .maxTokens(64)
.temperature(0.0f)
.build();
}
diff --git a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModelIT.java b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModelIT.java
index 1d9c30217..818a6a2d5 100644
--- a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModelIT.java
+++ b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModelIT.java
@@ -26,7 +26,7 @@ class JlamaStreamingLanguageModelIT {
model = JlamaStreamingLanguageModel.builder()
.modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
.modelCachePath(tmpDir.toPath())
- .maxTokens(30)
+ .maxTokens(64)
.temperature(0.0f)
.build();
}