mirror of https://github.com/tracel-ai/burn.git
Add enum module support (#1337)
This commit is contained in:
parent
4427768570
commit
bff4961426
|
@ -34,6 +34,23 @@ struct ModuleWithGenericModule<B: Backend, M> {
|
|||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
enum ModuleEnum<B: Backend> {
|
||||
Basic(ModuleBasic<B>),
|
||||
Composed(ModuleComposed<B>),
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
enum ModuleEnumNested<B: Backend> {
|
||||
AnotherEnum(ModuleEnum<B>),
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
enum ModuleEnumWithGenericModule<B: Backend, M: Module<B>> {
|
||||
Basic(ModuleBasic<B>),
|
||||
Generic(ModuleWithGenericModule<B, M>),
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ModuleComposed<B: Backend> {
|
||||
weight: Param<Tensor<B, 2>>,
|
||||
|
@ -95,6 +112,46 @@ mod state {
|
|||
module_2.basic.weight_basic.to_data()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_load_from_record_enum() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let module_1 = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
|
||||
let mut module_2 = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
|
||||
let state_1 = module_1.clone().into_record();
|
||||
|
||||
let ModuleEnum::Basic(module_1_basic) = module_1 else {
|
||||
panic!("Invalid module type")
|
||||
};
|
||||
let ModuleEnum::Basic(module_2_basic) = module_2.clone() else {
|
||||
panic!("Invalid module type")
|
||||
};
|
||||
assert_ne!(
|
||||
module_1_basic.weight_basic.to_data(),
|
||||
module_2_basic.weight_basic.to_data()
|
||||
);
|
||||
|
||||
module_2 = module_2.load_record(state_1);
|
||||
|
||||
let ModuleEnum::Basic(module_2_basic) = module_2 else {
|
||||
panic!("Invalid module type")
|
||||
};
|
||||
assert_eq!(
|
||||
module_1_basic.weight_basic.to_data(),
|
||||
module_2_basic.weight_basic.to_data()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Can't parse record from a different variant")]
|
||||
fn should_panic_load_from_incorrect_enum_variant() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let module_1 = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
|
||||
let module_2 = ModuleEnum::Composed(ModuleComposed::<TestBackend>::new(&device));
|
||||
let state_1 = module_1.clone().into_record();
|
||||
|
||||
module_2.load_record(state_1);
|
||||
}
|
||||
}
|
||||
|
||||
mod num_params {
|
||||
|
@ -113,6 +170,16 @@ mod num_params {
|
|||
let module = ModuleComposed::<TestBackend>::new(&device);
|
||||
assert_eq!(4 * 20 * 20, module.num_params());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_calculate_num_params_enum() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let module = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
|
||||
assert_eq!(20 * 20, module.num_params());
|
||||
|
||||
let module = ModuleEnum::Composed(ModuleComposed::<TestBackend>::new(&device));
|
||||
assert_eq!(4 * 20 * 20, module.num_params());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use super::{
|
||||
codegen::{generate_module_const, generate_module_standard},
|
||||
codegen_enum::EnumModuleCodegen,
|
||||
codegen_struct::StructModuleCodegen,
|
||||
};
|
||||
use proc_macro::TokenStream;
|
||||
|
@ -22,7 +23,7 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
|
|||
}
|
||||
syn::Data::Enum(_data) => {
|
||||
if has_backend {
|
||||
panic!("Enum modules aren't supported yet.")
|
||||
generate_module_standard(ast, EnumModuleCodegen::from_ast(ast))
|
||||
} else {
|
||||
generate_module_const(ast)
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
|
|||
|
||||
let record = codegen.record_codegen();
|
||||
let record_name = Ident::new(format!("{}Record", name).as_str(), name.span());
|
||||
let record_struct = record.gen_record_type(&record_name, &generics.module);
|
||||
let record_type = record.gen_record_type(&record_name, &generics.module);
|
||||
|
||||
let (generics_module, generics_ty_module, generics_where_module) =
|
||||
generics.module.split_for_impl();
|
||||
|
@ -86,7 +86,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
|
|||
#clone_fn
|
||||
}
|
||||
|
||||
#record_struct
|
||||
#record_type
|
||||
};
|
||||
|
||||
gen
|
||||
|
|
|
@ -0,0 +1,192 @@
|
|||
use super::{codegen::ModuleCodegen, record_enum::EnumModuleRecordCodegen};
|
||||
use crate::shared::enum_variant::{parse_variants, EnumVariant};
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
|
||||
pub(crate) struct EnumModuleCodegen {
|
||||
pub variants: Vec<EnumVariant>,
|
||||
}
|
||||
|
||||
impl ModuleCodegen for EnumModuleCodegen {
|
||||
type RecordCodegen = EnumModuleRecordCodegen;
|
||||
|
||||
fn gen_num_params(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|_| {
|
||||
quote! {
|
||||
burn::module::Module::<B>::num_params(module)
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn num_params(&self) -> usize {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_visit(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|_| {
|
||||
quote! {
|
||||
burn::module::Module::visit(module, visitor)
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn visit<Visitor: burn::module::ModuleVisitor<B>>(&self, visitor: &mut Visitor) {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_collect_devices(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|_| {
|
||||
quote! {
|
||||
burn::module::Module::<B>::collect_devices(module, devices)
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn collect_devices(
|
||||
&self,
|
||||
devices: burn::module::Devices<B>
|
||||
) -> burn::module::Devices<B> {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_to_device(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
quote! {
|
||||
Self::#variant(burn::module::Module::<B>::to_device(module, device))
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn to_device(self, device: &B::Device) -> Self {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_fork(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
quote! {
|
||||
Self::#variant(burn::module::Module::<B>::fork(module, device))
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn fork(self, device: &B::Device) -> Self {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_map(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
quote! {
|
||||
Self::#variant(burn::module::Module::<B>::map(module, mapper))
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn map<Mapper: burn::module::ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_valid(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
quote! {
|
||||
Self::InnerModule::#variant(burn::module::AutodiffModule::<B>::valid(module))
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_into_record(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
quote! {
|
||||
Self::Record::#variant(burn::module::Module::<B>::into_record(module))
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn into_record(self) -> Self::Record {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_load_record(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
quote! {
|
||||
{
|
||||
let Self::Record::#variant(r) = record else {panic!("Can't parse record from a different variant");};
|
||||
Self::#variant(burn::module::Module::<B>::load_record(module, r))
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_clone(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
quote! {
|
||||
Self::#variant(module.clone())
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn clone(&self) -> Self {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn record_codegen(self) -> Self::RecordCodegen {
|
||||
EnumModuleRecordCodegen::new(self.variants)
|
||||
}
|
||||
}
|
||||
|
||||
impl EnumModuleCodegen {
|
||||
pub fn from_ast(ast: &syn::DeriveInput) -> Self {
|
||||
Self {
|
||||
variants: parse_variants(ast),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate the enum variants' match arm with the provided function
|
||||
fn gen_variants_match_fn<F>(&self, func: F) -> TokenStream
|
||||
where
|
||||
F: Fn(Ident) -> TokenStream,
|
||||
{
|
||||
let mut match_arms = quote! {};
|
||||
|
||||
for variant in self.variants.iter() {
|
||||
let name = &variant.ident;
|
||||
let arm_pattern = quote! {Self::#name(module)};
|
||||
let arm_code = func(name.clone());
|
||||
|
||||
match_arms.extend(quote! {#arm_pattern => #arm_code,})
|
||||
}
|
||||
|
||||
quote! {
|
||||
match self {
|
||||
#match_arms
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,7 +1,9 @@
|
|||
pub(crate) mod codegen;
|
||||
pub(crate) mod codegen_enum;
|
||||
pub(crate) mod codegen_struct;
|
||||
pub(crate) mod display;
|
||||
pub(crate) mod record;
|
||||
pub(crate) mod record_enum;
|
||||
pub(crate) mod record_struct;
|
||||
|
||||
mod base;
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
use crate::shared::enum_variant::EnumVariant;
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::Generics;
|
||||
|
||||
use super::record::ModuleRecordCodegen;
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct EnumModuleRecordCodegen {
|
||||
variants: Vec<EnumVariant>,
|
||||
}
|
||||
|
||||
impl ModuleRecordCodegen for EnumModuleRecordCodegen {
|
||||
fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream {
|
||||
let mut variants = quote! {};
|
||||
|
||||
// Capture the Record enum variant types
|
||||
for variant in self.variants.iter() {
|
||||
let ty = &variant.ty;
|
||||
let name = &variant.ident;
|
||||
|
||||
variants.extend(quote! {
|
||||
/// The module record associative type.
|
||||
#name(<#ty as burn::module::Module<B>>::Record),
|
||||
});
|
||||
}
|
||||
|
||||
let (generics, _generics_ty, generics_where) = generics.split_for_impl();
|
||||
|
||||
quote! {
|
||||
|
||||
/// The record type for the module.
|
||||
#[derive(burn::record::Record)]
|
||||
pub enum #record_name #generics #generics_where {
|
||||
#variants
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,122 +1,13 @@
|
|||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{parse_quote, Generics};
|
||||
|
||||
use super::{codegen::RecordItemCodegen, codegen_struct::StructRecordItemCodegen};
|
||||
use crate::shared::field::{parse_fields, FieldTypeAnalyzer};
|
||||
use super::{
|
||||
codegen::generate_record,
|
||||
item::{codegen_enum::EnumRecordItemCodegen, codegen_struct::StructRecordItemCodegen},
|
||||
};
|
||||
|
||||
pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> proc_macro::TokenStream {
|
||||
let record_gen = RecordDeriveCodegen::from_ast(ast);
|
||||
let item_struct = record_gen.gen_record_type();
|
||||
let record_impl = record_gen.gen_impl_record();
|
||||
|
||||
quote! {
|
||||
#item_struct
|
||||
#record_impl
|
||||
match &ast.data {
|
||||
syn::Data::Struct(_) => generate_record::<StructRecordItemCodegen>(ast),
|
||||
syn::Data::Enum(_) => generate_record::<EnumRecordItemCodegen>(ast),
|
||||
syn::Data::Union(_) => panic!("Union modules aren't supported yet."),
|
||||
}
|
||||
.into()
|
||||
}
|
||||
|
||||
struct RecordDeriveCodegen {
|
||||
name_record: Ident,
|
||||
name_item: Ident,
|
||||
gen: StructRecordItemCodegen,
|
||||
generics: Generics,
|
||||
has_backend: bool,
|
||||
}
|
||||
|
||||
impl RecordDeriveCodegen {
|
||||
pub(crate) fn from_ast(ast: &syn::DeriveInput) -> Self {
|
||||
let name_record = ast.ident.clone();
|
||||
let name_item = Ident::new(format!("{}Item", name_record).as_str(), name_record.span());
|
||||
let has_backend = ast
|
||||
.generics
|
||||
.type_params()
|
||||
.map(|param| param.ident == "B")
|
||||
.reduce(|accum, is_backend| is_backend || accum)
|
||||
.unwrap_or(false);
|
||||
|
||||
Self {
|
||||
name_record,
|
||||
name_item,
|
||||
gen: StructRecordItemCodegen::new(
|
||||
parse_fields(ast)
|
||||
.into_iter()
|
||||
.map(FieldTypeAnalyzer::new)
|
||||
.collect(),
|
||||
),
|
||||
generics: ast.generics.clone(),
|
||||
has_backend,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate the record type with the correct generics.
|
||||
pub(crate) fn gen_record_type(&self) -> TokenStream {
|
||||
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
|
||||
let mut generics = self.generics.clone();
|
||||
|
||||
for param in param.params.into_iter() {
|
||||
generics.params.push(param);
|
||||
}
|
||||
|
||||
self.gen
|
||||
.gen_item_type(&self.name_item, &generics, self.has_backend)
|
||||
}
|
||||
|
||||
/// Generate the implementation for the Record trait.
|
||||
pub(crate) fn gen_impl_record(&self) -> TokenStream {
|
||||
let name = &self.name_record;
|
||||
let item_generics = self.record_item_generics();
|
||||
let (_, ty_generics_item, _) = item_generics.split_for_impl();
|
||||
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
|
||||
|
||||
let impl_generics = if let Some(impl_generic) = self.impl_generics() {
|
||||
impl_generic
|
||||
} else {
|
||||
quote! { #impl_generics }
|
||||
};
|
||||
|
||||
let name_item = &self.name_item;
|
||||
let into_item_fn = self.gen.gen_into_item(name_item);
|
||||
let from_item_fn = self.gen.gen_from_item();
|
||||
|
||||
quote! {
|
||||
impl #impl_generics burn::record::Record<B> for #name #ty_generics #where_clause {
|
||||
type Item<S: burn::record::PrecisionSettings> = #name_item #ty_generics_item;
|
||||
|
||||
#into_item_fn
|
||||
#from_item_fn
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_generics(&self) -> Option<TokenStream> {
|
||||
if self.has_backend {
|
||||
return None;
|
||||
}
|
||||
|
||||
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
|
||||
let mut generics = self.generics.clone();
|
||||
generics.params.push(syn::GenericParam::Type(param));
|
||||
|
||||
let (impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
|
||||
|
||||
Some(quote! {#impl_generics})
|
||||
}
|
||||
|
||||
fn record_item_generics(&self) -> Generics {
|
||||
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
|
||||
let mut generics = self.generics.clone();
|
||||
for param in param.params.into_iter() {
|
||||
generics.params.push(param);
|
||||
}
|
||||
|
||||
if !self.has_backend {
|
||||
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
|
||||
generics.params.push(syn::GenericParam::Type(param));
|
||||
}
|
||||
|
||||
generics
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,17 +1,140 @@
|
|||
use proc_macro2::{Ident, TokenStream};
|
||||
use syn::Generics;
|
||||
use quote::quote;
|
||||
use syn::{parse_quote, Generics};
|
||||
|
||||
/// Basic trait to be implemented for record generation.
|
||||
pub(crate) trait RecordItemCodegen {
|
||||
/// Generate the record item type (i.e a struct)
|
||||
fn gen_item_type(
|
||||
&self,
|
||||
item_name: &Ident,
|
||||
generics: &Generics,
|
||||
has_backend: bool,
|
||||
) -> TokenStream;
|
||||
/// Generate the into_item function.
|
||||
fn gen_into_item(&self, item_name: &Ident) -> TokenStream;
|
||||
/// Generate the from item function.
|
||||
fn gen_from_item(&self) -> TokenStream;
|
||||
use crate::record::item::codegen::RecordItemCodegen;
|
||||
|
||||
pub(crate) fn generate_record<G: RecordItemCodegen>(ast: &syn::DeriveInput) -> TokenStream {
|
||||
let record_gen: RecordCodegen<G> = RecordCodegen::from_ast(ast);
|
||||
let item_type = record_gen.gen_record_type();
|
||||
let record_impl = record_gen.gen_impl_record();
|
||||
|
||||
quote! {
|
||||
#item_type
|
||||
#record_impl
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct RecordCodegen<G: RecordItemCodegen> {
|
||||
/// Record type info.
|
||||
ty: RecordType,
|
||||
/// Record item code gen.
|
||||
gen: G,
|
||||
}
|
||||
|
||||
impl<G: RecordItemCodegen> RecordCodegen<G> {
|
||||
/// Generate the record type with the correct generics.
|
||||
pub(crate) fn gen_record_type(&self) -> TokenStream {
|
||||
// Add precision settings type bound
|
||||
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
|
||||
let mut generics = self.ty.generics.clone();
|
||||
|
||||
for param in param.params.into_iter() {
|
||||
generics.params.push(param);
|
||||
}
|
||||
|
||||
// Generate the record item definition
|
||||
self.gen
|
||||
.gen_item_type(&self.ty.item, &generics, self.ty.has_backend)
|
||||
}
|
||||
|
||||
/// Generate the implementation for the Record trait.
|
||||
pub(crate) fn gen_impl_record(&self) -> TokenStream {
|
||||
// Capture the record type's generics and bounds in where clauses
|
||||
let item_generics = self.record_item_generics();
|
||||
let (_, ty_generics_item, _) = item_generics.split_for_impl();
|
||||
let (impl_generics, ty_generics, where_clause) = self.ty.generics.split_for_impl();
|
||||
|
||||
let impl_generics = if let Some(impl_generic) = self.impl_generics() {
|
||||
impl_generic
|
||||
} else {
|
||||
quote! { #impl_generics }
|
||||
};
|
||||
|
||||
let name_item = &self.ty.item;
|
||||
let into_item_fn = self.gen.gen_into_item(name_item);
|
||||
let from_item_fn = self.gen.gen_from_item();
|
||||
|
||||
// Return the generated stream of token trees (i.e., code to be generated)
|
||||
let name = &self.ty.name;
|
||||
quote! {
|
||||
impl #impl_generics burn::record::Record<B> for #name #ty_generics #where_clause {
|
||||
type Item<S: burn::record::PrecisionSettings> = #name_item #ty_generics_item;
|
||||
|
||||
#into_item_fn
|
||||
#from_item_fn
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add backend generic type to the implementation block.
|
||||
fn impl_generics(&self) -> Option<TokenStream> {
|
||||
if self.ty.has_backend {
|
||||
return None;
|
||||
}
|
||||
|
||||
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
|
||||
let mut generics = self.ty.generics.clone();
|
||||
generics.params.push(syn::GenericParam::Type(param));
|
||||
|
||||
let (impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
|
||||
|
||||
Some(quote! {#impl_generics})
|
||||
}
|
||||
|
||||
/// Get the generics attached to the record item type.
|
||||
fn record_item_generics(&self) -> Generics {
|
||||
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
|
||||
let mut generics = self.ty.generics.clone();
|
||||
for param in param.params.into_iter() {
|
||||
generics.params.push(param);
|
||||
}
|
||||
|
||||
if !self.ty.has_backend {
|
||||
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
|
||||
generics.params.push(syn::GenericParam::Type(param));
|
||||
}
|
||||
|
||||
generics
|
||||
}
|
||||
|
||||
pub(crate) fn from_ast(ast: &syn::DeriveInput) -> Self {
|
||||
Self {
|
||||
ty: RecordType::from_ast(ast),
|
||||
gen: G::from_ast(ast),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Information about a record type.
|
||||
struct RecordType {
|
||||
/// Record type name.
|
||||
name: Ident,
|
||||
/// Record item type name.
|
||||
item: Ident,
|
||||
/// Lifetimes and type parameters attached to the record type declaration.
|
||||
generics: Generics,
|
||||
/// Whether or not the record type should specify a backend generic.
|
||||
has_backend: bool,
|
||||
}
|
||||
|
||||
impl RecordType {
|
||||
fn from_ast(ast: &syn::DeriveInput) -> Self {
|
||||
let name = ast.ident.clone();
|
||||
let item = Ident::new(format!("{}Item", name).as_str(), name.span());
|
||||
let has_backend = ast
|
||||
.generics
|
||||
.type_params()
|
||||
.map(|param| param.ident == "B")
|
||||
.reduce(|accum, is_backend| is_backend || accum)
|
||||
.unwrap_or(false);
|
||||
|
||||
Self {
|
||||
name,
|
||||
item,
|
||||
generics: ast.generics.clone(),
|
||||
has_backend,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
use proc_macro2::{Ident, TokenStream};
|
||||
use syn::Generics;
|
||||
|
||||
/// Basic trait to be implemented for record generation.
|
||||
pub(crate) trait RecordItemCodegen {
|
||||
/// Initialize the record item.
|
||||
fn from_ast(ast: &syn::DeriveInput) -> Self;
|
||||
/// Generate the record item type.
|
||||
fn gen_item_type(
|
||||
&self,
|
||||
item_name: &Ident,
|
||||
generics: &Generics,
|
||||
has_backend: bool,
|
||||
) -> TokenStream;
|
||||
/// Generate the into_item function.
|
||||
fn gen_into_item(&self, item_name: &Ident) -> TokenStream;
|
||||
/// Generate the from item function.
|
||||
fn gen_from_item(&self) -> TokenStream;
|
||||
}
|
|
@ -0,0 +1,110 @@
|
|||
use crate::shared::enum_variant::{parse_variants, EnumVariant};
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{parse_quote, Generics};
|
||||
|
||||
use super::codegen::RecordItemCodegen;
|
||||
|
||||
pub(crate) struct EnumRecordItemCodegen {
|
||||
/// Enum variants.
|
||||
variants: Vec<EnumVariant>,
|
||||
}
|
||||
|
||||
impl RecordItemCodegen for EnumRecordItemCodegen {
|
||||
fn from_ast(ast: &syn::DeriveInput) -> Self {
|
||||
Self {
|
||||
variants: parse_variants(ast),
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_item_type(
|
||||
&self,
|
||||
item_name: &Ident,
|
||||
generics: &Generics,
|
||||
has_backend: bool,
|
||||
) -> TokenStream {
|
||||
let mut variants = quote! {};
|
||||
let mut bounds = quote! {};
|
||||
|
||||
// Capture the Record enum variant types and names to transpose them in RecordItem
|
||||
for variant in self.variants.iter() {
|
||||
let ty = &variant.ty;
|
||||
let name = &variant.ident;
|
||||
|
||||
variants.extend(quote! {
|
||||
/// Variant to be serialized.
|
||||
#name(<#ty as burn::record::Record<B>>::Item<S>),
|
||||
});
|
||||
|
||||
// Item types must implement serialization/deserialization
|
||||
bounds.extend(quote! {
|
||||
<#ty as burn::record::Record<B>>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
|
||||
});
|
||||
}
|
||||
let bound = bounds.to_string();
|
||||
|
||||
// Capture the type's generics and bounds in where clauses
|
||||
let (generics, generics_where) = if !has_backend {
|
||||
let mut generics = generics.clone();
|
||||
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
|
||||
generics.params.push(syn::GenericParam::Type(param));
|
||||
let (generics, _, generics_where) = generics.split_for_impl();
|
||||
(quote! { #generics }, quote! { #generics_where })
|
||||
} else {
|
||||
let (generics, _, generics_where) = generics.split_for_impl();
|
||||
(quote! { #generics }, quote! { #generics_where })
|
||||
};
|
||||
|
||||
// Return the generated stream of token trees (i.e., code to be generated)
|
||||
quote! {
|
||||
|
||||
/// The record item type for the module.
|
||||
#[derive(burn::serde::Serialize, burn::serde::Deserialize)]
|
||||
#[serde(crate = "burn::serde")]
|
||||
#[serde(bound = #bound)]
|
||||
pub enum #item_name #generics #generics_where {
|
||||
#variants
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_into_item(&self, _item_name: &Ident) -> TokenStream {
|
||||
let mut into_item_match_arms = quote! {};
|
||||
|
||||
for variant in self.variants.iter() {
|
||||
let name = &variant.ident;
|
||||
|
||||
into_item_match_arms.extend(quote! {
|
||||
Self::#name(record) => Self::Item::#name(burn::record::Record::<B>::into_item::<S>(record)),
|
||||
});
|
||||
}
|
||||
|
||||
quote! {
|
||||
fn into_item<S: burn::record::PrecisionSettings>(self) -> Self::Item<S> {
|
||||
match self {
|
||||
#into_item_match_arms
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_from_item(&self) -> TokenStream {
|
||||
let mut from_item_match_arms = quote! {};
|
||||
|
||||
for variant in self.variants.iter() {
|
||||
let name = &variant.ident;
|
||||
|
||||
from_item_match_arms.extend(quote! {
|
||||
Self::Item::#name(item) => Self::#name(burn::record::Record::<B>::from_item::<S>(item, device)),
|
||||
});
|
||||
}
|
||||
|
||||
quote! {
|
||||
fn from_item<S: burn::record::PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
match item {
|
||||
#from_item_match_arms
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,16 +1,24 @@
|
|||
use crate::shared::field::FieldTypeAnalyzer;
|
||||
use crate::shared::field::{parse_fields, FieldTypeAnalyzer};
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{parse_quote, Generics};
|
||||
|
||||
use super::codegen::RecordItemCodegen;
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct StructRecordItemCodegen {
|
||||
fields: Vec<FieldTypeAnalyzer>,
|
||||
}
|
||||
|
||||
impl RecordItemCodegen for StructRecordItemCodegen {
|
||||
fn from_ast(ast: &syn::DeriveInput) -> Self {
|
||||
Self {
|
||||
fields: parse_fields(ast)
|
||||
.into_iter()
|
||||
.map(FieldTypeAnalyzer::new)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_item_type(
|
||||
&self,
|
||||
item_name: &Ident,
|
|
@ -0,0 +1,3 @@
|
|||
pub(crate) mod codegen;
|
||||
pub(crate) mod codegen_enum;
|
||||
pub(crate) mod codegen_struct;
|
|
@ -1,5 +1,5 @@
|
|||
pub(crate) mod codegen;
|
||||
pub(crate) mod codegen_struct;
|
||||
pub(crate) mod item;
|
||||
|
||||
mod base;
|
||||
pub(crate) use base::*;
|
||||
|
|
|
@ -49,3 +49,32 @@ where
|
|||
syn::Fields::Unit => (quote! {}, quote! {}),
|
||||
}
|
||||
}
|
||||
|
||||
/// An enum variant (simplified).
|
||||
pub(crate) struct EnumVariant {
|
||||
pub ident: syn::Ident,
|
||||
pub ty: syn::Type,
|
||||
}
|
||||
pub(crate) fn parse_variants(ast: &syn::DeriveInput) -> Vec<EnumVariant> {
|
||||
let mut variants = Vec::new();
|
||||
|
||||
if let syn::Data::Enum(enum_data) = &ast.data {
|
||||
for variant in enum_data.variants.iter() {
|
||||
if variant.fields.len() != 1 {
|
||||
// No support for unit variants or variants with multiple fields
|
||||
panic!("Enums are only supported for one field type")
|
||||
}
|
||||
|
||||
let field = variant.fields.iter().next().unwrap();
|
||||
|
||||
variants.push(EnumVariant {
|
||||
ident: variant.ident.clone(),
|
||||
ty: field.ty.clone(),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
panic!("Only enum can be derived")
|
||||
}
|
||||
|
||||
variants
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue