mirror of https://github.com/tracel-ai/burn.git
Enable optimized handling of bytes (#2003)
* Enable optimized handling of bytes * Implement byte buffer de/serialization * Use serde_bytes w/ alloc (no_std compatible)
This commit is contained in:
parent
69be99b802
commit
c30ffcf6ac
|
@ -727,6 +727,7 @@ dependencies = [
|
|||
"rand",
|
||||
"rand_distr",
|
||||
"serde",
|
||||
"serde_bytes",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -4886,6 +4887,15 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_bytes"
|
||||
version = "0.11.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "387cc504cb06bb40a96c8e04e951fe01854cf6bc921053c954e4a606d9675c6a"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.204"
|
||||
|
|
|
@ -67,6 +67,7 @@ rstest = "0.19.0"
|
|||
rusqlite = { version = "0.31.0" }
|
||||
rust-format = { version = "0.3.4" }
|
||||
sanitize-filename = "0.5.0"
|
||||
serde_bytes = { version = "0.11.15", default-features = false, features = ["alloc"] } # alloc for no_std
|
||||
serde_rusqlite = "0.35.0"
|
||||
serial_test = "3.1.1"
|
||||
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
|
||||
|
|
|
@ -183,6 +183,14 @@ impl NestedValue {
|
|||
}
|
||||
}
|
||||
|
||||
/// Get the nested value as a vector of bytes.
|
||||
pub fn as_bytes(self) -> Option<Vec<u8>> {
|
||||
match self {
|
||||
NestedValue::U8s(u) => Some(u),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Deserialize a nested value into a record type.
|
||||
pub fn try_into_record<T, PS, A, B>(self, device: &B::Device) -> Result<T, Error>
|
||||
where
|
||||
|
|
|
@ -229,11 +229,11 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
|
|||
unimplemented!("deserialize_bytes is not implemented")
|
||||
}
|
||||
|
||||
fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
|
||||
fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
unimplemented!("deserialize_byte_buf is not implemented")
|
||||
visitor.visit_byte_buf(self.value.unwrap().as_bytes().unwrap())
|
||||
}
|
||||
|
||||
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
|
|
|
@ -107,8 +107,8 @@ impl SerializerTrait for Serializer {
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> {
|
||||
unimplemented!()
|
||||
fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
|
||||
Ok(NestedValue::U8s(v.to_vec()))
|
||||
}
|
||||
|
||||
fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
|
||||
|
|
|
@ -34,6 +34,7 @@ hashbrown = { workspace = true } # no_std compatible
|
|||
|
||||
# Serialization
|
||||
serde = { workspace = true }
|
||||
serde_bytes = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std
|
||||
|
|
|
@ -32,6 +32,7 @@ pub enum DataError {
|
|||
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
||||
pub struct TensorData {
|
||||
/// The values of the tensor (as bytes).
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub bytes: Vec<u8>,
|
||||
|
||||
/// The shape of the tensor.
|
||||
|
|
Loading…
Reference in New Issue