diff --git a/docs/docs/integrations/scoring-reranking-models/in-process.md b/docs/docs/integrations/scoring-reranking-models/in-process.md
index 81226f46f..8f7c7e4b4 100644
--- a/docs/docs/integrations/scoring-reranking-models/in-process.md
+++ b/docs/docs/integrations/scoring-reranking-models/in-process.md
@@ -4,8 +4,8 @@ sidebar_position: 1
 
 # In-process (ONNX)
 
-LangChain4j provides local Scoring (Reranking) models, powered by [ONNX runtime](https://onnxruntime.ai/docs/get-started/with-java.html),
-running in the same Java process.
+LangChain4j provides local scoring (reranking) models,
+powered by [ONNX runtime](https://onnxruntime.ai/docs/get-started/with-java.html), running in the same Java process.
 
 Many models (e.g., from [Hugging Face](https://huggingface.co/)) can be used, as long as they are in the ONNX format.
 
@@ -16,7 +16,7 @@ Many models already converted to ONNX format are available [here](https://huggin
 
 ### Usage
 
-By default, Scoring (Reranking) use the CPU.
+By default, the scoring (reranking) model uses the CPU.
 ```xml
 <dependency>
     <groupId>dev.langchain4j</groupId>
@@ -33,7 +33,8 @@
 Response<Double> response = scoringModel.score("query", "passage");
 Double score = response.content();
 ```
-If you want to use the GPU, onnxruntime_gpu version can be found [here](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html).
+If you want to use the GPU, the `onnxruntime_gpu` version can be found
+[here](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html).
 ```xml
 <dependency>
     <groupId>dev.langchain4j</groupId>
@@ -65,4 +66,4 @@
 OnnxScoringModel scoringModel = new OnnxScoringModel(pathToModel, options, pathToTokenizer);
 Response<Double> response = scoringModel.score("query", "passage");
 Double score = response.content();
-```
\ No newline at end of file
+```
diff --git a/langchain4j-onnx-scoring/src/main/java/dev/langchain4j/model/scoring/onnx/AbstractInProcessScoringModel.java b/langchain4j-onnx-scoring/src/main/java/dev/langchain4j/model/scoring/onnx/AbstractInProcessScoringModel.java
index e524a736e..2dec182ff 100644
--- a/langchain4j-onnx-scoring/src/main/java/dev/langchain4j/model/scoring/onnx/AbstractInProcessScoringModel.java
+++ b/langchain4j-onnx-scoring/src/main/java/dev/langchain4j/model/scoring/onnx/AbstractInProcessScoringModel.java
@@ -9,7 +9,7 @@ import dev.langchain4j.model.scoring.ScoringModel;
 import java.util.List;
 import java.util.stream.Collectors;
 
-public abstract class AbstractInProcessScoringModel implements ScoringModel {
+abstract class AbstractInProcessScoringModel implements ScoringModel {
 
     public AbstractInProcessScoringModel() {
     }
diff --git a/langchain4j-onnx-scoring/src/main/java/dev/langchain4j/model/scoring/onnx/OnnxScoringBertCrossEncoder.java b/langchain4j-onnx-scoring/src/main/java/dev/langchain4j/model/scoring/onnx/OnnxScoringBertCrossEncoder.java
index 17192dc9c..b1a74dac5 100644
--- a/langchain4j-onnx-scoring/src/main/java/dev/langchain4j/model/scoring/onnx/OnnxScoringBertCrossEncoder.java
+++ b/langchain4j-onnx-scoring/src/main/java/dev/langchain4j/model/scoring/onnx/OnnxScoringBertCrossEncoder.java
@@ -14,7 +14,7 @@ import java.util.*;
 
 import static ai.onnxruntime.OnnxTensor.createTensor;
 
-public class OnnxScoringBertCrossEncoder {
+class OnnxScoringBertCrossEncoder {
 
     private final OrtEnvironment environment;
     private final OrtSession session;
diff --git a/langchain4j-onnx-scoring/src/main/test/java/dev/langchain4j/model/scoring/onnx/OnnxScoringModelIT.java b/langchain4j-onnx-scoring/src/main/test/java/dev/langchain4j/model/scoring/onnx/OnnxScoringModelIT.java
index daac82074..44398d574 100644
--- a/langchain4j-onnx-scoring/src/main/test/java/dev/langchain4j/model/scoring/onnx/OnnxScoringModelIT.java
+++ b/langchain4j-onnx-scoring/src/main/test/java/dev/langchain4j/model/scoring/onnx/OnnxScoringModelIT.java
@@ -42,7 +42,7 @@ class OnnxScoringModelIT {
         Files.copy(tokenizerUrl.openStream(), tokenizerPath, REPLACE_EXISTING);
 
         // To check the modelMaxLength parameter, refer to the model configuration file at https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2/resolve/main/tokenizer_config.json
-        model = new OnnxScoringModel(modelPath.toString(), new OrtSession.SessionOptions(), tokenizerPath.toString(), 512, false);
+        model = new OnnxScoringModel(modelPath.toString(), new OrtSession.SessionOptions(), tokenizerPath.toString(), 512, false);
     }
 
     @Test
@@ -60,13 +60,12 @@ class OnnxScoringModelIT {
 
         List<Double> scores = response.content();
         assertThat(scores).hasSize(2);
 
-        // python output results: [ 8.845855712890625, -11.245561599731445 ]
-        assertThat(scores.get(0)).isCloseTo(8.84, withPercentage(1));
-        assertThat(scores.get(1)).isCloseTo(-11.24, withPercentage(1));
+        // Python output results on https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2: [ 8.663132667541504, -11.245542526245117 ]
+        assertThat(scores.get(0)).isCloseTo(8.663132667541504, withPercentage(0.1));
+        assertThat(scores.get(1)).isCloseTo(-11.245542526245117, withPercentage(0.1));
 
         assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0);
         assertThat(response.finishReason()).isNull();
-
     }
 }
\ No newline at end of file
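For reviewers, a minimal, self-contained sketch of how the scoring model touched by this patch is used on the CPU. It is not part of the patch: the example class name and the model/tokenizer paths are hypothetical, while the 5-arg constructor mirrors the call exercised in `OnnxScoringModelIT` (512 is the `modelMaxLength` the test's comment points at).

```java
import ai.onnxruntime.OrtSession;

import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.onnx.OnnxScoringModel;

public class OnnxScoringExample { // hypothetical example class

    public static void main(String[] args) {
        // Hypothetical local paths; the model and tokenizer files can be
        // downloaded from https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2
        String pathToModel = "/path/to/model.onnx";
        String pathToTokenizer = "/path/to/tokenizer.json";

        // Same 5-arg constructor as in OnnxScoringModelIT:
        // (modelPath, sessionOptions, tokenizerPath, modelMaxLength, ...)
        OnnxScoringModel scoringModel = new OnnxScoringModel(
                pathToModel, new OrtSession.SessionOptions(), pathToTokenizer, 512, false);

        // Cross-encoder scoring, as in the docs example above
        Response<Double> response = scoringModel.score("query", "passage");
        Double score = response.content();
        System.out.println(score);
    }
}
```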
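And a sketch of the GPU variant the docs change describes, under the same hypothetical paths and class name. `OrtSession.SessionOptions.addCUDA(deviceId)` is the ONNX Runtime Java call that enables the CUDA execution provider; it assumes the `onnxruntime_gpu` artifact replaces the CPU one on the classpath.

```java
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.onnx.OnnxScoringModel;

public class OnnxScoringGpuExample { // hypothetical example class

    public static void main(String[] args) throws OrtException {
        String pathToModel = "/path/to/model.onnx";         // hypothetical path
        String pathToTokenizer = "/path/to/tokenizer.json"; // hypothetical path

        // Enable the CUDA execution provider on GPU device 0; addCUDA throws
        // OrtException if the provider is unavailable, and it requires the
        // onnxruntime_gpu dependency mentioned in the docs change above.
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        options.addCUDA(0);

        OnnxScoringModel scoringModel = new OnnxScoringModel(
                pathToModel, options, pathToTokenizer, 512, false);

        Response<Double> response = scoringModel.score("query", "passage");
        System.out.println(response.content());
    }
}
```

Both sketches go through `OnnxScoringModel` only, which is consistent with this patch narrowing `AbstractInProcessScoringModel` and `OnnxScoringBertCrossEncoder` to package-private visibility.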