mirror of https://github.com/tracel-ai/burn.git
Remove copy restriction for const generic modules (#2222)
This commit is contained in:
parent
cc214d366c
commit
59d41bd4b2
|
@ -173,8 +173,7 @@ where
|
|||
|
||||
impl<const N: usize, T, B> Module<B> for [T; N]
|
||||
where
|
||||
T: Module<B> + Debug + Send + Clone + Copy,
|
||||
T::Record: Debug,
|
||||
T: Module<B> + Debug + Send + Clone,
|
||||
B: Backend,
|
||||
{
|
||||
type Record = [T::Record; N];
|
||||
|
@ -245,16 +244,14 @@ impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
|
|||
|
||||
impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
|
||||
where
|
||||
T: AutodiffModule<B> + Debug + Send + Clone + Copy,
|
||||
T::InnerModule: Copy + Debug,
|
||||
<T::InnerModule as Module<B::InnerBackend>>::Record: Debug,
|
||||
<T as Module<B>>::Record: Debug,
|
||||
T: AutodiffModule<B> + Debug + Send + Clone,
|
||||
T::InnerModule: Debug,
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
type InnerModule = [T::InnerModule; N];
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
self.map(|module| module.valid())
|
||||
self.clone().map(|module| module.valid())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -32,6 +32,11 @@ impl<B: Backend> ModuleBasic<B> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
struct ModuleWithConstGeneric<B: Backend, const N: usize> {
|
||||
modules: [ModuleBasic<B>; N],
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
struct ModuleWithGenericModule<B: Backend, M> {
|
||||
module: M,
|
||||
|
@ -151,6 +156,44 @@ mod state {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_load_from_record_const_generic() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let module_1 = ModuleWithConstGeneric {
|
||||
modules: [
|
||||
ModuleBasic::<TestBackend>::new(&device),
|
||||
ModuleBasic::<TestBackend>::new(&device),
|
||||
],
|
||||
};
|
||||
let mut module_2 = ModuleWithConstGeneric {
|
||||
modules: [
|
||||
ModuleBasic::<TestBackend>::new(&device),
|
||||
ModuleBasic::<TestBackend>::new(&device),
|
||||
],
|
||||
};
|
||||
let state_1 = module_1.clone().into_record();
|
||||
|
||||
assert_ne!(
|
||||
module_1.modules[0].weight_basic.to_data(),
|
||||
module_2.modules[0].weight_basic.to_data(),
|
||||
);
|
||||
assert_ne!(
|
||||
module_1.modules[1].weight_basic.to_data(),
|
||||
module_2.modules[1].weight_basic.to_data(),
|
||||
);
|
||||
|
||||
module_2 = module_2.load_record(state_1);
|
||||
|
||||
assert_eq!(
|
||||
module_1.modules[0].weight_basic.to_data(),
|
||||
module_2.modules[0].weight_basic.to_data(),
|
||||
);
|
||||
assert_eq!(
|
||||
module_1.modules[1].weight_basic.to_data(),
|
||||
module_2.modules[1].weight_basic.to_data(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Can't parse record from a different variant")]
|
||||
fn should_panic_load_from_incorrect_enum_variant() {
|
||||
|
|
Loading…
Reference in New Issue