diff --git a/raptor-search/src/main.rs b/raptor-search/src/main.rs index 0b156719b..0e66954fe 100644 --- a/raptor-search/src/main.rs +++ b/raptor-search/src/main.rs @@ -12,7 +12,7 @@ use raptor::{load_map, DocIndexMap, RankedStream, LevBuilder}; fn search(map: &DocIndexMap, lev_builder: &LevBuilder, query: &str) { let mut automatons = Vec::new(); for query in query.split_whitespace() { - let lev = lev_builder.build_automaton(query); + let lev = lev_builder.get_automaton(query); automatons.push(lev); } @@ -46,12 +46,11 @@ fn search(map: &DocIndexMap, lev_builder: &LevBuilder, query: &str) { fn main() { drop(env_logger::init()); - let (elapsed, (lev_builder, map)) = measure_time(|| { - let lev_builder = LevBuilder::new(); - let map = load_map("map.fst", "values.vecs").unwrap(); - (lev_builder, map) - }); - println!("Loaded in {}", elapsed); + let (elapsed, map) = measure_time(|| load_map("map.fst", "values.vecs").unwrap()); + println!("{} to load the map", elapsed); + + let (elapsed, lev_builder) = measure_time(|| LevBuilder::new()); + println!("{} to load the levenshtein automaton", elapsed); match env::args().nth(1) { Some(query) => { diff --git a/src/levenshtein.rs b/src/levenshtein.rs index 0cc1334a7..e3acea8a0 100644 --- a/src/levenshtein.rs +++ b/src/levenshtein.rs @@ -15,13 +15,21 @@ impl LevBuilder { } } - pub fn build_automaton(&self, query: &str) -> DFA { - if query.len() <= 4 { - self.automatons[0].build_dfa(query) + pub fn get_automaton(&self, query: &str) -> Levenshtein { + let dfa = if query.len() <= 4 { + self.automatons[0].build_prefix_dfa(query) } else if query.len() <= 8 { - self.automatons[1].build_dfa(query) + self.automatons[1].build_prefix_dfa(query) } else { - self.automatons[2].build_dfa(query) - } + self.automatons[2].build_prefix_dfa(query) + }; + + Levenshtein { dfa, query_len: query.len() } } } + +#[derive(Clone)] +pub struct Levenshtein { + pub dfa: DFA, + pub query_len: usize, +} diff --git a/src/lib.rs b/src/lib.rs index f97326611..46c109d69 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -83,6 +83,9 @@ pub struct Match { /// The index in the attribute is limited to a maximum of `2^32` /// this is because we index only the first 1000 words in an attribute. pub attribute_index: u32, + + /// Whether the word that match is an exact match or a prefix. + pub is_exact: bool, } impl Match { @@ -92,6 +95,7 @@ impl Match { distance: 0, attribute: 0, attribute_index: 0, + is_exact: false, } } @@ -101,6 +105,7 @@ impl Match { distance: u8::max_value(), attribute: u8::max_value(), attribute_index: u32::max_value(), + is_exact: true, } } } diff --git a/src/rank.rs b/src/rank.rs index 59a2b1888..5bca5ebe9 100644 --- a/src/rank.rs +++ b/src/rank.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use std::{mem, vec, iter}; use DocIndexMap; use fst; -use levenshtein_automata::DFA; +use levenshtein::Levenshtein; use map::{ OpWithStateBuilder, UnionWithState, StreamWithStateBuilder, @@ -102,7 +102,12 @@ fn sum_of_words_position(lhs: &Document, rhs: &Document) -> Ordering { } fn exact(lhs: &Document, rhs: &Document) -> Ordering { - unimplemented!() + let contains_exact = |matches: &[Match]| matches.iter().any(|m| m.is_exact); + let key = |doc: &Document| -> usize { + GroupBy::new(&doc.matches, match_query_index).map(contains_exact).filter(|x| *x).count() + }; + + key(lhs).cmp(&key(rhs)) } pub struct Pool { @@ -155,6 +160,7 @@ impl IntoIterator for Pool { words_proximity, sum_of_words_attribute, sum_of_words_position, + exact, ]; for (i, sort) in sorts.iter().enumerate() { @@ -176,7 +182,7 @@ impl IntoIterator for Pool { pub enum RankedStream<'m, 'v> { Fed { inner: UnionWithState<'m, 'v, DocIndex, u32>, - automatons: Vec, + automatons: Vec, pool: Pool, }, Pours { @@ -185,10 +191,10 @@ pub enum RankedStream<'m, 'v> { } impl<'m, 'v> RankedStream<'m, 'v> { - pub fn new(map: &'m DocIndexMap, values: &'v Values, automatons: Vec, limit: usize) -> Self { + pub fn new(map: &'m DocIndexMap, values: &'v Values, automatons: Vec, limit: usize) -> Self { let mut op = OpWithStateBuilder::new(values); - for automaton in automatons.iter().cloned() { + for automaton in automatons.iter().map(|l| l.dfa.clone()) { let stream = map.as_map().search(automaton).with_state(); op.push(stream); } @@ -216,7 +222,7 @@ impl<'m, 'v, 'a> fst::Streamer<'a> for RankedStream<'m, 'v> { match self { RankedStream::Fed { inner, automatons, pool } => { match inner.next() { - Some((_string, indexed_values)) => { + Some((string, indexed_values)) => { for iv in indexed_values { // TODO extend documents matches by batch of query_index @@ -224,10 +230,8 @@ impl<'m, 'v, 'a> fst::Streamer<'a> for RankedStream<'m, 'v> { // have an invalid distance *before* adding them // to the matches of the documents and, that way, avoid a sort - // let string = unsafe { str::from_utf8_unchecked(_string) }; - // println!("for {:15} ", string); - - let distance = automatons[iv.index].distance(iv.state).to_u8(); + let automaton = &automatons[iv.index]; + let distance = automaton.dfa.distance(iv.state).to_u8(); // TODO remove the Pool system ! // this is an internal Pool rule but @@ -240,6 +244,7 @@ impl<'m, 'v, 'a> fst::Streamer<'a> for RankedStream<'m, 'v> { distance: distance, attribute: di.attribute, attribute_index: di.attribute_index, + is_exact: string.len() == automaton.query_len, }; matches.entry(di.document) .and_modify(|ms: &mut Vec<_>| ms.push(match_)) @@ -249,6 +254,7 @@ impl<'m, 'v, 'a> fst::Streamer<'a> for RankedStream<'m, 'v> { } }, None => { + // TODO remove this when NLL are here ! transfert_pool = Some(mem::replace(pool, Pool::new(1, 1))); }, }