Add enum module support (#1337)

This commit is contained in:
Guillaume Lagrange 2024-02-21 17:03:34 -05:00 committed by GitHub
parent 4427768570
commit bff4961426
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 621 additions and 137 deletions

View File

@ -34,6 +34,23 @@ struct ModuleWithGenericModule<B: Backend, M> {
_backend: PhantomData<B>,
}
#[derive(Module, Debug)]
enum ModuleEnum<B: Backend> {
Basic(ModuleBasic<B>),
Composed(ModuleComposed<B>),
}
#[derive(Module, Debug)]
enum ModuleEnumNested<B: Backend> {
AnotherEnum(ModuleEnum<B>),
}
#[derive(Module, Debug)]
enum ModuleEnumWithGenericModule<B: Backend, M: Module<B>> {
Basic(ModuleBasic<B>),
Generic(ModuleWithGenericModule<B, M>),
}
#[derive(Module, Debug)]
pub struct ModuleComposed<B: Backend> {
weight: Param<Tensor<B, 2>>,
@ -95,6 +112,46 @@ mod state {
module_2.basic.weight_basic.to_data()
);
}
#[test]
fn should_load_from_record_enum() {
let device = <TestBackend as Backend>::Device::default();
let module_1 = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
let mut module_2 = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
let state_1 = module_1.clone().into_record();
let ModuleEnum::Basic(module_1_basic) = module_1 else {
panic!("Invalid module type")
};
let ModuleEnum::Basic(module_2_basic) = module_2.clone() else {
panic!("Invalid module type")
};
assert_ne!(
module_1_basic.weight_basic.to_data(),
module_2_basic.weight_basic.to_data()
);
module_2 = module_2.load_record(state_1);
let ModuleEnum::Basic(module_2_basic) = module_2 else {
panic!("Invalid module type")
};
assert_eq!(
module_1_basic.weight_basic.to_data(),
module_2_basic.weight_basic.to_data()
);
}
#[test]
#[should_panic(expected = "Can't parse record from a different variant")]
fn should_panic_load_from_incorrect_enum_variant() {
let device = <TestBackend as Backend>::Device::default();
let module_1 = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
let module_2 = ModuleEnum::Composed(ModuleComposed::<TestBackend>::new(&device));
let state_1 = module_1.clone().into_record();
module_2.load_record(state_1);
}
}
mod num_params {
@ -113,6 +170,16 @@ mod num_params {
let module = ModuleComposed::<TestBackend>::new(&device);
assert_eq!(4 * 20 * 20, module.num_params());
}
#[test]
fn should_calculate_num_params_enum() {
let device = <TestBackend as Backend>::Device::default();
let module = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));
assert_eq!(20 * 20, module.num_params());
let module = ModuleEnum::Composed(ModuleComposed::<TestBackend>::new(&device));
assert_eq!(4 * 20 * 20, module.num_params());
}
}
#[cfg(feature = "std")]

View File

@ -1,5 +1,6 @@
use super::{
codegen::{generate_module_const, generate_module_standard},
codegen_enum::EnumModuleCodegen,
codegen_struct::StructModuleCodegen,
};
use proc_macro::TokenStream;
@ -22,7 +23,7 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
}
syn::Data::Enum(_data) => {
if has_backend {
panic!("Enum modules aren't supported yet.")
generate_module_standard(ast, EnumModuleCodegen::from_ast(ast))
} else {
generate_module_const(ast)
}

View File

@ -45,7 +45,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
let record = codegen.record_codegen();
let record_name = Ident::new(format!("{}Record", name).as_str(), name.span());
let record_struct = record.gen_record_type(&record_name, &generics.module);
let record_type = record.gen_record_type(&record_name, &generics.module);
let (generics_module, generics_ty_module, generics_where_module) =
generics.module.split_for_impl();
@ -86,7 +86,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
#clone_fn
}
#record_struct
#record_type
};
gen

View File

@ -0,0 +1,192 @@
use super::{codegen::ModuleCodegen, record_enum::EnumModuleRecordCodegen};
use crate::shared::enum_variant::{parse_variants, EnumVariant};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
pub(crate) struct EnumModuleCodegen {
pub variants: Vec<EnumVariant>,
}
impl ModuleCodegen for EnumModuleCodegen {
type RecordCodegen = EnumModuleRecordCodegen;
fn gen_num_params(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|_| {
quote! {
burn::module::Module::<B>::num_params(module)
}
});
quote! {
fn num_params(&self) -> usize {
#match_body
}
}
}
fn gen_visit(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|_| {
quote! {
burn::module::Module::visit(module, visitor)
}
});
quote! {
fn visit<Visitor: burn::module::ModuleVisitor<B>>(&self, visitor: &mut Visitor) {
#match_body
}
}
}
fn gen_collect_devices(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|_| {
quote! {
burn::module::Module::<B>::collect_devices(module, devices)
}
});
quote! {
fn collect_devices(
&self,
devices: burn::module::Devices<B>
) -> burn::module::Devices<B> {
#match_body
}
}
}
fn gen_to_device(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::#variant(burn::module::Module::<B>::to_device(module, device))
}
});
quote! {
fn to_device(self, device: &B::Device) -> Self {
#match_body
}
}
}
fn gen_fork(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::#variant(burn::module::Module::<B>::fork(module, device))
}
});
quote! {
fn fork(self, device: &B::Device) -> Self {
#match_body
}
}
}
fn gen_map(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::#variant(burn::module::Module::<B>::map(module, mapper))
}
});
quote! {
fn map<Mapper: burn::module::ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self {
#match_body
}
}
}
fn gen_valid(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::InnerModule::#variant(burn::module::AutodiffModule::<B>::valid(module))
}
});
quote! {
fn valid(&self) -> Self::InnerModule {
#match_body
}
}
}
fn gen_into_record(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::Record::#variant(burn::module::Module::<B>::into_record(module))
}
});
quote! {
fn into_record(self) -> Self::Record {
#match_body
}
}
}
fn gen_load_record(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
{
let Self::Record::#variant(r) = record else {panic!("Can't parse record from a different variant");};
Self::#variant(burn::module::Module::<B>::load_record(module, r))
}
}
});
quote! {
fn load_record(self, record: Self::Record) -> Self {
#match_body
}
}
}
fn gen_clone(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::#variant(module.clone())
}
});
quote! {
fn clone(&self) -> Self {
#match_body
}
}
}
fn record_codegen(self) -> Self::RecordCodegen {
EnumModuleRecordCodegen::new(self.variants)
}
}
impl EnumModuleCodegen {
pub fn from_ast(ast: &syn::DeriveInput) -> Self {
Self {
variants: parse_variants(ast),
}
}
/// Generate the enum variants' match arm with the provided function
fn gen_variants_match_fn<F>(&self, func: F) -> TokenStream
where
F: Fn(Ident) -> TokenStream,
{
let mut match_arms = quote! {};
for variant in self.variants.iter() {
let name = &variant.ident;
let arm_pattern = quote! {Self::#name(module)};
let arm_code = func(name.clone());
match_arms.extend(quote! {#arm_pattern => #arm_code,})
}
quote! {
match self {
#match_arms
}
}
}
}

View File

@ -1,7 +1,9 @@
pub(crate) mod codegen;
pub(crate) mod codegen_enum;
pub(crate) mod codegen_struct;
pub(crate) mod display;
pub(crate) mod record;
pub(crate) mod record_enum;
pub(crate) mod record_struct;
mod base;

View File

@ -0,0 +1,39 @@
use crate::shared::enum_variant::EnumVariant;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::Generics;
use super::record::ModuleRecordCodegen;
#[derive(new)]
pub(crate) struct EnumModuleRecordCodegen {
variants: Vec<EnumVariant>,
}
impl ModuleRecordCodegen for EnumModuleRecordCodegen {
fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream {
let mut variants = quote! {};
// Capture the Record enum variant types
for variant in self.variants.iter() {
let ty = &variant.ty;
let name = &variant.ident;
variants.extend(quote! {
/// The module record associative type.
#name(<#ty as burn::module::Module<B>>::Record),
});
}
let (generics, _generics_ty, generics_where) = generics.split_for_impl();
quote! {
/// The record type for the module.
#[derive(burn::record::Record)]
pub enum #record_name #generics #generics_where {
#variants
}
}
}
}

View File

@ -1,122 +1,13 @@
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{parse_quote, Generics};
use super::{codegen::RecordItemCodegen, codegen_struct::StructRecordItemCodegen};
use crate::shared::field::{parse_fields, FieldTypeAnalyzer};
use super::{
codegen::generate_record,
item::{codegen_enum::EnumRecordItemCodegen, codegen_struct::StructRecordItemCodegen},
};
pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> proc_macro::TokenStream {
let record_gen = RecordDeriveCodegen::from_ast(ast);
let item_struct = record_gen.gen_record_type();
let record_impl = record_gen.gen_impl_record();
quote! {
#item_struct
#record_impl
match &ast.data {
syn::Data::Struct(_) => generate_record::<StructRecordItemCodegen>(ast),
syn::Data::Enum(_) => generate_record::<EnumRecordItemCodegen>(ast),
syn::Data::Union(_) => panic!("Union modules aren't supported yet."),
}
.into()
}
struct RecordDeriveCodegen {
name_record: Ident,
name_item: Ident,
gen: StructRecordItemCodegen,
generics: Generics,
has_backend: bool,
}
impl RecordDeriveCodegen {
pub(crate) fn from_ast(ast: &syn::DeriveInput) -> Self {
let name_record = ast.ident.clone();
let name_item = Ident::new(format!("{}Item", name_record).as_str(), name_record.span());
let has_backend = ast
.generics
.type_params()
.map(|param| param.ident == "B")
.reduce(|accum, is_backend| is_backend || accum)
.unwrap_or(false);
Self {
name_record,
name_item,
gen: StructRecordItemCodegen::new(
parse_fields(ast)
.into_iter()
.map(FieldTypeAnalyzer::new)
.collect(),
),
generics: ast.generics.clone(),
has_backend,
}
}
/// Generate the record type with the correct generics.
pub(crate) fn gen_record_type(&self) -> TokenStream {
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
let mut generics = self.generics.clone();
for param in param.params.into_iter() {
generics.params.push(param);
}
self.gen
.gen_item_type(&self.name_item, &generics, self.has_backend)
}
/// Generate the implementation for the Record trait.
pub(crate) fn gen_impl_record(&self) -> TokenStream {
let name = &self.name_record;
let item_generics = self.record_item_generics();
let (_, ty_generics_item, _) = item_generics.split_for_impl();
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
let impl_generics = if let Some(impl_generic) = self.impl_generics() {
impl_generic
} else {
quote! { #impl_generics }
};
let name_item = &self.name_item;
let into_item_fn = self.gen.gen_into_item(name_item);
let from_item_fn = self.gen.gen_from_item();
quote! {
impl #impl_generics burn::record::Record<B> for #name #ty_generics #where_clause {
type Item<S: burn::record::PrecisionSettings> = #name_item #ty_generics_item;
#into_item_fn
#from_item_fn
}
}
}
fn impl_generics(&self) -> Option<TokenStream> {
if self.has_backend {
return None;
}
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
let mut generics = self.generics.clone();
generics.params.push(syn::GenericParam::Type(param));
let (impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
Some(quote! {#impl_generics})
}
fn record_item_generics(&self) -> Generics {
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
let mut generics = self.generics.clone();
for param in param.params.into_iter() {
generics.params.push(param);
}
if !self.has_backend {
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
generics.params.push(syn::GenericParam::Type(param));
}
generics
}
}

View File

@ -1,17 +1,140 @@
use proc_macro2::{Ident, TokenStream};
use syn::Generics;
use quote::quote;
use syn::{parse_quote, Generics};
/// Basic trait to be implemented for record generation.
pub(crate) trait RecordItemCodegen {
/// Generate the record item type (i.e a struct)
fn gen_item_type(
&self,
item_name: &Ident,
generics: &Generics,
has_backend: bool,
) -> TokenStream;
/// Generate the into_item function.
fn gen_into_item(&self, item_name: &Ident) -> TokenStream;
/// Generate the from item function.
fn gen_from_item(&self) -> TokenStream;
use crate::record::item::codegen::RecordItemCodegen;
pub(crate) fn generate_record<G: RecordItemCodegen>(ast: &syn::DeriveInput) -> TokenStream {
let record_gen: RecordCodegen<G> = RecordCodegen::from_ast(ast);
let item_type = record_gen.gen_record_type();
let record_impl = record_gen.gen_impl_record();
quote! {
#item_type
#record_impl
}
}
pub(crate) struct RecordCodegen<G: RecordItemCodegen> {
/// Record type info.
ty: RecordType,
/// Record item code gen.
gen: G,
}
impl<G: RecordItemCodegen> RecordCodegen<G> {
/// Generate the record type with the correct generics.
pub(crate) fn gen_record_type(&self) -> TokenStream {
// Add precision settings type bound
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
let mut generics = self.ty.generics.clone();
for param in param.params.into_iter() {
generics.params.push(param);
}
// Generate the record item definition
self.gen
.gen_item_type(&self.ty.item, &generics, self.ty.has_backend)
}
/// Generate the implementation for the Record trait.
pub(crate) fn gen_impl_record(&self) -> TokenStream {
// Capture the record type's generics and bounds in where clauses
let item_generics = self.record_item_generics();
let (_, ty_generics_item, _) = item_generics.split_for_impl();
let (impl_generics, ty_generics, where_clause) = self.ty.generics.split_for_impl();
let impl_generics = if let Some(impl_generic) = self.impl_generics() {
impl_generic
} else {
quote! { #impl_generics }
};
let name_item = &self.ty.item;
let into_item_fn = self.gen.gen_into_item(name_item);
let from_item_fn = self.gen.gen_from_item();
// Return the generated stream of token trees (i.e., code to be generated)
let name = &self.ty.name;
quote! {
impl #impl_generics burn::record::Record<B> for #name #ty_generics #where_clause {
type Item<S: burn::record::PrecisionSettings> = #name_item #ty_generics_item;
#into_item_fn
#from_item_fn
}
}
}
/// Add backend generic type to the implementation block.
fn impl_generics(&self) -> Option<TokenStream> {
if self.ty.has_backend {
return None;
}
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
let mut generics = self.ty.generics.clone();
generics.params.push(syn::GenericParam::Type(param));
let (impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
Some(quote! {#impl_generics})
}
/// Get the generics attached to the record item type.
fn record_item_generics(&self) -> Generics {
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
let mut generics = self.ty.generics.clone();
for param in param.params.into_iter() {
generics.params.push(param);
}
if !self.ty.has_backend {
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
generics.params.push(syn::GenericParam::Type(param));
}
generics
}
pub(crate) fn from_ast(ast: &syn::DeriveInput) -> Self {
Self {
ty: RecordType::from_ast(ast),
gen: G::from_ast(ast),
}
}
}
/// Information about a record type.
struct RecordType {
/// Record type name.
name: Ident,
/// Record item type name.
item: Ident,
/// Lifetimes and type parameters attached to the record type declaration.
generics: Generics,
/// Whether or not the record type should specify a backend generic.
has_backend: bool,
}
impl RecordType {
fn from_ast(ast: &syn::DeriveInput) -> Self {
let name = ast.ident.clone();
let item = Ident::new(format!("{}Item", name).as_str(), name.span());
let has_backend = ast
.generics
.type_params()
.map(|param| param.ident == "B")
.reduce(|accum, is_backend| is_backend || accum)
.unwrap_or(false);
Self {
name,
item,
generics: ast.generics.clone(),
has_backend,
}
}
}

View File

@ -0,0 +1,19 @@
use proc_macro2::{Ident, TokenStream};
use syn::Generics;
/// Basic trait to be implemented for record generation.
pub(crate) trait RecordItemCodegen {
/// Initialize the record item.
fn from_ast(ast: &syn::DeriveInput) -> Self;
/// Generate the record item type.
fn gen_item_type(
&self,
item_name: &Ident,
generics: &Generics,
has_backend: bool,
) -> TokenStream;
/// Generate the into_item function.
fn gen_into_item(&self, item_name: &Ident) -> TokenStream;
/// Generate the from item function.
fn gen_from_item(&self) -> TokenStream;
}

View File

@ -0,0 +1,110 @@
use crate::shared::enum_variant::{parse_variants, EnumVariant};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{parse_quote, Generics};
use super::codegen::RecordItemCodegen;
pub(crate) struct EnumRecordItemCodegen {
/// Enum variants.
variants: Vec<EnumVariant>,
}
impl RecordItemCodegen for EnumRecordItemCodegen {
fn from_ast(ast: &syn::DeriveInput) -> Self {
Self {
variants: parse_variants(ast),
}
}
fn gen_item_type(
&self,
item_name: &Ident,
generics: &Generics,
has_backend: bool,
) -> TokenStream {
let mut variants = quote! {};
let mut bounds = quote! {};
// Capture the Record enum variant types and names to transpose them in RecordItem
for variant in self.variants.iter() {
let ty = &variant.ty;
let name = &variant.ident;
variants.extend(quote! {
/// Variant to be serialized.
#name(<#ty as burn::record::Record<B>>::Item<S>),
});
// Item types must implement serialization/deserialization
bounds.extend(quote! {
<#ty as burn::record::Record<B>>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
});
}
let bound = bounds.to_string();
// Capture the type's generics and bounds in where clauses
let (generics, generics_where) = if !has_backend {
let mut generics = generics.clone();
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
generics.params.push(syn::GenericParam::Type(param));
let (generics, _, generics_where) = generics.split_for_impl();
(quote! { #generics }, quote! { #generics_where })
} else {
let (generics, _, generics_where) = generics.split_for_impl();
(quote! { #generics }, quote! { #generics_where })
};
// Return the generated stream of token trees (i.e., code to be generated)
quote! {
/// The record item type for the module.
#[derive(burn::serde::Serialize, burn::serde::Deserialize)]
#[serde(crate = "burn::serde")]
#[serde(bound = #bound)]
pub enum #item_name #generics #generics_where {
#variants
}
}
}
fn gen_into_item(&self, _item_name: &Ident) -> TokenStream {
let mut into_item_match_arms = quote! {};
for variant in self.variants.iter() {
let name = &variant.ident;
into_item_match_arms.extend(quote! {
Self::#name(record) => Self::Item::#name(burn::record::Record::<B>::into_item::<S>(record)),
});
}
quote! {
fn into_item<S: burn::record::PrecisionSettings>(self) -> Self::Item<S> {
match self {
#into_item_match_arms
}
}
}
}
fn gen_from_item(&self) -> TokenStream {
let mut from_item_match_arms = quote! {};
for variant in self.variants.iter() {
let name = &variant.ident;
from_item_match_arms.extend(quote! {
Self::Item::#name(item) => Self::#name(burn::record::Record::<B>::from_item::<S>(item, device)),
});
}
quote! {
fn from_item<S: burn::record::PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
match item {
#from_item_match_arms
}
}
}
}
}

View File

@ -1,16 +1,24 @@
use crate::shared::field::FieldTypeAnalyzer;
use crate::shared::field::{parse_fields, FieldTypeAnalyzer};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{parse_quote, Generics};
use super::codegen::RecordItemCodegen;
#[derive(new)]
pub(crate) struct StructRecordItemCodegen {
fields: Vec<FieldTypeAnalyzer>,
}
impl RecordItemCodegen for StructRecordItemCodegen {
fn from_ast(ast: &syn::DeriveInput) -> Self {
Self {
fields: parse_fields(ast)
.into_iter()
.map(FieldTypeAnalyzer::new)
.collect(),
}
}
fn gen_item_type(
&self,
item_name: &Ident,

View File

@ -0,0 +1,3 @@
pub(crate) mod codegen;
pub(crate) mod codegen_enum;
pub(crate) mod codegen_struct;

View File

@ -1,5 +1,5 @@
pub(crate) mod codegen;
pub(crate) mod codegen_struct;
pub(crate) mod item;
mod base;
pub(crate) use base::*;

View File

@ -49,3 +49,32 @@ where
syn::Fields::Unit => (quote! {}, quote! {}),
}
}
/// An enum variant (simplified).
pub(crate) struct EnumVariant {
pub ident: syn::Ident,
pub ty: syn::Type,
}
pub(crate) fn parse_variants(ast: &syn::DeriveInput) -> Vec<EnumVariant> {
let mut variants = Vec::new();
if let syn::Data::Enum(enum_data) = &ast.data {
for variant in enum_data.variants.iter() {
if variant.fields.len() != 1 {
// No support for unit variants or variants with multiple fields
panic!("Enums are only supported for one field type")
}
let field = variant.fields.iter().next().unwrap();
variants.push(EnumVariant {
ident: variant.ident.clone(),
ty: field.ty.clone(),
});
}
} else {
panic!("Only enum can be derived")
}
variants
}