mirror of https://github.com/tracel-ai/burn.git
Fix: constant record loading (#1902)
This commit is contained in:
parent
263add23a0
commit
e758fd43db
|
@ -21,6 +21,12 @@ where
|
|||
}
|
||||
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
let is_constant = self.num_params() == 0;
|
||||
|
||||
if is_constant {
|
||||
return self;
|
||||
}
|
||||
|
||||
self.zip(record)
|
||||
.map(|(module, record)| module.load_record(record))
|
||||
}
|
||||
|
@ -89,6 +95,14 @@ where
|
|||
}
|
||||
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
assert_eq!(
|
||||
self.len(),
|
||||
record.len(),
|
||||
r#"[Load Record Error] The vec record does not the same length as the module.
|
||||
Make sure you module initialization is compatible with the record being loaded.
|
||||
"#,
|
||||
);
|
||||
|
||||
self.into_iter()
|
||||
.zip(record)
|
||||
.map(|(module, record)| module.load_record(record))
|
||||
|
@ -267,3 +281,28 @@ impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]);
|
|||
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
|
||||
#[test]
|
||||
fn dont_override_constant_module_when_loading_record() {
|
||||
let module = Some(42);
|
||||
|
||||
let record = Module::<TestBackend>::into_record(module);
|
||||
let loaded = Module::<TestBackend>::load_record(module, record);
|
||||
|
||||
assert_eq!(loaded, module);
|
||||
}
|
||||
#[test]
|
||||
fn dont_override_constant_module_when_loading_none_record() {
|
||||
let module = Some(42);
|
||||
|
||||
let record = None;
|
||||
let loaded = Module::<TestBackend>::load_record(module, record);
|
||||
|
||||
assert_eq!(loaded, module);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue