mirror of https://github.com/tracel-ai/burn.git
fix: backward order
This commit is contained in:
parent
2f7f65cea5
commit
230cd01ea1
|
@ -1,19 +1,36 @@
|
||||||
use super::NodeStateRef;
|
use super::NodeStateRef;
|
||||||
use crate::ops::{RecordedOpsParent, RecordedOpsParentRef, RecordedOpsRef};
|
use crate::ops::{RecordedOpsParent, RecordedOpsParentRef, RecordedOpsRef};
|
||||||
use std::{collections::HashSet, ops::Add, rc::Rc};
|
use std::{collections::HashMap, ops::Add, rc::Rc};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Node<Out> {
|
pub struct Node<Out> {
|
||||||
pub id: String,
|
pub order: usize,
|
||||||
pub state: NodeStateRef<Out>,
|
pub state: NodeStateRef<Out>,
|
||||||
pub ops: RecordedOpsRef<Out>,
|
pub ops: RecordedOpsRef<Out>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Out> Node<Out> {
|
impl<Out> Node<Out> {
|
||||||
pub fn new(state: NodeStateRef<Out>, ops: RecordedOpsRef<Out>) -> Self {
|
pub fn from_root(state: NodeStateRef<Out>, ops: RecordedOpsRef<Out>) -> Self {
|
||||||
let id = nanoid::nanoid!();
|
let order = 0;
|
||||||
println!("Creating node {}", id);
|
Self { order, state, ops }
|
||||||
Self { id, state, ops }
|
}
|
||||||
|
|
||||||
|
pub fn from_unary<T>(
|
||||||
|
node: &Node<T>,
|
||||||
|
state: NodeStateRef<Out>,
|
||||||
|
ops: RecordedOpsRef<Out>,
|
||||||
|
) -> Self {
|
||||||
|
let order = node.order + 1;
|
||||||
|
Self { order, state, ops }
|
||||||
|
}
|
||||||
|
pub fn from_binary<Lhs, Rhs>(
|
||||||
|
lhs: &Node<Lhs>,
|
||||||
|
rhs: &Node<Rhs>,
|
||||||
|
state: NodeStateRef<Out>,
|
||||||
|
ops: RecordedOpsRef<Out>,
|
||||||
|
) -> Self {
|
||||||
|
let order = usize::max(lhs.order, rhs.order) + 1;
|
||||||
|
Self { order, state, ops }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,47 +41,55 @@ where
|
||||||
{
|
{
|
||||||
pub fn backward(&self) {
|
pub fn backward(&self) {
|
||||||
let grad = self.state.borrow().value().ones();
|
let grad = self.state.borrow().value().ones();
|
||||||
|
|
||||||
self.state.borrow_mut().update_grad(grad);
|
self.state.borrow_mut().update_grad(grad);
|
||||||
self.ops.backward_step(&self.state);
|
self.ops.backward_step(&self.state);
|
||||||
|
|
||||||
|
let mut nodes = HashMap::new();
|
||||||
let mut parents = self.ops.backward_parents();
|
let mut parents = self.ops.backward_parents();
|
||||||
let mut visited = HashSet::new();
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match parents.pop() {
|
match parents.pop() {
|
||||||
Some(node) => {
|
Some(node) => {
|
||||||
let id = node.id();
|
let id = node.id();
|
||||||
if visited.contains(&id) {
|
|
||||||
|
if id == 0 {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
visited.insert(id);
|
if nodes.contains_key(&id) {
|
||||||
node.backward_step();
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
for parent in node.backward_parents() {
|
for parent in node.backward_parents() {
|
||||||
if !visited.contains(&parent.id()) {
|
if !nodes.contains_key(&parent.id()) {
|
||||||
parents.push(parent);
|
parents.push(parent);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
nodes.insert(id, node);
|
||||||
}
|
}
|
||||||
None => break,
|
None => break,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i in (0..self.order + 1).rev() {
|
||||||
|
if let Some(node) = nodes.get(&i) {
|
||||||
|
node.backward_step();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: std::fmt::Debug> RecordedOpsParent for Node<T> {
|
impl<T: std::fmt::Debug> RecordedOpsParent for Node<T> {
|
||||||
fn backward_step(&self) {
|
fn backward_step(&self) {
|
||||||
println!("backward node {}", self.id);
|
println!("backward node {}", self.order);
|
||||||
self.ops.backward_step(&self.state)
|
self.ops.backward_step(&self.state)
|
||||||
}
|
}
|
||||||
fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
|
fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
|
||||||
self.ops.backward_parents()
|
self.ops.backward_parents()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn id(&self) -> String {
|
fn id(&self) -> usize {
|
||||||
self.id.clone()
|
self.order
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ pub trait RecordedOps<T>: std::fmt::Debug {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait RecordedOpsParent: std::fmt::Debug {
|
pub trait RecordedOpsParent: std::fmt::Debug {
|
||||||
fn id(&self) -> String;
|
fn id(&self) -> usize;
|
||||||
fn backward_step(&self);
|
fn backward_step(&self);
|
||||||
fn backward_parents(&self) -> Vec<RecordedOpsParentRef>;
|
fn backward_parents(&self) -> Vec<RecordedOpsParentRef>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -121,7 +121,7 @@ macro_rules! execute_ops {
|
||||||
let ops = BinaryRecordedOps::new($lhs, $rhs, ops);
|
let ops = BinaryRecordedOps::new($lhs, $rhs, ops);
|
||||||
let ops = std::rc::Rc::new(ops);
|
let ops = std::rc::Rc::new(ops);
|
||||||
|
|
||||||
let node = $crate::node::Node::new(state, ops);
|
let node = $crate::node::Node::from_binary(&$lhs, &$rhs, state, ops);
|
||||||
std::rc::Rc::new(node)
|
std::rc::Rc::new(node)
|
||||||
};
|
};
|
||||||
callback()
|
callback()
|
||||||
|
@ -139,7 +139,7 @@ macro_rules! execute_ops {
|
||||||
let ops = UnaryRecordedOps::new($input, ops);
|
let ops = UnaryRecordedOps::new($input, ops);
|
||||||
let ops = std::rc::Rc::new(ops);
|
let ops = std::rc::Rc::new(ops);
|
||||||
|
|
||||||
let node = $crate::node::Node::new(state, ops);
|
let node = $crate::node::Node::from_unary(&$input, state, ops);
|
||||||
std::rc::Rc::new(node)
|
std::rc::Rc::new(node)
|
||||||
};
|
};
|
||||||
callback()
|
callback()
|
||||||
|
|
|
@ -41,7 +41,7 @@ where
|
||||||
|
|
||||||
let ops = InitRecordedOps::new();
|
let ops = InitRecordedOps::new();
|
||||||
let ops = Rc::new(ops);
|
let ops = Rc::new(ops);
|
||||||
let node = Rc::new(Node::new(state, ops));
|
let node = Rc::new(Node::from_root(state, ops));
|
||||||
|
|
||||||
Self { node, shape, kind }
|
Self { node, shape, kind }
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue