mirror of https://github.com/tracel-ai/burn.git
fix: backward order
This commit is contained in:
parent
2f7f65cea5
commit
230cd01ea1
|
@ -1,19 +1,36 @@
|
|||
use super::NodeStateRef;
|
||||
use crate::ops::{RecordedOpsParent, RecordedOpsParentRef, RecordedOpsRef};
|
||||
use std::{collections::HashSet, ops::Add, rc::Rc};
|
||||
use std::{collections::HashMap, ops::Add, rc::Rc};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Node<Out> {
|
||||
pub id: String,
|
||||
pub order: usize,
|
||||
pub state: NodeStateRef<Out>,
|
||||
pub ops: RecordedOpsRef<Out>,
|
||||
}
|
||||
|
||||
impl<Out> Node<Out> {
|
||||
pub fn new(state: NodeStateRef<Out>, ops: RecordedOpsRef<Out>) -> Self {
|
||||
let id = nanoid::nanoid!();
|
||||
println!("Creating node {}", id);
|
||||
Self { id, state, ops }
|
||||
pub fn from_root(state: NodeStateRef<Out>, ops: RecordedOpsRef<Out>) -> Self {
|
||||
let order = 0;
|
||||
Self { order, state, ops }
|
||||
}
|
||||
|
||||
pub fn from_unary<T>(
|
||||
node: &Node<T>,
|
||||
state: NodeStateRef<Out>,
|
||||
ops: RecordedOpsRef<Out>,
|
||||
) -> Self {
|
||||
let order = node.order + 1;
|
||||
Self { order, state, ops }
|
||||
}
|
||||
pub fn from_binary<Lhs, Rhs>(
|
||||
lhs: &Node<Lhs>,
|
||||
rhs: &Node<Rhs>,
|
||||
state: NodeStateRef<Out>,
|
||||
ops: RecordedOpsRef<Out>,
|
||||
) -> Self {
|
||||
let order = usize::max(lhs.order, rhs.order) + 1;
|
||||
Self { order, state, ops }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -24,47 +41,55 @@ where
|
|||
{
|
||||
pub fn backward(&self) {
|
||||
let grad = self.state.borrow().value().ones();
|
||||
|
||||
self.state.borrow_mut().update_grad(grad);
|
||||
self.ops.backward_step(&self.state);
|
||||
|
||||
let mut nodes = HashMap::new();
|
||||
let mut parents = self.ops.backward_parents();
|
||||
let mut visited = HashSet::new();
|
||||
|
||||
loop {
|
||||
match parents.pop() {
|
||||
Some(node) => {
|
||||
let id = node.id();
|
||||
if visited.contains(&id) {
|
||||
|
||||
if id == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
visited.insert(id);
|
||||
node.backward_step();
|
||||
if nodes.contains_key(&id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for parent in node.backward_parents() {
|
||||
if !visited.contains(&parent.id()) {
|
||||
if !nodes.contains_key(&parent.id()) {
|
||||
parents.push(parent);
|
||||
}
|
||||
}
|
||||
nodes.insert(id, node);
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
for i in (0..self.order + 1).rev() {
|
||||
if let Some(node) = nodes.get(&i) {
|
||||
node.backward_step();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: std::fmt::Debug> RecordedOpsParent for Node<T> {
|
||||
fn backward_step(&self) {
|
||||
println!("backward node {}", self.id);
|
||||
println!("backward node {}", self.order);
|
||||
self.ops.backward_step(&self.state)
|
||||
}
|
||||
fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
|
||||
self.ops.backward_parents()
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
self.id.clone()
|
||||
fn id(&self) -> usize {
|
||||
self.order
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ pub trait RecordedOps<T>: std::fmt::Debug {
|
|||
}
|
||||
|
||||
pub trait RecordedOpsParent: std::fmt::Debug {
|
||||
fn id(&self) -> String;
|
||||
fn id(&self) -> usize;
|
||||
fn backward_step(&self);
|
||||
fn backward_parents(&self) -> Vec<RecordedOpsParentRef>;
|
||||
}
|
||||
|
|
|
@ -121,7 +121,7 @@ macro_rules! execute_ops {
|
|||
let ops = BinaryRecordedOps::new($lhs, $rhs, ops);
|
||||
let ops = std::rc::Rc::new(ops);
|
||||
|
||||
let node = $crate::node::Node::new(state, ops);
|
||||
let node = $crate::node::Node::from_binary(&$lhs, &$rhs, state, ops);
|
||||
std::rc::Rc::new(node)
|
||||
};
|
||||
callback()
|
||||
|
@ -139,7 +139,7 @@ macro_rules! execute_ops {
|
|||
let ops = UnaryRecordedOps::new($input, ops);
|
||||
let ops = std::rc::Rc::new(ops);
|
||||
|
||||
let node = $crate::node::Node::new(state, ops);
|
||||
let node = $crate::node::Node::from_unary(&$input, state, ops);
|
||||
std::rc::Rc::new(node)
|
||||
};
|
||||
callback()
|
||||
|
|
|
@ -41,7 +41,7 @@ where
|
|||
|
||||
let ops = InitRecordedOps::new();
|
||||
let ops = Rc::new(ops);
|
||||
let node = Rc::new(Node::new(state, ops));
|
||||
let node = Rc::new(Node::from_root(state, ops));
|
||||
|
||||
Self { node, shape, kind }
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue