Use the latest version of heed (0.20.1) to get rid of unsafe code

This commit is contained in:
Tamo 2024-05-16 16:11:08 +02:00
parent 897d25780e
commit 273c6e8c5c
3 changed files with 19 additions and 33 deletions

14
Cargo.lock generated
View File

@ -378,7 +378,9 @@ dependencies = [
[[package]] [[package]]
name = "arroy" name = "arroy"
version = "0.3.0" version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73897699bf04bac935c0b120990d2a511e91e563e0f9769f9c8bb983d98dfbc9"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"byteorder", "byteorder",
@ -2260,7 +2262,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]] [[package]]
name = "heed" name = "heed"
version = "0.20.0" version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f7acb9683d7c7068aa46d47557bfa4e35a277964b350d9504a87b03610163fd"
dependencies = [ dependencies = [
"bitflags 2.5.0", "bitflags 2.5.0",
"byteorder", "byteorder",
@ -2277,10 +2281,14 @@ dependencies = [
[[package]] [[package]]
name = "heed-traits" name = "heed-traits"
version = "0.20.0" version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb3130048d404c57ce5a1ac61a903696e8fcde7e8c2991e9fcfc1f27c3ef74ff"
[[package]] [[package]]
name = "heed-types" name = "heed-types"
version = "0.20.0" version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cb0d6ba3700c9a57e83c013693e3eddb68a6d9b6781cacafc62a0d992e8ddb3"
dependencies = [ dependencies = [
"bincode", "bincode",
"byteorder", "byteorder",
@ -3181,6 +3189,8 @@ checksum = "f9d642685b028806386b2b6e75685faadd3eb65a85fff7df711ce18446a422da"
[[package]] [[package]]
name = "lmdb-master-sys" name = "lmdb-master-sys"
version = "0.2.0" version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc9048db3a58c0732d7236abc4909058f9d2708cfb6d7d047eb895fddec6419a"
dependencies = [ dependencies = [
"cc", "cc",
"doxygen-rs", "doxygen-rs",

View File

@ -1,5 +1,4 @@
use std::io::{ErrorKind, Write}; use std::io::{ErrorKind, Write};
use std::pin::Pin;
use actix_web::http::header::CONTENT_TYPE; use actix_web::http::header::CONTENT_TYPE;
use actix_web::web::Data; use actix_web::web::Data;
@ -627,31 +626,19 @@ fn some_documents<'a, 't: 'a>(
pub struct DocumentsStreamer { pub struct DocumentsStreamer {
attributes_to_retrieve: Option<Vec<String>>, attributes_to_retrieve: Option<Vec<String>>,
documents: RoaringBitmap, documents: RoaringBitmap,
// safety: The `rtxn` contains a reference to the index thus: rtxn: RoTxn<'static>,
// - The `rtxn` MUST BE dropped before the index. index: Index,
// - The index MUST BE `Pin`ned in RAM and never moved.
rtxn: Option<RoTxn<'static>>,
index: Pin<Box<Index>>,
pub total_documents: u64, pub total_documents: u64,
} }
impl Drop for DocumentsStreamer {
fn drop(&mut self) {
// safety: we drop the rtxn before the index
self.rtxn = None;
}
}
impl Serialize for DocumentsStreamer { impl Serialize for DocumentsStreamer {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where
S: serde::Serializer, S: serde::Serializer,
{ {
let rtxn = self.rtxn.as_ref().unwrap();
let mut seq = serializer.serialize_seq(Some(self.documents.len() as usize)).unwrap(); let mut seq = serializer.serialize_seq(Some(self.documents.len() as usize)).unwrap();
let documents = some_documents(&self.index, rtxn, self.documents.iter()).unwrap(); let documents = some_documents(&self.index, &self.rtxn, self.documents.iter()).unwrap();
for document in documents { for document in documents {
let document = document.unwrap(); let document = document.unwrap();
let document = match self.attributes_to_retrieve { let document = match self.attributes_to_retrieve {
@ -675,9 +662,7 @@ fn retrieve_documents(
filter: Option<Value>, filter: Option<Value>,
attributes_to_retrieve: Option<Vec<String>>, attributes_to_retrieve: Option<Vec<String>>,
) -> Result<DocumentsStreamer, ResponseError> { ) -> Result<DocumentsStreamer, ResponseError> {
// safety: The index MUST NOT move while we hold the `rtxn` on it let rtxn = index.static_read_txn()?;
let index = Box::pin(index);
let rtxn = index.read_txn()?;
let filter = &filter; let filter = &filter;
let filter = if let Some(filter) = filter { let filter = if let Some(filter) = filter {
@ -702,10 +687,7 @@ fn retrieve_documents(
total_documents: candidates.len(), total_documents: candidates.len(),
attributes_to_retrieve, attributes_to_retrieve,
documents: candidates.into_iter().skip(offset).take(limit).collect(), documents: candidates.into_iter().skip(offset).take(limit).collect(),
// safety: It is safe to make the lifetime in the Rtxn static because it points to the index right below. rtxn,
// The index is `Pin`ned on the RAM and won't move even if the structure is moved.
// The `rtxn` is held in an `Option`, so we're able to drop it before dropping the index.
rtxn: Some(unsafe { std::mem::transmute(rtxn) }),
index, index,
}) })
} }

View File

@ -30,12 +30,7 @@ grenad = { version = "0.4.6", default-features = false, features = [
"rayon", "rayon",
"tempfile", "tempfile",
] } ] }
# heed = { version = "0.20.0", default-features = false, features = [ heed = { version = "0.20.1", default-features = false, features = [
# "serde-json",
# "serde-bincode",
# "read-txn-no-tls",
# ] }
heed = { path = "../../heed/heed", default-features = false, features = [
"serde-json", "serde-json",
"serde-bincode", "serde-bincode",
"read-txn-no-tls", "read-txn-no-tls",
@ -87,8 +82,7 @@ hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls",
] } ] }
tiktoken-rs = "0.5.8" tiktoken-rs = "0.5.8"
liquid = "0.26.4" liquid = "0.26.4"
# arroy = "0.2.0" arroy = "0.3.1"
arroy = { path = "../../arroy" }
rand = "0.8.5" rand = "0.8.5"
tracing = "0.1.40" tracing = "0.1.40"
ureq = { version = "2.9.7", features = ["json"] } ureq = { version = "2.9.7", features = ["json"] }