Enable optimized handling of bytes (#2003)

* Enable optimized handling of bytes

* Implement byte buffer de/serialization

* Use serde_bytes w/ alloc (no_std compatible)
This commit is contained in:
Guillaume Lagrange 2024-07-11 07:48:43 -04:00 committed by GitHub
parent 69be99b802
commit c30ffcf6ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 25 additions and 4 deletions

10
Cargo.lock generated
View File

@ -727,6 +727,7 @@ dependencies = [
"rand", "rand",
"rand_distr", "rand_distr",
"serde", "serde",
"serde_bytes",
] ]
[[package]] [[package]]
@ -4886,6 +4887,15 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "serde_bytes"
version = "0.11.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "387cc504cb06bb40a96c8e04e951fe01854cf6bc921053c954e4a606d9675c6a"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.204" version = "1.0.204"

View File

@ -67,6 +67,7 @@ rstest = "0.19.0"
rusqlite = { version = "0.31.0" } rusqlite = { version = "0.31.0" }
rust-format = { version = "0.3.4" } rust-format = { version = "0.3.4" }
sanitize-filename = "0.5.0" sanitize-filename = "0.5.0"
serde_bytes = { version = "0.11.15", default-features = false, features = ["alloc"] } # alloc for no_std
serde_rusqlite = "0.35.0" serde_rusqlite = "0.35.0"
serial_test = "3.1.1" serial_test = "3.1.1"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] } spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }

View File

@ -183,6 +183,14 @@ impl NestedValue {
} }
} }
/// Get the nested value as a vector of bytes.
pub fn as_bytes(self) -> Option<Vec<u8>> {
match self {
NestedValue::U8s(u) => Some(u),
_ => None,
}
}
/// Deserialize a nested value into a record type. /// Deserialize a nested value into a record type.
pub fn try_into_record<T, PS, A, B>(self, device: &B::Device) -> Result<T, Error> pub fn try_into_record<T, PS, A, B>(self, device: &B::Device) -> Result<T, Error>
where where

View File

@ -229,11 +229,11 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
unimplemented!("deserialize_bytes is not implemented") unimplemented!("deserialize_bytes is not implemented")
} }
fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error> fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
unimplemented!("deserialize_byte_buf is not implemented") visitor.visit_byte_buf(self.value.unwrap().as_bytes().unwrap())
} }
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>

View File

@ -107,8 +107,8 @@ impl SerializerTrait for Serializer {
unimplemented!() unimplemented!()
} }
fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> { fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
unimplemented!() Ok(NestedValue::U8s(v.to_vec()))
} }
fn serialize_none(self) -> Result<Self::Ok, Self::Error> { fn serialize_none(self) -> Result<Self::Ok, Self::Error> {

View File

@ -34,6 +34,7 @@ hashbrown = { workspace = true } # no_std compatible
# Serialization # Serialization
serde = { workspace = true } serde = { workspace = true }
serde_bytes = { workspace = true }
[dev-dependencies] [dev-dependencies]
rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std

View File

@ -32,6 +32,7 @@ pub enum DataError {
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TensorData { pub struct TensorData {
/// The values of the tensor (as bytes). /// The values of the tensor (as bytes).
#[serde(with = "serde_bytes")]
pub bytes: Vec<u8>, pub bytes: Vec<u8>,
/// The shape of the tensor. /// The shape of the tensor.