mirror of https://github.com/tracel-ai/burn.git
Enable optimized handling of bytes (#2003)
* Enable optimized handling of bytes * Implement byte buffer de/serialization * Use serde_bytes w/ alloc (no_std compatible)
This commit is contained in:
parent
69be99b802
commit
c30ffcf6ac
|
@ -727,6 +727,7 @@ dependencies = [
|
||||||
"rand",
|
"rand",
|
||||||
"rand_distr",
|
"rand_distr",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_bytes",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -4886,6 +4887,15 @@ dependencies = [
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_bytes"
|
||||||
|
version = "0.11.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "387cc504cb06bb40a96c8e04e951fe01854cf6bc921053c954e4a606d9675c6a"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_derive"
|
name = "serde_derive"
|
||||||
version = "1.0.204"
|
version = "1.0.204"
|
||||||
|
|
|
@ -67,6 +67,7 @@ rstest = "0.19.0"
|
||||||
rusqlite = { version = "0.31.0" }
|
rusqlite = { version = "0.31.0" }
|
||||||
rust-format = { version = "0.3.4" }
|
rust-format = { version = "0.3.4" }
|
||||||
sanitize-filename = "0.5.0"
|
sanitize-filename = "0.5.0"
|
||||||
|
serde_bytes = { version = "0.11.15", default-features = false, features = ["alloc"] } # alloc for no_std
|
||||||
serde_rusqlite = "0.35.0"
|
serde_rusqlite = "0.35.0"
|
||||||
serial_test = "3.1.1"
|
serial_test = "3.1.1"
|
||||||
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
|
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
|
||||||
|
|
|
@ -183,6 +183,14 @@ impl NestedValue {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the nested value as a vector of bytes.
|
||||||
|
pub fn as_bytes(self) -> Option<Vec<u8>> {
|
||||||
|
match self {
|
||||||
|
NestedValue::U8s(u) => Some(u),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Deserialize a nested value into a record type.
|
/// Deserialize a nested value into a record type.
|
||||||
pub fn try_into_record<T, PS, A, B>(self, device: &B::Device) -> Result<T, Error>
|
pub fn try_into_record<T, PS, A, B>(self, device: &B::Device) -> Result<T, Error>
|
||||||
where
|
where
|
||||||
|
|
|
@ -229,11 +229,11 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
|
||||||
unimplemented!("deserialize_bytes is not implemented")
|
unimplemented!("deserialize_bytes is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
|
fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||||
where
|
where
|
||||||
V: Visitor<'de>,
|
V: Visitor<'de>,
|
||||||
{
|
{
|
||||||
unimplemented!("deserialize_byte_buf is not implemented")
|
visitor.visit_byte_buf(self.value.unwrap().as_bytes().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||||
|
|
|
@ -107,8 +107,8 @@ impl SerializerTrait for Serializer {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> {
|
fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
|
||||||
unimplemented!()
|
Ok(NestedValue::U8s(v.to_vec()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
|
fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
|
||||||
|
|
|
@ -34,6 +34,7 @@ hashbrown = { workspace = true } # no_std compatible
|
||||||
|
|
||||||
# Serialization
|
# Serialization
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
|
serde_bytes = { workspace = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std
|
rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std
|
||||||
|
|
|
@ -32,6 +32,7 @@ pub enum DataError {
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
||||||
pub struct TensorData {
|
pub struct TensorData {
|
||||||
/// The values of the tensor (as bytes).
|
/// The values of the tensor (as bytes).
|
||||||
|
#[serde(with = "serde_bytes")]
|
||||||
pub bytes: Vec<u8>,
|
pub bytes: Vec<u8>,
|
||||||
|
|
||||||
/// The shape of the tensor.
|
/// The shape of the tensor.
|
||||||
|
|
Loading…
Reference in New Issue