From c30ffcf6ac7972b41215dc99faf7b9ddfde79c93 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 11 Jul 2024 07:48:43 -0400 Subject: [PATCH] Enable optimized handling of bytes (#2003) * Enable optimized handling of bytes * Implement byte buffer de/serialization * Use serde_bytes w/ alloc (no_std compatible) --- Cargo.lock | 10 ++++++++++ Cargo.toml | 1 + crates/burn-core/src/record/serde/data.rs | 8 ++++++++ crates/burn-core/src/record/serde/de.rs | 4 ++-- crates/burn-core/src/record/serde/ser.rs | 4 ++-- crates/burn-tensor/Cargo.toml | 1 + crates/burn-tensor/src/tensor/data.rs | 1 + 7 files changed, 25 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6d452da6e..55e58354f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -727,6 +727,7 @@ dependencies = [ "rand", "rand_distr", "serde", + "serde_bytes", ] [[package]] @@ -4886,6 +4887,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "serde_bytes" +version = "0.11.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "387cc504cb06bb40a96c8e04e951fe01854cf6bc921053c954e4a606d9675c6a" +dependencies = [ + "serde", +] + [[package]] name = "serde_derive" version = "1.0.204" diff --git a/Cargo.toml b/Cargo.toml index c54e08217..ea3307837 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,6 +67,7 @@ rstest = "0.19.0" rusqlite = { version = "0.31.0" } rust-format = { version = "0.3.4" } sanitize-filename = "0.5.0" +serde_bytes = { version = "0.11.15", default-features = false, features = ["alloc"] } # alloc for no_std serde_rusqlite = "0.35.0" serial_test = "3.1.1" spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] } diff --git a/crates/burn-core/src/record/serde/data.rs b/crates/burn-core/src/record/serde/data.rs index a18ade924..27a55c5b2 100644 --- a/crates/burn-core/src/record/serde/data.rs +++ b/crates/burn-core/src/record/serde/data.rs @@ -183,6 +183,14 @@ impl NestedValue { } } + /// Get the nested value as a vector of bytes. + pub fn as_bytes(self) -> Option> { + match self { + NestedValue::U8s(u) => Some(u), + _ => None, + } + } + /// Deserialize a nested value into a record type. pub fn try_into_record(self, device: &B::Device) -> Result where diff --git a/crates/burn-core/src/record/serde/de.rs b/crates/burn-core/src/record/serde/de.rs index 0f72ae4f4..3c93afed1 100644 --- a/crates/burn-core/src/record/serde/de.rs +++ b/crates/burn-core/src/record/serde/de.rs @@ -229,11 +229,11 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer { unimplemented!("deserialize_bytes is not implemented") } - fn deserialize_byte_buf(self, _visitor: V) -> Result + fn deserialize_byte_buf(self, visitor: V) -> Result where V: Visitor<'de>, { - unimplemented!("deserialize_byte_buf is not implemented") + visitor.visit_byte_buf(self.value.unwrap().as_bytes().unwrap()) } fn deserialize_option(self, visitor: V) -> Result diff --git a/crates/burn-core/src/record/serde/ser.rs b/crates/burn-core/src/record/serde/ser.rs index 088b82dbf..44844160a 100644 --- a/crates/burn-core/src/record/serde/ser.rs +++ b/crates/burn-core/src/record/serde/ser.rs @@ -107,8 +107,8 @@ impl SerializerTrait for Serializer { unimplemented!() } - fn serialize_bytes(self, _v: &[u8]) -> Result { - unimplemented!() + fn serialize_bytes(self, v: &[u8]) -> Result { + Ok(NestedValue::U8s(v.to_vec())) } fn serialize_none(self) -> Result { diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index 7a8ef8b72..c23bde5e1 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -34,6 +34,7 @@ hashbrown = { workspace = true } # no_std compatible # Serialization serde = { workspace = true } +serde_bytes = { workspace = true } [dev-dependencies] rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index c69184e01..f3c8b3843 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -32,6 +32,7 @@ pub enum DataError { #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct TensorData { /// The values of the tensor (as bytes). + #[serde(with = "serde_bytes")] pub bytes: Vec, /// The shape of the tensor.