Feat/config derive (#38)

Nathaniel Simard 2022-09-18 15:56:12 -04:00 committed by GitHub
parent b9f833767f
commit 21757ee534
20 changed files with 374 additions and 131 deletions

View File

@ -0,0 +1,157 @@
use crate::shared::{
attribute::AttributeItem,
field::{parse_fields, FieldTypeAnalyzer},
};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
pub(crate) fn config_attr_impl(item: &syn::DeriveInput) -> TokenStream {
let name = item.ident.clone();
let fields = parse_fields(item);
let fields = fields.into_iter().map(FieldTypeAnalyzer::new).collect();
let config = Config { name, fields }.analyze();
let constructor = config.gen_constructor_impl();
let builders = config.gen_builder_fn_impl();
quote! {
#constructor
#builders
}
}
struct Config {
name: Ident,
fields: Vec<FieldTypeAnalyzer>,
}
struct ConfigAnalyzer {
name: Ident,
fields_required: Vec<FieldTypeAnalyzer>,
fields_option: Vec<FieldTypeAnalyzer>,
fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>,
}
impl ConfigAnalyzer {
fn gen_constructor_impl(&self) -> TokenStream {
let mut body = quote! {};
let mut names = Vec::new();
for field in self.fields_required.iter() {
let name = field.ident();
let ty = &field.field.ty;
body.extend(quote! {
#name: #name,
});
names.push(quote! {
#name: #ty
});
}
for field in self.fields_option.iter() {
let name = field.ident();
body.extend(quote! {
#name: None,
});
}
for (field, attribute) in self.fields_default.iter() {
let name = field.ident();
let value = &attribute.value;
body.extend(quote! {
#name: #value,
});
}
let body = quote! {
pub fn new(
#(#names),*
) -> Self {
Self { #body }
}
};
self.wrap_impl_block(body)
}
fn gen_builder_fn_impl(&self) -> TokenStream {
let mut body = quote! {};
for (field, _) in self.fields_default.iter() {
let name = field.ident();
let ty = &field.field.ty;
let fn_name = Ident::new(&format!("with_{}", name), name.span());
body.extend(quote! {
pub fn #fn_name(mut self, #name: #ty) -> Self {
self.#name = #name;
self
}
});
}
for field in self.fields_option.iter() {
let name = field.ident();
let ty = &field.field.ty;
let fn_name = Ident::new(&format!("with_{}", name), name.span());
body.extend(quote! {
pub fn #fn_name(mut self, #name: #ty) -> Self {
self.#name = #name;
self
}
});
}
self.wrap_impl_block(body)
}
fn wrap_impl_block(&self, tokens: TokenStream) -> TokenStream {
let name = &self.name;
quote! {
impl #name {
#tokens
}
}
}
}
impl Config {
fn analyze(&self) -> ConfigAnalyzer {
let mut fields_required = Vec::new();
let mut fields_option = Vec::new();
let mut fields_default = Vec::new();
for field in self.fields.iter() {
let attributes: Vec<AttributeItem> = field
.attributes()
.filter(|attr| attr.has_name("config"))
.map(|attr| attr.items())
.filter_map(|attr| attr.first().map(Clone::clone))
.collect();
if !attributes.is_empty() {
let item = attributes.first().unwrap().clone();
fields_default.push((field.clone(), item));
continue;
}
if field.is_of_type(&["Option"]) {
fields_option.push(field.clone());
continue;
}
fields_required.push(field.clone());
}
ConfigAnalyzer {
name: self.name.clone(),
fields_required,
fields_option,
fields_default,
}
}
}
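
To make the generated API concrete, here is a hand-written sketch of what this analyzer emits for a hypothetical struct (the struct and its field names are illustrative, not part of this PR): required fields become arguments of new, Option fields start as None, #[config(default = ...)] fields start at their default value, and with_* builders are produced for the defaulted and optional fields.

// Hypothetical input:
//
//     #[derive(Config)]
//     struct MyConfig {
//         size: usize,                // required
//         #[config(default = 0.1)]
//         rate: f64,                  // defaulted
//         checkpoint: Option<String>, // optional
//     }
//
// Approximate expansion of gen_constructor_impl + gen_builder_fn_impl:
struct MyConfig {
    size: usize,
    rate: f64,
    checkpoint: Option<String>,
}

impl MyConfig {
    // Only the required field becomes a constructor argument.
    pub fn new(size: usize) -> Self {
        Self {
            size,
            rate: 0.1,        // from #[config(default = 0.1)]
            checkpoint: None, // Option fields are initialized to None
        }
    }

    // Builders are generated for defaulted and Option fields.
    pub fn with_rate(mut self, rate: f64) -> Self {
        self.rate = rate;
        self
    }

    pub fn with_checkpoint(mut self, checkpoint: Option<String>) -> Self {
        self.checkpoint = checkpoint;
        self
    }
}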

View File

@ -0,0 +1,3 @@
mod base;
pub(crate) use base::*;

View File

@ -1,63 +1,20 @@
use proc_macro::TokenStream;
use quote::quote;
pub(crate) mod field;
pub(crate) mod config;
pub(crate) mod module;
pub(crate) mod shared;
mod display;
mod param;
use param::Param;
use config::config_attr_impl;
use module::module_derive_impl;
#[proc_macro_derive(Module)]
pub fn module_derive(input: TokenStream) -> TokenStream {
let ast = syn::parse(input).unwrap();
module_derive_impl(&ast)
let input = syn::parse(input).unwrap();
module_derive_impl(&input)
}
fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();
let display_fn = display::display_fn();
let name_fn = display::name_fn(name);
let param = Param::from_ast(ast);
let num_params_fn = param.gen_num_params_fn();
let update_params_fn = param.gen_update_params_fn();
let load_optim_state = param.gen_load_optim_state_fn();
let register_optim_state = param.gen_register_optim_state_fn();
let devices_fn = param.gen_devices_fn();
let to_device_fn = param.gen_to_device_fn();
let state_fn = param.gen_state_fn();
let load_fn = param.gen_load_fn();
let inner_fn = param.gen_inner_fn();
let gen = quote! {
impl #generics burn::module::Module for #name #generics_ty #generics_where {
type Backend=B;
#name_fn
#num_params_fn
#update_params_fn
#load_optim_state
#register_optim_state
#devices_fn
#to_device_fn
#state_fn
#load_fn
}
impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::backend::ADBackend, {
type ADBackend=B;
type InnerModule=#name<B::InnerBackend>;
#inner_fn
}
impl #generics std::fmt::Display for #name #generics_ty #generics_where {
#display_fn
}
};
gen.into()
#[proc_macro_derive(Config, attributes(config))]
pub fn config_derive(input: TokenStream) -> TokenStream {
let item = syn::parse(input).unwrap();
config_attr_impl(&item).into()
}

View File

@ -0,0 +1,52 @@
use super::param::Param;
use crate::module::display;
use proc_macro::TokenStream;
use quote::quote;
pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();
let display_fn = display::display_fn();
let name_fn = display::name_fn(name);
let param = Param::from_ast(ast);
let num_params_fn = param.gen_num_params_fn();
let update_params_fn = param.gen_update_params_fn();
let load_optim_state = param.gen_load_optim_state_fn();
let register_optim_state = param.gen_register_optim_state_fn();
let devices_fn = param.gen_devices_fn();
let to_device_fn = param.gen_to_device_fn();
let state_fn = param.gen_state_fn();
let load_fn = param.gen_load_fn();
let inner_fn = param.gen_inner_fn();
let gen = quote! {
impl #generics burn::module::Module for #name #generics_ty #generics_where {
type Backend=B;
#name_fn
#num_params_fn
#update_params_fn
#load_optim_state
#register_optim_state
#devices_fn
#to_device_fn
#state_fn
#load_fn
}
impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::backend::ADBackend, {
type ADBackend=B;
type InnerModule=#name<B::InnerBackend>;
#inner_fn
}
impl #generics std::fmt::Display for #name #generics_ty #generics_where {
#display_fn
}
};
gen.into()
}

View File

@ -0,0 +1,6 @@
pub(crate) mod display;
pub(crate) mod param;
mod base;
pub(crate) use base::*;

View File

@ -1,7 +1,6 @@
use crate::field::FieldTypeAnalyzer;
use crate::shared::field::{parse_fields, FieldTypeAnalyzer};
use proc_macro2::TokenStream;
use quote::quote;
use syn::Field;
pub struct Param {
fields_param: Vec<FieldTypeAnalyzer>,
@ -215,18 +214,3 @@ impl Param {
}
}
}
fn parse_fields(ast: &syn::DeriveInput) -> Vec<Field> {
let mut fields = Vec::new();
match &ast.data {
syn::Data::Struct(struct_data) => {
for field in struct_data.fields.iter() {
fields.push(field.clone());
}
}
syn::Data::Enum(_) => panic!("Only structs can be derived"),
syn::Data::Union(_) => panic!("Only structs can be derived"),
};
fields
}

View File

@ -0,0 +1,58 @@
use syn::{Attribute, Ident, Meta, NestedMeta};
#[derive(Debug)]
pub struct AttributeAnalyzer {
attr: Attribute,
}
#[derive(Debug, Clone)]
pub struct AttributeItem {
pub ident: Ident,
pub value: syn::Lit,
}
impl AttributeAnalyzer {
pub fn new(attr: Attribute) -> Self {
Self { attr }
}
pub fn items(&self) -> Vec<AttributeItem> {
let config = match self.attr.parse_meta() {
Ok(val) => val,
_ => return Vec::new(),
};
let nested = match config {
Meta::List(val) => val.nested,
_ => return Vec::new(),
};
let mut output = Vec::new();
for pair in nested.into_iter() {
if let NestedMeta::Meta(Meta::NameValue(value)) = pair {
output.push(AttributeItem {
ident: value.path.get_ident().unwrap().clone(),
value: value.lit,
});
};
}
output
}
pub fn has_name(&self, name: &str) -> bool {
Self::path_syn_name(&self.attr.path) == name
}
fn path_syn_name(path: &syn::Path) -> String {
let length = path.segments.len();
let mut name = String::new();
for (i, segment) in path.segments.iter().enumerate() {
if i == length - 1 {
name += segment.ident.to_string().as_str();
} else {
let tmp = segment.ident.to_string() + "::";
name += tmp.as_str();
}
}
name
}
}
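
A quick usage sketch of the analyzer (illustrative only; it assumes syn's parse_quote! macro is available for building an Attribute in a test, which is not part of this file): an attribute written #[config(default = 0.9)] matches has_name("config"), and items() yields a single AttributeItem whose ident is default and whose value is the literal 0.9.

#[cfg(test)]
mod attribute_sketch {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn parses_config_default() {
        // Build the attribute the same way it would appear on a config field.
        let attr: syn::Attribute = parse_quote!(#[config(default = 0.9)]);
        let analyzer = AttributeAnalyzer::new(attr);

        assert!(analyzer.has_name("config"));

        let items = analyzer.items();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].ident.to_string(), "default");
    }
}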

View File

@ -1,6 +1,8 @@
use super::attribute::AttributeAnalyzer;
use proc_macro2::Ident;
use syn::{Field, Type, TypePath};
#[derive(Debug, Clone)]
pub struct FieldTypeAnalyzer {
pub field: Field,
}
@ -50,7 +52,7 @@ impl FieldTypeAnalyzer {
}
}
pub fn path_name(path: &TypePath) -> String {
fn path_name(path: &TypePath) -> String {
let length = path.path.segments.len();
let mut name = String::new();
for (i, segment) in path.path.segments.iter().enumerate() {
@ -64,8 +66,30 @@ impl FieldTypeAnalyzer {
name
}
pub fn attributes(&self) -> impl Iterator<Item = AttributeAnalyzer> {
self.field
.attrs
.clone()
.into_iter()
.map(AttributeAnalyzer::new)
}
pub fn is_param(&self) -> bool {
let params_types = vec!["Param", "burn::Param"];
self.is_of_type(&params_types)
self.is_of_type(&["Param", "burn::Param"])
}
}
pub(crate) fn parse_fields(ast: &syn::DeriveInput) -> Vec<Field> {
let mut fields = Vec::new();
match &ast.data {
syn::Data::Struct(struct_data) => {
for field in struct_data.fields.iter() {
fields.push(field.clone());
}
}
syn::Data::Enum(_) => panic!("Only structs can be derived"),
syn::Data::Union(_) => panic!("Only structs can be derived"),
};
fields
}

View File

@ -0,0 +1,2 @@
pub(crate) mod attribute;
pub(crate) mod field;

View File

@ -2,7 +2,6 @@ use burn::data::dataloader::batcher::Batcher;
use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataset::source::huggingface::{MNISTDataset, MNISTItem};
use burn::module::{Forward, Module, Param, State};
use burn::nn;
use burn::optim::decay::WeightDecayConfig;
use burn::optim::momentum::MomentumConfig;
use burn::optim::{Optimizer, Sgd, SgdConfig};
@ -12,6 +11,7 @@ use burn::tensor::{Data, ElementConversion, Shape, Tensor};
use burn::train::logger::{AsyncLogger, CLILogger};
use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric};
use burn::train::{ClassificationLearner, ClassificationOutput, SupervisedTrainer};
use burn::{config, nn};
use std::sync::Arc;
#[derive(Module, Debug)]
@ -21,6 +21,17 @@ struct Model<B: Backend> {
output: Param<nn::Linear<B>>,
}
config!(
struct MlpConfig {
#[config(default = 4)]
num_layers: usize,
#[config(default = 0.2)]
dropout: f64,
#[config(default = 1024)]
dim: usize,
}
);
#[derive(Module, Debug)]
struct Mlp<B: Backend> {
linears: Param<Vec<nn::Linear<B>>>,
@ -69,42 +80,28 @@ impl<B: Backend> Forward<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
}
impl<B: Backend> Mlp<B> {
fn new(dim: usize, num_layers: usize) -> Self {
let mut linears = Vec::with_capacity(num_layers);
fn new(config: &MlpConfig) -> Self {
let mut linears = Vec::with_capacity(config.num_layers);
for _ in 0..num_layers {
let config = nn::LinearConfig {
d_input: dim,
d_output: dim,
bias: true,
};
let linear = nn::Linear::new(&config);
for _ in 0..config.num_layers {
let linear = nn::Linear::new(&nn::LinearConfig::new(config.dim, config.dim));
linears.push(linear);
}
Self {
linears: Param::new(linears),
dropout: nn::Dropout::new(&nn::DropoutConfig { prob: 0.3 }),
dropout: nn::Dropout::new(&nn::DropoutConfig::new(0.3)),
activation: nn::ReLU::new(),
}
}
}
impl<B: Backend> Model<B> {
fn new(d_input: usize, d_hidden: usize, num_layers: usize, num_classes: usize) -> Self {
let mlp = Mlp::new(d_hidden, num_layers);
let config_input = nn::LinearConfig {
d_input,
d_output: d_hidden,
bias: true,
};
let config_output = nn::LinearConfig {
d_input: d_hidden,
d_output: num_classes,
bias: true,
};
let output = nn::Linear::new(&config_output);
let input = nn::Linear::new(&config_input);
fn new(d_input: usize, num_classes: usize) -> Self {
let mlp_config = MlpConfig::new();
let mlp = Mlp::new(&mlp_config);
let output = nn::Linear::new(&nn::LinearConfig::new(mlp_config.dim, num_classes));
let input = nn::Linear::new(&nn::LinearConfig::new(d_input, mlp_config.dim));
Self {
mlp: Param::new(mlp),
@ -150,14 +147,12 @@ fn run<B: ADBackend>(device: B::Device) {
let batch_size = 128;
let num_epochs = 15;
let num_workers = 8;
let num_layers = 4;
let hidden_dim = 1024;
let seed = 42;
let state_model = State::<f32>::load("/tmp/mnist_state_model").ok();
let state_optim = State::<f32>::load("/tmp/mnist_state_optim").ok();
let mut model = Model::new(784, hidden_dim, num_layers, 10);
let mut model = Model::new(784, 10);
model.to_device(device);
if let Some(state) = state_model {
@ -165,16 +160,12 @@ fn run<B: ADBackend>(device: B::Device) {
model.load(&state.convert()).unwrap();
}
let mut optim = Sgd::new(&SgdConfig {
learning_rate: 2.5e-2,
weight_decay: Some(WeightDecayConfig { penalty: 0.01 }),
momentum: Some(MomentumConfig {
momentum: 0.9,
dampening: 0.1,
nesterov: true,
}),
});
let optim_config = SgdConfig::new()
.with_learning_rate(2.5e-2)
.with_weight_decay(Some(WeightDecayConfig::new(0.05)))
.with_momentum(Some(MomentumConfig::new().with_nesterov(true)));
let mut optim = Sgd::new(&optim_config);
if let Some(state) = state_optim {
println!("Loading optimizer state");
optim.load(&model, &state.convert()).unwrap();

View File

@ -2,13 +2,14 @@
extern crate derive_new;
pub mod data;
pub mod macros;
pub mod module;
pub mod nn;
pub mod optim;
pub mod tensor;
pub mod train;
pub(crate) mod macros;
pub use burn_derive::Config;
#[cfg(test)]
pub type TestBackend = crate::tensor::backend::TchBackend<f32>;

View File

@ -1,8 +1,7 @@
#[macro_export]
macro_rules! config {
($item:item) => {
#[derive(new, serde::Serialize, serde::Deserialize, Clone, Debug)]
#[derive(burn::Config, serde::Serialize, serde::Deserialize, Clone, Debug)]
$item
};
}
pub(crate) use config;
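
To illustrate the combined effect of the macro and the new derive, a hedged sketch with a hypothetical struct (not part of this PR): the serde derives keep (de)serialization working, while burn::Config replaces derive_new's all-arguments constructor with one that honors #[config(default = ...)] and Option fields, plus with_* builders.

config!(
    pub struct ExampleConfig {
        pub width: usize,
        #[config(default = 0.5)]
        pub scale: f64,
    }
);

// Usage: `new` takes only the required field; the default can be overridden
// through the generated builder.
//
//     let config = ExampleConfig::new(64).with_scale(0.25);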

View File

@ -1,4 +1,6 @@
use crate::macros::config;
use crate as burn;
use crate::config;
use crate::module::Forward;
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, ElementConversion, Tensor};

View File

@ -1,6 +1,6 @@
use crate as burn;
use crate::macros::config;
use crate::config;
use crate::module::Module;
use crate::module::{Forward, Param};
use crate::tensor::backend::Backend;
@ -11,7 +11,8 @@ config!(
pub struct LayerNormConfig {
/// The size of the input features.
pub d_model: usize,
/// A value required for numerical stability, typically 1e-5.
/// A value required for numerical stability. Default: 1e-5
#[config(default = 1e-5)]
pub epsilon: f64,
}
);
@ -61,10 +62,7 @@ mod tests {
#[test]
fn layer_norm_forward() {
let config = LayerNormConfig {
d_model: 10,
epsilon: 1e-5,
};
let config = LayerNormConfig::new(10);
let module = LayerNorm::<TestBackend>::new(&config);
let input = Tensor::from_data(Data::from([[
-0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
@ -82,10 +80,7 @@ mod tests {
#[test]
fn layer_norm_backward() {
let config = LayerNormConfig {
d_model: 2,
epsilon: 1e-5,
};
let config = LayerNormConfig::new(2);
let module = LayerNorm::<TestADBackend>::new(&config);
let tensor_1 = Tensor::<TestADBackend, 2>::from_data(Data::from([[0.0, 1.0], [3.0, 4.0]]));
let tensor_2 = Tensor::<TestADBackend, 2>::from_data(Data::from([[6.0, 7.0], [9.0, 10.0]]));

View File

@ -1,6 +1,6 @@
use crate as burn;
use crate::macros::config;
use crate::config;
use crate::module::Module;
use crate::module::{Forward, Param};
use crate::tensor::backend::Backend;
@ -15,6 +15,7 @@ config!(
/// The size of the output features.
pub d_output: usize,
/// If a bias should be applied during the linear transformation.
#[config(default = true)]
pub bias: bool,
}
);
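
With bias now defaulted, constructing a linear layer config in downstream code (as in the updated MNIST example above) only needs the two dimensions. A hedged sketch of the resulting usage, written as a test for illustration:

#[cfg(test)]
mod linear_config_sketch {
    use super::*;

    #[test]
    fn builds_with_default_bias() {
        // d_input = 784, d_output = 1024; bias falls back to its #[config(default = true)].
        let config = LinearConfig::new(784, 1024);
        assert!(config.bias);

        // The generated builder overrides the default when needed.
        let config = LinearConfig::new(784, 1024).with_bias(false);
        assert!(!config.bias);
    }
}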

View File

@ -1,4 +1,5 @@
use crate::macros::config;
use crate as burn;
use crate::config;
config!(
pub struct AdamConfig {

View File

@ -1,5 +1,7 @@
use crate as burn;
use super::{load_state_gradients, register_state_gradients};
use crate::macros::config;
use crate::config;
use crate::module::{ParamId, StateNamed};
use crate::tensor::backend::ADBackend;
use crate::tensor::{ElementConversion, Gradients, Tensor};

View File

@ -1,5 +1,7 @@
use crate as burn;
use super::{load_state_gradients, register_state_gradients};
use crate::macros::config;
use crate::config;
use crate::module::{ParamId, StateNamed};
use crate::tensor::backend::ADBackend;
use crate::tensor::{ElementConversion, Gradients, Tensor};
@ -8,10 +10,13 @@ config!(
/// Configuration to create momentum [Momentum](Momentum).
pub struct MomentumConfig {
/// Momentum factor
#[config(default = 0.9)]
pub momentum: f64,
/// Dampening factor.
#[config(default = 0.1)]
pub dampening: f64,
/// Enables Nesterov momentum, see [On the importance of initialization and momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf).
#[config(default = false)]
pub nesterov: bool,
}
);
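
With these defaults in place, the momentum setup in the MNIST example reduces to overriding only nesterov. A hedged sketch, written as a test for illustration:

#[cfg(test)]
mod momentum_config_sketch {
    use super::*;

    #[test]
    fn builder_fills_defaults() {
        // momentum and dampening come from the #[config(default = ...)] attributes above.
        let config = MomentumConfig::new().with_nesterov(true);

        assert_eq!(config.momentum, 0.9);
        assert_eq!(config.dampening, 0.1);
        assert!(config.nesterov);
    }
}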

View File

@ -1,6 +1,8 @@
use crate as burn;
use super::decay::{WeightDecay, WeightDecayConfig};
use super::momentum::{Momentum, MomentumConfig};
use crate::macros::config;
use crate::config;
use crate::module::{ParamId, StateNamed};
use crate::optim::Optimizer;
use crate::tensor::backend::ADBackend;
@ -10,6 +12,7 @@ config!(
/// Configuration to create the [Sgd](Sgd) optimizer.
pub struct SgdConfig {
/// Learning rate for the optimizer.
#[config(default = 0.01)]
pub learning_rate: f64,
/// [Weight decay](WeightDecayConfig) config.
pub weight_decay: Option<WeightDecayConfig>,