mirror of https://github.com/tracel-ai/burn.git
refactor: autodiff gradients types (#107)
parent dda067e79b
commit ca94a9f105

@@ -17,11 +17,5 @@ export_tests = ["burn-tensor-testgen"]
 [dependencies]
 burn-tensor-testgen = { path = "../burn-tensor-testgen", optional = true }
 burn-tensor = { path = "../burn-tensor", version = "0.2.3" }
-libm = "0.2"
-derive-new = "0.5"
-rand = "0.8"
-num-traits = "0.2"
-nanoid = "0.4"
-
 [dev-dependencies]
 burn-tensor = { path = "../burn-tensor", version = "0.2.3", features = ["export_tests"] }

@@ -2,7 +2,11 @@ use crate::graph::{
     node::{BackwardNode, ForwardNode},
     traversal::{BreadthFirstSearch, GraphTraversal},
 };
-use burn_tensor::{backend::Gradients, ops::Zeros};
+use burn_tensor::{
+    backend::{ADBackend, Gradients},
+    ops::Zeros,
+    Tensor,
+};
 use std::{any::Any, collections::HashMap, ops::Add};
 
 #[derive(Default)]

@@ -10,13 +14,13 @@ pub struct Grads {
     grads: HashMap<String, Box<dyn Any + Send + Sync>>,
 }
 
-impl Gradients for Grads {
+impl<B: ADBackend> Gradients<B> for Grads {
     fn empty() -> Self {
         Self {
             grads: HashMap::new(),
         }
     }
-    fn get<V: 'static>(&self, id: &str) -> Option<&V> {
+    fn get<const D: usize>(&self, id: &str) -> Option<&Tensor<B::InnerBackend, D>> {
         let grad = match self.grads.get(id) {
             Some(grad) => grad,
             None => return None,

@@ -24,16 +28,13 @@ impl Gradients for Grads {
 
         grad.downcast_ref()
     }
-    fn register<V>(&mut self, id: String, value: V)
-    where
-        V: std::fmt::Debug + 'static + Send + Sync,
-    {
+    fn register<const D: usize>(&mut self, id: String, value: Tensor<B::InnerBackend, D>) {
         self.grads.insert(id, Box::new(value));
     }
 }
 
 impl Grads {
-    pub fn register<T>(&mut self, node: &BackwardNode<T>)
+    pub fn register_node<T>(&mut self, node: &BackwardNode<T>)
     where
         T: Zeros + Clone + Add<Output = T>,
         T: std::fmt::Debug + 'static + Send + Sync,
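
For readers skimming the diff: the container keeps every value behind a Box<dyn Any + Send + Sync>, so retrieval goes through downcast_ref. Below is a minimal standalone sketch (toy types, not burn's actual code) of that pattern and of the pitfall the old get::<V> signature allowed, where asking for the wrong type compiled fine but silently returned None.

// Standalone sketch of the type-erased storage the old trait exposed
// directly: any 'static value could be registered under an id.
use std::any::Any;
use std::collections::HashMap;

struct Store {
    grads: HashMap<String, Box<dyn Any + Send + Sync>>,
}

impl Store {
    fn register<V: std::fmt::Debug + 'static + Send + Sync>(&mut self, id: String, value: V) {
        self.grads.insert(id, Box::new(value));
    }
    fn get<V: 'static>(&self, id: &str) -> Option<&V> {
        // `downcast_ref` yields None when the stored type is not V.
        self.grads.get(id)?.downcast_ref()
    }
}

fn main() {
    let mut store = Store { grads: HashMap::new() };
    store.register("w1".to_string(), vec![0.1f32, 0.2]);
    assert_eq!(store.get::<Vec<f32>>("w1").unwrap().len(), 2);
    // Requesting the wrong type was not a compile error under the old API:
    assert!(store.get::<Vec<f64>>("w1").is_none());
}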

@@ -42,14 +43,14 @@ impl Grads {
         self.grads.insert(node.id.clone(), Box::new(grad));
     }
 
-    pub fn from<T>(node: &BackwardNode<T>) -> Self
+    pub fn from_node<T>(node: &BackwardNode<T>) -> Self
     where
         T: Zeros + Clone + Add<Output = T>,
         T: std::fmt::Debug + 'static + Send + Sync,
     {
-        let mut grads = Self::empty();
+        let mut grads = Self::default();
         let traversal = BreadthFirstSearch::new(node);
-        grads.register(node);
+        grads.register_node(node);
 
         traversal.traverse(|node| {
             node.register_grad(&mut grads);
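
from_node above seeds the container with the output node's gradient and then walks the backward graph breadth-first so every reachable node registers its own gradient. A standalone sketch of that traversal flow (the node and arena shapes here are assumptions, not burn's graph types):

// Toy sketch of the `from_node` flow: register the root, then BFS the
// backward graph, registering each visited node's gradient.
use std::collections::{HashMap, VecDeque};

struct Node {
    id: String,
    grad: f32,
    parents: Vec<usize>, // indices into the arena below
}

fn from_node(nodes: &[Node], root: usize) -> HashMap<String, f32> {
    let mut grads = HashMap::new();
    let mut visited = vec![false; nodes.len()];
    let mut queue = VecDeque::from([root]);
    visited[root] = true;

    while let Some(i) = queue.pop_front() {
        // Plays the role of `node.register_grad(&mut grads)` in the diff.
        grads.insert(nodes[i].id.clone(), nodes[i].grad);
        for &p in &nodes[i].parents {
            if !visited[p] {
                visited[p] = true;
                queue.push_back(p);
            }
        }
    }
    grads
}

fn main() {
    // z = x * y: z's parents are x and y.
    let nodes = vec![
        Node { id: "x".into(), grad: 2.0, parents: vec![] },
        Node { id: "y".into(), grad: 3.0, parents: vec![] },
        Node { id: "z".into(), grad: 1.0, parents: vec![0, 1] },
    ];
    let grads = from_node(&nodes, 2);
    assert_eq!(grads.len(), 3);
    assert_eq!(grads["x"], 2.0);
}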

@@ -66,7 +66,7 @@ where
             }
         }
 
-        Grads::from(self)
+        Grads::from_node(self)
     }
 }
 

@@ -89,6 +89,6 @@ where
         &self.id
     }
     fn register_grad(&self, grads: &mut Grads) {
-        grads.register(self)
+        grads.register_node(self)
     }
 }

@@ -67,7 +67,7 @@ pub(crate) type ADBackendTensorPrimitive<const D: usize, B> =
 
 pub trait ADBackend: Backend {
     type InnerBackend: Backend<Device = Self::Device, Elem = Self::Elem>;
-    type Gradients: Gradients;
+    type Gradients: Gradients<Self>;
 
     fn backward<const D: usize>(tensor: &Self::TensorPrimitive<D>) -> Self::Gradients;
     fn grad<const D: usize>(
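
The new bound on the associated type is the other half of the refactor. A compilable toy sketch (simplified traits mirroring the diff, not burn's full definitions) of why the parameter is Self:

// Each autodiff backend's gradient container is now tied, at the type
// level, to the backend that produced it.
trait Backend: Sized + 'static {}

trait Gradients<B: ADBackend> {
    fn empty() -> Self;
}

trait ADBackend: Backend {
    type InnerBackend: Backend;
    // `Self` is the concrete autodiff backend, so `Self::Gradients` is
    // pinned to that backend family and, through it, to `Self::InnerBackend`.
    type Gradients: Gradients<Self>;
}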

@@ -1,7 +1,8 @@
-pub trait Gradients: Send + Sync {
+use crate::backend::ADBackend;
+use crate::Tensor;
+
+pub trait Gradients<B: ADBackend>: Send + Sync {
     fn empty() -> Self;
-    fn get<V: 'static>(&self, id: &str) -> Option<&V>;
-    fn register<V>(&mut self, id: String, value: V)
-    where
-        V: std::fmt::Debug + 'static + Send + Sync;
+    fn get<const D: usize>(&self, id: &str) -> Option<&Tensor<B::InnerBackend, D>>;
+    fn register<const D: usize>(&mut self, id: String, value: Tensor<B::InnerBackend, D>);
 }
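
Putting the trait change together: the dyn-Any erasure can still live inside the container, but the trait surface now only speaks in tensors of the inner backend, with the rank D as a const generic. A self-contained sketch with toy Backend/Tensor types (names assumed, logic simplified from the diff):

use std::any::Any;
use std::collections::HashMap;

trait Backend: Sized + 'static {
    type Elem: 'static;
}

trait ADBackend: Backend {
    type InnerBackend: Backend;
}

struct Tensor<B: Backend, const D: usize> {
    data: Vec<B::Elem>,
}

trait Gradients<B: ADBackend> {
    fn empty() -> Self;
    fn get<const D: usize>(&self, id: &str) -> Option<&Tensor<B::InnerBackend, D>>;
    fn register<const D: usize>(&mut self, id: String, value: Tensor<B::InnerBackend, D>);
}

struct Grads {
    grads: HashMap<String, Box<dyn Any>>,
}

impl<B: ADBackend> Gradients<B> for Grads {
    fn empty() -> Self {
        Self { grads: HashMap::new() }
    }
    fn get<const D: usize>(&self, id: &str) -> Option<&Tensor<B::InnerBackend, D>> {
        // The downcast is now an implementation detail; the signature fixes
        // the only type a caller can ask for.
        self.grads.get(id)?.downcast_ref()
    }
    fn register<const D: usize>(&mut self, id: String, value: Tensor<B::InnerBackend, D>) {
        self.grads.insert(id, Box::new(value));
    }
}

// Concrete toy backends to exercise the trait.
struct Cpu;
impl Backend for Cpu {
    type Elem = f32;
}
struct AdCpu;
impl Backend for AdCpu {
    type Elem = f32;
}
impl ADBackend for AdCpu {
    type InnerBackend = Cpu;
}

fn main() {
    let mut grads = <Grads as Gradients<AdCpu>>::empty();
    Gradients::<AdCpu>::register(&mut grads, "w".to_string(), Tensor::<Cpu, 2> { data: vec![0.0; 4] });
    // Only the rank is needed at the call site, as in the optimizer hunks below.
    let grad = Gradients::<AdCpu>::get::<2>(&grads, "w").unwrap();
    assert_eq!(grad.data.len(), 4);
}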

@@ -86,7 +86,7 @@ pub(super) fn register_state_gradients<const D: usize, B: ADBackend, F: Fn(&str
 ) {
     let id = id.to_string();
 
-    if let Some(velocity) = grads.get::<Tensor<B::InnerBackend, D>>(&id) {
+    if let Some(velocity) = grads.get::<D>(&id) {
         let data = State::Data(velocity.to_data().serialize());
         state.register_state(id_to_key(&id).as_str(), data);
     };

@@ -35,7 +35,7 @@ impl<B: ADBackend> WeightDecay<B> {
     ) -> Tensor<B::InnerBackend, D> {
         let id = id.to_string();
 
-        let grad = match self.gradients.get::<Tensor<B::InnerBackend, D>>(&id) {
+        let grad = match self.gradients.get::<D>(&id) {
             Some(grad_last_step) => grad_last_step.mul_scalar(self.penalty).add(&grad),
             None => grad,
         };

@@ -46,7 +46,7 @@ impl<B: ADBackend> Momentum<B> {
     ) -> Tensor<B::InnerBackend, D> {
         let id = id.to_string();
 
-        let velocity = match self.velocity.get::<Tensor<B::InnerBackend, D>>(&id) {
+        let velocity = match self.velocity.get::<D>(&id) {
             Some(grad_last_step) => grad
                 .mul_scalar(1.0 - self.dampening)
                 .add(&grad_last_step.mul_scalar(self.momentum)),
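
The Some arm above is the classic momentum rule. A scalar sketch of the same arithmetic (the fallback to the raw gradient when no velocity is stored is an assumption here, since the hunk ends before Momentum's None arm):

// v_t = (1 - dampening) * g_t + momentum * v_{t-1}
fn momentum_update(grad: f32, velocity_last_step: Option<f32>, momentum: f32, dampening: f32) -> f32 {
    match velocity_last_step {
        Some(v_prev) => (1.0 - dampening) * grad + momentum * v_prev,
        None => grad, // assumed fallback for the first step
    }
}

fn main() {
    let v1 = momentum_update(0.5, None, 0.9, 0.0); // first step: raw gradient
    let v2 = momentum_update(0.3, Some(v1), 0.9, 0.0); // blends previous velocity
    assert_eq!(v1, 0.5);
    assert_eq!(v2, 0.3 + 0.9 * 0.5);
}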