From bff4961426e6fb4341caa9f1c124d47f5e104271 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 21 Feb 2024 17:03:34 -0500 Subject: [PATCH] Add enum module support (#1337) --- crates/burn-core/tests/derive_module.rs | 67 ++++++ crates/burn-derive/src/module/base.rs | 3 +- crates/burn-derive/src/module/codegen.rs | 4 +- crates/burn-derive/src/module/codegen_enum.rs | 192 ++++++++++++++++++ crates/burn-derive/src/module/mod.rs | 2 + crates/burn-derive/src/module/record_enum.rs | 39 ++++ crates/burn-derive/src/record/base.rs | 125 +----------- crates/burn-derive/src/record/codegen.rs | 151 ++++++++++++-- crates/burn-derive/src/record/item/codegen.rs | 19 ++ .../src/record/item/codegen_enum.rs | 110 ++++++++++ .../src/record/{ => item}/codegen_struct.rs | 12 +- crates/burn-derive/src/record/item/mod.rs | 3 + crates/burn-derive/src/record/mod.rs | 2 +- crates/burn-derive/src/shared/enum_variant.rs | 29 +++ 14 files changed, 621 insertions(+), 137 deletions(-) create mode 100644 crates/burn-derive/src/module/codegen_enum.rs create mode 100644 crates/burn-derive/src/module/record_enum.rs create mode 100644 crates/burn-derive/src/record/item/codegen.rs create mode 100644 crates/burn-derive/src/record/item/codegen_enum.rs rename crates/burn-derive/src/record/{ => item}/codegen_struct.rs (91%) create mode 100644 crates/burn-derive/src/record/item/mod.rs diff --git a/crates/burn-core/tests/derive_module.rs b/crates/burn-core/tests/derive_module.rs index 79d2dc9cb..831a27947 100644 --- a/crates/burn-core/tests/derive_module.rs +++ b/crates/burn-core/tests/derive_module.rs @@ -34,6 +34,23 @@ struct ModuleWithGenericModule { _backend: PhantomData, } +#[derive(Module, Debug)] +enum ModuleEnum { + Basic(ModuleBasic), + Composed(ModuleComposed), +} + +#[derive(Module, Debug)] +enum ModuleEnumNested { + AnotherEnum(ModuleEnum), +} + +#[derive(Module, Debug)] +enum ModuleEnumWithGenericModule> { + Basic(ModuleBasic), + Generic(ModuleWithGenericModule), +} + #[derive(Module, Debug)] pub struct ModuleComposed { weight: Param>, @@ -95,6 +112,46 @@ mod state { module_2.basic.weight_basic.to_data() ); } + + #[test] + fn should_load_from_record_enum() { + let device = ::Device::default(); + let module_1 = ModuleEnum::Basic(ModuleBasic::::new(&device)); + let mut module_2 = ModuleEnum::Basic(ModuleBasic::::new(&device)); + let state_1 = module_1.clone().into_record(); + + let ModuleEnum::Basic(module_1_basic) = module_1 else { + panic!("Invalid module type") + }; + let ModuleEnum::Basic(module_2_basic) = module_2.clone() else { + panic!("Invalid module type") + }; + assert_ne!( + module_1_basic.weight_basic.to_data(), + module_2_basic.weight_basic.to_data() + ); + + module_2 = module_2.load_record(state_1); + + let ModuleEnum::Basic(module_2_basic) = module_2 else { + panic!("Invalid module type") + }; + assert_eq!( + module_1_basic.weight_basic.to_data(), + module_2_basic.weight_basic.to_data() + ); + } + + #[test] + #[should_panic(expected = "Can't parse record from a different variant")] + fn should_panic_load_from_incorrect_enum_variant() { + let device = ::Device::default(); + let module_1 = ModuleEnum::Basic(ModuleBasic::::new(&device)); + let module_2 = ModuleEnum::Composed(ModuleComposed::::new(&device)); + let state_1 = module_1.clone().into_record(); + + module_2.load_record(state_1); + } } mod num_params { @@ -113,6 +170,16 @@ mod num_params { let module = ModuleComposed::::new(&device); assert_eq!(4 * 20 * 20, module.num_params()); } + + #[test] + fn should_calculate_num_params_enum() { + let device = ::Device::default(); + let module = ModuleEnum::Basic(ModuleBasic::::new(&device)); + assert_eq!(20 * 20, module.num_params()); + + let module = ModuleEnum::Composed(ModuleComposed::::new(&device)); + assert_eq!(4 * 20 * 20, module.num_params()); + } } #[cfg(feature = "std")] diff --git a/crates/burn-derive/src/module/base.rs b/crates/burn-derive/src/module/base.rs index 827ec9f9b..fa88242ac 100644 --- a/crates/burn-derive/src/module/base.rs +++ b/crates/burn-derive/src/module/base.rs @@ -1,5 +1,6 @@ use super::{ codegen::{generate_module_const, generate_module_standard}, + codegen_enum::EnumModuleCodegen, codegen_struct::StructModuleCodegen, }; use proc_macro::TokenStream; @@ -22,7 +23,7 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream { } syn::Data::Enum(_data) => { if has_backend { - panic!("Enum modules aren't supported yet.") + generate_module_standard(ast, EnumModuleCodegen::from_ast(ast)) } else { generate_module_const(ast) } diff --git a/crates/burn-derive/src/module/codegen.rs b/crates/burn-derive/src/module/codegen.rs index 57dadf1bf..f8ae69145 100644 --- a/crates/burn-derive/src/module/codegen.rs +++ b/crates/burn-derive/src/module/codegen.rs @@ -45,7 +45,7 @@ pub(crate) fn generate_module_standard( let record = codegen.record_codegen(); let record_name = Ident::new(format!("{}Record", name).as_str(), name.span()); - let record_struct = record.gen_record_type(&record_name, &generics.module); + let record_type = record.gen_record_type(&record_name, &generics.module); let (generics_module, generics_ty_module, generics_where_module) = generics.module.split_for_impl(); @@ -86,7 +86,7 @@ pub(crate) fn generate_module_standard( #clone_fn } - #record_struct + #record_type }; gen diff --git a/crates/burn-derive/src/module/codegen_enum.rs b/crates/burn-derive/src/module/codegen_enum.rs new file mode 100644 index 000000000..00a81edd1 --- /dev/null +++ b/crates/burn-derive/src/module/codegen_enum.rs @@ -0,0 +1,192 @@ +use super::{codegen::ModuleCodegen, record_enum::EnumModuleRecordCodegen}; +use crate::shared::enum_variant::{parse_variants, EnumVariant}; +use proc_macro2::{Ident, TokenStream}; +use quote::quote; + +pub(crate) struct EnumModuleCodegen { + pub variants: Vec, +} + +impl ModuleCodegen for EnumModuleCodegen { + type RecordCodegen = EnumModuleRecordCodegen; + + fn gen_num_params(&self) -> TokenStream { + let match_body = self.gen_variants_match_fn(|_| { + quote! { + burn::module::Module::::num_params(module) + } + }); + + quote! { + fn num_params(&self) -> usize { + #match_body + } + } + } + + fn gen_visit(&self) -> TokenStream { + let match_body = self.gen_variants_match_fn(|_| { + quote! { + burn::module::Module::visit(module, visitor) + } + }); + + quote! { + fn visit>(&self, visitor: &mut Visitor) { + #match_body + } + } + } + + fn gen_collect_devices(&self) -> TokenStream { + let match_body = self.gen_variants_match_fn(|_| { + quote! { + burn::module::Module::::collect_devices(module, devices) + } + }); + + quote! { + fn collect_devices( + &self, + devices: burn::module::Devices + ) -> burn::module::Devices { + #match_body + } + } + } + + fn gen_to_device(&self) -> TokenStream { + let match_body = self.gen_variants_match_fn(|variant| { + quote! { + Self::#variant(burn::module::Module::::to_device(module, device)) + } + }); + + quote! { + fn to_device(self, device: &B::Device) -> Self { + #match_body + } + } + } + + fn gen_fork(&self) -> TokenStream { + let match_body = self.gen_variants_match_fn(|variant| { + quote! { + Self::#variant(burn::module::Module::::fork(module, device)) + } + }); + + quote! { + fn fork(self, device: &B::Device) -> Self { + #match_body + } + } + } + + fn gen_map(&self) -> TokenStream { + let match_body = self.gen_variants_match_fn(|variant| { + quote! { + Self::#variant(burn::module::Module::::map(module, mapper)) + } + }); + + quote! { + fn map>(self, mapper: &mut Mapper) -> Self { + #match_body + } + } + } + + fn gen_valid(&self) -> TokenStream { + let match_body = self.gen_variants_match_fn(|variant| { + quote! { + Self::InnerModule::#variant(burn::module::AutodiffModule::::valid(module)) + } + }); + + quote! { + fn valid(&self) -> Self::InnerModule { + #match_body + } + } + } + + fn gen_into_record(&self) -> TokenStream { + let match_body = self.gen_variants_match_fn(|variant| { + quote! { + Self::Record::#variant(burn::module::Module::::into_record(module)) + } + }); + + quote! { + fn into_record(self) -> Self::Record { + #match_body + } + } + } + + fn gen_load_record(&self) -> TokenStream { + let match_body = self.gen_variants_match_fn(|variant| { + quote! { + { + let Self::Record::#variant(r) = record else {panic!("Can't parse record from a different variant");}; + Self::#variant(burn::module::Module::::load_record(module, r)) + } + } + }); + + quote! { + fn load_record(self, record: Self::Record) -> Self { + #match_body + } + } + } + + fn gen_clone(&self) -> TokenStream { + let match_body = self.gen_variants_match_fn(|variant| { + quote! { + Self::#variant(module.clone()) + } + }); + + quote! { + fn clone(&self) -> Self { + #match_body + } + } + } + + fn record_codegen(self) -> Self::RecordCodegen { + EnumModuleRecordCodegen::new(self.variants) + } +} + +impl EnumModuleCodegen { + pub fn from_ast(ast: &syn::DeriveInput) -> Self { + Self { + variants: parse_variants(ast), + } + } + + /// Generate the enum variants' match arm with the provided function + fn gen_variants_match_fn(&self, func: F) -> TokenStream + where + F: Fn(Ident) -> TokenStream, + { + let mut match_arms = quote! {}; + + for variant in self.variants.iter() { + let name = &variant.ident; + let arm_pattern = quote! {Self::#name(module)}; + let arm_code = func(name.clone()); + + match_arms.extend(quote! {#arm_pattern => #arm_code,}) + } + + quote! { + match self { + #match_arms + } + } + } +} diff --git a/crates/burn-derive/src/module/mod.rs b/crates/burn-derive/src/module/mod.rs index 6d4ec6c79..95642fcd7 100644 --- a/crates/burn-derive/src/module/mod.rs +++ b/crates/burn-derive/src/module/mod.rs @@ -1,7 +1,9 @@ pub(crate) mod codegen; +pub(crate) mod codegen_enum; pub(crate) mod codegen_struct; pub(crate) mod display; pub(crate) mod record; +pub(crate) mod record_enum; pub(crate) mod record_struct; mod base; diff --git a/crates/burn-derive/src/module/record_enum.rs b/crates/burn-derive/src/module/record_enum.rs new file mode 100644 index 000000000..d4c003532 --- /dev/null +++ b/crates/burn-derive/src/module/record_enum.rs @@ -0,0 +1,39 @@ +use crate::shared::enum_variant::EnumVariant; +use proc_macro2::{Ident, TokenStream}; +use quote::quote; +use syn::Generics; + +use super::record::ModuleRecordCodegen; + +#[derive(new)] +pub(crate) struct EnumModuleRecordCodegen { + variants: Vec, +} + +impl ModuleRecordCodegen for EnumModuleRecordCodegen { + fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream { + let mut variants = quote! {}; + + // Capture the Record enum variant types + for variant in self.variants.iter() { + let ty = &variant.ty; + let name = &variant.ident; + + variants.extend(quote! { + /// The module record associative type. + #name(<#ty as burn::module::Module>::Record), + }); + } + + let (generics, _generics_ty, generics_where) = generics.split_for_impl(); + + quote! { + + /// The record type for the module. + #[derive(burn::record::Record)] + pub enum #record_name #generics #generics_where { + #variants + } + } + } +} diff --git a/crates/burn-derive/src/record/base.rs b/crates/burn-derive/src/record/base.rs index 61a12dbd7..e68191bc0 100644 --- a/crates/burn-derive/src/record/base.rs +++ b/crates/burn-derive/src/record/base.rs @@ -1,122 +1,13 @@ -use proc_macro2::{Ident, TokenStream}; -use quote::quote; -use syn::{parse_quote, Generics}; - -use super::{codegen::RecordItemCodegen, codegen_struct::StructRecordItemCodegen}; -use crate::shared::field::{parse_fields, FieldTypeAnalyzer}; +use super::{ + codegen::generate_record, + item::{codegen_enum::EnumRecordItemCodegen, codegen_struct::StructRecordItemCodegen}, +}; pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> proc_macro::TokenStream { - let record_gen = RecordDeriveCodegen::from_ast(ast); - let item_struct = record_gen.gen_record_type(); - let record_impl = record_gen.gen_impl_record(); - - quote! { - #item_struct - #record_impl + match &ast.data { + syn::Data::Struct(_) => generate_record::(ast), + syn::Data::Enum(_) => generate_record::(ast), + syn::Data::Union(_) => panic!("Union modules aren't supported yet."), } .into() } - -struct RecordDeriveCodegen { - name_record: Ident, - name_item: Ident, - gen: StructRecordItemCodegen, - generics: Generics, - has_backend: bool, -} - -impl RecordDeriveCodegen { - pub(crate) fn from_ast(ast: &syn::DeriveInput) -> Self { - let name_record = ast.ident.clone(); - let name_item = Ident::new(format!("{}Item", name_record).as_str(), name_record.span()); - let has_backend = ast - .generics - .type_params() - .map(|param| param.ident == "B") - .reduce(|accum, is_backend| is_backend || accum) - .unwrap_or(false); - - Self { - name_record, - name_item, - gen: StructRecordItemCodegen::new( - parse_fields(ast) - .into_iter() - .map(FieldTypeAnalyzer::new) - .collect(), - ), - generics: ast.generics.clone(), - has_backend, - } - } - - /// Generate the record type with the correct generics. - pub(crate) fn gen_record_type(&self) -> TokenStream { - let param: syn::Generics = parse_quote! { }; - let mut generics = self.generics.clone(); - - for param in param.params.into_iter() { - generics.params.push(param); - } - - self.gen - .gen_item_type(&self.name_item, &generics, self.has_backend) - } - - /// Generate the implementation for the Record trait. - pub(crate) fn gen_impl_record(&self) -> TokenStream { - let name = &self.name_record; - let item_generics = self.record_item_generics(); - let (_, ty_generics_item, _) = item_generics.split_for_impl(); - let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); - - let impl_generics = if let Some(impl_generic) = self.impl_generics() { - impl_generic - } else { - quote! { #impl_generics } - }; - - let name_item = &self.name_item; - let into_item_fn = self.gen.gen_into_item(name_item); - let from_item_fn = self.gen.gen_from_item(); - - quote! { - impl #impl_generics burn::record::Record for #name #ty_generics #where_clause { - type Item = #name_item #ty_generics_item; - - #into_item_fn - #from_item_fn - - } - } - } - - fn impl_generics(&self) -> Option { - if self.has_backend { - return None; - } - - let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend }; - let mut generics = self.generics.clone(); - generics.params.push(syn::GenericParam::Type(param)); - - let (impl_generics, _ty_generics, _where_clause) = generics.split_for_impl(); - - Some(quote! {#impl_generics}) - } - - fn record_item_generics(&self) -> Generics { - let param: syn::Generics = parse_quote! { }; - let mut generics = self.generics.clone(); - for param in param.params.into_iter() { - generics.params.push(param); - } - - if !self.has_backend { - let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend }; - generics.params.push(syn::GenericParam::Type(param)); - } - - generics - } -} diff --git a/crates/burn-derive/src/record/codegen.rs b/crates/burn-derive/src/record/codegen.rs index c61377e1e..7a76dcc3c 100644 --- a/crates/burn-derive/src/record/codegen.rs +++ b/crates/burn-derive/src/record/codegen.rs @@ -1,17 +1,140 @@ use proc_macro2::{Ident, TokenStream}; -use syn::Generics; +use quote::quote; +use syn::{parse_quote, Generics}; -/// Basic trait to be implemented for record generation. -pub(crate) trait RecordItemCodegen { - /// Generate the record item type (i.e a struct) - fn gen_item_type( - &self, - item_name: &Ident, - generics: &Generics, - has_backend: bool, - ) -> TokenStream; - /// Generate the into_item function. - fn gen_into_item(&self, item_name: &Ident) -> TokenStream; - /// Generate the from item function. - fn gen_from_item(&self) -> TokenStream; +use crate::record::item::codegen::RecordItemCodegen; + +pub(crate) fn generate_record(ast: &syn::DeriveInput) -> TokenStream { + let record_gen: RecordCodegen = RecordCodegen::from_ast(ast); + let item_type = record_gen.gen_record_type(); + let record_impl = record_gen.gen_impl_record(); + + quote! { + #item_type + #record_impl + } +} + +pub(crate) struct RecordCodegen { + /// Record type info. + ty: RecordType, + /// Record item code gen. + gen: G, +} + +impl RecordCodegen { + /// Generate the record type with the correct generics. + pub(crate) fn gen_record_type(&self) -> TokenStream { + // Add precision settings type bound + let param: syn::Generics = parse_quote! { }; + let mut generics = self.ty.generics.clone(); + + for param in param.params.into_iter() { + generics.params.push(param); + } + + // Generate the record item definition + self.gen + .gen_item_type(&self.ty.item, &generics, self.ty.has_backend) + } + + /// Generate the implementation for the Record trait. + pub(crate) fn gen_impl_record(&self) -> TokenStream { + // Capture the record type's generics and bounds in where clauses + let item_generics = self.record_item_generics(); + let (_, ty_generics_item, _) = item_generics.split_for_impl(); + let (impl_generics, ty_generics, where_clause) = self.ty.generics.split_for_impl(); + + let impl_generics = if let Some(impl_generic) = self.impl_generics() { + impl_generic + } else { + quote! { #impl_generics } + }; + + let name_item = &self.ty.item; + let into_item_fn = self.gen.gen_into_item(name_item); + let from_item_fn = self.gen.gen_from_item(); + + // Return the generated stream of token trees (i.e., code to be generated) + let name = &self.ty.name; + quote! { + impl #impl_generics burn::record::Record for #name #ty_generics #where_clause { + type Item = #name_item #ty_generics_item; + + #into_item_fn + #from_item_fn + + } + } + } + + /// Add backend generic type to the implementation block. + fn impl_generics(&self) -> Option { + if self.ty.has_backend { + return None; + } + + let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend }; + let mut generics = self.ty.generics.clone(); + generics.params.push(syn::GenericParam::Type(param)); + + let (impl_generics, _ty_generics, _where_clause) = generics.split_for_impl(); + + Some(quote! {#impl_generics}) + } + + /// Get the generics attached to the record item type. + fn record_item_generics(&self) -> Generics { + let param: syn::Generics = parse_quote! { }; + let mut generics = self.ty.generics.clone(); + for param in param.params.into_iter() { + generics.params.push(param); + } + + if !self.ty.has_backend { + let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend }; + generics.params.push(syn::GenericParam::Type(param)); + } + + generics + } + + pub(crate) fn from_ast(ast: &syn::DeriveInput) -> Self { + Self { + ty: RecordType::from_ast(ast), + gen: G::from_ast(ast), + } + } +} + +/// Information about a record type. +struct RecordType { + /// Record type name. + name: Ident, + /// Record item type name. + item: Ident, + /// Lifetimes and type parameters attached to the record type declaration. + generics: Generics, + /// Whether or not the record type should specify a backend generic. + has_backend: bool, +} + +impl RecordType { + fn from_ast(ast: &syn::DeriveInput) -> Self { + let name = ast.ident.clone(); + let item = Ident::new(format!("{}Item", name).as_str(), name.span()); + let has_backend = ast + .generics + .type_params() + .map(|param| param.ident == "B") + .reduce(|accum, is_backend| is_backend || accum) + .unwrap_or(false); + + Self { + name, + item, + generics: ast.generics.clone(), + has_backend, + } + } } diff --git a/crates/burn-derive/src/record/item/codegen.rs b/crates/burn-derive/src/record/item/codegen.rs new file mode 100644 index 000000000..ef2b5bc20 --- /dev/null +++ b/crates/burn-derive/src/record/item/codegen.rs @@ -0,0 +1,19 @@ +use proc_macro2::{Ident, TokenStream}; +use syn::Generics; + +/// Basic trait to be implemented for record generation. +pub(crate) trait RecordItemCodegen { + /// Initialize the record item. + fn from_ast(ast: &syn::DeriveInput) -> Self; + /// Generate the record item type. + fn gen_item_type( + &self, + item_name: &Ident, + generics: &Generics, + has_backend: bool, + ) -> TokenStream; + /// Generate the into_item function. + fn gen_into_item(&self, item_name: &Ident) -> TokenStream; + /// Generate the from item function. + fn gen_from_item(&self) -> TokenStream; +} diff --git a/crates/burn-derive/src/record/item/codegen_enum.rs b/crates/burn-derive/src/record/item/codegen_enum.rs new file mode 100644 index 000000000..112888a34 --- /dev/null +++ b/crates/burn-derive/src/record/item/codegen_enum.rs @@ -0,0 +1,110 @@ +use crate::shared::enum_variant::{parse_variants, EnumVariant}; +use proc_macro2::{Ident, TokenStream}; +use quote::quote; +use syn::{parse_quote, Generics}; + +use super::codegen::RecordItemCodegen; + +pub(crate) struct EnumRecordItemCodegen { + /// Enum variants. + variants: Vec, +} + +impl RecordItemCodegen for EnumRecordItemCodegen { + fn from_ast(ast: &syn::DeriveInput) -> Self { + Self { + variants: parse_variants(ast), + } + } + + fn gen_item_type( + &self, + item_name: &Ident, + generics: &Generics, + has_backend: bool, + ) -> TokenStream { + let mut variants = quote! {}; + let mut bounds = quote! {}; + + // Capture the Record enum variant types and names to transpose them in RecordItem + for variant in self.variants.iter() { + let ty = &variant.ty; + let name = &variant.ident; + + variants.extend(quote! { + /// Variant to be serialized. + #name(<#ty as burn::record::Record>::Item), + }); + + // Item types must implement serialization/deserialization + bounds.extend(quote! { + <#ty as burn::record::Record>::Item: burn::serde::Serialize + burn::serde::de::DeserializeOwned, + }); + } + let bound = bounds.to_string(); + + // Capture the type's generics and bounds in where clauses + let (generics, generics_where) = if !has_backend { + let mut generics = generics.clone(); + let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend }; + generics.params.push(syn::GenericParam::Type(param)); + let (generics, _, generics_where) = generics.split_for_impl(); + (quote! { #generics }, quote! { #generics_where }) + } else { + let (generics, _, generics_where) = generics.split_for_impl(); + (quote! { #generics }, quote! { #generics_where }) + }; + + // Return the generated stream of token trees (i.e., code to be generated) + quote! { + + /// The record item type for the module. + #[derive(burn::serde::Serialize, burn::serde::Deserialize)] + #[serde(crate = "burn::serde")] + #[serde(bound = #bound)] + pub enum #item_name #generics #generics_where { + #variants + } + } + } + + fn gen_into_item(&self, _item_name: &Ident) -> TokenStream { + let mut into_item_match_arms = quote! {}; + + for variant in self.variants.iter() { + let name = &variant.ident; + + into_item_match_arms.extend(quote! { + Self::#name(record) => Self::Item::#name(burn::record::Record::::into_item::(record)), + }); + } + + quote! { + fn into_item(self) -> Self::Item { + match self { + #into_item_match_arms + } + } + } + } + + fn gen_from_item(&self) -> TokenStream { + let mut from_item_match_arms = quote! {}; + + for variant in self.variants.iter() { + let name = &variant.ident; + + from_item_match_arms.extend(quote! { + Self::Item::#name(item) => Self::#name(burn::record::Record::::from_item::(item, device)), + }); + } + + quote! { + fn from_item(item: Self::Item, device: &B::Device) -> Self { + match item { + #from_item_match_arms + } + } + } + } +} diff --git a/crates/burn-derive/src/record/codegen_struct.rs b/crates/burn-derive/src/record/item/codegen_struct.rs similarity index 91% rename from crates/burn-derive/src/record/codegen_struct.rs rename to crates/burn-derive/src/record/item/codegen_struct.rs index 04e81af03..de1ffd23a 100644 --- a/crates/burn-derive/src/record/codegen_struct.rs +++ b/crates/burn-derive/src/record/item/codegen_struct.rs @@ -1,16 +1,24 @@ -use crate::shared::field::FieldTypeAnalyzer; +use crate::shared::field::{parse_fields, FieldTypeAnalyzer}; use proc_macro2::{Ident, TokenStream}; use quote::quote; use syn::{parse_quote, Generics}; use super::codegen::RecordItemCodegen; -#[derive(new)] pub(crate) struct StructRecordItemCodegen { fields: Vec, } impl RecordItemCodegen for StructRecordItemCodegen { + fn from_ast(ast: &syn::DeriveInput) -> Self { + Self { + fields: parse_fields(ast) + .into_iter() + .map(FieldTypeAnalyzer::new) + .collect(), + } + } + fn gen_item_type( &self, item_name: &Ident, diff --git a/crates/burn-derive/src/record/item/mod.rs b/crates/burn-derive/src/record/item/mod.rs new file mode 100644 index 000000000..6b2b0964a --- /dev/null +++ b/crates/burn-derive/src/record/item/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod codegen; +pub(crate) mod codegen_enum; +pub(crate) mod codegen_struct; diff --git a/crates/burn-derive/src/record/mod.rs b/crates/burn-derive/src/record/mod.rs index 0df0b4f32..31321f714 100644 --- a/crates/burn-derive/src/record/mod.rs +++ b/crates/burn-derive/src/record/mod.rs @@ -1,5 +1,5 @@ pub(crate) mod codegen; -pub(crate) mod codegen_struct; +pub(crate) mod item; mod base; pub(crate) use base::*; diff --git a/crates/burn-derive/src/shared/enum_variant.rs b/crates/burn-derive/src/shared/enum_variant.rs index ceed9fe48..5c059bf0f 100644 --- a/crates/burn-derive/src/shared/enum_variant.rs +++ b/crates/burn-derive/src/shared/enum_variant.rs @@ -49,3 +49,32 @@ where syn::Fields::Unit => (quote! {}, quote! {}), } } + +/// An enum variant (simplified). +pub(crate) struct EnumVariant { + pub ident: syn::Ident, + pub ty: syn::Type, +} +pub(crate) fn parse_variants(ast: &syn::DeriveInput) -> Vec { + let mut variants = Vec::new(); + + if let syn::Data::Enum(enum_data) = &ast.data { + for variant in enum_data.variants.iter() { + if variant.fields.len() != 1 { + // No support for unit variants or variants with multiple fields + panic!("Enums are only supported for one field type") + } + + let field = variant.fields.iter().next().unwrap(); + + variants.push(EnumVariant { + ident: variant.ident.clone(), + ty: field.ty.clone(), + }); + } + } else { + panic!("Only enum can be derived") + } + + variants +}