mirror of https://github.com/tracel-ai/burn.git
Feat/config derive (#38)
This commit is contained in:
parent
b9f833767f
commit
21757ee534
|
@ -0,0 +1,157 @@
|
|||
use crate::shared::{
|
||||
attribute::AttributeItem,
|
||||
field::{parse_fields, FieldTypeAnalyzer},
|
||||
};
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
|
||||
pub(crate) fn config_attr_impl(item: &syn::DeriveInput) -> TokenStream {
|
||||
let name = item.ident.clone();
|
||||
let fields = parse_fields(item);
|
||||
let fields = fields.into_iter().map(FieldTypeAnalyzer::new).collect();
|
||||
let config = Config { name, fields }.analyze();
|
||||
|
||||
let constructor = config.gen_constructor_impl();
|
||||
let builders = config.gen_builder_fn_impl();
|
||||
|
||||
quote! {
|
||||
#constructor
|
||||
#builders
|
||||
}
|
||||
}
|
||||
|
||||
struct Config {
|
||||
name: Ident,
|
||||
fields: Vec<FieldTypeAnalyzer>,
|
||||
}
|
||||
|
||||
struct ConfigAnalyzer {
|
||||
name: Ident,
|
||||
fields_required: Vec<FieldTypeAnalyzer>,
|
||||
fields_option: Vec<FieldTypeAnalyzer>,
|
||||
fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>,
|
||||
}
|
||||
|
||||
impl ConfigAnalyzer {
|
||||
fn gen_constructor_impl(&self) -> TokenStream {
|
||||
let mut body = quote! {};
|
||||
let mut names = Vec::new();
|
||||
|
||||
for field in self.fields_required.iter() {
|
||||
let name = field.ident();
|
||||
let ty = &field.field.ty;
|
||||
|
||||
body.extend(quote! {
|
||||
#name: #name,
|
||||
});
|
||||
names.push(quote! {
|
||||
#name: #ty
|
||||
});
|
||||
}
|
||||
|
||||
for field in self.fields_option.iter() {
|
||||
let name = field.ident();
|
||||
|
||||
body.extend(quote! {
|
||||
#name: None,
|
||||
});
|
||||
}
|
||||
|
||||
for (field, attribute) in self.fields_default.iter() {
|
||||
let name = field.ident();
|
||||
let value = &attribute.value;
|
||||
|
||||
body.extend(quote! {
|
||||
#name: #value,
|
||||
});
|
||||
}
|
||||
|
||||
let body = quote! {
|
||||
pub fn new(
|
||||
#(#names),*
|
||||
) -> Self {
|
||||
Self { #body }
|
||||
}
|
||||
};
|
||||
self.wrap_impl_block(body)
|
||||
}
|
||||
|
||||
fn gen_builder_fn_impl(&self) -> TokenStream {
|
||||
let mut body = quote! {};
|
||||
|
||||
for (field, _) in self.fields_default.iter() {
|
||||
let name = field.ident();
|
||||
let ty = &field.field.ty;
|
||||
let fn_name = Ident::new(&format!("with_{}", name), name.span());
|
||||
|
||||
body.extend(quote! {
|
||||
pub fn #fn_name(mut self, #name: #ty) -> Self {
|
||||
self.#name = #name;
|
||||
self
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for field in self.fields_option.iter() {
|
||||
let name = field.ident();
|
||||
let ty = &field.field.ty;
|
||||
let fn_name = Ident::new(&format!("with_{}", name), name.span());
|
||||
|
||||
body.extend(quote! {
|
||||
pub fn #fn_name(mut self, #name: #ty) -> Self {
|
||||
self.#name = #name;
|
||||
self
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
self.wrap_impl_block(body)
|
||||
}
|
||||
|
||||
fn wrap_impl_block(&self, tokens: TokenStream) -> TokenStream {
|
||||
let name = &self.name;
|
||||
|
||||
quote! {
|
||||
impl #name {
|
||||
#tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn analyze(&self) -> ConfigAnalyzer {
|
||||
let mut fields_required = Vec::new();
|
||||
let mut fields_option = Vec::new();
|
||||
let mut fields_default = Vec::new();
|
||||
|
||||
for field in self.fields.iter() {
|
||||
let attributes: Vec<AttributeItem> = field
|
||||
.attributes()
|
||||
.filter(|attr| attr.has_name("config"))
|
||||
.map(|attr| attr.items())
|
||||
.filter_map(|attr| attr.first().map(Clone::clone))
|
||||
.collect();
|
||||
|
||||
if !attributes.is_empty() {
|
||||
let item = attributes.first().unwrap().clone();
|
||||
fields_default.push((field.clone(), item));
|
||||
continue;
|
||||
}
|
||||
|
||||
if field.is_of_type(&["Option"]) {
|
||||
fields_option.push(field.clone());
|
||||
continue;
|
||||
}
|
||||
|
||||
fields_required.push(field.clone());
|
||||
}
|
||||
|
||||
ConfigAnalyzer {
|
||||
name: self.name.clone(),
|
||||
fields_required,
|
||||
fields_option,
|
||||
fields_default,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
mod base;
|
||||
|
||||
pub(crate) use base::*;
|
|
@ -1,63 +1,20 @@
|
|||
use proc_macro::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
pub(crate) mod field;
|
||||
pub(crate) mod config;
|
||||
pub(crate) mod module;
|
||||
pub(crate) mod shared;
|
||||
|
||||
mod display;
|
||||
mod param;
|
||||
|
||||
use param::Param;
|
||||
use config::config_attr_impl;
|
||||
use module::module_derive_impl;
|
||||
|
||||
#[proc_macro_derive(Module)]
|
||||
pub fn module_derive(input: TokenStream) -> TokenStream {
|
||||
let ast = syn::parse(input).unwrap();
|
||||
module_derive_impl(&ast)
|
||||
let input = syn::parse(input).unwrap();
|
||||
module_derive_impl(&input)
|
||||
}
|
||||
|
||||
fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
|
||||
let name = &ast.ident;
|
||||
let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();
|
||||
|
||||
let display_fn = display::display_fn();
|
||||
let name_fn = display::name_fn(name);
|
||||
|
||||
let param = Param::from_ast(ast);
|
||||
let num_params_fn = param.gen_num_params_fn();
|
||||
let update_params_fn = param.gen_update_params_fn();
|
||||
let load_optim_state = param.gen_load_optim_state_fn();
|
||||
let register_optim_state = param.gen_register_optim_state_fn();
|
||||
let devices_fn = param.gen_devices_fn();
|
||||
let to_device_fn = param.gen_to_device_fn();
|
||||
let state_fn = param.gen_state_fn();
|
||||
let load_fn = param.gen_load_fn();
|
||||
let inner_fn = param.gen_inner_fn();
|
||||
|
||||
let gen = quote! {
|
||||
impl #generics burn::module::Module for #name #generics_ty #generics_where {
|
||||
type Backend=B;
|
||||
#name_fn
|
||||
#num_params_fn
|
||||
#update_params_fn
|
||||
#load_optim_state
|
||||
#register_optim_state
|
||||
#devices_fn
|
||||
#to_device_fn
|
||||
|
||||
#state_fn
|
||||
#load_fn
|
||||
}
|
||||
|
||||
impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::backend::ADBackend, {
|
||||
type ADBackend=B;
|
||||
type InnerModule=#name<B::InnerBackend>;
|
||||
|
||||
#inner_fn
|
||||
}
|
||||
|
||||
impl #generics std::fmt::Display for #name #generics_ty #generics_where {
|
||||
#display_fn
|
||||
}
|
||||
};
|
||||
|
||||
gen.into()
|
||||
#[proc_macro_derive(Config, attributes(config))]
|
||||
pub fn config_derive(input: TokenStream) -> TokenStream {
|
||||
let item = syn::parse(input).unwrap();
|
||||
config_attr_impl(&item).into()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
use super::param::Param;
|
||||
use crate::module::display;
|
||||
use proc_macro::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
|
||||
let name = &ast.ident;
|
||||
let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();
|
||||
|
||||
let display_fn = display::display_fn();
|
||||
let name_fn = display::name_fn(name);
|
||||
|
||||
let param = Param::from_ast(ast);
|
||||
let num_params_fn = param.gen_num_params_fn();
|
||||
let update_params_fn = param.gen_update_params_fn();
|
||||
let load_optim_state = param.gen_load_optim_state_fn();
|
||||
let register_optim_state = param.gen_register_optim_state_fn();
|
||||
let devices_fn = param.gen_devices_fn();
|
||||
let to_device_fn = param.gen_to_device_fn();
|
||||
let state_fn = param.gen_state_fn();
|
||||
let load_fn = param.gen_load_fn();
|
||||
let inner_fn = param.gen_inner_fn();
|
||||
|
||||
let gen = quote! {
|
||||
impl #generics burn::module::Module for #name #generics_ty #generics_where {
|
||||
type Backend=B;
|
||||
#name_fn
|
||||
#num_params_fn
|
||||
#update_params_fn
|
||||
#load_optim_state
|
||||
#register_optim_state
|
||||
#devices_fn
|
||||
#to_device_fn
|
||||
|
||||
#state_fn
|
||||
#load_fn
|
||||
}
|
||||
|
||||
impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::backend::ADBackend, {
|
||||
type ADBackend=B;
|
||||
type InnerModule=#name<B::InnerBackend>;
|
||||
|
||||
#inner_fn
|
||||
}
|
||||
|
||||
impl #generics std::fmt::Display for #name #generics_ty #generics_where {
|
||||
#display_fn
|
||||
}
|
||||
};
|
||||
|
||||
gen.into()
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
pub(crate) mod display;
|
||||
pub(crate) mod param;
|
||||
|
||||
mod base;
|
||||
|
||||
pub(crate) use base::*;
|
|
@ -1,7 +1,6 @@
|
|||
use crate::field::FieldTypeAnalyzer;
|
||||
use crate::shared::field::{parse_fields, FieldTypeAnalyzer};
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
use syn::Field;
|
||||
|
||||
pub struct Param {
|
||||
fields_param: Vec<FieldTypeAnalyzer>,
|
||||
|
@ -215,18 +214,3 @@ impl Param {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_fields(ast: &syn::DeriveInput) -> Vec<Field> {
|
||||
let mut fields = Vec::new();
|
||||
|
||||
match &ast.data {
|
||||
syn::Data::Struct(struct_data) => {
|
||||
for field in struct_data.fields.iter() {
|
||||
fields.push(field.clone());
|
||||
}
|
||||
}
|
||||
syn::Data::Enum(_) => panic!("Only struct can be derived"),
|
||||
syn::Data::Union(_) => panic!("Only struct cna be derived"),
|
||||
};
|
||||
fields
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
use syn::{Attribute, Ident, Meta, NestedMeta};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AttributeAnalyzer {
|
||||
attr: Attribute,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AttributeItem {
|
||||
pub ident: Ident,
|
||||
pub value: syn::Lit,
|
||||
}
|
||||
|
||||
impl AttributeAnalyzer {
|
||||
pub fn new(attr: Attribute) -> Self {
|
||||
Self { attr }
|
||||
}
|
||||
|
||||
pub fn items(&self) -> Vec<AttributeItem> {
|
||||
let config = match self.attr.parse_meta() {
|
||||
Ok(val) => val,
|
||||
_ => return Vec::new(),
|
||||
};
|
||||
let nested = match config {
|
||||
Meta::List(val) => val.nested,
|
||||
_ => return Vec::new(),
|
||||
};
|
||||
|
||||
let mut output = Vec::new();
|
||||
for pair in nested.into_iter() {
|
||||
if let NestedMeta::Meta(Meta::NameValue(value)) = pair {
|
||||
output.push(AttributeItem {
|
||||
ident: value.path.get_ident().unwrap().clone(),
|
||||
value: value.lit,
|
||||
});
|
||||
};
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
pub fn has_name(&self, name: &str) -> bool {
|
||||
Self::path_syn_name(&self.attr.path) == name
|
||||
}
|
||||
|
||||
fn path_syn_name(path: &syn::Path) -> String {
|
||||
let length = path.segments.len();
|
||||
let mut name = String::new();
|
||||
for (i, segment) in path.segments.iter().enumerate() {
|
||||
if i == length - 1 {
|
||||
name += segment.ident.to_string().as_str();
|
||||
} else {
|
||||
let tmp = segment.ident.to_string() + "::";
|
||||
name += tmp.as_str();
|
||||
}
|
||||
}
|
||||
name
|
||||
}
|
||||
}
|
|
@ -1,6 +1,8 @@
|
|||
use super::attribute::AttributeAnalyzer;
|
||||
use proc_macro2::Ident;
|
||||
use syn::{Field, Type, TypePath};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FieldTypeAnalyzer {
|
||||
pub field: Field,
|
||||
}
|
||||
|
@ -50,7 +52,7 @@ impl FieldTypeAnalyzer {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn path_name(path: &TypePath) -> String {
|
||||
fn path_name(path: &TypePath) -> String {
|
||||
let length = path.path.segments.len();
|
||||
let mut name = String::new();
|
||||
for (i, segment) in path.path.segments.iter().enumerate() {
|
||||
|
@ -64,8 +66,30 @@ impl FieldTypeAnalyzer {
|
|||
name
|
||||
}
|
||||
|
||||
pub fn attributes(&self) -> impl Iterator<Item = AttributeAnalyzer> {
|
||||
self.field
|
||||
.attrs
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(AttributeAnalyzer::new)
|
||||
}
|
||||
|
||||
pub fn is_param(&self) -> bool {
|
||||
let params_types = vec!["Param", "burn::Param"];
|
||||
self.is_of_type(¶ms_types)
|
||||
self.is_of_type(&["Param", "burn::Param"])
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn parse_fields(ast: &syn::DeriveInput) -> Vec<Field> {
|
||||
let mut fields = Vec::new();
|
||||
|
||||
match &ast.data {
|
||||
syn::Data::Struct(struct_data) => {
|
||||
for field in struct_data.fields.iter() {
|
||||
fields.push(field.clone());
|
||||
}
|
||||
}
|
||||
syn::Data::Enum(_) => panic!("Only struct can be derived"),
|
||||
syn::Data::Union(_) => panic!("Only struct cna be derived"),
|
||||
};
|
||||
fields
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
pub(crate) mod attribute;
|
||||
pub(crate) mod field;
|
|
@ -2,7 +2,6 @@ use burn::data::dataloader::batcher::Batcher;
|
|||
use burn::data::dataloader::DataLoaderBuilder;
|
||||
use burn::data::dataset::source::huggingface::{MNISTDataset, MNISTItem};
|
||||
use burn::module::{Forward, Module, Param, State};
|
||||
use burn::nn;
|
||||
use burn::optim::decay::WeightDecayConfig;
|
||||
use burn::optim::momentum::MomentumConfig;
|
||||
use burn::optim::{Optimizer, Sgd, SgdConfig};
|
||||
|
@ -12,6 +11,7 @@ use burn::tensor::{Data, ElementConversion, Shape, Tensor};
|
|||
use burn::train::logger::{AsyncLogger, CLILogger};
|
||||
use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric};
|
||||
use burn::train::{ClassificationLearner, ClassificationOutput, SupervisedTrainer};
|
||||
use burn::{config, nn};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
|
@ -21,6 +21,17 @@ struct Model<B: Backend> {
|
|||
output: Param<nn::Linear<B>>,
|
||||
}
|
||||
|
||||
config!(
|
||||
struct MlpConfig {
|
||||
#[config(default = 4)]
|
||||
num_layers: usize,
|
||||
#[config(default = 0.2)]
|
||||
dropout: f64,
|
||||
#[config(default = 1024)]
|
||||
dim: usize,
|
||||
}
|
||||
);
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
struct Mlp<B: Backend> {
|
||||
linears: Param<Vec<nn::Linear<B>>>,
|
||||
|
@ -69,42 +80,28 @@ impl<B: Backend> Forward<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
|
|||
}
|
||||
|
||||
impl<B: Backend> Mlp<B> {
|
||||
fn new(dim: usize, num_layers: usize) -> Self {
|
||||
let mut linears = Vec::with_capacity(num_layers);
|
||||
fn new(config: &MlpConfig) -> Self {
|
||||
let mut linears = Vec::with_capacity(config.num_layers);
|
||||
|
||||
for _ in 0..num_layers {
|
||||
let config = nn::LinearConfig {
|
||||
d_input: dim,
|
||||
d_output: dim,
|
||||
bias: true,
|
||||
};
|
||||
let linear = nn::Linear::new(&config);
|
||||
for _ in 0..config.num_layers {
|
||||
let linear = nn::Linear::new(&nn::LinearConfig::new(config.dim, config.dim));
|
||||
linears.push(linear);
|
||||
}
|
||||
|
||||
Self {
|
||||
linears: Param::new(linears),
|
||||
dropout: nn::Dropout::new(&nn::DropoutConfig { prob: 0.3 }),
|
||||
dropout: nn::Dropout::new(&nn::DropoutConfig::new(0.3)),
|
||||
activation: nn::ReLU::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B> {
|
||||
fn new(d_input: usize, d_hidden: usize, num_layers: usize, num_classes: usize) -> Self {
|
||||
let mlp = Mlp::new(d_hidden, num_layers);
|
||||
let config_input = nn::LinearConfig {
|
||||
d_input,
|
||||
d_output: d_hidden,
|
||||
bias: true,
|
||||
};
|
||||
let config_output = nn::LinearConfig {
|
||||
d_input: d_hidden,
|
||||
d_output: num_classes,
|
||||
bias: true,
|
||||
};
|
||||
let output = nn::Linear::new(&config_output);
|
||||
let input = nn::Linear::new(&config_input);
|
||||
fn new(d_input: usize, num_classes: usize) -> Self {
|
||||
let mlp_config = MlpConfig::new();
|
||||
let mlp = Mlp::new(&mlp_config);
|
||||
let output = nn::Linear::new(&nn::LinearConfig::new(mlp_config.dim, num_classes));
|
||||
let input = nn::Linear::new(&nn::LinearConfig::new(d_input, mlp_config.dim));
|
||||
|
||||
Self {
|
||||
mlp: Param::new(mlp),
|
||||
|
@ -150,14 +147,12 @@ fn run<B: ADBackend>(device: B::Device) {
|
|||
let batch_size = 128;
|
||||
let num_epochs = 15;
|
||||
let num_workers = 8;
|
||||
let num_layers = 4;
|
||||
let hidden_dim = 1024;
|
||||
let seed = 42;
|
||||
|
||||
let state_model = State::<f32>::load("/tmp/mnist_state_model").ok();
|
||||
let state_optim = State::<f32>::load("/tmp/mnist_state_optim").ok();
|
||||
|
||||
let mut model = Model::new(784, hidden_dim, num_layers, 10);
|
||||
let mut model = Model::new(784, 10);
|
||||
model.to_device(device);
|
||||
|
||||
if let Some(state) = state_model {
|
||||
|
@ -165,16 +160,12 @@ fn run<B: ADBackend>(device: B::Device) {
|
|||
model.load(&state.convert()).unwrap();
|
||||
}
|
||||
|
||||
let mut optim = Sgd::new(&SgdConfig {
|
||||
learning_rate: 2.5e-2,
|
||||
weight_decay: Some(WeightDecayConfig { penalty: 0.01 }),
|
||||
momentum: Some(MomentumConfig {
|
||||
momentum: 0.9,
|
||||
dampening: 0.1,
|
||||
nesterov: true,
|
||||
}),
|
||||
});
|
||||
let optim_config = SgdConfig::new()
|
||||
.with_learning_rate(2.5e-2)
|
||||
.with_weight_decay(Some(WeightDecayConfig::new(0.05)))
|
||||
.with_momentum(Some(MomentumConfig::new().with_nesterov(true)));
|
||||
|
||||
let mut optim = Sgd::new(&optim_config);
|
||||
if let Some(state) = state_optim {
|
||||
println!("Loading optimizer state");
|
||||
optim.load(&model, &state.convert()).unwrap();
|
||||
|
|
|
@ -2,13 +2,14 @@
|
|||
extern crate derive_new;
|
||||
|
||||
pub mod data;
|
||||
pub mod macros;
|
||||
pub mod module;
|
||||
pub mod nn;
|
||||
pub mod optim;
|
||||
pub mod tensor;
|
||||
pub mod train;
|
||||
|
||||
pub(crate) mod macros;
|
||||
pub use burn_derive::Config;
|
||||
|
||||
#[cfg(test)]
|
||||
pub type TestBackend = crate::tensor::backend::TchBackend<f32>;
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
#[macro_export]
|
||||
macro_rules! config {
|
||||
($item:item) => {
|
||||
#[derive(new, serde::Serialize, serde::Deserialize, Clone, Debug)]
|
||||
#[derive(burn::Config, serde::Serialize, serde::Deserialize, Clone, Debug)]
|
||||
$item
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) use config;
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
use crate::macros::config;
|
||||
use crate as burn;
|
||||
|
||||
use crate::config;
|
||||
use crate::module::Forward;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::{Distribution, ElementConversion, Tensor};
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::macros::config;
|
||||
use crate::config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Forward, Param};
|
||||
use crate::tensor::backend::Backend;
|
||||
|
@ -11,7 +11,8 @@ config!(
|
|||
pub struct LayerNormConfig {
|
||||
/// The size of the input features.
|
||||
pub d_model: usize,
|
||||
/// A value required for numerical stability, typically 1e-5.
|
||||
/// A value required for numerical stability. Default: 1e-5
|
||||
#[config(default = 1e-5)]
|
||||
pub epsilon: f64,
|
||||
}
|
||||
);
|
||||
|
@ -61,10 +62,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn layer_norm_forward() {
|
||||
let config = LayerNormConfig {
|
||||
d_model: 10,
|
||||
epsilon: 1e-5,
|
||||
};
|
||||
let config = LayerNormConfig::new(10);
|
||||
let module = LayerNorm::<TestBackend>::new(&config);
|
||||
let input = Tensor::from_data(Data::from([[
|
||||
-0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
|
||||
|
@ -82,10 +80,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn layer_norm_backward() {
|
||||
let config = LayerNormConfig {
|
||||
d_model: 2,
|
||||
epsilon: 1e-5,
|
||||
};
|
||||
let config = LayerNormConfig::new(2);
|
||||
let module = LayerNorm::<TestADBackend>::new(&config);
|
||||
let tensor_1 = Tensor::<TestADBackend, 2>::from_data(Data::from([[0.0, 1.0], [3.0, 4.0]]));
|
||||
let tensor_2 = Tensor::<TestADBackend, 2>::from_data(Data::from([[6.0, 7.0], [9.0, 10.0]]));
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::macros::config;
|
||||
use crate::config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Forward, Param};
|
||||
use crate::tensor::backend::Backend;
|
||||
|
@ -15,6 +15,7 @@ config!(
|
|||
/// The size of the output features.
|
||||
pub d_output: usize,
|
||||
/// If a bias should be applied during the linear transformation.
|
||||
#[config(default = true)]
|
||||
pub bias: bool,
|
||||
}
|
||||
);
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::macros::config;
|
||||
use crate as burn;
|
||||
use crate::config;
|
||||
|
||||
config!(
|
||||
pub struct AdamConfig {
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use super::{load_state_gradients, register_state_gradients};
|
||||
use crate::macros::config;
|
||||
use crate::config;
|
||||
use crate::module::{ParamId, StateNamed};
|
||||
use crate::tensor::backend::ADBackend;
|
||||
use crate::tensor::{ElementConversion, Gradients, Tensor};
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use super::{load_state_gradients, register_state_gradients};
|
||||
use crate::macros::config;
|
||||
use crate::config;
|
||||
use crate::module::{ParamId, StateNamed};
|
||||
use crate::tensor::backend::ADBackend;
|
||||
use crate::tensor::{ElementConversion, Gradients, Tensor};
|
||||
|
@ -8,10 +10,13 @@ config!(
|
|||
/// Configuration to create momentum [Momentum](Momentum).
|
||||
pub struct MomentumConfig {
|
||||
/// Momemtum factor
|
||||
#[config(default = 0.9)]
|
||||
pub momentum: f64,
|
||||
/// Dampening factor.
|
||||
#[config(default = 0.1)]
|
||||
pub dampening: f64,
|
||||
/// Enables Nesterov momentum, see [On the importance of initialization and momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf).
|
||||
#[config(default = false)]
|
||||
pub nesterov: bool,
|
||||
}
|
||||
);
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
use crate as burn;
|
||||
|
||||
use super::decay::{WeightDecay, WeightDecayConfig};
|
||||
use super::momentum::{Momentum, MomentumConfig};
|
||||
use crate::macros::config;
|
||||
use crate::config;
|
||||
use crate::module::{ParamId, StateNamed};
|
||||
use crate::optim::Optimizer;
|
||||
use crate::tensor::backend::ADBackend;
|
||||
|
@ -10,6 +12,7 @@ config!(
|
|||
/// Configuration to create the [Sgd](Sgd) optimizer.
|
||||
pub struct SgdConfig {
|
||||
/// Learning rate for the optimizer.
|
||||
#[config(default = 0.01)]
|
||||
pub learning_rate: f64,
|
||||
/// [Weight decay](WeightDecayConfig) config.
|
||||
pub weight_decay: Option<WeightDecayConfig>,
|
||||
|
|
Loading…
Reference in New Issue