fix: backward order

This commit is contained in:
nathaniel 2022-07-21 07:39:47 -04:00
parent 2f7f65cea5
commit 230cd01ea1
4 changed files with 44 additions and 19 deletions

View File

@ -1,19 +1,36 @@
use super::NodeStateRef; use super::NodeStateRef;
use crate::ops::{RecordedOpsParent, RecordedOpsParentRef, RecordedOpsRef}; use crate::ops::{RecordedOpsParent, RecordedOpsParentRef, RecordedOpsRef};
use std::{collections::HashSet, ops::Add, rc::Rc}; use std::{collections::HashMap, ops::Add, rc::Rc};
#[derive(Debug)] #[derive(Debug)]
pub struct Node<Out> { pub struct Node<Out> {
pub id: String, pub order: usize,
pub state: NodeStateRef<Out>, pub state: NodeStateRef<Out>,
pub ops: RecordedOpsRef<Out>, pub ops: RecordedOpsRef<Out>,
} }
impl<Out> Node<Out> { impl<Out> Node<Out> {
pub fn new(state: NodeStateRef<Out>, ops: RecordedOpsRef<Out>) -> Self { pub fn from_root(state: NodeStateRef<Out>, ops: RecordedOpsRef<Out>) -> Self {
let id = nanoid::nanoid!(); let order = 0;
println!("Creating node {}", id); Self { order, state, ops }
Self { id, state, ops } }
pub fn from_unary<T>(
node: &Node<T>,
state: NodeStateRef<Out>,
ops: RecordedOpsRef<Out>,
) -> Self {
let order = node.order + 1;
Self { order, state, ops }
}
pub fn from_binary<Lhs, Rhs>(
lhs: &Node<Lhs>,
rhs: &Node<Rhs>,
state: NodeStateRef<Out>,
ops: RecordedOpsRef<Out>,
) -> Self {
let order = usize::max(lhs.order, rhs.order) + 1;
Self { order, state, ops }
} }
} }
@ -24,47 +41,55 @@ where
{ {
pub fn backward(&self) { pub fn backward(&self) {
let grad = self.state.borrow().value().ones(); let grad = self.state.borrow().value().ones();
self.state.borrow_mut().update_grad(grad); self.state.borrow_mut().update_grad(grad);
self.ops.backward_step(&self.state); self.ops.backward_step(&self.state);
let mut nodes = HashMap::new();
let mut parents = self.ops.backward_parents(); let mut parents = self.ops.backward_parents();
let mut visited = HashSet::new();
loop { loop {
match parents.pop() { match parents.pop() {
Some(node) => { Some(node) => {
let id = node.id(); let id = node.id();
if visited.contains(&id) {
if id == 0 {
continue; continue;
} }
visited.insert(id); if nodes.contains_key(&id) {
node.backward_step(); continue;
}
for parent in node.backward_parents() { for parent in node.backward_parents() {
if !visited.contains(&parent.id()) { if !nodes.contains_key(&parent.id()) {
parents.push(parent); parents.push(parent);
} }
} }
nodes.insert(id, node);
} }
None => break, None => break,
} }
} }
for i in (0..self.order + 1).rev() {
if let Some(node) = nodes.get(&i) {
node.backward_step();
}
}
} }
} }
impl<T: std::fmt::Debug> RecordedOpsParent for Node<T> { impl<T: std::fmt::Debug> RecordedOpsParent for Node<T> {
fn backward_step(&self) { fn backward_step(&self) {
println!("backward node {}", self.id); println!("backward node {}", self.order);
self.ops.backward_step(&self.state) self.ops.backward_step(&self.state)
} }
fn backward_parents(&self) -> Vec<RecordedOpsParentRef> { fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
self.ops.backward_parents() self.ops.backward_parents()
} }
fn id(&self) -> String { fn id(&self) -> usize {
self.id.clone() self.order
} }
} }

View File

@ -20,7 +20,7 @@ pub trait RecordedOps<T>: std::fmt::Debug {
} }
pub trait RecordedOpsParent: std::fmt::Debug { pub trait RecordedOpsParent: std::fmt::Debug {
fn id(&self) -> String; fn id(&self) -> usize;
fn backward_step(&self); fn backward_step(&self);
fn backward_parents(&self) -> Vec<RecordedOpsParentRef>; fn backward_parents(&self) -> Vec<RecordedOpsParentRef>;
} }

View File

@ -121,7 +121,7 @@ macro_rules! execute_ops {
let ops = BinaryRecordedOps::new($lhs, $rhs, ops); let ops = BinaryRecordedOps::new($lhs, $rhs, ops);
let ops = std::rc::Rc::new(ops); let ops = std::rc::Rc::new(ops);
let node = $crate::node::Node::new(state, ops); let node = $crate::node::Node::from_binary(&$lhs, &$rhs, state, ops);
std::rc::Rc::new(node) std::rc::Rc::new(node)
}; };
callback() callback()
@ -139,7 +139,7 @@ macro_rules! execute_ops {
let ops = UnaryRecordedOps::new($input, ops); let ops = UnaryRecordedOps::new($input, ops);
let ops = std::rc::Rc::new(ops); let ops = std::rc::Rc::new(ops);
let node = $crate::node::Node::new(state, ops); let node = $crate::node::Node::from_unary(&$input, state, ops);
std::rc::Rc::new(node) std::rc::Rc::new(node)
}; };
callback() callback()

View File

@ -41,7 +41,7 @@ where
let ops = InitRecordedOps::new(); let ops = InitRecordedOps::new();
let ops = Rc::new(ops); let ops = Rc::new(ops);
let node = Rc::new(Node::new(state, ops)); let node = Rc::new(Node::from_root(state, ops));
Self { node, shape, kind } Self { node, shape, kind }
} }