mirror of https://github.com/tracel-ai/burn.git
Feat/module no grad (#274)
parent: d8f64ce1dd
commit: f04fe101d8
@@ -745,16 +745,30 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
    }

    fn detach<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
        // When we detach a tensor, we remove it from the graph, but we still want to keep the
        // `require_grad` setting.
        let is_require_grad = Self::is_require_grad(&tensor);
        let tensor = ADTensor::new(tensor.primitive);

        match tensor.node.requirement {
            Requirement::Grad => tensor.require_grad(),
            _ => tensor,
        match is_require_grad {
            true => tensor.require_grad(),
            false => tensor,
        }
    }

    fn require_grad<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
        tensor.require_grad()
    fn set_require_grad<const D: usize>(
        tensor: ADTensor<B, D>,
        require_grad: bool,
    ) -> ADTensor<B, D> {
        if require_grad {
            return tensor.require_grad();
        }

        ADTensor::new(tensor.primitive)
    }

    fn is_require_grad<const D: usize>(tensor: &ADTensor<B, D>) -> bool {
        matches!(tensor.node.requirement, Requirement::Grad)
    }

    fn mean<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, 1> {
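The hunk above splits the old `require_grad` op into `set_require_grad` and `is_require_grad`, and makes `detach` keep the tracking flag. A minimal sketch of the resulting behaviour through the public `Tensor` API, assuming the autodiff-decorated ndarray backend that the crate's own tests use:

```rust
use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;
use burn_tensor::Tensor;

type AD = ADBackendDecorator<NdArrayBackend<f32>>;

fn tracking_semantics() {
    let tensor = Tensor::<AD, 2>::ones([2, 2]).require_grad();
    assert!(tensor.is_require_grad());

    // `detach` cuts the tensor out of the current graph but keeps the flag,
    // so the next operations start a fresh graph that is still tracked.
    let detached = tensor.detach();
    assert!(detached.is_require_grad());

    // `set_require_grad(false)` removes tracking entirely.
    let frozen = detached.set_require_grad(false);
    assert!(!frozen.is_require_grad());
}
```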
@@ -1,4 +1,4 @@
use alloc::{format, string::String, vec::Vec};
use alloc::vec::Vec;

use super::ParamId;
use crate::{
@@ -8,6 +8,58 @@ use crate::{
pub use burn_derive::Module;
use burn_tensor::Tensor;

// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
macro_rules! module {
    (map=$module:ident, ops=$item:expr) => {{
        struct Mapper;
        impl<B: Backend> ModuleMapper<B> for Mapper {
            fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
                let func = $item;
                func(tensor)
            }
        }
        let mut mapper = Mapper;
        $module.map(&mut mapper)
    }};
    (map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{
        struct Mapper<'a, B: Backend> {
            capture: &'a $ty,
            backend: core::marker::PhantomData<B>,
        }
        impl<'a, B: Backend> ModuleMapper<B> for Mapper<'a, B> {
            fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
                let func = $item;
                func(tensor, self.capture)
            }
        }
        let mut mapper = Mapper {
            capture: $capture,
            backend: core::marker::PhantomData::default(),
        };
        $module.map(&mut mapper)
    }};
    (visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
        struct Visitor<'a, B: Backend> {
            state: &'a mut $state_ty,
            backend: core::marker::PhantomData<B>,
        }
        impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
            fn visit<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
                let func = $item;
                func(tensor, &mut self.state)
            }
        }
        let mut state = $init();
        let mut visitor = Visitor {
            state: &mut state,
            backend: core::marker::PhantomData::default(),
        };
        $module.visit(&mut visitor);
        state
    }};
}

/// Trait for all neural network modules.
///
/// Modules should be created using the [derive](burn_derive::Module) attribute.
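The `module!` macro only generates a throwaway `ModuleVisitor` or `ModuleMapper` per call site. A hand-written sketch of what the `visit` form boils down to for the parameter-count case (import paths are assumed from the test setup in this diff, and this is not the literal macro expansion):

```rust
use burn_core::module::{Module, ModuleVisitor, ParamId};
use burn_tensor::{backend::Backend, Tensor};

// One-off visitor that accumulates the number of elements of every parameter tensor.
struct ParamCounter {
    count: usize,
}

impl<B: Backend> ModuleVisitor<B> for ParamCounter {
    fn visit<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
        self.count += tensor.shape().num_elements();
    }
}

fn count_params<B: Backend, M: Module<B>>(module: &M) -> usize {
    let mut visitor = ParamCounter { count: 0 };
    module.visit(&mut visitor);
    visitor.count
}
```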
@@ -42,13 +94,80 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
    type Record: Record;

    /// Get the device list of the module and all of its sub-modules.
    fn devices(&self) -> Vec<B::Device>;
    fn devices(&self) -> Vec<B::Device> {
        module!(
            visit = self,
            ops = |tensor: &Tensor<B, D>, state: &mut Vec<B::Device>| {
                let device = tensor.device();
                if !state.contains(&device) {
                    state.push(device);
                }
            },
            state = Vec<B::Device>,
            init = Vec::new
        )
    }
    /// Fork the module and all of its sub-modules to the given device.
    ///
    /// # Notes
    ///
    /// This is similar to [to_device](Module::to_device), but it ensures the module will
    /// have its own autodiff graph.
    fn fork(self, device: &B::Device) -> Self {
        module!(
            map = self,
            ops = |tensor: Tensor<B, D>, device: &B::Device| {
                let is_require_grad = tensor.is_require_grad();
                let mut tensor = tensor.to_device(device).detach();

                if is_require_grad {
                    tensor = tensor.require_grad();
                }

                tensor
            },
            capture = { device: B::Device }
        )
    }
    /// Move the module and all of its sub-modules to the given device.
    fn to_device(self, device: &B::Device) -> Self;
    /// Detach the module from the graph.
    fn detach(self) -> Self;
    ///
    /// # Warnings
    ///
    /// The device operations will be registered in the autodiff graph. Therefore, be sure to call
    /// backward only one time even if you have the same module on multiple devices. If you want to
    /// call backward multiple times, look into using [fork](Module::fork) instead.
    fn to_device(self, device: &B::Device) -> Self {
        module!(
            map = self,
            ops = |tensor: Tensor<B, D>, device: &B::Device| tensor.to_device(device),
            capture = { device: B::Device }
        )
    }
    /// Each tensor in the module tree will not require grad.
    ///
    /// # Warnings
    ///
    /// This should not be used for inference, use [valid](ADModule::valid) when using
    /// AD modules. This is mostly useful when performing partial finetuning, which is updating only
    /// a small fraction of the parameters instead of finetuning all of them.
    fn no_grad(self) -> Self {
        module!(
            map = self,
            ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
        )
    }

    /// Get the number of parameters the module has, including all of its sub-modules.
    fn num_params(&self) -> usize;
    fn num_params(&self) -> usize {
        module!(
            visit = self,
            ops = |tensor: &Tensor<B, D>, state: &mut usize| {
                *state += tensor.shape().num_elements();
            },
            state = usize,
            init = || 0
        )
    }
    /// Visit each tensor in the module with a [visitor](ModuleVisitor).
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);
    /// Map each tensor in the module with a [mapper](ModuleMapper).
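With these defaults, a derived module gets `devices`, `to_device`, `fork`, `no_grad`, and `num_params` with no extra code. A usage sketch on a plain `nn::Linear`; the backend alias mirrors the test types in this diff, and the config values are arbitrary:

```rust
use burn_core::module::Module;
use burn_core::nn::{Linear, LinearConfig};
use burn_core::tensor::backend::Backend;

type AD = burn_autodiff::ADBackendDecorator<burn_ndarray::NdArrayBackend<f32>>;

fn prepare(device: &<AD as Backend>::Device) -> Linear<AD> {
    let linear: Linear<AD> = LinearConfig::new(32, 32).init();

    println!("parameters: {}", linear.num_params());

    // Own autodiff graph on the target device, then freeze the weights:
    // the partial fine-tuning setup described in the doc comment above.
    linear.fork(device).no_grad()
}
```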
@@ -72,21 +191,5 @@ pub trait ADModule<B: ADBackend>: Module<B> + Send + Sync + core::fmt::Debug {
    type InnerModule: Module<B::InnerBackend>;

    /// Get the same module, but on the inner backend without auto-differentiation.
    fn inner(self) -> Self::InnerModule;
    fn from_inner(module: Self::InnerModule) -> Self;
    fn valid(&self) -> Self::InnerModule;
}

#[derive(new, Debug)]
pub struct LoadingError {
    message: String,
}

impl core::fmt::Display for LoadingError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(format!("Loading error: {}", self.message).as_str())
    }
}

// TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765)
#[cfg(feature = "std")]
impl std::error::Error for LoadingError {}
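`ADModule` drops the consuming `inner`/`from_inner` pair in favour of the borrow-based `valid`. A sketch of the validation pattern this enables (the validation step itself is elided):

```rust
use burn_core::module::ADModule;
use burn_tensor::backend::ADBackend;

fn run_validation<B, M>(model: &M)
where
    B: ADBackend,
    M: ADModule<B>,
{
    // The training model keeps ownership and its autodiff graph; `valid` hands back
    // an inner-backend copy with gradient tracking disabled for inference.
    let model_valid: M::InnerModule = model.valid();
    let _ = model_valid; // ... run the validation loop here ...
}
```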
@@ -1,10 +1,8 @@
use alloc::format;
use serde::{Deserialize, Serialize};

use super::ParamId;
use alloc::format;

/// Define a trainable parameter.
#[derive(new, Debug, Clone, Serialize, Deserialize)]
/// Define a parameter.
#[derive(new, Debug, Clone)]
pub struct Param<T> {
    pub(crate) id: ParamId,
    pub(crate) value: T,
@@ -5,22 +5,6 @@ macro_rules! constant {
    (module) => {
        type Record = ();

        fn devices(&self) -> alloc::vec::Vec<<B as burn_tensor::backend::Backend>::Device> {
            alloc::vec::Vec::new()
        }

        fn to_device(self, _device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
            self
        }

        fn detach(self) -> Self {
            self
        }

        fn num_params(&self) -> usize {
            0
        }

        fn visit<V: burn::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
            // Nothing to do
        }

@@ -39,12 +23,8 @@ macro_rules! constant {
    (ad_module, $type:ty) => {
        type InnerModule = $type;

        fn inner(self) -> Self::InnerModule {
            self
        }

        fn from_inner(module: Self::InnerModule) -> Self {
            module
        fn valid(&self) -> Self::InnerModule {
            self.clone()
        }
    };
@@ -1,10 +1,7 @@
use alloc::string::{String, ToString};

use burn_common::id::IdGenerator;

use serde::{Deserialize, Serialize};

#[derive(Debug, Hash, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub struct ParamId {
    value: String,
}

@@ -35,6 +32,9 @@ impl ParamId {
            value: IdGenerator::generate(),
        }
    }
    pub fn into_string(self) -> String {
        self.value
    }
}

impl core::fmt::Display for ParamId {
@@ -10,29 +10,6 @@ where
{
    type Record = Option<T::Record>;

    fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
        if let Some(module) = self {
            return Module::<B>::devices(module);
        }

        Vec::new()
    }

    fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
        self.map(|module| module.to_device(device))
    }

    fn detach(self) -> Self {
        self.map(|module| module.detach())
    }

    fn num_params(&self) -> usize {
        match &self {
            Some(module) => module.num_params(),
            None => 0,
        }
    }

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        if let Some(module) = self {
            module.visit(visitor)

@@ -60,12 +37,8 @@ where
{
    type InnerModule = Option<T::InnerModule>;

    fn inner(self) -> Self::InnerModule {
        self.map(|module| module.inner())
    }

    fn from_inner(module: Self::InnerModule) -> Self {
        module.map(|module| T::from_inner(module))
    fn valid(&self) -> Self::InnerModule {
        self.as_ref().map(|module| module.valid())
    }
}

@@ -76,22 +49,6 @@ where
{
    type Record = Vec<T::Record>;

    fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
        let mut devices = Vec::new();
        for module in self.iter() {
            devices.append(&mut module.devices());
        }
        devices
    }

    fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
        self.into_iter().map(|val| val.to_device(device)).collect()
    }

    fn detach(self) -> Self {
        self.into_iter().map(|module| module.detach()).collect()
    }

    fn num_params(&self) -> usize {
        let mut num_params = 0;
        for module in self.iter() {

@@ -130,15 +87,8 @@ where
{
    type InnerModule = Vec<T::InnerModule>;

    fn inner(self) -> Self::InnerModule {
        self.into_iter().map(|module| module.inner()).collect()
    }

    fn from_inner(module: Self::InnerModule) -> Self {
        module
            .into_iter()
            .map(|module| T::from_inner(module))
            .collect()
    fn valid(&self) -> Self::InnerModule {
        self.iter().map(|module| module.valid()).collect()
    }
}

@@ -158,14 +108,6 @@ where
        devices
    }

    fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
        self.map(|val| val.to_device(device))
    }

    fn detach(self) -> Self {
        self.map(|module| module.detach())
    }

    fn num_params(&self) -> usize {
        let mut num_params = 0;
        for module in self.iter() {

@@ -209,11 +151,7 @@ where
{
    type InnerModule = [T::InnerModule; N];

    fn inner(self) -> Self::InnerModule {
        self.map(|module| module.inner())
    }

    fn from_inner(module: Self::InnerModule) -> Self {
        module.map(|module| T::from_inner(module))
    fn valid(&self) -> Self::InnerModule {
        self.map(|module| module.valid())
    }
}
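These blanket impls are what let optional and repeated sub-modules pick up the new API automatically. A small sketch, with paths assumed as above:

```rust
use burn_core::module::Module;
use burn_core::nn::Linear;
use burn_core::tensor::backend::Backend;

// `Vec<T: Module>` and `Option<T: Module>` are modules themselves, so a stack of
// layers or an optional sub-module inherits `num_params`, `no_grad`, `valid`, ...
fn total_params<B: Backend>(layers: &Vec<Linear<B>>, extra: &Option<Linear<B>>) -> usize {
    layers.num_params() + extra.num_params()
}
```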
@@ -1,4 +1,4 @@
use alloc::{sync::Arc, vec, vec::Vec};
use alloc::sync::Arc;

use super::ParamId;
use crate::module::{ADModule, Module, ModuleMapper, ModuleVisitor, Param};

@@ -40,62 +40,22 @@ use threading::*;
/// The state value is the average of all updates on all threads.
#[derive(Clone, Debug)]
pub struct RunningState<V> {
    id: ParamId,
    values: Arc<Mutex<HashMap<ThreadId, V>>>,
    value: Arc<RwLock<V>>,
}

impl<B: Backend, const D: usize> From<RunningState<Tensor<B, D>>>
    for Param<RunningState<Tensor<B, D>>>
{
    fn from(value: RunningState<Tensor<B, D>>) -> Self {
        Param {
            id: ParamId::new(),
            value,
        }
    }
}

impl<const D: usize, B: Backend> Module<B> for Param<RunningState<Tensor<B, D>>> {
impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
    type Record = Param<Tensor<B, D>>;

    fn num_params(&self) -> usize {
        let tensor = self.value.value.read().unwrap();
        tensor.shape().num_elements()
    }

    fn devices(&self) -> Vec<B::Device> {
        let tensor = self.value.value.read().unwrap();
        vec![tensor.device()]
    }

    fn to_device(self, device: &B::Device) -> Self {
        self.value.sync();

        let mut tensor = self.value.value.write().unwrap();
        tensor.inplace(|tensor| tensor.to_device(device));
        core::mem::drop(tensor);

        self
    }

    fn detach(self) -> Self {
        self.sync();

        let mut tensor = self.value.value.write().unwrap();
        tensor.inplace(|tensor| tensor.detach());
        core::mem::drop(tensor);

        self
    }

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        let tensor = self.value.value.read().unwrap();
        let tensor = self.value.read().unwrap();

        visitor.visit(&self.id, &tensor)
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        let mut tensor = self.value.value.write().unwrap();
        let mut tensor = self.value.write().unwrap();
        let tensor_out = mapper.map(&self.id, tensor.clone());

        *tensor = tensor_out;
@@ -106,13 +66,13 @@ impl<const D: usize, B: Backend> Module<B> for Param<RunningState<Tensor<B, D>>>

    fn into_record(self) -> Self::Record {
        self.sync();
        let tensor = self.value.value.read().unwrap();
        let tensor = self.value.read().unwrap();

        Param::new(self.id, tensor.clone())
    }

    fn load_record(mut self, record: Self::Record) -> Self {
        let mut tensor = self.value.value.write().unwrap();
        let mut tensor = self.value.write().unwrap();
        *tensor = record.value.to_device(&tensor.device());
        self.id = record.id;

@@ -126,6 +86,16 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
    /// Create a new running state.
    pub fn new(value: Tensor<B, D>) -> Self {
        Self {
            id: ParamId::new(),
            values: Arc::new(Mutex::new(HashMap::new())),
            value: Arc::new(RwLock::new(value)),
        }
    }

    /// Create a new running state.
    pub fn with_id(id: ParamId, value: Tensor<B, D>) -> Self {
        Self {
            id,
            values: Arc::new(Mutex::new(HashMap::new())),
            value: Arc::new(RwLock::new(value)),
        }
@@ -200,27 +170,13 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
    }
}

impl<const D: usize, B: ADBackend> ADModule<B> for Param<RunningState<Tensor<B, D>>> {
    type InnerModule = Param<RunningState<Tensor<B::InnerBackend, D>>>;
impl<const D: usize, B: ADBackend> ADModule<B> for RunningState<Tensor<B, D>> {
    type InnerModule = RunningState<Tensor<B::InnerBackend, D>>;

    fn inner(self) -> Self::InnerModule {
    fn valid(&self) -> Self::InnerModule {
        self.sync();
        let value = self.value.value();
        let value = self.value();

        Param {
            id: self.id,
            value: RunningState::new(value.inner()),
        }
    }

    fn from_inner(module: Self::InnerModule) -> Self {
        module.sync();

        let value = module.value.value();

        Param {
            id: module.id,
            value: RunningState::new(Tensor::from_inner(value)),
        }
        RunningState::with_id(self.id.clone(), value.inner())
    }
}
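After this change `RunningState` is a module in its own right rather than being wrapped in `Param`, and `valid()` syncs per-thread updates before copying the value to the inner backend. A sketch using only methods that appear in this diff (`sync`, `value`); the import path is an assumption:

```rust
use burn_core::module::RunningState;
use burn_core::tensor::{backend::Backend, Tensor};

// Read the current running statistic, folding in updates made on other threads first.
fn current_stat<B: Backend>(state: &RunningState<Tensor<B, 1>>) -> Tensor<B, 1> {
    state.sync();
    state.value()
}
```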
@@ -1,5 +1,3 @@
use alloc::{vec, vec::Vec};

use super::{Param, ParamId};
use crate::module::{ADModule, Module, ModuleMapper, ModuleVisitor};
use crate::tensor::{

@@ -9,45 +7,20 @@ use crate::tensor::{

impl<B: Backend, const D: usize> From<Tensor<B, D>> for Param<Tensor<B, D>> {
    fn from(value: Tensor<B, D>) -> Self {
        Param {
            id: ParamId::new(),
            value: value.require_grad(),
        }
        Param::new(ParamId::new(), value.require_grad())
    }
}

impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
    type Record = Param<Tensor<B, D>>;

    fn num_params(&self) -> usize {
        self.value.shape().num_elements()
    }

    fn devices(&self) -> Vec<B::Device> {
        vec![self.value.device()]
    }

    fn to_device(self, device: &B::Device) -> Self {
        Self {
            id: self.id,
            value: self.value.to_device(device).require_grad(),
        }
    }

    fn detach(self) -> Self {
        Self {
            id: self.id,
            value: self.value.detach().require_grad(),
        }
    }

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit(&self.id, &self.value)
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        let value = mapper.map(&self.id, self.value).require_grad();
        Self { id: self.id, value }
        let value = mapper.map(&self.id, self.value);
        Self::new(self.id, value)
    }

    fn into_record(self) -> Self::Record {
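`From<Tensor>` still calls `require_grad()`, so a freshly created parameter is tracked by default, while `map` no longer re-applies the flag and instead keeps whatever the mapper returned (which is what lets `no_grad` propagate). A small check, assuming the crate's autodiff test backend:

```rust
use burn_autodiff::ADBackendDecorator;
use burn_core::module::Param;
use burn_ndarray::NdArrayBackend;
use burn_tensor::Tensor;

type AD = ADBackendDecorator<NdArrayBackend<f32>>;

fn new_params_are_tracked() {
    let weight = Param::from(Tensor::<AD, 2>::ones([4, 4]));
    // `val()` hands back a copy of the wrapped tensor; the tracking flag comes with it.
    assert!(weight.val().is_require_grad());
}
```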
@@ -55,24 +28,63 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
    }

    fn load_record(self, record: Self::Record) -> Self {
        record.to_device(&self.device())
        let mut tensor = record.value.detach();
        let device = self.device();

        // Make sure we load the record into the same module device.
        if tensor.device() != device {
            tensor = tensor.to_device(&device).detach();
        }

        // Make sure we load the record with the same autodiff setting.
        if self.is_require_grad() {
            tensor = tensor.require_grad();
        }

        Self::new(record.id, tensor)
    }
}

impl<const D: usize, B: ADBackend> ADModule<B> for Param<Tensor<B, D>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D>>;

    fn inner(self) -> Self::InnerModule {
        Param {
            id: self.id,
            value: self.value.inner(),
        }
    }

    fn from_inner(module: Self::InnerModule) -> Self {
        Param {
            id: module.id,
            value: Tensor::from_inner(module.value).require_grad(),
        }
    fn valid(&self) -> Self::InnerModule {
        Param::new(
            self.id.clone(),
            self.value.clone().inner().set_require_grad(false),
        )
    }
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use crate::{
        record::{NoStdInferenceRecordSettings, Record},
        TestADBackend,
    };

    use super::*;

    #[test]
    fn test_load_record_setting() {
        let tensor = Tensor::<TestADBackend, 2>::ones([3, 3]);
        let bytes = Param::from(tensor.clone())
            .into_record()
            .record::<NoStdInferenceRecordSettings>(())
            .unwrap();

        let no_grad_is_require_grad = Param::from(tensor.clone())
            .no_grad()
            .load_record(Param::load::<NoStdInferenceRecordSettings>(bytes.clone()).unwrap())
            .value
            .is_require_grad();

        let with_default_is_require_grad = Param::from(tensor)
            .load_record(Param::load::<NoStdInferenceRecordSettings>(bytes).unwrap())
            .value
            .is_require_grad();

        assert!(!no_grad_is_require_grad);
        assert!(with_default_is_require_grad);
    }
}
@@ -1,5 +1,3 @@
use alloc::vec::Vec;

use crate as burn;

use crate::nn::cache::TensorCache;

@@ -9,7 +7,6 @@ use crate::{
    nn,
    tensor::{activation, backend::Backend, Bool, Tensor},
};

use libm::sqrtf;

/// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer.

@@ -257,6 +254,7 @@ pub struct MHAAutoregressiveCache<B: Backend> {
mod tests {
    use super::*;
    use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
    use alloc::vec::Vec;
    use burn::tensor::{Distribution, Shape};
    use burn_tensor::Int;

@@ -1,5 +1,3 @@
use alloc::vec::Vec;

use crate as burn;

use crate::config::Config;

@@ -1,5 +1,3 @@
use alloc::vec::Vec;

use crate as burn;

use crate::config::Config;

@@ -1,6 +1,3 @@
use alloc::vec::Vec;
use burn_tensor::Int;

use crate as burn;

use super::Initializer;

@@ -9,6 +6,7 @@ use crate::module::Module;
use crate::module::Param;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::Int;

/// Configuration to create an [Embedding](Embedding) layer.
#[derive(Config)]

@@ -1,5 +1,3 @@
use alloc::vec::Vec;

use crate as burn;

use crate::config::Config;

@@ -1,5 +1,3 @@
use alloc::vec::Vec;

use crate as burn;

use crate::{
@@ -28,8 +26,8 @@ pub struct BatchNormConfig {
pub struct BatchNorm<B: Backend, const D: usize> {
    gamma: Param<Tensor<B, 1>>,
    beta: Param<Tensor<B, 1>>,
    running_mean: Param<RunningState<Tensor<B, 1>>>,
    running_var: Param<RunningState<Tensor<B, 1>>>,
    running_mean: RunningState<Tensor<B, 1>>,
    running_var: RunningState<Tensor<B, 1>>,
    momentum: f64,
    epsilon: f64,
}

@@ -46,8 +44,8 @@ impl BatchNormConfig {
        BatchNorm {
            gamma: Param::from(gamma),
            beta: Param::from(beta),
            running_mean: Param::from(RunningState::new(running_mean)),
            running_var: Param::from(RunningState::new(running_var)),
            running_mean: RunningState::new(running_mean),
            running_var: RunningState::new(running_var),
            momentum: self.momentum,
            epsilon: self.epsilon,
        }

@@ -76,8 +74,8 @@ impl<const D: usize, B: Backend> BatchNorm<B, D> {

    fn forward_inference<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
        let channels = input.dims()[1];
        let mean = self.running_mean.val().value();
        let var = self.running_var.val().value();
        let mean = self.running_mean.value();
        let var = self.running_var.value();

        let mut shape = [1; DI];
        shape[1] = channels;

@@ -192,7 +190,7 @@ mod tests_1d {
        let module = BatchNormConfig::new(3).init::<TestADBackend, 1>();

        module.forward(input_tensor());
        let module = module.inner();
        let module = module.valid();
        let output = module.forward(input_tensor());

        output.to_data().assert_approx_eq(

@@ -247,7 +245,7 @@ mod tests_2d {
        let module = BatchNormConfig::new(3).init::<TestADBackend, 2>();

        module.forward(input_tensor());
        let module = module.inner();
        let module = module.valid();
        let output = module.forward(input_tensor());

        output.to_data().assert_approx_eq(

@@ -301,11 +299,9 @@ mod tests_2d {

        let _output = module.forward(input_tensor());

        let module_valid = module.inner();
        let module_valid = module.valid();
        let running_mean = module_valid.running_mean.value();

        let module_train = BatchNorm::<TestADBackend, 2>::from_inner(module_valid);
        let running_mean_after = module_train.running_mean.value();
        let running_mean_after = module.running_mean.value();

        running_mean_after
            .into_data()

@@ -1,5 +1,3 @@
use alloc::vec::Vec;

use crate as burn;

use crate::config::Config;

@@ -1,5 +1,3 @@
use alloc::vec::Vec;

use crate as burn;

use crate::{
@@ -97,9 +97,7 @@ mod tests {
    fn test_convert_grads() {
        let layer_1 = layer();
        let mut layer_2 = layer_1.clone();
        layer_2 = layer_2
            .to_device(&<TestADBackend as Backend>::Device::default())
            .detach();
        layer_2 = layer_2.fork(&<TestADBackend as Backend>::Device::default());
        let loss_1 = layer_1.forward(random_tensor());
        let loss_2 = layer_2.forward(random_tensor());
        let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1);

@@ -83,7 +83,9 @@ where

        if let Some(grad) = grad {
            let device = grad.device();
            let is_require_grad = tensor.is_require_grad();
            let (key, record) = self.records.remove_entry(id).unzip();

            let (tensor, state) = self.optimizer.step(
                tensor.inner(),
                grad,

@@ -97,7 +99,11 @@ where
                );
            }

            return Tensor::from_inner(tensor);
            let mut tensor = Tensor::from_inner(tensor);
            if is_require_grad {
                tensor = tensor.require_grad();
            }
            return tensor;
        }

        tensor
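The optimizer hunks above capture `is_require_grad` before stepping on the inner backend and re-apply it afterwards. The same pattern at the plain tensor level, as a sketch with the update math elided:

```rust
use burn_tensor::backend::ADBackend;
use burn_tensor::Tensor;

fn step_preserving_tracking<B: ADBackend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let is_require_grad = tensor.is_require_grad();

    // Drop to the inner backend so the parameter update itself is never recorded.
    let inner = tensor.inner();
    // ... optimizer update on `inner` would happen here ...

    let mut tensor = Tensor::from_inner(inner);
    if is_require_grad {
        tensor = tensor.require_grad();
    }
    tensor
}
```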
@@ -1,6 +1,8 @@
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec::Vec;
use serde::Deserialize;
use serde::Serialize;

use super::{Record, RecordSettings};
use crate::module::{Param, ParamId};

@@ -87,21 +89,22 @@ impl<E: Element> Record for DataSerialize<E> {
    }
}

/// (De)serialize parameters into a clean format.
#[derive(new, Debug, Clone, Serialize, Deserialize)]
pub struct ParamSerde<T> {
    id: String,
    param: T,
}

impl<T: Record> Record for Param<T> {
    type Item<S: RecordSettings> = Param<T::Item<S>>;
    type Item<S: RecordSettings> = ParamSerde<T::Item<S>>;

    fn into_item<S: RecordSettings>(self) -> Self::Item<S> {
        Param {
            id: self.id,
            value: self.value.into_item(),
        }
        ParamSerde::new(self.id.into_string(), self.value.into_item())
    }

    fn from_item<S: RecordSettings>(item: Self::Item<S>) -> Self {
        Param {
            id: item.id,
            value: T::from_item(item.value),
        }
        Param::new(ParamId::from(item.id), T::from_item(item.param))
    }
}
@@ -4,6 +4,8 @@ use burn::tensor::{Distribution, Shape, Tensor};
use burn_core as burn;

pub type TestBackend = burn_ndarray::NdArrayBackend<f32>;
#[cfg(feature = "std")]
pub type TestADBackend = burn_autodiff::ADBackendDecorator<TestBackend>;

#[derive(Module, Debug)]
pub struct ModuleBasic<B: Backend> {

@@ -93,3 +95,53 @@ mod num_params {
        assert_eq!(2 * 20 * 20, module.num_params());
    }
}

#[cfg(feature = "std")]
mod require_grad {
    use burn_tensor::backend::ADBackend;

    use super::*;

    #[test]
    fn should_have_grad_by_default() {
        let module = ModuleBasic::<TestADBackend>::new();
        let mut grads = calculate_grads(&module);

        let grad_x = module.weight_basic.grad_remove(&mut grads);

        assert!(grad_x.is_some());
    }

    #[test]
    fn should_have_no_grad_after_no_grad() {
        let module = ModuleBasic::<TestADBackend>::new().no_grad();
        let mut grads = calculate_grads(&module);

        let grad_x = module.weight_basic.grad_remove(&mut grads);

        assert!(grad_x.is_none());
    }

    #[test]
    fn should_have_grad_when_from_record() {
        let module = ModuleBasic::<TestADBackend>::new();
        let record = ModuleBasicRecord {
            weight_basic: module.weight_basic.clone(), // Even when param is no_grad,
        };
        let module = module.load_record(record);
        let mut grads = calculate_grads(&module);

        let grad_x = module.weight_basic.grad_remove(&mut grads);

        assert!(grad_x.is_some());
    }

    fn calculate_grads(
        module: &ModuleBasic<TestADBackend>,
    ) -> <TestADBackend as ADBackend>::Gradients {
        let x = Tensor::ones([20, 20]).require_grad();
        let y = module.weight_basic.val().matmul(x);

        y.backward()
    }
}
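These tests pin down what makes partial fine-tuning work: `no_grad` removes a parameter from the gradient container, and loading a record does not silently change the flag. A sketch of the intended workflow with a hypothetical two-layer module (field names are placeholders):

```rust
use burn::module::Module;
use burn::nn::Linear;
use burn::tensor::backend::Backend;

#[derive(Module, Debug)]
pub struct FineTuneModel<B: Backend> {
    encoder: Linear<B>, // pretrained block we want to freeze
    head: Linear<B>,    // the part that keeps training
}

impl<B: Backend> FineTuneModel<B> {
    pub fn freeze_encoder(self) -> Self {
        Self {
            // Gradients are no longer tracked for the encoder parameters,
            // so `backward` only produces gradients for the head.
            encoder: self.encoder.no_grad(),
            head: self.head,
        }
    }
}
```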
@@ -25,13 +25,9 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
    let num_params_fn = generator.gen_num_params_fn();
    let visit = generator.gen_visit_fn();
    let map_mut = generator.gen_map_fn();
    let devices_fn = generator.gen_devices_fn();
    let to_device_fn = generator.gen_to_device_fn();
    let inner_fn = generator.gen_inner_fn();
    let from_inner_fn = generator.gen_from_inner_fn();
    let valid_fn = generator.gen_valid_fn();
    let into_record_fn = generator.gen_into_record_fn();
    let load_record_fn = generator.gen_load_record_fn();
    let detach_fn = generator.gen_detach_fn();
    let clone_fn = generator.gen_clone_fn();
    let generics_names_except_backend = generics_names_except_backend(&ast.generics);

@@ -45,14 +41,10 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
        impl #generics burn::module::Module<B> for #name #generics_ty #generics_where {
            type Record = #record_name #generics_ty;

            #devices_fn
            #to_device_fn

            #load_record_fn
            #into_record_fn

            #num_params_fn
            #detach_fn

            #visit
            #map_mut

@@ -61,8 +53,7 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
        impl #generics burn::module::ADModule<B> for #name #generics_ty where B: burn::tensor::backend::ADBackend, {
            type InnerModule=#name<B::InnerBackend, #generics_names_except_backend>;

            #inner_fn
            #from_inner_fn
            #valid_fn
        }

        impl #generics core::fmt::Display for #name #generics_ty #generics_where {

@@ -96,68 +96,15 @@ impl FnGenerator {
        }
    }

    pub fn gen_devices_fn(&self) -> TokenStream {
        let body = self.gen_fields_fn(|name| {
            quote! {
                devices.append(&mut burn::module::Module::<B>::devices(&self.#name));
            }
        });

        quote! {
            fn devices(&self) -> Vec<B::Device> {
                let mut devices = Vec::new();
                #body
                devices
            }
        }
    }

    pub fn gen_to_device_fn(&self) -> TokenStream {
    pub fn gen_valid_fn(&self) -> TokenStream {
        let (names, body) = self.gen_fields_fn_names(|name| {
            quote! {
                let #name = burn::module::Module::<B>::to_device(self.#name, device);
                let #name = burn::module::ADModule::<B>::valid(&self.#name);
            }
        });

        quote! {
            fn to_device(self, device: &B::Device) -> Self {
                #body

                Self {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_detach_fn(&self) -> TokenStream {
        let (names, body) = self.gen_fields_fn_names(|name| {
            quote! {
                let #name = burn::module::Module::<B>::detach(self.#name);
            }
        });

        quote! {
            fn detach(self) -> Self {
                #body

                Self {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_inner_fn(&self) -> TokenStream {
        let (names, body) = self.gen_fields_fn_names(|name| {
            quote! {
                let #name = burn::module::ADModule::<B>::inner(self.#name);
            }
        });

        quote! {
            fn inner(self) -> Self::InnerModule {
            fn valid(&self) -> Self::InnerModule {
                #body

                Self::InnerModule {

@@ -167,24 +114,6 @@ impl FnGenerator {
        }
    }

    pub fn gen_from_inner_fn(&self) -> TokenStream {
        let (names, body) = self.gen_fields_fn_names(|name| {
            quote! {
                let #name = burn::module::ADModule::<B>::from_inner(module.#name);
            }
        });

        quote! {
            fn from_inner(module: Self::InnerModule) -> Self {
                #body

                Self {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_clone_fn(&self) -> TokenStream {
        let (names, body) = self.gen_fields_fn_names(|name| {
            quote! {
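For a module with two fields `linear1` and `linear2`, the `gen_valid_fn` above roughly expands to the following shape (a sketch, not the literal token stream):

```rust
fn valid(&self) -> Self::InnerModule {
    let linear1 = burn::module::ADModule::<B>::valid(&self.linear1);
    let linear2 = burn::module::ADModule::<B>::valid(&self.linear2);

    Self::InnerModule { linear1, linear2 }
}
```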
@@ -16,7 +16,7 @@ use burn_common::stub::Mutex;

pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NdArrayDevice {
    Cpu,
}

@@ -1,7 +1,5 @@
// Orginally copied from the burn/examples/mnist package

use alloc::vec::Vec;

use burn::{
    config::Config,
    module::Module,

@@ -1,7 +1,5 @@
// Orginally copied from the burn/examples/mnist package

use alloc::vec::Vec;

use crate::{
    conv::{ConvBlock, ConvBlockConfig},
    mlp::{Mlp, MlpConfig},

@@ -2,7 +2,7 @@ use super::element::TchElement;
use super::TchTensor;
use burn_tensor::backend::Backend;

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// The device struct when using the `tch` backend.
///
/// Note that you need to provide the device index when using Cuda.

@@ -305,7 +305,20 @@ where
    /// Mark the tensor to keep gradients during the backward pass.
    /// This function does nothing when autodiff is not enabled.
    pub fn require_grad(self) -> Self {
        Self::new(B::require_grad(self.primitive))
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        B::is_require_grad(&self.primitive)
    }

    /// Mark the tensor as tracked or untracked depending on the require grad argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        Self::new(B::set_require_grad(self.primitive, require_grad))
    }

    /// Applies the relu function to the tensor.

@@ -63,7 +63,7 @@ pub trait Backend:
    + 'static
{
    /// Device type.
    type Device: Clone + Default + core::fmt::Debug + Send + Sync;
    type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync;

    /// Pointer to another backend that have a full precision float element type
    type FullPrecisionBackend: Backend<FloatElem = Self::FullPrecisionElem, Device = Self::Device>;

@@ -193,10 +193,17 @@ pub trait TensorOps<B: Backend> {
        // Should only be overriden by autodiff backends.
        tensor
    }
    fn require_grad<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
    fn set_require_grad<const D: usize>(
        tensor: B::TensorPrimitive<D>,
        _require_grad: bool,
    ) -> B::TensorPrimitive<D> {
        // Should only be overriden by autodiff backends.
        tensor
    }
    fn is_require_grad<const D: usize>(_tensor: &B::TensorPrimitive<D>) -> bool {
        // Should only be overriden by autodiff backends.
        false
    }
    fn sum<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1>;
    fn sum_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D>;
    fn mean<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1>;
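Because these `TensorOps` defaults are no-ops, only autodiff backends need to override them, and generic code can call the new methods on any backend. A sketch:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

// On a plain backend `is_require_grad` is always false and `set_require_grad` is a
// no-op; on an autodiff backend this actually untracks the tensor.
fn freeze<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    if tensor.is_require_grad() {
        tensor.set_require_grad(false)
    } else {
        tensor
    }
}
```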
@@ -192,7 +192,6 @@ where
            }
            None => None,
        };
        let model = model.detach();

        Learner {
            model,

@@ -24,14 +24,14 @@ pub struct TrainEpoch<TI> {
}

impl<I> ValidEpoch<I> {
    pub fn run<B, M, TO, VO>(&self, model: M, callback: &mut Box<dyn LearnerCallback<TO, VO>>) -> M
    pub fn run<B, M, TO, VO>(&self, model: &M, callback: &mut Box<dyn LearnerCallback<TO, VO>>)
    where
        B: ADBackend,
        M: ADModule<B>,
        M::InnerModule: ValidStep<I, VO>,
    {
        log::info!("Executing validation step for epoch {}", self.epoch);
        let model = model.inner();
        let model = model.valid();

        let mut iterator = self.dataloader.iter();
        let mut iteration = 0;

@@ -50,8 +50,6 @@ impl<I> ValidEpoch<I> {
            ));
        }
        callback.on_valid_end_epoch(self.epoch);

        ADModule::from_inner(model)
    }
}

@@ -77,6 +75,7 @@ impl<TI> TrainEpoch<TI> {

        while let Some(item) = iterator.next() {
            iteration += 1;
            log::info!("Iteration {}", iteration);

            let progress = iterator.progress();
            let item = model.step(item);

@@ -154,7 +153,6 @@ impl<TI> TrainEpoch<TI> {

                let grads = item.grads.to_device(&device_main, &model);

                log::info!("Updated device");
                accumulator.accumulate(&model, grads);
                accumulation_current += 1;

@@ -47,7 +47,7 @@ where
        spawn(move || loop {
            match receiver_input.recv() {
                Ok(item) => {
                    let step = item.model.to_device(&device).detach();
                    let step = item.model.fork(&device);
                    let output = step.step(item.item);

                    sender_output.send(output).unwrap();

@@ -49,7 +49,7 @@ where
        log::info!("Fitting {}", self.model.to_string());
        // The reference model is always on the first device provided.
        if let Some(device) = self.devices.get(0) {
            self.model = self.model.to_device(device).detach();
            self.model = self.model.fork(device);
        }

        let starting_epoch = match self.checkpoint {

@@ -83,7 +83,7 @@ where
        }

        let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs);
        model = epoch_valid.run(model, &mut self.callback);
        epoch_valid.run(&model, &mut self.callback);

        Self::checkpoint(
            &model,
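In the learner, `to_device(...).detach()` becomes `fork(...)`: each copy gets its own graph, so workers can call `backward` independently. A sketch of preparing one copy per device (generic placeholders, paths assumed):

```rust
use burn::module::Module;
use burn::tensor::backend::ADBackend;

fn model_per_device<B, M>(model: &M, devices: &[B::Device]) -> Vec<M>
where
    B: ADBackend,
    M: Module<B>,
{
    devices
        .iter()
        .map(|device| model.clone().fork(device)) // fresh autodiff graph per device
        .collect()
}
```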
Binary file not shown.
@@ -2,8 +2,6 @@

// Orginally copied from the burn/examples/mnist package

use alloc::vec::Vec;

use burn::{
    module::Module,
    nn::{self, conv::Conv2dPaddingConfig, BatchNorm},

@@ -15,6 +13,7 @@ pub struct Model<B: Backend> {
    conv1: ConvBlock<B>,
    conv2: ConvBlock<B>,
    conv3: ConvBlock<B>,
    dropout: nn::Dropout,
    fc1: nn::Linear<B>,
    fc2: nn::Linear<B>,
    activation: nn::GELU,

@@ -27,7 +26,6 @@ impl<B: Backend> Model<B> {
        let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26]
        let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24x24]
        let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22x22]

        let hidden_size = 24 * 22 * 22;
        let fc1 = nn::LinearConfig::new(hidden_size, 32)
            .with_bias(false)

@@ -36,12 +34,15 @@ impl<B: Backend> Model<B> {
            .with_bias(false)
            .init();

        let dropout = nn::DropoutConfig::new(0.5).init();

        Self {
            conv1,
            conv2,
            conv3,
            fc1,
            fc2,
            dropout,
            activation: nn::GELU::new(),
        }
    }

@@ -57,6 +58,7 @@ impl<B: Backend> Model<B> {
        let [batch_size, channels, heigth, width] = x.dims();
        let x = x.reshape([batch_size, channels * heigth * width]);

        let x = self.dropout.forward(x);
        let x = self.fc1.forward(x);
        let x = self.activation.forward(x);

@@ -36,7 +36,7 @@ impl<B: Backend> Model<B> {
            .with_bias(false)
            .init();

        let dropout = nn::DropoutConfig::new(0.3).init();
        let dropout = nn::DropoutConfig::new(0.5).init();

        Self {
            conv1,

@@ -60,9 +60,9 @@ impl<B: Backend> Model<B> {
        let [batch_size, channels, heigth, width] = x.dims();
        let x = x.reshape([batch_size, channels * heigth * width]);

        let x = self.dropout.forward(x);
        let x = self.fc1.forward(x);
        let x = self.activation.forward(x);
        let x = self.dropout.forward(x);

        self.fc2.forward(x)
    }

@@ -21,7 +21,7 @@ static ARTIFACT_DIR: &str = "/tmp/burn-example-mnist";

#[derive(Config)]
pub struct MnistTrainingConfig {
    #[config(default = 4)]
    #[config(default = 10)]
    pub num_epochs: usize,

    #[config(default = 64)]

@@ -42,7 +42,7 @@ pub fn infer<B: Backend, D: TextClassificationDataset + 'static>(
    let record = Record::load::<DefaultRecordSettings>(format!("{artifact_dir}/model").into())
        .expect("Trained model weights");
    let model = model.load_record(record);
    let model = model.to_device(&device);
    let model = model.fork(&device);

    println!("Running inference ...");
    let item = batcher.batch(samples.clone());