add scoring onnx (#1769)

This commit is contained in:
LangChain4j 2024-09-20 12:21:44 +02:00
parent 1c0671617d
commit 40b0d01349
4 changed files with 12 additions and 12 deletions

View File

@ -4,8 +4,8 @@ sidebar_position: 1
# In-process (ONNX)
LangChain4j provides local Scoring (Reranking) models, powered by [ONNX runtime](https://onnxruntime.ai/docs/get-started/with-java.html),
running in the same Java process.
LangChain4j provides local scoring (reranking) models,
powered by [ONNX runtime](https://onnxruntime.ai/docs/get-started/with-java.html), running in the same Java process.
Many models (e.g., from [Hugging Face](https://huggingface.co/)) can be used,
as long as they are in the ONNX format.
@ -16,7 +16,7 @@ Many models already converted to ONNX format are available [here](https://huggin
### Usage
By default, Scoring (Reranking) use the CPU.
By default, the scoring (reranking) model uses the CPU.
```xml
<dependency>
<groupId>dev.langchain4j</groupId>
@ -33,7 +33,8 @@ Response<Double> response = scoringModel.score("query", "passage");
Double score = response.content();
```
If you want to use the GPU, onnxruntime_gpu version can be found [here](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html).
If you want to use the GPU, the `onnxruntime_gpu` version can be found
[here](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html).
```xml
<dependency>
<groupId>dev.langchain4j</groupId>
@ -65,4 +66,4 @@ OnnxScoringModel scoringModel = new OnnxScoringModel(pathToModel, options, pathT
Response<Double> response = scoringModel.score("query", "passage");
Double score = response.content();
```
```

View File

@ -9,7 +9,7 @@ import dev.langchain4j.model.scoring.ScoringModel;
import java.util.List;
import java.util.stream.Collectors;
public abstract class AbstractInProcessScoringModel implements ScoringModel {
abstract class AbstractInProcessScoringModel implements ScoringModel {
public AbstractInProcessScoringModel() {
}

View File

@ -14,7 +14,7 @@ import java.util.*;
import static ai.onnxruntime.OnnxTensor.createTensor;
public class OnnxScoringBertCrossEncoder {
class OnnxScoringBertCrossEncoder {
private final OrtEnvironment environment;
private final OrtSession session;

View File

@ -42,7 +42,7 @@ class OnnxScoringModelIT {
Files.copy(tokenizerUrl.openStream(), tokenizerPath, REPLACE_EXISTING);
// To check the modelMaxLength parameter, refer to the model configuration file at https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2/resolve/main/tokenizer_config.json
model = new OnnxScoringModel(modelPath.toString(), new OrtSession.SessionOptions(), tokenizerPath.toString(), 512, false);
model = new OnnxScoringModel(modelPath.toString(), new OrtSession.SessionOptions(), tokenizerPath.toString(), 512, false);
}
@Test
@ -60,13 +60,12 @@ class OnnxScoringModelIT {
List<Double> scores = response.content();
assertThat(scores).hasSize(2);
// python output results: [ 8.845855712890625, -11.245561599731445 ]
assertThat(scores.get(0)).isCloseTo(8.84, withPercentage(1));
assertThat(scores.get(1)).isCloseTo(-11.24, withPercentage(1));
// python output results on https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2: [ 8.663132667541504, -11.245542526245117 ]
assertThat(scores.get(0)).isCloseTo(8.663132667541504, withPercentage(0.1));
assertThat(scores.get(1)).isCloseTo(-11.245542526245117, withPercentage(0.1));
assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0);
assertThat(response.finishReason()).isNull();
}
}