add scoring onnx (#1769)
parent 1c0671617d
commit 40b0d01349
@@ -4,8 +4,8 @@ sidebar_position: 1
 # In-process (ONNX)
 
-LangChain4j provides local Scoring (Reranking) models, powered by [ONNX runtime](https://onnxruntime.ai/docs/get-started/with-java.html),
-running in the same Java process.
+LangChain4j provides local scoring (reranking) models,
+powered by [ONNX runtime](https://onnxruntime.ai/docs/get-started/with-java.html), running in the same Java process.
 
 Many models (e.g., from [Hugging Face](https://huggingface.co/)) can be used,
 as long as they are in the ONNX format.
 
@@ -16,7 +16,7 @@ Many models already converted to ONNX format are available [here](https://huggin
 
 ### Usage
 
-By default, Scoring (Reranking) use the CPU.
+By default, the scoring (reranking) model uses the CPU.
 ```xml
 <dependency>
     <groupId>dev.langchain4j</groupId>
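
The dependency block in this hunk is cut off after the groupId. For orientation only, a complete CPU dependency might look like the sketch below; the artifactId is inferred from this commit's title and the version is a placeholder, neither appears in the diff.

```xml
<dependency>
    <groupId>dev.langchain4j</groupId>
    <!-- artifactId is an assumption inferred from the commit title -->
    <artifactId>langchain4j-onnx-scoring</artifactId>
    <!-- placeholder; use the release that matches your LangChain4j version -->
    <version>X.Y.Z</version>
</dependency>
```
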
@@ -33,7 +33,8 @@ Response<Double> response = scoringModel.score("query", "passage");
 Double score = response.content();
 ```
 
-If you want to use the GPU, onnxruntime_gpu version can be found [here](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html).
+If you want to use the GPU, the `onnxruntime_gpu` version can be found
+[here](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html).
 ```xml
 <dependency>
     <groupId>dev.langchain4j</groupId>
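
Putting the documented calls together, a minimal end-to-end usage sketch might look like the following. The concrete OnnxScoringModel class and its constructor arguments are assumptions for illustration; only scoringModel.score("query", "passage") and response.content() appear in the docs above.

```java
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;

public class ScoringExample {

    public static void main(String[] args) {
        // Hypothetical concrete class: some ONNX-backed ScoringModel from this
        // module, loading a local cross-encoder model and its tokenizer.
        // Class name, constructor, and paths are illustrative assumptions.
        ScoringModel scoringModel = new OnnxScoringModel(
                "/path/to/cross-encoder.onnx",
                "/path/to/tokenizer.json");

        // Documented call: scores how relevant "passage" is to "query".
        Response<Double> response = scoringModel.score("query", "passage");
        Double score = response.content(); // higher means more relevant
        System.out.println(score);
    }
}
```
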
@@ -9,7 +9,7 @@ import dev.langchain4j.model.scoring.ScoringModel;
 import java.util.List;
 import java.util.stream.Collectors;
 
-public abstract class AbstractInProcessScoringModel implements ScoringModel {
+abstract class AbstractInProcessScoringModel implements ScoringModel {
 
     public AbstractInProcessScoringModel() {
     }
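
Making the abstract base package-private keeps it out of the module's public API, so callers depend only on the ScoringModel interface. Inferred from its usage in this diff, the interface shape is roughly the sketch below; only the two-string score method is confirmed by the docs above, and the batch method is an assumption.

```java
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;

import java.util.List;

// Rough sketch of the ScoringModel shape, inferred from this diff's usage.
interface ScoringModelSketch {

    // Confirmed by the docs above: scores one text against a query.
    Response<Double> score(String text, String query);

    // Assumed batch variant for reranking many segments against one query.
    Response<List<Double>> scoreAll(List<TextSegment> segments, String query);
}
```
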
@@ -14,7 +14,7 @@ import java.util.*;
 
 import static ai.onnxruntime.OnnxTensor.createTensor;
 
-public class OnnxScoringBertCrossEncoder {
+class OnnxScoringBertCrossEncoder {
 
     private final OrtEnvironment environment;
     private final OrtSession session;
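
The fields above (OrtEnvironment, OrtSession) and the static import of OnnxTensor.createTensor outline the inference flow. Below is a minimal sketch of that flow with the ONNX Runtime Java API, assuming the usual BERT input names (input_ids, attention_mask, token_type_ids) and a single relevance logit per pair, neither of which is shown in this diff; tokenization is omitted.

```java
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.util.Map;

import static ai.onnxruntime.OnnxTensor.createTensor;

// Sketch only: a cross-encoder scores a (query, passage) pair jointly by
// feeding the tokenized pair through one forward pass and reading the logit.
class CrossEncoderSketch {

    private final OrtEnvironment environment = OrtEnvironment.getEnvironment();
    private final OrtSession session;

    CrossEncoderSketch(String modelPath) throws OrtException {
        this.session = environment.createSession(modelPath, new OrtSession.SessionOptions());
    }

    // Inputs come from tokenizing "query [SEP] passage"; input names are assumptions.
    float score(long[][] tokenIds, long[][] attentionMask, long[][] tokenTypeIds) throws OrtException {
        try (OnnxTensor ids = createTensor(environment, tokenIds);
             OnnxTensor mask = createTensor(environment, attentionMask);
             OnnxTensor types = createTensor(environment, tokenTypeIds);
             OrtSession.Result result = session.run(Map.of(
                     "input_ids", ids,
                     "attention_mask", mask,
                     "token_type_ids", types))) {
            // Assumed output layout: one relevance logit per batch entry.
            float[][] logits = (float[][]) result.get(0).getValue();
            return logits[0][0];
        }
    }
}
```
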
@@ -60,13 +60,12 @@ class OnnxScoringModelIT {
         List<Double> scores = response.content();
         assertThat(scores).hasSize(2);
 
-        // python output results: [ 8.845855712890625, -11.245561599731445 ]
-        assertThat(scores.get(0)).isCloseTo(8.84, withPercentage(1));
-        assertThat(scores.get(1)).isCloseTo(-11.24, withPercentage(1));
-
+        // python output results on https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2: [ 8.663132667541504, -11.245542526245117 ]
+        assertThat(scores.get(0)).isCloseTo(8.663132667541504, withPercentage(0.1));
+        assertThat(scores.get(1)).isCloseTo(-11.245542526245117, withPercentage(0.1));
 
         assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0);
 
         assertThat(response.finishReason()).isNull();
     }
 }