Remove stuff, add distribution shift (WIP)
This commit is contained in:
parent
e56f160032
commit
65e49b7092
|
@ -56,7 +56,7 @@ dependencies = [
|
|||
"flate2",
|
||||
"futures-core",
|
||||
"h2",
|
||||
"http",
|
||||
"http 0.2.9",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
|
@ -90,7 +90,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "d66ff4d247d2b160861fa2866457e85706833527840e4133f8f49aa423a38799"
|
||||
dependencies = [
|
||||
"bytestring",
|
||||
"http",
|
||||
"http 0.2.9",
|
||||
"regex",
|
||||
"serde",
|
||||
"tracing",
|
||||
|
@ -189,7 +189,7 @@ dependencies = [
|
|||
"encoding_rs",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http 0.2.9",
|
||||
"itoa",
|
||||
"language-tags",
|
||||
"log",
|
||||
|
@ -1407,7 +1407,7 @@ dependencies = [
|
|||
"anyhow",
|
||||
"big_s",
|
||||
"flate2",
|
||||
"http",
|
||||
"http 0.2.9",
|
||||
"log",
|
||||
"maplit",
|
||||
"meili-snap",
|
||||
|
@ -1702,21 +1702,6 @@ version = "1.0.7"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||
|
||||
[[package]]
|
||||
name = "foreign-types"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
|
||||
dependencies = [
|
||||
"foreign-types-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "foreign-types-shared"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
|
||||
|
||||
[[package]]
|
||||
name = "form_urlencoded"
|
||||
version = "1.2.0"
|
||||
|
@ -2047,7 +2032,7 @@ dependencies = [
|
|||
"futures-core",
|
||||
"futures-sink",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http 0.2.9",
|
||||
"indexmap 1.9.3",
|
||||
"slab",
|
||||
"tokio",
|
||||
|
@ -2171,13 +2156,12 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
|||
[[package]]
|
||||
name = "hf-hub"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
|
||||
source = "git+https://github.com/dureuill/hf-hub.git?branch=rust_tls#88d4f11cb9fa079f2912bacb96f5080b16825ce8"
|
||||
dependencies = [
|
||||
"dirs",
|
||||
"http 1.0.0",
|
||||
"indicatif",
|
||||
"log",
|
||||
"native-tls",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -2205,6 +2189,17 @@ dependencies = [
|
|||
"itoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
"itoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "0.4.5"
|
||||
|
@ -2212,7 +2207,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http",
|
||||
"http 0.2.9",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
|
@ -2245,7 +2240,7 @@ dependencies = [
|
|||
"futures-core",
|
||||
"futures-util",
|
||||
"h2",
|
||||
"http",
|
||||
"http 0.2.9",
|
||||
"http-body",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
|
@ -2265,7 +2260,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "8d78e1e73ec14cf7375674f74d7dde185c8206fd9dea6fb6295e8a98098aaa97"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"http",
|
||||
"http 0.2.9",
|
||||
"hyper",
|
||||
"rustls 0.21.6",
|
||||
"tokio",
|
||||
|
@ -2868,21 +2863,6 @@ dependencies = [
|
|||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "instant-distance"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c619cdaa30bb84088963968bee12a45ea5fbbf355f2c021bcd15589f5ca494a"
|
||||
dependencies = [
|
||||
"num_cpus",
|
||||
"ordered-float 3.7.0",
|
||||
"parking_lot",
|
||||
"rand",
|
||||
"rayon",
|
||||
"serde",
|
||||
"serde-big-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "io-lifetimes"
|
||||
version = "1.0.11"
|
||||
|
@ -3531,7 +3511,7 @@ dependencies = [
|
|||
"futures",
|
||||
"futures-util",
|
||||
"hex",
|
||||
"http",
|
||||
"http 0.2.9",
|
||||
"index-scheduler",
|
||||
"indexmap 2.0.0",
|
||||
"insta",
|
||||
|
@ -3718,7 +3698,6 @@ dependencies = [
|
|||
"hf-hub",
|
||||
"indexmap 2.0.0",
|
||||
"insta",
|
||||
"instant-distance",
|
||||
"itertools 0.11.0",
|
||||
"json-depth-checker",
|
||||
"levenshtein_automata",
|
||||
|
@ -3730,7 +3709,6 @@ dependencies = [
|
|||
"meili-snap",
|
||||
"memmap2 0.7.1",
|
||||
"mimalloc",
|
||||
"nolife",
|
||||
"obkv",
|
||||
"once_cell",
|
||||
"ordered-float 3.7.0",
|
||||
|
@ -3829,35 +3807,11 @@ dependencies = [
|
|||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "native-tls"
|
||||
version = "0.2.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
"libc",
|
||||
"log",
|
||||
"openssl",
|
||||
"openssl-probe",
|
||||
"openssl-sys",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
"security-framework-sys",
|
||||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nelson"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/meilisearch/nelson.git?rev=675f13885548fb415ead8fbb447e9e6d9314000a#675f13885548fb415ead8fbb447e9e6d9314000a"
|
||||
|
||||
[[package]]
|
||||
name = "nolife"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc52aaf087e8a52e7a2692f83f2dac6ac7ff9d0136bf9c6ac496635cfe3e50dc"
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "7.1.3"
|
||||
|
@ -3994,50 +3948,6 @@ version = "11.1.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.59"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a257ad03cd8fb16ad4172fedf8094451e1af1c4b70097636ef2eac9a5f0cc33"
|
||||
dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"cfg-if",
|
||||
"foreign-types",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"openssl-macros",
|
||||
"openssl-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl-macros"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl-probe"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-sys"
|
||||
version = "0.9.95"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "40a4130519a360279579c2053038317e40eff64d13fd3f004f9e1b72b8a6aaf9"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "option-ext"
|
||||
version = "0.2.0"
|
||||
|
@ -4655,7 +4565,7 @@ dependencies = [
|
|||
"futures-core",
|
||||
"futures-util",
|
||||
"h2",
|
||||
"http",
|
||||
"http 0.2.9",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"hyper-rustls",
|
||||
|
@ -4802,7 +4712,7 @@ checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb"
|
|||
dependencies = [
|
||||
"log",
|
||||
"ring",
|
||||
"rustls-webpki 0.101.3",
|
||||
"rustls-webpki",
|
||||
"sct",
|
||||
]
|
||||
|
||||
|
@ -4815,16 +4725,6 @@ dependencies = [
|
|||
"base64 0.21.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.100.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e98ff011474fa39949b7e5c0428f9b4937eda7da7848bbb947786b7be0b27dab"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.101.3"
|
||||
|
@ -4866,15 +4766,6 @@ dependencies = [
|
|||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88"
|
||||
dependencies = [
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.2.0"
|
||||
|
@ -4891,29 +4782,6 @@ dependencies = [
|
|||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.9.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
"security-framework-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework-sys"
|
||||
version = "2.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "segment"
|
||||
version = "0.2.2"
|
||||
|
@ -4949,15 +4817,6 @@ dependencies = [
|
|||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde-big-array"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde-cs"
|
||||
version = "0.2.4"
|
||||
|
@ -5151,6 +5010,17 @@ dependencies = [
|
|||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "socks"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spin"
|
||||
version = "0.5.2"
|
||||
|
@ -5713,21 +5583,21 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
|
|||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "2.7.1"
|
||||
version = "2.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9"
|
||||
checksum = "f8cdd25c339e200129fe4de81451814e5228c9b771d57378817d6117cc2b3f97"
|
||||
dependencies = [
|
||||
"base64 0.21.5",
|
||||
"flate2",
|
||||
"log",
|
||||
"native-tls",
|
||||
"once_cell",
|
||||
"rustls 0.21.6",
|
||||
"rustls-webpki 0.100.2",
|
||||
"rustls-webpki",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"socks",
|
||||
"url",
|
||||
"webpki-roots 0.23.1",
|
||||
"webpki-roots 0.25.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -5958,15 +5828,6 @@ dependencies = [
|
|||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.23.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338"
|
||||
dependencies = [
|
||||
"rustls-webpki 0.100.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.25.3"
|
||||
|
|
|
@ -14,18 +14,14 @@ use meilisearch_types::error::deserr_codes::*;
|
|||
use meilisearch_types::heed::RoTxn;
|
||||
use meilisearch_types::index_uid::IndexUid;
|
||||
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
|
||||
use meilisearch_types::milli::{
|
||||
dot_product_similarity, FacetValueHit, InternalError, OrderBy, SearchForFacetValues,
|
||||
VectorQuery,
|
||||
};
|
||||
use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery};
|
||||
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
|
||||
use meilisearch_types::{milli, Document};
|
||||
use milli::tokenizer::TokenizerBuilder;
|
||||
use milli::{
|
||||
AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder,
|
||||
SortError, TermsMatchingStrategy, VectorOrArrayOfVectors, DEFAULT_VALUES_PER_FACET,
|
||||
SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
|
||||
};
|
||||
use ordered_float::OrderedFloat;
|
||||
use regex::Regex;
|
||||
use serde::Serialize;
|
||||
use serde_json::{json, Value};
|
||||
|
@ -550,13 +546,8 @@ pub fn perform_search(
|
|||
insert_geo_distance(sort, &mut document);
|
||||
}
|
||||
|
||||
let semantic_score = /*match query.vector.as_ref() {
|
||||
Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? {
|
||||
Some(vectors) => compute_semantic_score(vector, vectors)?,
|
||||
None => None,
|
||||
},
|
||||
None => None,
|
||||
};*/ None;
|
||||
/// FIXME: remove this or set to value from the score details
|
||||
let semantic_score = None;
|
||||
|
||||
let ranking_score =
|
||||
query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
|
||||
|
@ -689,18 +680,6 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) {
|
|||
}
|
||||
}
|
||||
|
||||
fn compute_semantic_score(query: &[f32], vectors: Value) -> milli::Result<Option<f32>> {
|
||||
let vectors = serde_json::from_value(vectors)
|
||||
.map(VectorOrArrayOfVectors::into_array_of_vectors)
|
||||
.map_err(InternalError::SerdeJson)?;
|
||||
Ok(vectors
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(|v| OrderedFloat(dot_product_similarity(query, &v)))
|
||||
.max()
|
||||
.map(OrderedFloat::into_inner))
|
||||
}
|
||||
|
||||
fn compute_formatted_options(
|
||||
attr_to_highlight: &HashSet<String>,
|
||||
attr_to_crop: &[String],
|
||||
|
@ -828,22 +807,6 @@ fn make_document(
|
|||
Ok(document)
|
||||
}
|
||||
|
||||
/// Extract the JSON value under the field name specified
|
||||
/// but doesn't support nested objects.
|
||||
fn extract_field(
|
||||
field_name: &str,
|
||||
field_ids_map: &FieldsIdsMap,
|
||||
obkv: obkv::KvReaderU16,
|
||||
) -> Result<Option<serde_json::Value>, MeilisearchHttpError> {
|
||||
match field_ids_map.id(field_name) {
|
||||
Some(fid) => match obkv.get(fid) {
|
||||
Some(value) => Ok(serde_json::from_slice(value).map(Some)?),
|
||||
None => Ok(None),
|
||||
},
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn format_fields<'a>(
|
||||
document: &Document,
|
||||
field_ids_map: &FieldsIdsMap,
|
||||
|
|
|
@ -36,7 +36,6 @@ heed = { version = "0.20.0-alpha.9", default-features = false, features = [
|
|||
"read-txn-no-tls",
|
||||
] }
|
||||
indexmap = { version = "2.0.0", features = ["serde"] }
|
||||
instant-distance = { version = "0.6.1", features = ["with-serde"] }
|
||||
json-depth-checker = { path = "../json-depth-checker" }
|
||||
levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] }
|
||||
memmap2 = "0.7.1"
|
||||
|
@ -79,10 +78,11 @@ candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.
|
|||
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
|
||||
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
|
||||
tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1" }
|
||||
hf-hub = "0.3.2"
|
||||
hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [
|
||||
"online",
|
||||
] }
|
||||
tokio = { version = "1.34.0", features = ["rt"] }
|
||||
futures = "0.3.29"
|
||||
nolife = { version = "0.3.1" }
|
||||
reqwest = { version = "0.11.16", features = [
|
||||
"rustls-tls",
|
||||
"json",
|
||||
|
@ -102,7 +102,15 @@ meili-snap = { path = "../meili-snap" }
|
|||
rand = { version = "0.8.5", features = ["small_rng"] }
|
||||
|
||||
[features]
|
||||
all-tokenizations = ["charabia/chinese", "charabia/hebrew", "charabia/japanese", "charabia/thai", "charabia/korean", "charabia/greek", "charabia/khmer"]
|
||||
all-tokenizations = [
|
||||
"charabia/chinese",
|
||||
"charabia/hebrew",
|
||||
"charabia/japanese",
|
||||
"charabia/thai",
|
||||
"charabia/korean",
|
||||
"charabia/greek",
|
||||
"charabia/khmer",
|
||||
]
|
||||
|
||||
# Use POSIX semaphores instead of SysV semaphores in LMDB
|
||||
# For more information on this feature, see heed's Cargo.toml
|
||||
|
|
|
@ -1,41 +0,0 @@
|
|||
use std::ops;
|
||||
|
||||
use instant_distance::Point;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::normalize_vector;
|
||||
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
|
||||
pub struct NDotProductPoint(Vec<f32>);
|
||||
|
||||
impl NDotProductPoint {
|
||||
pub fn new(point: Vec<f32>) -> Self {
|
||||
NDotProductPoint(normalize_vector(point))
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> Vec<f32> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl ops::Deref for NDotProductPoint {
|
||||
type Target = [f32];
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0.as_slice()
|
||||
}
|
||||
}
|
||||
|
||||
impl Point for NDotProductPoint {
|
||||
fn distance(&self, other: &Self) -> f32 {
|
||||
let dist = 1.0 - dot_product_similarity(&self.0, &other.0);
|
||||
debug_assert!(!dist.is_nan());
|
||||
dist
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the dot product similarity score that will between 0.0 and 1.0
|
||||
/// if both vectors are normalized. The higher the more similar the vectors are.
|
||||
pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b).map(|(a, b)| a * b).sum()
|
||||
}
|
|
@ -10,7 +10,6 @@ use roaring::RoaringBitmap;
|
|||
use rstar::RTree;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
use crate::distance::NDotProductPoint;
|
||||
use crate::documents::PrimaryKey;
|
||||
use crate::error::{InternalError, UserError};
|
||||
use crate::fields_ids_map::FieldsIdsMap;
|
||||
|
@ -30,9 +29,6 @@ use crate::{
|
|||
BEU32, BEU64,
|
||||
};
|
||||
|
||||
/// The HNSW data-structure that we serialize, fill and search in.
|
||||
pub type Hnsw = instant_distance::Hnsw<NDotProductPoint>;
|
||||
|
||||
pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
|
||||
pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9;
|
||||
|
||||
|
|
|
@ -10,7 +10,6 @@ pub mod documents;
|
|||
|
||||
mod asc_desc;
|
||||
mod criterion;
|
||||
pub mod distance;
|
||||
mod error;
|
||||
mod external_documents_ids;
|
||||
pub mod facet;
|
||||
|
@ -33,7 +32,6 @@ use std::convert::{TryFrom, TryInto};
|
|||
use std::hash::BuildHasherDefault;
|
||||
|
||||
use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer};
|
||||
pub use distance::dot_product_similarity;
|
||||
pub use filter_parser::{Condition, FilterCondition, Span, Token};
|
||||
use fxhash::{FxHasher32, FxHasher64};
|
||||
pub use grenad::CompressionType;
|
||||
|
|
|
@ -50,6 +50,7 @@ use self::vector_sort::VectorSort;
|
|||
use crate::error::FieldIdMapMissingEntry;
|
||||
use crate::score_details::{ScoreDetails, ScoringStrategy};
|
||||
use crate::search::new::distinct::apply_distinct_rule;
|
||||
use crate::vector::DistributionShift;
|
||||
use crate::{
|
||||
AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError,
|
||||
};
|
||||
|
@ -264,6 +265,7 @@ fn get_ranking_rules_for_vector<'ctx>(
|
|||
geo_strategy: geo_sort::Strategy,
|
||||
limit_plus_offset: usize,
|
||||
target: &[f32],
|
||||
distribution_shift: Option<DistributionShift>,
|
||||
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
|
||||
// query graph search
|
||||
|
||||
|
@ -289,6 +291,7 @@ fn get_ranking_rules_for_vector<'ctx>(
|
|||
target.to_vec(),
|
||||
vector_candidates,
|
||||
limit_plus_offset,
|
||||
distribution_shift,
|
||||
)?;
|
||||
ranking_rules.push(Box::new(vector_sort));
|
||||
vector = true;
|
||||
|
@ -515,8 +518,14 @@ pub fn execute_vector_search(
|
|||
|
||||
/// FIXME: input universe = universe & documents_with_vectors
|
||||
// for now if we're computing embeddings for ALL documents, we can assume that this is just universe
|
||||
let ranking_rules =
|
||||
get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, from + length, vector)?;
|
||||
let ranking_rules = get_ranking_rules_for_vector(
|
||||
ctx,
|
||||
sort_criteria,
|
||||
geo_strategy,
|
||||
from + length,
|
||||
vector,
|
||||
None,
|
||||
)?;
|
||||
|
||||
let mut placeholder_search_logger = logger::DefaultSearchLogger;
|
||||
let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> =
|
||||
|
|
|
@ -5,6 +5,7 @@ use roaring::RoaringBitmap;
|
|||
|
||||
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
|
||||
use crate::score_details::{self, ScoreDetails};
|
||||
use crate::vector::DistributionShift;
|
||||
use crate::{DocumentId, Result, SearchContext, SearchLogger};
|
||||
|
||||
pub struct VectorSort<Q: RankingRuleQueryTrait> {
|
||||
|
@ -13,6 +14,7 @@ pub struct VectorSort<Q: RankingRuleQueryTrait> {
|
|||
vector_candidates: RoaringBitmap,
|
||||
cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
|
||||
limit: usize,
|
||||
distribution_shift: Option<DistributionShift>,
|
||||
}
|
||||
|
||||
impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
|
||||
|
@ -21,6 +23,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
|
|||
target: Vec<f32>,
|
||||
vector_candidates: RoaringBitmap,
|
||||
limit: usize,
|
||||
distribution_shift: Option<DistributionShift>,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
query: None,
|
||||
|
@ -28,6 +31,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
|
|||
vector_candidates,
|
||||
cached_sorted_docids: Default::default(),
|
||||
limit,
|
||||
distribution_shift,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -52,7 +56,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
|
|||
for reader in readers.iter() {
|
||||
let nns_by_vector = reader.nns_by_vector(
|
||||
ctx.txn,
|
||||
&target,
|
||||
target,
|
||||
self.limit,
|
||||
None,
|
||||
Some(&self.vector_candidates),
|
||||
|
@ -66,6 +70,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
|
|||
}
|
||||
results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance));
|
||||
self.cached_sorted_docids = results.into_iter();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -111,14 +116,19 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> {
|
|||
}));
|
||||
}
|
||||
|
||||
while let Some((docid, distance, vector)) = self.cached_sorted_docids.next() {
|
||||
for (docid, distance, vector) in self.cached_sorted_docids.by_ref() {
|
||||
if self.vector_candidates.contains(docid) {
|
||||
let score = 1.0 - distance;
|
||||
let score = self
|
||||
.distribution_shift
|
||||
.map(|distribution| distribution.shift(score))
|
||||
.unwrap_or(score);
|
||||
return Ok(Some(RankingRuleOutput {
|
||||
query,
|
||||
candidates: RoaringBitmap::from_iter([docid]),
|
||||
score: ScoreDetails::Vector(score_details::Vector {
|
||||
target_vector: self.target.clone(),
|
||||
value_similarity: Some((vector, 1.0 - distance)),
|
||||
value_similarity: Some((vector, score)),
|
||||
}),
|
||||
}));
|
||||
}
|
||||
|
|
|
@ -415,7 +415,7 @@ pub(crate) fn write_typed_chunk_into_index(
|
|||
|
||||
let mut deleted_index = None;
|
||||
for (index, writer) in writers.iter().enumerate() {
|
||||
let Some(candidate) = writer.item_vector(&wtxn, docid)? else {
|
||||
let Some(candidate) = writer.item_vector(wtxn, docid)? else {
|
||||
// uses invariant: vectors are packed in the first writers.
|
||||
break;
|
||||
};
|
||||
|
@ -429,7 +429,7 @@ pub(crate) fn write_typed_chunk_into_index(
|
|||
if let Some(deleted_index) = deleted_index {
|
||||
let mut last_index_with_a_vector = None;
|
||||
for (index, writer) in writers.iter().enumerate().skip(deleted_index) {
|
||||
let Some(candidate) = writer.item_vector(&wtxn, docid)? else {
|
||||
let Some(candidate) = writer.item_vector(wtxn, docid)? else {
|
||||
break;
|
||||
};
|
||||
last_index_with_a_vector = Some((index, candidate));
|
||||
|
|
|
@ -140,3 +140,47 @@ impl Embedder {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct DistributionShift {
|
||||
pub current_mean: f32,
|
||||
pub current_sigma: f32,
|
||||
}
|
||||
|
||||
impl DistributionShift {
|
||||
/// `None` if sigma <= 0.
|
||||
pub fn new(mean: f32, sigma: f32) -> Option<Self> {
|
||||
if sigma <= 0.0 {
|
||||
None
|
||||
} else {
|
||||
Some(Self { current_mean: mean, current_sigma: sigma })
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shift(&self, score: f32) -> f32 {
|
||||
// <https://math.stackexchange.com/a/2894689>
|
||||
// We're somewhat abusively mapping the distribution of distances to a gaussian.
|
||||
// The parameters we're given is the mean and sigma of the native result distribution.
|
||||
// We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4.
|
||||
|
||||
let target_mean = 0.5;
|
||||
let target_sigma = 0.4;
|
||||
|
||||
// a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
|
||||
let factor = target_sigma / self.current_sigma;
|
||||
// a*mu1 + b = mu2 => b = mu2 - a*mu1
|
||||
let offset = target_mean - (factor * self.current_mean);
|
||||
|
||||
let mut score = factor * score + offset;
|
||||
|
||||
// clamp the final score in the ]0, 1] interval.
|
||||
if score <= 0.0 {
|
||||
score = f32::EPSILON;
|
||||
}
|
||||
if score > 1.0 {
|
||||
score = 1.0;
|
||||
}
|
||||
|
||||
score
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue