Jlama revision bump, add working Q type to builder (#1825)

## Change
Bump the Jlama version to 0.5.0.

This revision is currently being promoted to Maven Central, so it should be available within the next six hours.

## General checklist
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green
Commit 21eb8b962f (parent 3579664e08), authored by Jake Luciani on 2024-09-25 02:55:25 -04:00 and committed via GitHub.
13 changed files with 49 additions and 13 deletions

**Docs: Jlama integration page**

```diff
@@ -3,7 +3,7 @@
 [Jlama](https://github.com/tjake/Jlama) is a Java library that provides a simple way to integrate LLM models into Java
 applications.
-Jlama is built with Java 21 and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference.
+Jlama is built with Java 20+ and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference.
 Jlama uses huggingface models in safetensor format.
 Models must be specified using the `owner/model-name` format. For example, `meta-llama/Llama-2-7b-chat-hf`.
```

**`pom.xml` (langchain4j-jlama)**

```diff
@@ -5,7 +5,7 @@
 <modelVersion>4.0.0</modelVersion>
 <artifactId>langchain4j-jlama</artifactId>
 <name>LangChain4j :: Integration :: Jlama</name>
-<description>Jlama: Pure Java LLM Inference Engine - Requires Java 21</description>
+<description>Jlama: LLM Inference Engine for Java - Requires Java 20+</description>
 <parent>
 <groupId>dev.langchain4j</groupId>
@@ -15,7 +15,7 @@
 </parent>
 <properties>
-<jlama.version>0.3.1</jlama.version>
+<jlama.version>0.5.0</jlama.version>
 <jackson.version>2.16.1</jackson.version>
 <spotless.version>2.40.0</spotless.version>
 <maven.compiler.release>21</maven.compiler.release>
```

**`JlamaChatModel.java`**

```diff
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.*;
 import com.github.tjake.jlama.util.JsonSupport;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
@@ -33,6 +34,7 @@ public class JlamaChatModel implements ChatLanguageModel {
 Integer threadCount,
 Boolean quantizeModelAtRuntime,
 Path workingDirectory,
+DType workingQuantizedType,
 Float temperature,
 Integer maxTokens) {
 JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -42,6 +44,9 @@ public class JlamaChatModel implements ChatLanguageModel {
 if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
 loader = loader.quantized();
+if (workingQuantizedType != null)
+loader = loader.workingQuantizationType(workingQuantizedType);
 if (threadCount != null)
 loader = loader.threadCount(threadCount);
```

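For reference, the new option can be exercised from the public builder roughly as below. This is a minimal sketch, not code from this PR: the `workingQuantizedType(...)` builder method name is inferred from the Lombok-generated builder and the constructor parameter above, and the model name is the one used in this PR's integration tests.

```java
import com.github.tjake.jlama.safetensors.DType;
import dev.langchain4j.model.jlama.JlamaChatModel;

public class WorkingQuantizationExample {

    public static void main(String[] args) {
        // quantizeModelAtRuntime(true) still means Q4 for the stored weights;
        // workingQuantizedType controls the DType used for working inference
        // memory, which was previously hardcoded to I8 inside the loader.
        JlamaChatModel model = JlamaChatModel.builder()
                .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                .quantizeModelAtRuntime(true)
                .workingQuantizedType(DType.I8)
                .temperature(0.0f)
                .maxTokens(64)
                .build();

        System.out.println(model.generate("Say hello in one short sentence."));
    }
}
```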
**`JlamaEmbeddingModel.java`**

```diff
@@ -3,6 +3,7 @@ package dev.langchain4j.model.jlama;
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.ModelSupport;
 import com.github.tjake.jlama.model.bert.BertModel;
+import com.github.tjake.jlama.model.functions.Generator;
 import dev.langchain4j.data.embedding.Embedding;
 import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.internal.RetryUtils;
@@ -20,6 +21,7 @@ import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
 private final BertModel model;
+private final Generator.PoolingType poolingType;
 @Builder
 public JlamaEmbeddingModel(Path modelCachePath,
@@ -27,6 +29,7 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
 String authToken,
 Integer threadCount,
 Boolean quantizeModelAtRuntime,
+Generator.PoolingType poolingType,
 Path workingDirectory) {
 JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -46,10 +49,12 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
 if (workingDirectory != null)
 loader = loader.workingDirectory(workingDirectory);
-loader = loader.inferenceType(AbstractModel.InferenceType.FORWARD_PASS);
+loader = loader.inferenceType(AbstractModel.InferenceType.FULL_EMBEDDING);
 this.model = (BertModel) loader.load();
 this.dimension = model.getConfig().embeddingLength;
+this.poolingType = poolingType == null ? Generator.PoolingType.MODEL : poolingType;
 }
 public static JlamaEmbeddingModelBuilder builder() {
@@ -64,7 +69,7 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
 List<Embedding> embeddings = new ArrayList<>();
 textSegments.forEach(textSegment -> {
-embeddings.add(Embedding.from(model.embed(textSegment.text())));
+embeddings.add(Embedding.from(model.embed(textSegment.text(), poolingType)));
 });
 return Response.from(embeddings);
```

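The new pooling option can be used similarly. A minimal sketch, assuming the Lombok builder exposes `poolingType(...)` from the constructor parameter above, and using `PoolingType.MODEL` (the default shown in the diff, which defers to the model's own configuration); the model name is a hypothetical placeholder for any BERT-style safetensors model.

```java
import com.github.tjake.jlama.model.functions.Generator;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.jlama.JlamaEmbeddingModel;

public class PoolingExample {

    public static void main(String[] args) {
        JlamaEmbeddingModel model = JlamaEmbeddingModel.builder()
                .modelName("intfloat/e5-small-v2") // hypothetical model id
                .poolingType(Generator.PoolingType.MODEL) // pool per the model's own config
                .build();

        Embedding embedding = model.embed("Hello, Jlama").content();
        System.out.println(embedding.dimension());
    }
}
```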
**`JlamaLanguageModel.java`**

```diff
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.PromptContext;
 import dev.langchain4j.internal.RetryUtils;
 import dev.langchain4j.model.jlama.spi.JlamaLanguageModelBuilderFactory;
@@ -30,6 +31,7 @@ public class JlamaLanguageModel implements LanguageModel {
 Integer threadCount,
 Boolean quantizeModelAtRuntime,
 Path workingDirectory,
+DType workingQuantizedType,
 Float temperature,
 Integer maxTokens) {
 JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -39,6 +41,9 @@ public class JlamaLanguageModel implements LanguageModel {
 if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
 loader = loader.quantized();
+if (workingQuantizedType != null)
+loader = loader.workingQuantizationType(workingQuantizedType);
 if (threadCount != null)
 loader = loader.threadCount(threadCount);
```

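The same knob lands on `JlamaLanguageModel`. A minimal sketch under the same assumption about the Lombok-generated builder method, here trading memory for precision by keeping working buffers in F32:

```java
import com.github.tjake.jlama.safetensors.DType;
import dev.langchain4j.model.jlama.JlamaLanguageModel;
import dev.langchain4j.model.output.Response;

public class LanguageModelExample {

    public static void main(String[] args) {
        JlamaLanguageModel model = JlamaLanguageModel.builder()
                .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                .workingQuantizedType(DType.F32) // larger working memory, higher precision
                .temperature(0.0f)
                .maxTokens(64)
                .build();

        Response<String> response = model.generate("The capital of France is");
        System.out.println(response.content());
    }
}
```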
**`JlamaModel.java`**

```diff
@@ -60,6 +60,7 @@ class JlamaModel {
 registry.getModelCachePath().toString(),
 owner,
 modelName,
+true,
 Optional.empty(),
 authToken,
 Optional.empty());
@@ -67,6 +68,7 @@ class JlamaModel {
 public class Loader {
 private Path workingDirectory;
+private DType workingQuantizationType = DType.I8;
 private DType quantizationType;
 private Integer threadCount;
 private AbstractModel.InferenceType inferenceType = AbstractModel.InferenceType.FULL_GENERATION;
@@ -75,11 +77,19 @@ class JlamaModel {
 }
 public Loader quantized() {
-//For now only allow Q4 quantization at load time
+//For now only allow Q4 quantization at runtime
 this.quantizationType = DType.Q4;
 return this;
 }
+/**
+ * Set the working quantization type. This is the type that the model will use for working inference memory.
+ */
+public Loader workingQuantizationType(DType workingQuantizationType) {
+this.workingQuantizationType = workingQuantizationType;
+return this;
+}
 public Loader workingDirectory(Path workingDirectory) {
 this.workingDirectory = workingDirectory;
 return this;
@@ -101,10 +111,11 @@ class JlamaModel {
 new File(registry.getModelCachePath().toFile(), modelName),
 workingDirectory == null ? null : workingDirectory.toFile(),
 DType.F32,
-DType.I8,
+workingQuantizationType,
 Optional.ofNullable(quantizationType),
 Optional.ofNullable(threadCount),
-Optional.empty());
+Optional.empty(),
+SafeTensorSupport::loadWeights);
 }
 }
```

**`JlamaModelRegistry.java`**

```diff
@@ -90,7 +90,7 @@ class JlamaModelRegistry {
 name = parts[1];
 }
-File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, Optional.empty(), authToken, Optional.empty());
+File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, true, Optional.empty(), authToken, Optional.empty());
 File config = new File(modelDir, "config.json");
 ModelSupport.ModelType type = SafeTensorSupport.detectModel(config);
```

**`JlamaStreamingChatModel.java`**

```diff
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
@@ -34,6 +35,7 @@ public class JlamaStreamingChatModel implements StreamingChatLanguageModel {
 Integer threadCount,
 Boolean quantizeModelAtRuntime,
 Path workingDirectory,
+DType workingQuantizedType,
 Float temperature,
 Integer maxTokens) {
 JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -43,6 +45,9 @@ public class JlamaStreamingChatModel implements StreamingChatLanguageModel {
 if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
 loader = loader.quantized();
+if (workingQuantizedType != null)
+loader = loader.workingQuantizationType(workingQuantizedType);
 if (threadCount != null)
 loader = loader.threadCount(threadCount);
```

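The streaming variants accept the option too. A minimal sketch of `JlamaStreamingChatModel` with the new builder method (same inferred name as above), using the `StreamingResponseHandler` API from this era of langchain4j:

```java
import com.github.tjake.jlama.safetensors.DType;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.jlama.JlamaStreamingChatModel;
import dev.langchain4j.model.output.Response;

public class StreamingExample {

    public static void main(String[] args) {
        JlamaStreamingChatModel model = JlamaStreamingChatModel.builder()
                .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                .workingQuantizedType(DType.I8)
                .temperature(0.0f)
                .maxTokens(64)
                .build();

        model.generate("Write a haiku about Java.", new StreamingResponseHandler<AiMessage>() {
            @Override
            public void onNext(String token) {
                System.out.print(token); // tokens arrive incrementally
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println();
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}
```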
**`JlamaStreamingLanguageModel.java`**

```diff
@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.PromptContext;
 import dev.langchain4j.internal.RetryUtils;
 import dev.langchain4j.model.StreamingResponseHandler;
@@ -31,6 +32,7 @@ public class JlamaStreamingLanguageModel implements StreamingLanguageModel {
 Integer threadCount,
 Boolean quantizeModelAtRuntime,
 Path workingDirectory,
+DType workingQuantizedType,
 Float temperature,
 Integer maxTokens) {
 JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -40,6 +42,9 @@ public class JlamaStreamingLanguageModel implements StreamingLanguageModel {
 if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
 loader = loader.quantized();
+if (workingQuantizedType != null)
+loader = loader.workingQuantizationType(workingQuantizedType);
 if (threadCount != null)
 loader = loader.threadCount(threadCount);
```

**`JlamaChatModelIT.java`**

```diff
@@ -29,7 +29,7 @@ class JlamaChatModelIT {
 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
 .modelCachePath(tmpDir.toPath())
 .temperature(0.0f)
-.maxTokens(30)
+.maxTokens(64)
 .build();
 }
```

**`JlamaLanguageModelIT.java`**

```diff
@@ -26,7 +26,7 @@ class JlamaLanguageModelIT {
 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
 .modelCachePath(tmpDir.toPath())
 .temperature(0.0f)
-.maxTokens(30)
+.maxTokens(64)
 .build();
 }
```

**`JlamaStreamingChatModelIT.java`**

```diff
@@ -27,7 +27,7 @@ class JlamaStreamingChatModelIT {
 model = JlamaStreamingChatModel.builder()
 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
 .modelCachePath(tmpDir.toPath())
-.maxTokens(30)
+.maxTokens(64)
 .temperature(0.0f)
 .build();
 }
```

**`JlamaStreamingLanguageModelIT.java`**

```diff
@@ -26,7 +26,7 @@ class JlamaStreamingLanguageModelIT {
 model = JlamaStreamingLanguageModel.builder()
 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
 .modelCachePath(tmpDir.toPath())
-.maxTokens(30)
+.maxTokens(64)
 .temperature(0.0f)
 .build();
 }
```