mirror of https://github.com/tracel-ai/burn.git
Remove copy restriction for const generic modules (#2222)
This commit is contained in:
parent
cc214d366c
commit
59d41bd4b2
|
@ -173,8 +173,7 @@ where
|
||||||
|
|
||||||
impl<const N: usize, T, B> Module<B> for [T; N]
|
impl<const N: usize, T, B> Module<B> for [T; N]
|
||||||
where
|
where
|
||||||
T: Module<B> + Debug + Send + Clone + Copy,
|
T: Module<B> + Debug + Send + Clone,
|
||||||
T::Record: Debug,
|
|
||||||
B: Backend,
|
B: Backend,
|
||||||
{
|
{
|
||||||
type Record = [T::Record; N];
|
type Record = [T::Record; N];
|
||||||
|
@ -245,16 +244,14 @@ impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
|
||||||
|
|
||||||
impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
|
impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
|
||||||
where
|
where
|
||||||
T: AutodiffModule<B> + Debug + Send + Clone + Copy,
|
T: AutodiffModule<B> + Debug + Send + Clone,
|
||||||
T::InnerModule: Copy + Debug,
|
T::InnerModule: Debug,
|
||||||
<T::InnerModule as Module<B::InnerBackend>>::Record: Debug,
|
|
||||||
<T as Module<B>>::Record: Debug,
|
|
||||||
B: AutodiffBackend,
|
B: AutodiffBackend,
|
||||||
{
|
{
|
||||||
type InnerModule = [T::InnerModule; N];
|
type InnerModule = [T::InnerModule; N];
|
||||||
|
|
||||||
fn valid(&self) -> Self::InnerModule {
|
fn valid(&self) -> Self::InnerModule {
|
||||||
self.map(|module| module.valid())
|
self.clone().map(|module| module.valid())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,11 @@ impl<B: Backend> ModuleBasic<B> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
struct ModuleWithConstGeneric<B: Backend, const N: usize> {
|
||||||
|
modules: [ModuleBasic<B>; N],
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
struct ModuleWithGenericModule<B: Backend, M> {
|
struct ModuleWithGenericModule<B: Backend, M> {
|
||||||
module: M,
|
module: M,
|
||||||
|
@ -151,6 +156,44 @@ mod state {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_load_from_record_const_generic() {
|
||||||
|
let device = <TestBackend as Backend>::Device::default();
|
||||||
|
let module_1 = ModuleWithConstGeneric {
|
||||||
|
modules: [
|
||||||
|
ModuleBasic::<TestBackend>::new(&device),
|
||||||
|
ModuleBasic::<TestBackend>::new(&device),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
let mut module_2 = ModuleWithConstGeneric {
|
||||||
|
modules: [
|
||||||
|
ModuleBasic::<TestBackend>::new(&device),
|
||||||
|
ModuleBasic::<TestBackend>::new(&device),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
let state_1 = module_1.clone().into_record();
|
||||||
|
|
||||||
|
assert_ne!(
|
||||||
|
module_1.modules[0].weight_basic.to_data(),
|
||||||
|
module_2.modules[0].weight_basic.to_data(),
|
||||||
|
);
|
||||||
|
assert_ne!(
|
||||||
|
module_1.modules[1].weight_basic.to_data(),
|
||||||
|
module_2.modules[1].weight_basic.to_data(),
|
||||||
|
);
|
||||||
|
|
||||||
|
module_2 = module_2.load_record(state_1);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
module_1.modules[0].weight_basic.to_data(),
|
||||||
|
module_2.modules[0].weight_basic.to_data(),
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
module_1.modules[1].weight_basic.to_data(),
|
||||||
|
module_2.modules[1].weight_basic.to_data(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "Can't parse record from a different variant")]
|
#[should_panic(expected = "Can't parse record from a different variant")]
|
||||||
fn should_panic_load_from_incorrect_enum_variant() {
|
fn should_panic_load_from_incorrect_enum_variant() {
|
||||||
|
|
Loading…
Reference in New Issue