Feat/module no grad (#274)

Nathaniel Simard 2023-04-07 09:01:27 -04:00 committed by GitHub
parent d8f64ce1dd
commit f04fe101d8
38 changed files with 366 additions and 389 deletions

View File

@ -745,16 +745,30 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
fn detach<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
// When we detach a tensor, we remove it from the graph, but we still want to keep the
// `require_grad` setting.
let is_require_grad = Self::is_require_grad(&tensor);
let tensor = ADTensor::new(tensor.primitive);
match tensor.node.requirement {
Requirement::Grad => tensor.require_grad(),
_ => tensor,
match is_require_grad {
true => tensor.require_grad(),
false => tensor,
}
}
fn require_grad<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
tensor.require_grad()
fn set_require_grad<const D: usize>(
tensor: ADTensor<B, D>,
require_grad: bool,
) -> ADTensor<B, D> {
if require_grad {
return tensor.require_grad();
}
ADTensor::new(tensor.primitive)
}
fn is_require_grad<const D: usize>(tensor: &ADTensor<B, D>) -> bool {
matches!(tensor.node.requirement, Requirement::Grad)
}
fn mean<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, 1> {

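For context, a minimal sketch of what the backend change above looks like through the public `Tensor` API (using the `TestADBackend` alias from the crate's tests; the assertions describe the intended behaviour and are an assumption, not part of the commit):

// Detach removes the tensor from the autodiff graph but now keeps the
// `require_grad` setting, so the next graph built from it is tracked again.
let weight = Tensor::<TestADBackend, 2>::ones([2, 2]).require_grad();
let detached = weight.detach();
assert!(detached.is_require_grad());

// `set_require_grad(false)` drops tracking entirely.
let frozen = detached.set_require_grad(false);
assert!(!frozen.is_require_grad());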
View File

@ -1,4 +1,4 @@
use alloc::{format, string::String, vec::Vec};
use alloc::vec::Vec;
use super::ParamId;
use crate::{
@ -8,6 +8,58 @@ use crate::{
pub use burn_derive::Module;
use burn_tensor::Tensor;
// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
macro_rules! module {
(map=$module:ident, ops=$item:expr) => {{
struct Mapper;
impl<B: Backend> ModuleMapper<B> for Mapper {
fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
let func = $item;
func(tensor)
}
}
let mut mapper = Mapper;
$module.map(&mut mapper)
}};
(map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{
struct Mapper<'a, B: Backend> {
capture: &'a $ty,
backend: core::marker::PhantomData<B>,
}
impl<'a, B: Backend> ModuleMapper<B> for Mapper<'a, B> {
fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
let func = $item;
func(tensor, self.capture)
}
}
let mut mapper = Mapper {
capture: $capture,
backend: core::marker::PhantomData::default(),
};
$module.map(&mut mapper)
}};
(visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
struct Visitor<'a, B: Backend> {
state: &'a mut $state_ty,
backend: core::marker::PhantomData<B>,
}
impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
fn visit<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
let func = $item;
func(tensor, &mut self.state)
}
}
let mut state = $init();
let mut visitor = Visitor {
state: &mut state,
backend: core::marker::PhantomData::default(),
};
$module.visit(&mut visitor);
state
}};
}
/// Trait for all neural network modules.
///
/// Modules should be created using the [derive](burn_derive::Module) attribute.
@ -42,13 +94,80 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
type Record: Record;
/// Get the device list of the module and all of its sub-modules.
fn devices(&self) -> Vec<B::Device>;
fn devices(&self) -> Vec<B::Device> {
module!(
visit = self,
ops = |tensor: &Tensor<B, D>, state: &mut Vec<B::Device>| {
let device = tensor.device();
if !state.contains(&device) {
state.push(device);
}
},
state = Vec<B::Device>,
init = Vec::new
)
}
/// Fork the module and all of its sub-modules to the given device.
///
/// # Notes
///
/// This is similar to [to_device](Module::to_device), but it ensures the module will
/// have its own autodiff graph.
fn fork(self, device: &B::Device) -> Self {
module!(
map = self,
ops = |tensor: Tensor<B, D>, device: &B::Device| {
let is_require_grad = tensor.is_require_grad();
let mut tensor = tensor.to_device(device).detach();
if is_require_grad {
tensor = tensor.require_grad();
}
tensor
},
capture = { device: B::Device }
)
}
/// Move the module and all of its sub-modules to the given device.
fn to_device(self, device: &B::Device) -> Self;
/// Detach the module from the graph.
fn detach(self) -> Self;
///
/// # Warnings
///
/// The device operations will be registered in the autodiff graph. Therefore, be sure to call
/// backward only once even if the same module is used on multiple devices. If you need to
/// call backward multiple times, consider using [fork](Module::fork) instead.
fn to_device(self, device: &B::Device) -> Self {
module!(
map = self,
ops = |tensor: Tensor<B, D>, device: &B::Device| tensor.to_device(device),
capture = { device: B::Device }
)
}
/// Mark each tensor in the module tree as no longer requiring gradients.
///
/// # Warnings
///
/// This should not be used for inference; use [valid](ADModule::valid) when working with
/// AD modules. It is mostly useful for partial fine-tuning, i.e. updating only
/// a small fraction of the parameters instead of fine-tuning all of them.
fn no_grad(self) -> Self {
module!(
map = self,
ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
)
}
/// Get the number of parameters the module has, including all of its sub-modules.
fn num_params(&self) -> usize;
fn num_params(&self) -> usize {
module!(
visit = self,
ops = |tensor: &Tensor<B, D>, state: &mut usize| {
*state += tensor.shape().num_elements();
},
state = usize,
init = || 0
)
}
/// Visit each tensor in the module with a [visitor](ModuleVisitor).
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);
/// Map each tensor in the module with a [mapper](ModuleMapper).
@ -72,21 +191,5 @@ pub trait ADModule<B: ADBackend>: Module<B> + Send + Sync + core::fmt::Debug {
type InnerModule: Module<B::InnerBackend>;
/// Get the same module, but on the inner backend without auto-differentiation.
fn inner(self) -> Self::InnerModule;
fn from_inner(module: Self::InnerModule) -> Self;
fn valid(&self) -> Self::InnerModule;
}
#[derive(new, Debug)]
pub struct LoadingError {
message: String,
}
impl core::fmt::Display for LoadingError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(format!("Loading error: {}", self.message).as_str())
}
}
// TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765)
#[cfg(feature = "std")]
impl std::error::Error for LoadingError {}

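As a usage sketch of the reworked trait surface (`Model` is a hypothetical struct deriving `Module`; `TestADBackend` is the test alias used elsewhere in the crate):

let device = <TestADBackend as Backend>::Device::default();
let model = Model::<TestADBackend>::new();

// `fork` moves the module to the device and gives it its own autodiff graph,
// so each copy can run `backward` independently.
let worker = model.clone().fork(&device);

// `to_device` only moves the tensors; the device ops stay in the shared graph.
let model = model.to_device(&device);

// `no_grad` keeps the values but stops tracking gradients (partial fine-tuning).
let frozen = model.no_grad();

// `valid` (from ADModule) yields the inner-backend module for validation/inference.
let inference = frozen.valid();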
View File

@ -1,10 +1,8 @@
use alloc::format;
use serde::{Deserialize, Serialize};
use super::ParamId;
use alloc::format;
/// Define a trainable parameter.
#[derive(new, Debug, Clone, Serialize, Deserialize)]
/// Define a parameter.
#[derive(new, Debug, Clone)]
pub struct Param<T> {
pub(crate) id: ParamId,
pub(crate) value: T,

View File

@ -5,22 +5,6 @@ macro_rules! constant {
(module) => {
type Record = ();
fn devices(&self) -> alloc::vec::Vec<<B as burn_tensor::backend::Backend>::Device> {
alloc::vec::Vec::new()
}
fn to_device(self, _device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
self
}
fn detach(self) -> Self {
self
}
fn num_params(&self) -> usize {
0
}
fn visit<V: burn::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
// Nothing to do
}
@ -39,12 +23,8 @@ macro_rules! constant {
(ad_module, $type:ty) => {
type InnerModule = $type;
fn inner(self) -> Self::InnerModule {
self
}
fn from_inner(module: Self::InnerModule) -> Self {
module
fn valid(&self) -> Self::InnerModule {
self.clone()
}
};

View File

@ -1,10 +1,7 @@
use alloc::string::{String, ToString};
use burn_common::id::IdGenerator;
use serde::{Deserialize, Serialize};
#[derive(Debug, Hash, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub struct ParamId {
value: String,
}
@ -35,6 +32,9 @@ impl ParamId {
value: IdGenerator::generate(),
}
}
pub fn into_string(self) -> String {
self.value
}
}
impl core::fmt::Display for ParamId {

View File

@ -10,29 +10,6 @@ where
{
type Record = Option<T::Record>;
fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
if let Some(module) = self {
return Module::<B>::devices(module);
}
Vec::new()
}
fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
self.map(|module| module.to_device(device))
}
fn detach(self) -> Self {
self.map(|module| module.detach())
}
fn num_params(&self) -> usize {
match &self {
Some(module) => module.num_params(),
None => 0,
}
}
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
if let Some(module) = self {
module.visit(visitor)
@ -60,12 +37,8 @@ where
{
type InnerModule = Option<T::InnerModule>;
fn inner(self) -> Self::InnerModule {
self.map(|module| module.inner())
}
fn from_inner(module: Self::InnerModule) -> Self {
module.map(|module| T::from_inner(module))
fn valid(&self) -> Self::InnerModule {
self.as_ref().map(|module| module.valid())
}
}
@ -76,22 +49,6 @@ where
{
type Record = Vec<T::Record>;
fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
let mut devices = Vec::new();
for module in self.iter() {
devices.append(&mut module.devices());
}
devices
}
fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
self.into_iter().map(|val| val.to_device(device)).collect()
}
fn detach(self) -> Self {
self.into_iter().map(|module| module.detach()).collect()
}
fn num_params(&self) -> usize {
let mut num_params = 0;
for module in self.iter() {
@ -130,15 +87,8 @@ where
{
type InnerModule = Vec<T::InnerModule>;
fn inner(self) -> Self::InnerModule {
self.into_iter().map(|module| module.inner()).collect()
}
fn from_inner(module: Self::InnerModule) -> Self {
module
.into_iter()
.map(|module| T::from_inner(module))
.collect()
fn valid(&self) -> Self::InnerModule {
self.iter().map(|module| module.valid()).collect()
}
}
@ -158,14 +108,6 @@ where
devices
}
fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
self.map(|val| val.to_device(device))
}
fn detach(self) -> Self {
self.map(|module| module.detach())
}
fn num_params(&self) -> usize {
let mut num_params = 0;
for module in self.iter() {
@ -209,11 +151,7 @@ where
{
type InnerModule = [T::InnerModule; N];
fn inner(self) -> Self::InnerModule {
self.map(|module| module.inner())
}
fn from_inner(module: Self::InnerModule) -> Self {
module.map(|module| T::from_inner(module))
fn valid(&self) -> Self::InnerModule {
self.map(|module| module.valid())
}
}

View File

@ -1,4 +1,4 @@
use alloc::{sync::Arc, vec, vec::Vec};
use alloc::sync::Arc;
use super::ParamId;
use crate::module::{ADModule, Module, ModuleMapper, ModuleVisitor, Param};
@ -40,62 +40,22 @@ use threading::*;
/// The state value is the average of all updates on all threads.
#[derive(Clone, Debug)]
pub struct RunningState<V> {
id: ParamId,
values: Arc<Mutex<HashMap<ThreadId, V>>>,
value: Arc<RwLock<V>>,
}
impl<B: Backend, const D: usize> From<RunningState<Tensor<B, D>>>
for Param<RunningState<Tensor<B, D>>>
{
fn from(value: RunningState<Tensor<B, D>>) -> Self {
Param {
id: ParamId::new(),
value,
}
}
}
impl<const D: usize, B: Backend> Module<B> for Param<RunningState<Tensor<B, D>>> {
impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
type Record = Param<Tensor<B, D>>;
fn num_params(&self) -> usize {
let tensor = self.value.value.read().unwrap();
tensor.shape().num_elements()
}
fn devices(&self) -> Vec<B::Device> {
let tensor = self.value.value.read().unwrap();
vec![tensor.device()]
}
fn to_device(self, device: &B::Device) -> Self {
self.value.sync();
let mut tensor = self.value.value.write().unwrap();
tensor.inplace(|tensor| tensor.to_device(device));
core::mem::drop(tensor);
self
}
fn detach(self) -> Self {
self.sync();
let mut tensor = self.value.value.write().unwrap();
tensor.inplace(|tensor| tensor.detach());
core::mem::drop(tensor);
self
}
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
let tensor = self.value.value.read().unwrap();
let tensor = self.value.read().unwrap();
visitor.visit(&self.id, &tensor)
}
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
let mut tensor = self.value.value.write().unwrap();
let mut tensor = self.value.write().unwrap();
let tensor_out = mapper.map(&self.id, tensor.clone());
*tensor = tensor_out;
@ -106,13 +66,13 @@ impl<const D: usize, B: Backend> Module<B> for Param<RunningState<Tensor<B, D>>>
fn into_record(self) -> Self::Record {
self.sync();
let tensor = self.value.value.read().unwrap();
let tensor = self.value.read().unwrap();
Param::new(self.id, tensor.clone())
}
fn load_record(mut self, record: Self::Record) -> Self {
let mut tensor = self.value.value.write().unwrap();
let mut tensor = self.value.write().unwrap();
*tensor = record.value.to_device(&tensor.device());
self.id = record.id;
@ -126,6 +86,16 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
/// Create a new running state.
pub fn new(value: Tensor<B, D>) -> Self {
Self {
id: ParamId::new(),
values: Arc::new(Mutex::new(HashMap::new())),
value: Arc::new(RwLock::new(value)),
}
}
/// Create a new running state with the given parameter id.
pub fn with_id(id: ParamId, value: Tensor<B, D>) -> Self {
Self {
id,
values: Arc::new(Mutex::new(HashMap::new())),
value: Arc::new(RwLock::new(value)),
}
@ -200,27 +170,13 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
}
}
impl<const D: usize, B: ADBackend> ADModule<B> for Param<RunningState<Tensor<B, D>>> {
type InnerModule = Param<RunningState<Tensor<B::InnerBackend, D>>>;
impl<const D: usize, B: ADBackend> ADModule<B> for RunningState<Tensor<B, D>> {
type InnerModule = RunningState<Tensor<B::InnerBackend, D>>;
fn inner(self) -> Self::InnerModule {
fn valid(&self) -> Self::InnerModule {
self.sync();
let value = self.value.value();
let value = self.value();
Param {
id: self.id,
value: RunningState::new(value.inner()),
}
}
fn from_inner(module: Self::InnerModule) -> Self {
module.sync();
let value = module.value.value();
Param {
id: module.id,
value: RunningState::new(Tensor::from_inner(value)),
}
RunningState::with_id(self.id.clone(), value.inner())
}
}

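A short sketch of the new shape of `RunningState`, which now carries its own `ParamId` and implements `Module` directly instead of being wrapped in `Param` (only `new`, `with_id`, `sync`, and `value` from the hunk above are used; the surrounding setup is an assumption):

let state = RunningState::new(Tensor::<TestADBackend, 1>::ones([3]));

// `with_id` supplies an explicit id, e.g. when restoring from a record.
let restored = RunningState::with_id(ParamId::new(), Tensor::<TestADBackend, 1>::ones([3]));

state.sync();                 // fold per-thread updates into the shared value
let current = state.value();  // read the aggregated tensor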
View File

@ -1,5 +1,3 @@
use alloc::{vec, vec::Vec};
use super::{Param, ParamId};
use crate::module::{ADModule, Module, ModuleMapper, ModuleVisitor};
use crate::tensor::{
@ -9,45 +7,20 @@ use crate::tensor::{
impl<B: Backend, const D: usize> From<Tensor<B, D>> for Param<Tensor<B, D>> {
fn from(value: Tensor<B, D>) -> Self {
Param {
id: ParamId::new(),
value: value.require_grad(),
}
Param::new(ParamId::new(), value.require_grad())
}
}
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
type Record = Param<Tensor<B, D>>;
fn num_params(&self) -> usize {
self.value.shape().num_elements()
}
fn devices(&self) -> Vec<B::Device> {
vec![self.value.device()]
}
fn to_device(self, device: &B::Device) -> Self {
Self {
id: self.id,
value: self.value.to_device(device).require_grad(),
}
}
fn detach(self) -> Self {
Self {
id: self.id,
value: self.value.detach().require_grad(),
}
}
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
visitor.visit(&self.id, &self.value)
}
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
let value = mapper.map(&self.id, self.value).require_grad();
Self { id: self.id, value }
let value = mapper.map(&self.id, self.value);
Self::new(self.id, value)
}
fn into_record(self) -> Self::Record {
@ -55,24 +28,63 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
}
fn load_record(self, record: Self::Record) -> Self {
record.to_device(&self.device())
let mut tensor = record.value.detach();
let device = self.device();
// Make sure we load the record onto the same device as the module.
if tensor.device() != device {
tensor = tensor.to_device(&device).detach();
}
// Make sure we load the record with the same autodiff setting.
if self.is_require_grad() {
tensor = tensor.require_grad();
}
Self::new(record.id, tensor)
}
}
impl<const D: usize, B: ADBackend> ADModule<B> for Param<Tensor<B, D>> {
type InnerModule = Param<Tensor<B::InnerBackend, D>>;
fn inner(self) -> Self::InnerModule {
Param {
id: self.id,
value: self.value.inner(),
}
}
fn from_inner(module: Self::InnerModule) -> Self {
Param {
id: module.id,
value: Tensor::from_inner(module.value).require_grad(),
}
fn valid(&self) -> Self::InnerModule {
Param::new(
self.id.clone(),
self.value.clone().inner().set_require_grad(false),
)
}
}
#[cfg(all(test, feature = "std"))]
mod tests {
use crate::{
record::{NoStdInferenceRecordSettings, Record},
TestADBackend,
};
use super::*;
#[test]
fn test_load_record_setting() {
let tensor = Tensor::<TestADBackend, 2>::ones([3, 3]);
let bytes = Param::from(tensor.clone())
.into_record()
.record::<NoStdInferenceRecordSettings>(())
.unwrap();
let no_grad_is_require_grad = Param::from(tensor.clone())
.no_grad()
.load_record(Param::load::<NoStdInferenceRecordSettings>(bytes.clone()).unwrap())
.value
.is_require_grad();
let with_default_is_require_grad = Param::from(tensor)
.load_record(Param::load::<NoStdInferenceRecordSettings>(bytes).unwrap())
.value
.is_require_grad();
assert!(!no_grad_is_require_grad);
assert!(with_default_is_require_grad);
}
}

View File

@ -1,5 +1,3 @@
use alloc::vec::Vec;
use crate as burn;
use crate::nn::cache::TensorCache;
@ -9,7 +7,6 @@ use crate::{
nn,
tensor::{activation, backend::Backend, Bool, Tensor},
};
use libm::sqrtf;
/// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer.
@ -257,6 +254,7 @@ pub struct MHAAutoregressiveCache<B: Backend> {
mod tests {
use super::*;
use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
use alloc::vec::Vec;
use burn::tensor::{Distribution, Shape};
use burn_tensor::Int;

View File

@ -1,5 +1,3 @@
use alloc::vec::Vec;
use crate as burn;
use crate::config::Config;

View File

@ -1,5 +1,3 @@
use alloc::vec::Vec;
use crate as burn;
use crate::config::Config;

View File

@ -1,6 +1,3 @@
use alloc::vec::Vec;
use burn_tensor::Int;
use crate as burn;
use super::Initializer;
@ -9,6 +6,7 @@ use crate::module::Module;
use crate::module::Param;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::Int;
/// Configuration to create an [Embedding](Embedding) layer.
#[derive(Config)]

View File

@ -1,5 +1,3 @@
use alloc::vec::Vec;
use crate as burn;
use crate::config::Config;

View File

@ -1,5 +1,3 @@
use alloc::vec::Vec;
use crate as burn;
use crate::{
@ -28,8 +26,8 @@ pub struct BatchNormConfig {
pub struct BatchNorm<B: Backend, const D: usize> {
gamma: Param<Tensor<B, 1>>,
beta: Param<Tensor<B, 1>>,
running_mean: Param<RunningState<Tensor<B, 1>>>,
running_var: Param<RunningState<Tensor<B, 1>>>,
running_mean: RunningState<Tensor<B, 1>>,
running_var: RunningState<Tensor<B, 1>>,
momentum: f64,
epsilon: f64,
}
@ -46,8 +44,8 @@ impl BatchNormConfig {
BatchNorm {
gamma: Param::from(gamma),
beta: Param::from(beta),
running_mean: Param::from(RunningState::new(running_mean)),
running_var: Param::from(RunningState::new(running_var)),
running_mean: RunningState::new(running_mean),
running_var: RunningState::new(running_var),
momentum: self.momentum,
epsilon: self.epsilon,
}
@ -76,8 +74,8 @@ impl<const D: usize, B: Backend> BatchNorm<B, D> {
fn forward_inference<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
let channels = input.dims()[1];
let mean = self.running_mean.val().value();
let var = self.running_var.val().value();
let mean = self.running_mean.value();
let var = self.running_var.value();
let mut shape = [1; DI];
shape[1] = channels;
@ -192,7 +190,7 @@ mod tests_1d {
let module = BatchNormConfig::new(3).init::<TestADBackend, 1>();
module.forward(input_tensor());
let module = module.inner();
let module = module.valid();
let output = module.forward(input_tensor());
output.to_data().assert_approx_eq(
@ -247,7 +245,7 @@ mod tests_2d {
let module = BatchNormConfig::new(3).init::<TestADBackend, 2>();
module.forward(input_tensor());
let module = module.inner();
let module = module.valid();
let output = module.forward(input_tensor());
output.to_data().assert_approx_eq(
@ -301,11 +299,9 @@ mod tests_2d {
let _output = module.forward(input_tensor());
let module_valid = module.inner();
let module_valid = module.valid();
let running_mean = module_valid.running_mean.value();
let module_train = BatchNorm::<TestADBackend, 2>::from_inner(module_valid);
let running_mean_after = module_train.running_mean.value();
let running_mean_after = module.running_mean.value();
running_mean_after
.into_data()

View File

@ -1,5 +1,3 @@
use alloc::vec::Vec;
use crate as burn;
use crate::config::Config;

View File

@ -1,5 +1,3 @@
use alloc::vec::Vec;
use crate as burn;
use crate::{

View File

@ -97,9 +97,7 @@ mod tests {
fn test_convert_grads() {
let layer_1 = layer();
let mut layer_2 = layer_1.clone();
layer_2 = layer_2
.to_device(&<TestADBackend as Backend>::Device::default())
.detach();
layer_2 = layer_2.fork(&<TestADBackend as Backend>::Device::default());
let loss_1 = layer_1.forward(random_tensor());
let loss_2 = layer_2.forward(random_tensor());
let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1);

View File

@ -83,7 +83,9 @@ where
if let Some(grad) = grad {
let device = grad.device();
let is_require_grad = tensor.is_require_grad();
let (key, record) = self.records.remove_entry(id).unzip();
let (tensor, state) = self.optimizer.step(
tensor.inner(),
grad,
@ -97,7 +99,11 @@ where
);
}
return Tensor::from_inner(tensor);
let mut tensor = Tensor::from_inner(tensor);
if is_require_grad {
tensor = tensor.require_grad();
}
return tensor;
}
tensor

View File

@ -1,6 +1,8 @@
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec::Vec;
use serde::Deserialize;
use serde::Serialize;
use super::{Record, RecordSettings};
use crate::module::{Param, ParamId};
@ -87,21 +89,22 @@ impl<E: Element> Record for DataSerialize<E> {
}
}
/// (De)serialize parameters into a clean format.
#[derive(new, Debug, Clone, Serialize, Deserialize)]
pub struct ParamSerde<T> {
id: String,
param: T,
}
impl<T: Record> Record for Param<T> {
type Item<S: RecordSettings> = Param<T::Item<S>>;
type Item<S: RecordSettings> = ParamSerde<T::Item<S>>;
fn into_item<S: RecordSettings>(self) -> Self::Item<S> {
Param {
id: self.id,
value: self.value.into_item(),
}
ParamSerde::new(self.id.into_string(), self.value.into_item())
}
fn from_item<S: RecordSettings>(item: Self::Item<S>) -> Self {
Param {
id: item.id,
value: T::from_item(item.value),
}
Param::new(ParamId::from(item.id), T::from_item(item.param))
}
}

View File

@ -4,6 +4,8 @@ use burn::tensor::{Distribution, Shape, Tensor};
use burn_core as burn;
pub type TestBackend = burn_ndarray::NdArrayBackend<f32>;
#[cfg(feature = "std")]
pub type TestADBackend = burn_autodiff::ADBackendDecorator<TestBackend>;
#[derive(Module, Debug)]
pub struct ModuleBasic<B: Backend> {
@ -93,3 +95,53 @@ mod num_params {
assert_eq!(2 * 20 * 20, module.num_params());
}
}
#[cfg(feature = "std")]
mod require_grad {
use burn_tensor::backend::ADBackend;
use super::*;
#[test]
fn should_have_grad_by_default() {
let module = ModuleBasic::<TestADBackend>::new();
let mut grads = calculate_grads(&module);
let grad_x = module.weight_basic.grad_remove(&mut grads);
assert!(grad_x.is_some());
}
#[test]
fn should_have_no_grad_after_no_grad() {
let module = ModuleBasic::<TestADBackend>::new().no_grad();
let mut grads = calculate_grads(&module);
let grad_x = module.weight_basic.grad_remove(&mut grads);
assert!(grad_x.is_none());
}
#[test]
fn should_have_grad_when_from_record() {
let module = ModuleBasic::<TestADBackend>::new();
let record = ModuleBasicRecord {
weight_basic: module.weight_basic.clone(), // Even when param is no_grad,
};
let module = module.load_record(record);
let mut grads = calculate_grads(&module);
let grad_x = module.weight_basic.grad_remove(&mut grads);
assert!(grad_x.is_some());
}
fn calculate_grads(
module: &ModuleBasic<TestADBackend>,
) -> <TestADBackend as ADBackend>::Gradients {
let x = Tensor::ones([20, 20]).require_grad();
let y = module.weight_basic.val().matmul(x);
y.backward()
}
}

View File

@ -25,13 +25,9 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let num_params_fn = generator.gen_num_params_fn();
let visit = generator.gen_visit_fn();
let map_mut = generator.gen_map_fn();
let devices_fn = generator.gen_devices_fn();
let to_device_fn = generator.gen_to_device_fn();
let inner_fn = generator.gen_inner_fn();
let from_inner_fn = generator.gen_from_inner_fn();
let valid_fn = generator.gen_valid_fn();
let into_record_fn = generator.gen_into_record_fn();
let load_record_fn = generator.gen_load_record_fn();
let detach_fn = generator.gen_detach_fn();
let clone_fn = generator.gen_clone_fn();
let generics_names_except_backend = generics_names_except_backend(&ast.generics);
@ -45,14 +41,10 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
impl #generics burn::module::Module<B> for #name #generics_ty #generics_where {
type Record = #record_name #generics_ty;
#devices_fn
#to_device_fn
#load_record_fn
#into_record_fn
#num_params_fn
#detach_fn
#visit
#map_mut
@ -61,8 +53,7 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
impl #generics burn::module::ADModule<B> for #name #generics_ty where B: burn::tensor::backend::ADBackend, {
type InnerModule=#name<B::InnerBackend, #generics_names_except_backend>;
#inner_fn
#from_inner_fn
#valid_fn
}
impl #generics core::fmt::Display for #name #generics_ty #generics_where {

View File

@ -96,68 +96,15 @@ impl FnGenerator {
}
}
pub fn gen_devices_fn(&self) -> TokenStream {
let body = self.gen_fields_fn(|name| {
quote! {
devices.append(&mut burn::module::Module::<B>::devices(&self.#name));
}
});
quote! {
fn devices(&self) -> Vec<B::Device> {
let mut devices = Vec::new();
#body
devices
}
}
}
pub fn gen_to_device_fn(&self) -> TokenStream {
pub fn gen_valid_fn(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::Module::<B>::to_device(self.#name, device);
let #name = burn::module::ADModule::<B>::valid(&self.#name);
}
});
quote! {
fn to_device(self, device: &B::Device) -> Self {
#body
Self {
#(#names),*
}
}
}
}
pub fn gen_detach_fn(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::Module::<B>::detach(self.#name);
}
});
quote! {
fn detach(self) -> Self {
#body
Self {
#(#names),*
}
}
}
}
pub fn gen_inner_fn(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::ADModule::<B>::inner(self.#name);
}
});
quote! {
fn inner(self) -> Self::InnerModule {
fn valid(&self) -> Self::InnerModule {
#body
Self::InnerModule {
@ -167,24 +114,6 @@ impl FnGenerator {
}
}
pub fn gen_from_inner_fn(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::ADModule::<B>::from_inner(module.#name);
}
});
quote! {
fn from_inner(module: Self::InnerModule) -> Self {
#body
Self {
#(#names),*
}
}
}
}
pub fn gen_clone_fn(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {

View File

@ -16,7 +16,7 @@ use burn_common::stub::Mutex;
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NdArrayDevice {
Cpu,
}

View File

@ -1,7 +1,5 @@
// Originally copied from the burn/examples/mnist package
use alloc::vec::Vec;
use burn::{
config::Config,
module::Module,

View File

@ -1,7 +1,5 @@
// Originally copied from the burn/examples/mnist package
use alloc::vec::Vec;
use crate::{
conv::{ConvBlock, ConvBlockConfig},
mlp::{Mlp, MlpConfig},

View File

@ -2,7 +2,7 @@ use super::element::TchElement;
use super::TchTensor;
use burn_tensor::backend::Backend;
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// The device struct when using the `tch` backend.
///
/// Note that you need to provide the device index when using Cuda.

View File

@ -305,7 +305,20 @@ where
/// Mark the tensor as requiring gradients during the backward pass.
/// This function does nothing when autodiff is not enabled.
pub fn require_grad(self) -> Self {
Self::new(B::require_grad(self.primitive))
self.set_require_grad(true)
}
/// Returns true if the tensor requires gradients during the backward pass.
pub fn is_require_grad(&self) -> bool {
B::is_require_grad(&self.primitive)
}
/// Mark the tensor as tracked or untracked depending on the `require_grad` argument.
/// When tracked, the gradients will be available after the backward pass.
///
/// This function does nothing when autodiff is not enabled.
pub fn set_require_grad(self, require_grad: bool) -> Self {
Self::new(B::set_require_grad(self.primitive, require_grad))
}
/// Applies the relu function to the tensor.

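A minimal sketch of the new flag at the tensor level (assuming an autodiff backend such as `TestADBackend`; on non-autodiff backends these calls are no-ops and `is_require_grad` always returns false):

let tensor = Tensor::<TestADBackend, 2>::ones([2, 2]);
assert!(!tensor.is_require_grad());          // untracked by default

let tensor = tensor.require_grad();          // shorthand for set_require_grad(true)
assert!(tensor.is_require_grad());

let tensor = tensor.set_require_grad(false); // stop tracking gradients
assert!(!tensor.is_require_grad());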
View File

@ -63,7 +63,7 @@ pub trait Backend:
+ 'static
{
/// Device type.
type Device: Clone + Default + core::fmt::Debug + Send + Sync;
type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync;
/// Pointer to another backend that has a full-precision float element type
type FullPrecisionBackend: Backend<FloatElem = Self::FullPrecisionElem, Device = Self::Device>;

View File

@ -193,10 +193,17 @@ pub trait TensorOps<B: Backend> {
// Should only be overridden by autodiff backends.
tensor
}
fn require_grad<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
fn set_require_grad<const D: usize>(
tensor: B::TensorPrimitive<D>,
_require_grad: bool,
) -> B::TensorPrimitive<D> {
// Should only be overridden by autodiff backends.
tensor
}
fn is_require_grad<const D: usize>(_tensor: &B::TensorPrimitive<D>) -> bool {
// Should only be overridden by autodiff backends.
false
}
fn sum<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1>;
fn sum_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D>;
fn mean<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1>;

View File

@ -192,7 +192,6 @@ where
}
None => None,
};
let model = model.detach();
Learner {
model,

View File

@ -24,14 +24,14 @@ pub struct TrainEpoch<TI> {
}
impl<I> ValidEpoch<I> {
pub fn run<B, M, TO, VO>(&self, model: M, callback: &mut Box<dyn LearnerCallback<TO, VO>>) -> M
pub fn run<B, M, TO, VO>(&self, model: &M, callback: &mut Box<dyn LearnerCallback<TO, VO>>)
where
B: ADBackend,
M: ADModule<B>,
M::InnerModule: ValidStep<I, VO>,
{
log::info!("Executing validation step for epoch {}", self.epoch);
let model = model.inner();
let model = model.valid();
let mut iterator = self.dataloader.iter();
let mut iteration = 0;
@ -50,8 +50,6 @@ impl<I> ValidEpoch<I> {
));
}
callback.on_valid_end_epoch(self.epoch);
ADModule::from_inner(model)
}
}
@ -77,6 +75,7 @@ impl<TI> TrainEpoch<TI> {
while let Some(item) = iterator.next() {
iteration += 1;
log::info!("Iteration {}", iteration);
let progress = iterator.progress();
let item = model.step(item);
@ -154,7 +153,6 @@ impl<TI> TrainEpoch<TI> {
let grads = item.grads.to_device(&device_main, &model);
log::info!("Updated device");
accumulator.accumulate(&model, grads);
accumulation_current += 1;

View File

@ -47,7 +47,7 @@ where
spawn(move || loop {
match receiver_input.recv() {
Ok(item) => {
let step = item.model.to_device(&device).detach();
let step = item.model.fork(&device);
let output = step.step(item.item);
sender_output.send(output).unwrap();

View File

@ -49,7 +49,7 @@ where
log::info!("Fitting {}", self.model.to_string());
// The reference model is always on the first device provided.
if let Some(device) = self.devices.get(0) {
self.model = self.model.to_device(device).detach();
self.model = self.model.fork(device);
}
let starting_epoch = match self.checkpoint {
@ -83,7 +83,7 @@ where
}
let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs);
model = epoch_valid.run(model, &mut self.callback);
epoch_valid.run(&model, &mut self.callback);
Self::checkpoint(
&model,

View File

@ -2,8 +2,6 @@
// Originally copied from the burn/examples/mnist package
use alloc::vec::Vec;
use burn::{
module::Module,
nn::{self, conv::Conv2dPaddingConfig, BatchNorm},
@ -15,6 +13,7 @@ pub struct Model<B: Backend> {
conv1: ConvBlock<B>,
conv2: ConvBlock<B>,
conv3: ConvBlock<B>,
dropout: nn::Dropout,
fc1: nn::Linear<B>,
fc2: nn::Linear<B>,
activation: nn::GELU,
@ -27,7 +26,6 @@ impl<B: Backend> Model<B> {
let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26]
let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24x24]
let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22x22]
let hidden_size = 24 * 22 * 22;
let fc1 = nn::LinearConfig::new(hidden_size, 32)
.with_bias(false)
@ -36,12 +34,15 @@ impl<B: Backend> Model<B> {
.with_bias(false)
.init();
let dropout = nn::DropoutConfig::new(0.5).init();
Self {
conv1,
conv2,
conv3,
fc1,
fc2,
dropout,
activation: nn::GELU::new(),
}
}
@ -57,6 +58,7 @@ impl<B: Backend> Model<B> {
let [batch_size, channels, heigth, width] = x.dims();
let x = x.reshape([batch_size, channels * heigth * width]);
let x = self.dropout.forward(x);
let x = self.fc1.forward(x);
let x = self.activation.forward(x);

View File

@ -36,7 +36,7 @@ impl<B: Backend> Model<B> {
.with_bias(false)
.init();
let dropout = nn::DropoutConfig::new(0.3).init();
let dropout = nn::DropoutConfig::new(0.5).init();
Self {
conv1,
@ -60,9 +60,9 @@ impl<B: Backend> Model<B> {
let [batch_size, channels, heigth, width] = x.dims();
let x = x.reshape([batch_size, channels * heigth * width]);
let x = self.dropout.forward(x);
let x = self.fc1.forward(x);
let x = self.activation.forward(x);
let x = self.dropout.forward(x);
self.fc2.forward(x)
}

View File

@ -21,7 +21,7 @@ static ARTIFACT_DIR: &str = "/tmp/burn-example-mnist";
#[derive(Config)]
pub struct MnistTrainingConfig {
#[config(default = 4)]
#[config(default = 10)]
pub num_epochs: usize,
#[config(default = 64)]

View File

@ -42,7 +42,7 @@ pub fn infer<B: Backend, D: TextClassificationDataset + 'static>(
let record = Record::load::<DefaultRecordSettings>(format!("{artifact_dir}/model").into())
.expect("Trained model weights");
let model = model.load_record(record);
let model = model.to_device(&device);
let model = model.fork(&device);
println!("Running inference ...");
let item = batcher.batch(samples.clone());