refactor: autodiff gradients types (#107)

Nathaniel Simard, 2022-11-19 19:43:49 -05:00 (committed by GitHub)
parent dda067e79b
commit ca94a9f105
8 changed files with 24 additions and 28 deletions
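
In short: the `Gradients` trait becomes generic over the autodiff backend (`pub trait Gradients<B: ADBackend>`), and its `get`/`register` methods trade type-erased `Any` values for `Tensor<B::InnerBackend, D>` with a const dimension parameter `D`. `ADBackend::Gradients` is bound as `Gradients<Self>` to match. The inherent helpers on `Grads` are renamed to `register_node` and `from_node`, optimizer call sites shrink from `get::<Tensor<B::InnerBackend, D>>(&id)` to `get::<D>(&id)`, and five now-unused dependencies are dropped from the crate manifest.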


@@ -17,11 +17,5 @@ export_tests = ["burn-tensor-testgen"]
 [dependencies]
 burn-tensor-testgen = { path = "../burn-tensor-testgen", optional = true }
 burn-tensor = { path = "../burn-tensor", version = "0.2.3" }
-libm = "0.2"
-derive-new = "0.5"
-rand = "0.8"
-num-traits = "0.2"
-nanoid = "0.4"
-
 [dev-dependencies]
 burn-tensor = { path = "../burn-tensor", version = "0.2.3", features = ["export_tests"] }


@@ -2,7 +2,11 @@ use crate::graph::{
     node::{BackwardNode, ForwardNode},
     traversal::{BreadthFirstSearch, GraphTraversal},
 };
-use burn_tensor::{backend::Gradients, ops::Zeros};
+use burn_tensor::{
+    backend::{ADBackend, Gradients},
+    ops::Zeros,
+    Tensor,
+};
 use std::{any::Any, collections::HashMap, ops::Add};
 
 #[derive(Default)]
@@ -10,13 +14,13 @@ pub struct Grads {
     grads: HashMap<String, Box<dyn Any + Send + Sync>>,
 }
 
-impl Gradients for Grads {
+impl<B: ADBackend> Gradients<B> for Grads {
     fn empty() -> Self {
         Self {
             grads: HashMap::new(),
         }
     }
 
-    fn get<V: 'static>(&self, id: &str) -> Option<&V> {
+    fn get<const D: usize>(&self, id: &str) -> Option<&Tensor<B::InnerBackend, D>> {
         let grad = match self.grads.get(id) {
             Some(grad) => grad,
             None => return None,
@@ -24,16 +28,13 @@ impl Gradients for Grads {
         grad.downcast_ref()
     }
 
-    fn register<V>(&mut self, id: String, value: V)
-    where
-        V: std::fmt::Debug + 'static + Send + Sync,
-    {
+    fn register<const D: usize>(&mut self, id: String, value: Tensor<B::InnerBackend, D>) {
         self.grads.insert(id, Box::new(value));
     }
 }
 
 impl Grads {
-    pub fn register<T>(&mut self, node: &BackwardNode<T>)
+    pub fn register_node<T>(&mut self, node: &BackwardNode<T>)
     where
         T: Zeros + Clone + Add<Output = T>,
         T: std::fmt::Debug + 'static + Send + Sync,
@@ -42,14 +43,14 @@ impl Grads {
         self.grads.insert(node.id.clone(), Box::new(grad));
     }
 
-    pub fn from<T>(node: &BackwardNode<T>) -> Self
+    pub fn from_node<T>(node: &BackwardNode<T>) -> Self
     where
         T: Zeros + Clone + Add<Output = T>,
         T: std::fmt::Debug + 'static + Send + Sync,
     {
-        let mut grads = Self::empty();
+        let mut grads = Self::default();
        let traversal = BreadthFirstSearch::new(node);
 
-        grads.register(node);
+        grads.register_node(node);
        traversal.traverse(|node| {
            node.register_grad(&mut grads);
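
Note on the renames in this file: once `Grads` implements the generic `Gradients<B>`, the node-based helpers take the distinct names `register_node` and `from_node` so they read apart from the trait's tensor-based `register`. `Self::empty()` also gives way to `Self::default()`, presumably because the trait method is now generic over `B`, and calling it from this non-generic `impl Grads` block would leave `B` uninferable, while the derived `Default` needs no type parameter.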


@@ -66,7 +66,7 @@ where
            }
        }
 
-        Grads::from(self)
+        Grads::from_node(self)
    }
}
@@ -89,6 +89,6 @@ where
        &self.id
    }
 
    fn register_grad(&self, grads: &mut Grads) {
-        grads.register(self)
+        grads.register_node(self)
    }
}


@@ -67,7 +67,7 @@ pub(crate) type ADBackendTensorPrimitive<const D: usize, B> =
 pub trait ADBackend: Backend {
     type InnerBackend: Backend<Device = Self::Device, Elem = Self::Elem>;
 
-    type Gradients: Gradients;
+    type Gradients: Gradients<Self>;
 
     fn backward<const D: usize>(tensor: &Self::TensorPrimitive<D>) -> Self::Gradients;
     fn grad<const D: usize>(


@@ -1,7 +1,8 @@
-pub trait Gradients: Send + Sync {
+use crate::backend::ADBackend;
+use crate::Tensor;
+
+pub trait Gradients<B: ADBackend>: Send + Sync {
     fn empty() -> Self;
-    fn get<V: 'static>(&self, id: &str) -> Option<&V>;
-    fn register<V>(&mut self, id: String, value: V)
-    where
-        V: std::fmt::Debug + 'static + Send + Sync;
+    fn get<const D: usize>(&self, id: &str) -> Option<&Tensor<B::InnerBackend, D>>;
+    fn register<const D: usize>(&mut self, id: String, value: Tensor<B::InnerBackend, D>);
 }
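
For orientation, a minimal usage sketch of the reworked trait. This is not part of the diff; the helper name `roundtrip` and its parameters are illustrative only:

use burn_tensor::backend::{ADBackend, Gradients};
use burn_tensor::Tensor;

// Illustrative helper: store a gradient under `id`, then read it back.
// The dimension D is the only turbofish parameter; the trait itself now
// fixes the stored value type to Tensor<B::InnerBackend, D>.
fn roundtrip<B: ADBackend, const D: usize>(
    grads: &mut B::Gradients,
    id: String,
    value: Tensor<B::InnerBackend, D>,
) -> bool {
    let key = id.clone();
    grads.register::<D>(id, value);
    grads.get::<D>(&key).is_some()
}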


@@ -86,7 +86,7 @@ pub(super) fn register_state_gradients<const D: usize, B: ADBackend, F: Fn(&str)
 ) {
     let id = id.to_string();
 
-    if let Some(velocity) = grads.get::<Tensor<B::InnerBackend, D>>(&id) {
+    if let Some(velocity) = grads.get::<D>(&id) {
         let data = State::Data(velocity.to_data().serialize());
         state.register_state(id_to_key(&id).as_str(), data);
     };
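
The same simplification repeats in the two optimizer hunks below: because the trait now fixes the stored value type, call sites name only the dimension `D` rather than spelling out `Tensor<B::InnerBackend, D>`.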


@@ -35,7 +35,7 @@ impl<B: ADBackend> WeightDecay<B> {
     ) -> Tensor<B::InnerBackend, D> {
         let id = id.to_string();
 
-        let grad = match self.gradients.get::<Tensor<B::InnerBackend, D>>(&id) {
+        let grad = match self.gradients.get::<D>(&id) {
             Some(grad_last_step) => grad_last_step.mul_scalar(self.penalty).add(&grad),
             None => grad,
         };


@@ -46,7 +46,7 @@ impl<B: ADBackend> Momentum<B> {
     ) -> Tensor<B::InnerBackend, D> {
         let id = id.to_string();
 
-        let velocity = match self.velocity.get::<Tensor<B::InnerBackend, D>>(&id) {
+        let velocity = match self.velocity.get::<D>(&id) {
             Some(grad_last_step) => grad
                 .mul_scalar(1.0 - self.dampening)
                 .add(&grad_last_step.mul_scalar(self.momentum)),