Jlama revision bump, add working Q type to builder (#1825)
## Change

Bump the Jlama revision to 0.5.0. This revision is currently being promoted to Maven, so it should be available within the next 6 hours.

## General checklist

- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green
parent 3579664e08
commit 21eb8b962f
Module documentation:

```diff
@@ -3,7 +3,7 @@
 [Jlama](https://github.com/tjake/Jlama) is a Java library that provides a simple way to integrate LLM models into Java
 applications.
 
-Jlama is built with Java 21 and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference.
+Jlama is built with Java 20+ and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference.
 
 Jlama uses huggingface models in safetensor format.
 Models must be specified using the `owner/model-name` format. For example, `meta-llama/Llama-2-7b-chat-hf`.
```
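For illustration, a minimal sketch of pointing the integration at a model using the `owner/model-name` convention described above. Builder method names are taken from the integration tests in this change; the model name and cache path are just examples:

```java
import dev.langchain4j.model.jlama.JlamaChatModel;

import java.nio.file.Path;

public class JlamaQuickStart {

    public static void main(String[] args) {
        // The model is resolved as <owner>/<model-name> and downloaded on first use
        JlamaChatModel model = JlamaChatModel.builder()
                .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                .modelCachePath(Path.of(System.getProperty("java.io.tmpdir"))) // where weights are cached
                .temperature(0.0f)
                .maxTokens(64)
                .build();

        System.out.println(model.generate("Say hello in one sentence."));
    }
}
```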
Module `pom.xml`:

```diff
@@ -5,7 +5,7 @@
     <modelVersion>4.0.0</modelVersion>
     <artifactId>langchain4j-jlama</artifactId>
     <name>LangChain4j :: Integration :: Jlama</name>
-    <description>Jlama: Pure Java LLM Inference Engine - Requires Java 21</description>
+    <description>Jlama: LLM Inference Engine for Java - Requires Java 20+</description>
 
     <parent>
         <groupId>dev.langchain4j</groupId>
@@ -15,7 +15,7 @@
     </parent>
 
     <properties>
-        <jlama.version>0.3.1</jlama.version>
+        <jlama.version>0.5.0</jlama.version>
         <jackson.version>2.16.1</jackson.version>
         <spotless.version>2.40.0</spotless.version>
         <maven.compiler.release>21</maven.compiler.release>
```
`JlamaChatModel`:

```diff
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.*;
 import com.github.tjake.jlama.util.JsonSupport;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
@@ -33,6 +34,7 @@ public class JlamaChatModel implements ChatLanguageModel {
                           Integer threadCount,
                           Boolean quantizeModelAtRuntime,
                           Path workingDirectory,
+                          DType workingQuantizedType,
                           Float temperature,
                           Integer maxTokens) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -42,6 +44,9 @@ public class JlamaChatModel implements ChatLanguageModel {
         if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
             loader = loader.quantized();
 
+        if (workingQuantizedType != null)
+            loader = loader.workingQuantizationType(workingQuantizedType);
+
         if (threadCount != null)
             loader = loader.threadCount(threadCount);
 
```
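A sketch of how the new option might be set from user code, assuming Lombok's `@Builder` exposes a `workingQuantizedType` setter matching the new constructor parameter (`DType.I8` is the loader default per the `JlamaModel` change below; the cache path is a placeholder):

```java
import com.github.tjake.jlama.safetensors.DType;
import dev.langchain4j.model.jlama.JlamaChatModel;

import java.nio.file.Path;

public class WorkingQTypeExample {

    public static void main(String[] args) {
        JlamaChatModel model = JlamaChatModel.builder()
                .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                .modelCachePath(Path.of("/tmp/jlama-models"))   // hypothetical cache dir
                .quantizeModelAtRuntime(true)                   // weights quantized to Q4 at load
                .workingQuantizedType(DType.I8)                 // type used for working inference memory
                .temperature(0.0f)
                .maxTokens(64)
                .build();

        System.out.println(model.generate("What is the capital of France?"));
    }
}
```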
`JlamaEmbeddingModel`:

```diff
@@ -3,6 +3,7 @@ package dev.langchain4j.model.jlama;
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.ModelSupport;
 import com.github.tjake.jlama.model.bert.BertModel;
+import com.github.tjake.jlama.model.functions.Generator;
 import dev.langchain4j.data.embedding.Embedding;
 import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.internal.RetryUtils;
@@ -20,6 +21,7 @@ import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 
 public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
     private final BertModel model;
+    private final Generator.PoolingType poolingType;
 
     @Builder
     public JlamaEmbeddingModel(Path modelCachePath,
@@ -27,6 +29,7 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
                                String authToken,
                                Integer threadCount,
                                Boolean quantizeModelAtRuntime,
+                               Generator.PoolingType poolingType,
                                Path workingDirectory) {
 
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -46,10 +49,12 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
         if (workingDirectory != null)
             loader = loader.workingDirectory(workingDirectory);
 
-        loader = loader.inferenceType(AbstractModel.InferenceType.FORWARD_PASS);
+        loader = loader.inferenceType(AbstractModel.InferenceType.FULL_EMBEDDING);
 
         this.model = (BertModel) loader.load();
         this.dimension = model.getConfig().embeddingLength;
+
+        this.poolingType = poolingType == null ? Generator.PoolingType.MODEL : poolingType;
     }
 
     public static JlamaEmbeddingModelBuilder builder() {
@@ -64,7 +69,7 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
         List<Embedding> embeddings = new ArrayList<>();
 
         textSegments.forEach(textSegment -> {
-            embeddings.add(Embedding.from(model.embed(textSegment.text())));
+            embeddings.add(Embedding.from(model.embed(textSegment.text(), poolingType)));
        });
 
         return Response.from(embeddings);
```
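A sketch of using the new pooling option. `Generator.PoolingType.MODEL` is the default applied when none is given; the model name and cache path here are hypothetical, and the builder setter name is assumed to follow the constructor parameter via Lombok:

```java
import com.github.tjake.jlama.model.functions.Generator;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.jlama.JlamaEmbeddingModel;
import dev.langchain4j.model.output.Response;

import java.nio.file.Path;

public class PoolingExample {

    public static void main(String[] args) {
        JlamaEmbeddingModel model = JlamaEmbeddingModel.builder()
                .modelName("intfloat/e5-small-v2")              // hypothetical BERT-style embedding model
                .modelCachePath(Path.of("/tmp/jlama-models"))   // hypothetical cache dir
                .poolingType(Generator.PoolingType.MODEL)       // explicit here; MODEL is also the default
                .build();

        Response<Embedding> response = model.embed("Hello, world!");
        System.out.println(response.content().dimension());
    }
}
```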
`JlamaLanguageModel`:

```diff
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.PromptContext;
 import dev.langchain4j.internal.RetryUtils;
 import dev.langchain4j.model.jlama.spi.JlamaLanguageModelBuilderFactory;
@@ -30,6 +31,7 @@ public class JlamaLanguageModel implements LanguageModel {
                               Integer threadCount,
                               Boolean quantizeModelAtRuntime,
                               Path workingDirectory,
+                              DType workingQuantizedType,
                               Float temperature,
                               Integer maxTokens) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -39,6 +41,9 @@ public class JlamaLanguageModel implements LanguageModel {
         if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
             loader = loader.quantized();
 
+        if (workingQuantizedType != null)
+            loader = loader.workingQuantizationType(workingQuantizedType);
+
         if (threadCount != null)
             loader = loader.threadCount(threadCount);
 
```
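The same builder option applies to the completion-style model. A minimal sketch (model name and cache path are placeholders):

```java
import com.github.tjake.jlama.safetensors.DType;
import dev.langchain4j.model.jlama.JlamaLanguageModel;
import dev.langchain4j.model.output.Response;

import java.nio.file.Path;

public class LanguageModelExample {

    public static void main(String[] args) {
        JlamaLanguageModel model = JlamaLanguageModel.builder()
                .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                .modelCachePath(Path.of("/tmp/jlama-models"))  // hypothetical cache dir
                .workingQuantizedType(DType.I8)                // working inference memory type
                .temperature(0.0f)
                .maxTokens(64)
                .build();

        Response<String> response = model.generate("Once upon a time");
        System.out.println(response.content());
    }
}
```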
`JlamaModel` (internal loader):

```diff
@@ -60,6 +60,7 @@ class JlamaModel {
                 registry.getModelCachePath().toString(),
                 owner,
                 modelName,
+                true,
                 Optional.empty(),
                 authToken,
                 Optional.empty());
@@ -67,6 +68,7 @@ class JlamaModel {
 
     public class Loader {
         private Path workingDirectory;
+        private DType workingQuantizationType = DType.I8;
         private DType quantizationType;
         private Integer threadCount;
         private AbstractModel.InferenceType inferenceType = AbstractModel.InferenceType.FULL_GENERATION;
@@ -75,11 +77,19 @@ class JlamaModel {
         }
 
         public Loader quantized() {
-            //For now only allow Q4 quantization at load time
+            //For now only allow Q4 quantization at runtime
             this.quantizationType = DType.Q4;
             return this;
         }
 
+        /**
+         * Set the working quantization type. This is the type that the model will use for working inference memory.
+         */
+        public Loader workingQuantizationType(DType workingQuantizationType) {
+            this.workingQuantizationType = workingQuantizationType;
+            return this;
+        }
+
         public Loader workingDirectory(Path workingDirectory) {
             this.workingDirectory = workingDirectory;
             return this;
@@ -101,10 +111,11 @@ class JlamaModel {
                     new File(registry.getModelCachePath().toFile(), modelName),
                     workingDirectory == null ? null : workingDirectory.toFile(),
                     DType.F32,
-                    DType.I8,
+                    workingQuantizationType,
                     Optional.ofNullable(quantizationType),
                     Optional.ofNullable(threadCount),
-                    Optional.empty());
+                    Optional.empty(),
+                    SafeTensorSupport::loadWeights);
         }
     }
 
```
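For illustration, the fluent configuration the model classes now perform internally. This is only a sketch: `Loader` is package-private, and `model.loader()` stands in for however `JlamaModel` hands out a `Loader`, which is not shown in this diff:

```java
package dev.langchain4j.model.jlama;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.safetensors.DType;

import java.nio.file.Path;

class LoaderSketch {

    static AbstractModel load(JlamaModel model) {
        return model.loader()                              // hypothetical accessor, not shown in this diff
                .quantized()                               // Q4 weight quantization at runtime
                .workingQuantizationType(DType.I8)         // working inference memory; I8 is the new default
                .threadCount(8)                            // optional inference thread count
                .workingDirectory(Path.of("/tmp/jlama"))   // optional scratch dir
                .load();                                   // builds the Jlama AbstractModel
    }
}
```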
`JlamaModelRegistry`:

```diff
@@ -90,7 +90,7 @@ class JlamaModelRegistry {
             name = parts[1];
         }
 
-        File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, Optional.empty(), authToken, Optional.empty());
+        File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, true, Optional.empty(), authToken, Optional.empty());
 
         File config = new File(modelDir, "config.json");
         ModelSupport.ModelType type = SafeTensorSupport.detectModel(config);
```
`JlamaStreamingChatModel`:

```diff
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
@@ -34,6 +35,7 @@ public class JlamaStreamingChatModel implements StreamingChatLanguageModel {
                                    Integer threadCount,
                                    Boolean quantizeModelAtRuntime,
                                    Path workingDirectory,
+                                   DType workingQuantizedType,
                                    Float temperature,
                                    Integer maxTokens) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -43,6 +45,9 @@ public class JlamaStreamingChatModel implements StreamingChatLanguageModel {
         if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
             loader = loader.quantized();
 
+        if (workingQuantizedType != null)
+            loader = loader.workingQuantizationType(workingQuantizedType);
+
         if (threadCount != null)
             loader = loader.threadCount(threadCount);
 
```
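A sketch of driving the streaming variant with langchain4j's standard `StreamingResponseHandler` (model name and cache path are placeholders):

```java
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.jlama.JlamaStreamingChatModel;
import dev.langchain4j.model.output.Response;

import java.nio.file.Path;

public class StreamingExample {

    public static void main(String[] args) {
        JlamaStreamingChatModel model = JlamaStreamingChatModel.builder()
                .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                .modelCachePath(Path.of("/tmp/jlama-models"))  // hypothetical cache dir
                .temperature(0.0f)
                .maxTokens(64)
                .build();

        model.generate("Tell me a joke about Java.", new StreamingResponseHandler<AiMessage>() {

            @Override
            public void onNext(String token) {
                System.out.print(token); // tokens arrive as they are generated
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println("\nDone: " + response.finishReason());
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}
```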
`JlamaStreamingLanguageModel`:

```diff
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.PromptContext;
 import dev.langchain4j.internal.RetryUtils;
 import dev.langchain4j.model.StreamingResponseHandler;
@@ -31,6 +32,7 @@ public class JlamaStreamingLanguageModel implements StreamingLanguageModel {
                                        Integer threadCount,
                                        Boolean quantizeModelAtRuntime,
                                        Path workingDirectory,
+                                       DType workingQuantizedType,
                                        Float temperature,
                                        Integer maxTokens) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -40,6 +42,9 @@ public class JlamaStreamingLanguageModel implements StreamingLanguageModel {
         if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
             loader = loader.quantized();
 
+        if (workingQuantizedType != null)
+            loader = loader.workingQuantizationType(workingQuantizedType);
+
         if (threadCount != null)
             loader = loader.threadCount(threadCount);
 
```
Integration tests raise `maxTokens` from 30 to 64. `JlamaChatModelIT`:

```diff
@@ -29,7 +29,7 @@ class JlamaChatModelIT {
                 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                 .modelCachePath(tmpDir.toPath())
                 .temperature(0.0f)
-                .maxTokens(30)
+                .maxTokens(64)
                 .build();
     }
 
```
`JlamaLanguageModelIT`:

```diff
@@ -26,7 +26,7 @@ class JlamaLanguageModelIT {
                 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                 .modelCachePath(tmpDir.toPath())
                 .temperature(0.0f)
-                .maxTokens(30)
+                .maxTokens(64)
                 .build();
     }
 
```
`JlamaStreamingChatModelIT`:

```diff
@@ -27,7 +27,7 @@ class JlamaStreamingChatModelIT {
         model = JlamaStreamingChatModel.builder()
                 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                 .modelCachePath(tmpDir.toPath())
-                .maxTokens(30)
+                .maxTokens(64)
                 .temperature(0.0f)
                 .build();
     }
```
`JlamaStreamingLanguageModelIT`:

```diff
@@ -26,7 +26,7 @@ class JlamaStreamingLanguageModelIT {
         model = JlamaStreamingLanguageModel.builder()
                 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                 .modelCachePath(tmpDir.toPath())
-                .maxTokens(30)
+                .maxTokens(64)
                 .temperature(0.0f)
                 .build();
     }
```