Jlama revision bump, add working Q type to builder (#1825)

## Change
Bump jlama rev to 0.5.0

This revision is currently being promoted to Maven Central, so it should be available within the next six hours.

## General checklist
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
Authored by Jake Luciani on 2024-09-25 02:55:25 -04:00, committed via GitHub (commit 21eb8b962f, parent 3579664e08).
13 changed files with 49 additions and 13 deletions

### Jlama docs

@@ -3,7 +3,7 @@
 [Jlama](https://github.com/tjake/Jlama) is a Java library that provides a simple way to integrate LLM models into Java
 applications.
-Jlama is built with Java 21 and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference.
+Jlama is built with Java 20+ and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference.
 Jlama uses huggingface models in safetensor format.
 Models must be specified using the `owner/model-name` format. For example, `meta-llama/Llama-2-7b-chat-hf`.
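
For context, here is a minimal usage sketch assembled from the builder calls exercised in this PR's integration tests (the cache path is illustrative, and `generate(String)` is the convenience method inherited from `ChatLanguageModel`):

```java
import dev.langchain4j.model.jlama.JlamaChatModel;

import java.nio.file.Path;

public class JlamaQuickstart {

    public static void main(String[] args) {
        // Build a chat model from a HuggingFace model in owner/model-name form.
        // Builder calls mirror this PR's integration tests; the cache path is illustrative.
        JlamaChatModel model = JlamaChatModel.builder()
                .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                .modelCachePath(Path.of(System.getProperty("java.io.tmpdir")))
                .temperature(0.0f)
                .maxTokens(64)
                .build();

        System.out.println(model.generate("Say hello in one sentence."));
    }
}
```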

### langchain4j-jlama/pom.xml

@@ -5,7 +5,7 @@
     <modelVersion>4.0.0</modelVersion>
     <artifactId>langchain4j-jlama</artifactId>
     <name>LangChain4j :: Integration :: Jlama</name>
-    <description>Jlama: Pure Java LLM Inference Engine - Requires Java 21</description>
+    <description>Jlama: LLM Inference Engine for Java - Requires Java 20+</description>
     <parent>
         <groupId>dev.langchain4j</groupId>
@@ -15,7 +15,7 @@
     </parent>
     <properties>
-        <jlama.version>0.3.1</jlama.version>
+        <jlama.version>0.5.0</jlama.version>
         <jackson.version>2.16.1</jackson.version>
         <spotless.version>2.40.0</spotless.version>
         <maven.compiler.release>21</maven.compiler.release>

### JlamaChatModel.java

@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.*;
 import com.github.tjake.jlama.util.JsonSupport;
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
@@ -33,6 +34,7 @@ public class JlamaChatModel implements ChatLanguageModel {
                           Integer threadCount,
                           Boolean quantizeModelAtRuntime,
                           Path workingDirectory,
+                          DType workingQuantizedType,
                           Float temperature,
                           Integer maxTokens) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -42,6 +44,9 @@ public class JlamaChatModel implements ChatLanguageModel {
         if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
             loader = loader.quantized();
+        if (workingQuantizedType != null)
+            loader = loader.workingQuantizationType(workingQuantizedType);
         if (threadCount != null)
             loader = loader.threadCount(threadCount);
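
The new constructor parameter is exposed through the generated builder; a hedged sketch follows (the setter name `workingQuantizedType` is assumed to follow the constructor parameter, and the `DType` constants used are the ones appearing in this diff):

```java
import com.github.tjake.jlama.safetensors.DType;

import dev.langchain4j.model.jlama.JlamaChatModel;

import java.nio.file.Path;

public class WorkingQuantizationExample {

    public static void main(String[] args) {
        // Sketch: keep Q4-quantized weights but run working inference memory in int8.
        // The workingQuantizedType setter name is assumed from the constructor parameter;
        // model name and cache path are illustrative.
        JlamaChatModel model = JlamaChatModel.builder()
                .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                .modelCachePath(Path.of(System.getProperty("java.io.tmpdir")))
                .quantizeModelAtRuntime(true)   // weights -> Q4 (see Loader.quantized())
                .workingQuantizedType(DType.I8) // working inference memory type (the default)
                .build();
    }
}
```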

### JlamaEmbeddingModel.java

@@ -3,6 +3,7 @@ package dev.langchain4j.model.jlama;
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.ModelSupport;
 import com.github.tjake.jlama.model.bert.BertModel;
+import com.github.tjake.jlama.model.functions.Generator;
 import dev.langchain4j.data.embedding.Embedding;
 import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.internal.RetryUtils;
@@ -20,6 +21,7 @@ import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
     private final BertModel model;
+    private final Generator.PoolingType poolingType;

     @Builder
     public JlamaEmbeddingModel(Path modelCachePath,
@@ -27,6 +29,7 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
                                String authToken,
                                Integer threadCount,
                                Boolean quantizeModelAtRuntime,
+                               Generator.PoolingType poolingType,
                                Path workingDirectory) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -46,10 +49,12 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
         if (workingDirectory != null)
             loader = loader.workingDirectory(workingDirectory);

-        loader = loader.inferenceType(AbstractModel.InferenceType.FORWARD_PASS);
+        loader = loader.inferenceType(AbstractModel.InferenceType.FULL_EMBEDDING);
         this.model = (BertModel) loader.load();
         this.dimension = model.getConfig().embeddingLength;
+        this.poolingType = poolingType == null ? Generator.PoolingType.MODEL : poolingType;
     }

     public static JlamaEmbeddingModelBuilder builder() {
@@ -64,7 +69,7 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
         List<Embedding> embeddings = new ArrayList<>();
         textSegments.forEach(textSegment -> {
-            embeddings.add(Embedding.from(model.embed(textSegment.text())));
+            embeddings.add(Embedding.from(model.embed(textSegment.text(), poolingType)));
         });

         return Response.from(embeddings);
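
The embedding model now threads a pooling strategy through to `model.embed(...)`. A short sketch, assuming the builder setters mirror the constructor parameters (`modelName` is assumed; `PoolingType.MODEL` is the default per this diff, and the model name is illustrative rather than taken from this PR):

```java
import com.github.tjake.jlama.model.functions.Generator;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.jlama.JlamaEmbeddingModel;
import dev.langchain4j.model.output.Response;

import java.nio.file.Path;

public class PoolingTypeExample {

    public static void main(String[] args) {
        // PoolingType.MODEL defers to the model's own pooling (the fallback when
        // poolingType is null); other enum constants are not verified here.
        JlamaEmbeddingModel model = JlamaEmbeddingModel.builder()
                .modelName("intfloat/e5-small-v2")
                .modelCachePath(Path.of(System.getProperty("java.io.tmpdir")))
                .poolingType(Generator.PoolingType.MODEL)
                .build();

        Response<Embedding> response = model.embed("Hello, world");
        System.out.println(response.content().dimension());
    }
}
```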

### JlamaLanguageModel.java

@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.PromptContext;
 import dev.langchain4j.internal.RetryUtils;
 import dev.langchain4j.model.jlama.spi.JlamaLanguageModelBuilderFactory;
@@ -30,6 +31,7 @@ public class JlamaLanguageModel implements LanguageModel {
                               Integer threadCount,
                               Boolean quantizeModelAtRuntime,
                               Path workingDirectory,
+                              DType workingQuantizedType,
                               Float temperature,
                               Integer maxTokens) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -39,6 +41,9 @@ public class JlamaLanguageModel implements LanguageModel {
         if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
             loader = loader.quantized();
+        if (workingQuantizedType != null)
+            loader = loader.workingQuantizationType(workingQuantizedType);
         if (threadCount != null)
             loader = loader.threadCount(threadCount);

### JlamaModel.java

@@ -60,6 +60,7 @@ class JlamaModel {
                 registry.getModelCachePath().toString(),
                 owner,
                 modelName,
+                true,
                 Optional.empty(),
                 authToken,
                 Optional.empty());
@@ -67,6 +68,7 @@ class JlamaModel {
     public class Loader {
         private Path workingDirectory;
+        private DType workingQuantizationType = DType.I8;
         private DType quantizationType;
         private Integer threadCount;
         private AbstractModel.InferenceType inferenceType = AbstractModel.InferenceType.FULL_GENERATION;
@@ -75,11 +77,19 @@ class JlamaModel {
         }

         public Loader quantized() {
-            //For now only allow Q4 quantization at load time
+            //For now only allow Q4 quantization at runtime
             this.quantizationType = DType.Q4;
             return this;
         }

+        /**
+         * Set the working quantization type. This is the type that the model will use for working inference memory.
+         */
+        public Loader workingQuantizationType(DType workingQuantizationType) {
+            this.workingQuantizationType = workingQuantizationType;
+            return this;
+        }

         public Loader workingDirectory(Path workingDirectory) {
             this.workingDirectory = workingDirectory;
             return this;
@@ -101,10 +111,11 @@ class JlamaModel {
                 new File(registry.getModelCachePath().toFile(), modelName),
                 workingDirectory == null ? null : workingDirectory.toFile(),
                 DType.F32,
-                DType.I8,
+                workingQuantizationType,
                 Optional.ofNullable(quantizationType),
                 Optional.ofNullable(threadCount),
-                Optional.empty());
+                Optional.empty(),
+                SafeTensorSupport::loadWeights);
     }
 }
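
Taken together, the model classes drive the loader roughly like the sketch below. How a `Loader` instance is obtained is outside this diff, so one is passed in; `Loader` is an inner class of the package-private `JlamaModel`, so this only compiles inside `dev.langchain4j.model.jlama`, and the argument values are illustrative:

```java
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.safetensors.DType;

import java.nio.file.Path;

class LoaderChainSketch {

    // Sketch of the fluent Loader chain after this change; each setter returns `this`.
    static AbstractModel load(JlamaModel.Loader loader) {
        return loader
                .quantized()                       // load-time weight quantization (Q4 only, for now)
                .workingQuantizationType(DType.I8) // working inference memory type (default I8)
                .threadCount(4)
                .workingDirectory(Path.of(System.getProperty("java.io.tmpdir")))
                .load();
    }
}
```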

### JlamaModelRegistry.java

@@ -90,7 +90,7 @@ class JlamaModelRegistry {
             name = parts[1];
         }

-        File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, Optional.empty(), authToken, Optional.empty());
+        File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, true, Optional.empty(), authToken, Optional.empty());
         File config = new File(modelDir, "config.json");
         ModelSupport.ModelType type = SafeTensorSupport.detectModel(config);

### JlamaStreamingChatModel.java

@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
@@ -34,6 +35,7 @@ public class JlamaStreamingChatModel implements StreamingChatLanguageModel {
                                    Integer threadCount,
                                    Boolean quantizeModelAtRuntime,
                                    Path workingDirectory,
+                                   DType workingQuantizedType,
                                    Float temperature,
                                    Integer maxTokens) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -43,6 +45,9 @@ public class JlamaStreamingChatModel implements StreamingChatLanguageModel {
         if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
             loader = loader.quantized();
+        if (workingQuantizedType != null)
+            loader = loader.workingQuantizationType(workingQuantizedType);
         if (threadCount != null)
             loader = loader.threadCount(threadCount);

### JlamaStreamingLanguageModel.java

@@ -2,6 +2,7 @@ package dev.langchain4j.model.jlama;
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
+import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.PromptContext;
 import dev.langchain4j.internal.RetryUtils;
 import dev.langchain4j.model.StreamingResponseHandler;
@@ -31,6 +32,7 @@ public class JlamaStreamingLanguageModel implements StreamingLanguageModel {
                                        Integer threadCount,
                                        Boolean quantizeModelAtRuntime,
                                        Path workingDirectory,
+                                       DType workingQuantizedType,
                                        Float temperature,
                                        Integer maxTokens) {
         JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
@@ -40,6 +42,9 @@ public class JlamaStreamingLanguageModel implements StreamingLanguageModel {
         if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
             loader = loader.quantized();
+        if (workingQuantizedType != null)
+            loader = loader.workingQuantizationType(workingQuantizedType);
         if (threadCount != null)
             loader = loader.threadCount(threadCount);
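
The streaming variants accept the same new builder option. A usage sketch based on the streaming integration tests below, using the `StreamingResponseHandler` callback from langchain4j-core (model name and cache path are illustrative):

```java
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.jlama.JlamaStreamingChatModel;
import dev.langchain4j.model.output.Response;

import java.nio.file.Path;

public class StreamingSketch {

    public static void main(String[] args) {
        // Builder calls mirror this PR's streaming integration tests.
        JlamaStreamingChatModel model = JlamaStreamingChatModel.builder()
                .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                .modelCachePath(Path.of(System.getProperty("java.io.tmpdir")))
                .maxTokens(64)
                .temperature(0.0f)
                .build();

        model.generate("Tell me a short joke.", new StreamingResponseHandler<AiMessage>() {
            @Override
            public void onNext(String token) {
                System.out.print(token); // tokens arrive as they are generated
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println();
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}
```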

### JlamaChatModelIT.java

@@ -29,7 +29,7 @@ class JlamaChatModelIT {
                 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                 .modelCachePath(tmpDir.toPath())
                 .temperature(0.0f)
-                .maxTokens(30)
+                .maxTokens(64)
                 .build();
     }

### JlamaLanguageModelIT.java

@@ -26,7 +26,7 @@ class JlamaLanguageModelIT {
                 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                 .modelCachePath(tmpDir.toPath())
                 .temperature(0.0f)
-                .maxTokens(30)
+                .maxTokens(64)
                 .build();
     }

### JlamaStreamingChatModelIT.java

@@ -27,7 +27,7 @@ class JlamaStreamingChatModelIT {
         model = JlamaStreamingChatModel.builder()
                 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                 .modelCachePath(tmpDir.toPath())
-                .maxTokens(30)
+                .maxTokens(64)
                 .temperature(0.0f)
                 .build();
     }

### JlamaStreamingLanguageModelIT.java

@@ -26,7 +26,7 @@ class JlamaStreamingLanguageModelIT {
         model = JlamaStreamingLanguageModel.builder()
                 .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
                 .modelCachePath(tmpDir.toPath())
-                .maxTokens(30)
+                .maxTokens(64)
                 .temperature(0.0f)
                 .build();
     }