feat: only use gradients struct for getting grad

This commit is contained in:
nathaniel 2022-07-25 11:28:42 -04:00
parent e430a62795
commit f9cbcd4db4
5 changed files with 67 additions and 77 deletions

View File

@@ -90,13 +90,13 @@ mod tests {
 let tensor_2 = ADTchTensor::from_data(data_2.clone());
 let tensor_3 = tensor_1.clone() + tensor_2.clone();
-tensor_3.backward();
+let grads = tensor_3.backward();
-let grad_1 = tensor_1.grad();
-let grad_2 = tensor_2.grad();
+let grad_1 = grads.wrt(&tensor_1).unwrap();
+let grad_2 = grads.wrt(&tensor_2).unwrap();
-assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0]));
-assert_eq!(grad_2.into_data(), Data::from([1.0, 1.0]));
+assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0]));
+assert_eq!(grad_2.to_data(), Data::from([1.0, 1.0]));
 assert_eq!(tensor_3.into_data(), Data::from([6.0, 6.0]));
 }
@@ -106,10 +106,11 @@ mod tests {
 let tensor = ADTchTensor::from_data(data.clone());
 let tensor_out = tensor.clone() + 5.0;
-tensor_out.backward();
+let grads = tensor_out.backward();
-let grad = tensor.grad();
-assert_eq!(grad.into_data(), Data::from([1.0, 1.0]));
+let grad = grads.wrt(&tensor).unwrap();
+assert_eq!(grad.to_data(), Data::from([1.0, 1.0]));
 assert_eq!(tensor_out.into_data(), Data::from([7.0, 15.0]));
 }
@@ -131,12 +132,12 @@ mod tests {
 .add(&tensor_2);
 let tensor_6 = tensor_1.add(&tensor_5);
-tensor_6.backward();
+let grads = tensor_6.backward();
-let grad_1 = tensor_1.grad();
-let grad_2 = tensor_2.grad();
+let grad_1 = grads.wrt(&tensor_1).unwrap();
+let grad_2 = grads.wrt(&tensor_2).unwrap();
-assert_eq!(grad_1.into_data(), Data::from([[3.0, 3.0], [3.0, 3.0]]));
-assert_eq!(grad_2.into_data(), Data::from([[2.0, 2.0], [2.0, 2.0]]));
+assert_eq!(grad_1.to_data(), Data::from([[3.0, 3.0], [3.0, 3.0]]));
+assert_eq!(grad_2.to_data(), Data::from([[2.0, 2.0], [2.0, 2.0]]));
 }
 }

View File

@@ -51,13 +51,13 @@ mod tests {
 let tensor_2 = ADTchTensor::from_data(data_2.clone());
 let tensor_3 = &tensor_1.matmul(&tensor_2);
-tensor_3.backward();
+let grads = tensor_3.backward();
-let grad_1 = tensor_1.grad();
-let grad_2 = tensor_2.grad();
+let grad_1 = grads.wrt(&tensor_1).unwrap();
+let grad_2 = grads.wrt(&tensor_2).unwrap();
-assert_eq!(grad_1.into_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
-assert_eq!(grad_2.into_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
+assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
+assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
 assert_eq!(
 tensor_3.clone().into_data(),
 Data::from([[18.0, 28.0], [14.0, 23.0]])
@@ -77,13 +77,13 @@ mod tests {
 let tensor_4 = tensor_1.matmul(&tensor_2);
 let tensor_5 = tensor_4.matmul(&tensor_3);
-tensor_5.backward();
+let grads = tensor_5.backward();
-let grad_1 = tensor_1.grad();
-let grad_2 = tensor_2.grad();
+let grad_1 = grads.wrt(&tensor_1).unwrap();
+let grad_2 = grads.wrt(&tensor_2).unwrap();
-assert_eq!(grad_1.into_data(), Data::from([[44.0, 20.0], [44.0, 20.0]]));
-assert_eq!(grad_2.into_data(), Data::from([[56.0, 56.0], [16.0, 16.0]]));
+assert_eq!(grad_1.to_data(), Data::from([[44.0, 20.0], [44.0, 20.0]]));
+assert_eq!(grad_2.to_data(), Data::from([[56.0, 56.0], [16.0, 16.0]]));
 }
 #[test]
 fn test_matmul_complex_2() {
@@ -99,17 +99,17 @@ mod tests {
 let tensor_5 = tensor_4.matmul(&tensor_3);
 let tensor_6 = tensor_1.matmul(&tensor_5);
-tensor_6.backward();
+let grads = tensor_6.backward();
-let grad_1 = tensor_1.grad();
-let grad_2 = tensor_2.grad();
+let grad_1 = grads.wrt(&tensor_1).unwrap();
+let grad_2 = grads.wrt(&tensor_2).unwrap();
 assert_eq!(
-grad_1.into_data(),
+grad_1.to_data(),
 Data::from([[800.0, 792.0], [360.0, 592.0]])
 );
 assert_eq!(
-grad_2.into_data(),
+grad_2.to_data(),
 Data::from([[264., 264.0], [344.0, 344.0]])
 );
 }

View File

@@ -90,13 +90,13 @@ mod tests {
 let tensor_2 = ADTchTensor::from_data(data_2.clone());
 let tensor_3 = tensor_1.clone() * tensor_2.clone();
-tensor_3.backward();
+let grads = tensor_3.backward();
-let grad_1 = tensor_1.grad();
-let grad_2 = tensor_2.grad();
+let grad_1 = grads.wrt(&tensor_1).unwrap();
+let grad_2 = grads.wrt(&tensor_2).unwrap();
-assert_eq!(grad_1.into_data(), data_2);
-assert_eq!(grad_2.into_data(), data_1);
+assert_eq!(grad_1.to_data(), data_2);
+assert_eq!(grad_2.to_data(), data_1);
 assert_eq!(tensor_3.into_data(), Data::from([4.0, 49.0]));
 }
@@ -106,11 +106,12 @@ mod tests {
 let tensor = ADTchTensor::from_data(data.clone());
 let tensor_out = tensor.clone() * 4.0;
-tensor_out.backward();
-let grad = tensor.grad();
+let grads = tensor_out.backward();
+let grad = grads.wrt(&tensor).unwrap();
 assert_eq!(tensor_out.into_data(), Data::from([8.0, 20.0]));
-assert_eq!(grad.into_data(), Data::from([4.0, 4.0]));
+assert_eq!(grad.to_data(), Data::from([4.0, 4.0]));
 }
 #[test]
@@ -127,15 +128,15 @@ mod tests {
 let tensor_5 = tensor_4.mul(&tensor_3);
 let tensor_6 = tensor_1.mul(&tensor_5);
-tensor_6.backward();
+let grads = tensor_6.backward();
-let grad_1 = tensor_1.grad();
-let grad_2 = tensor_2.grad();
+let grad_1 = grads.wrt(&tensor_1).unwrap();
+let grad_2 = grads.wrt(&tensor_2).unwrap();
 assert_eq!(
-grad_1.into_data(),
+grad_1.to_data(),
 Data::from([[16.0, 196.0], [104.0, -36.0]])
 );
-assert_eq!(grad_2.into_data(), Data::from([[2.0, 98.0], [338.0, 18.0]]));
+assert_eq!(grad_2.to_data(), Data::from([[2.0, 98.0], [338.0, 18.0]]));
 }
 }

View File

@@ -90,13 +90,13 @@ mod tests {
 let tensor_2 = ADTchTensor::from_data(data_2.clone());
 let tensor_3 = tensor_1.clone() - tensor_2.clone();
-tensor_3.backward();
+let grads = tensor_3.backward();
-let grad_1 = tensor_1.grad();
-let grad_2 = tensor_2.grad();
+let grad_1 = grads.wrt(&tensor_1).unwrap();
+let grad_2 = grads.wrt(&tensor_2).unwrap();
-assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0]));
-assert_eq!(grad_2.into_data(), Data::from([-1.0, -1.0]));
+assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0]));
+assert_eq!(grad_2.to_data(), Data::from([-1.0, -1.0]));
 assert_eq!(tensor_3.into_data(), Data::from([-2.0, 4.0]));
 }
@@ -105,10 +105,11 @@ mod tests {
 let data = Data::from([2.0, 10.0]);
 let tensor = ADTchTensor::from_data(data.clone());
 let tensor_out = tensor.clone() - 5.0;
-tensor_out.backward();
+let grads = tensor_out.backward();
-let grad = tensor.grad();
-assert_eq!(grad.into_data(), Data::from([1.0, 1.0]));
+let grad = grads.wrt(&tensor).unwrap();
+assert_eq!(grad.to_data(), Data::from([1.0, 1.0]));
 assert_eq!(tensor_out.into_data(), Data::from([-3.0, 5.0]));
 }
@@ -126,12 +127,12 @@ mod tests {
 let tensor_5 = tensor_4.sub(&tensor_3).sub_scalar(&5.0);
 let tensor_6 = tensor_1.sub(&tensor_5);
-tensor_6.backward();
+let grads = tensor_6.backward();
-let grad_1 = tensor_1.grad();
-let grad_2 = tensor_2.grad();
+let grad_1 = grads.wrt(&tensor_1).unwrap();
+let grad_2 = grads.wrt(&tensor_2).unwrap();
-assert_eq!(grad_1.into_data(), Data::from([[0.0, 0.0], [0.0, 0.0]]));
-assert_eq!(grad_2.into_data(), Data::from([[1.0, 1.0], [1.0, 1.0]]));
+assert_eq!(grad_1.to_data(), Data::from([[0.0, 0.0], [0.0, 0.0]]));
+assert_eq!(grad_2.to_data(), Data::from([[1.0, 1.0], [1.0, 1.0]]));
 }
 }

View File

@@ -15,16 +15,6 @@ where
 }
 }
-impl<T, P, const D: usize> ADTensor<P, D, T>
-where
-T: Zeros<T> + Clone + Add<Output = T>,
-T: std::fmt::Debug,
-{
-pub fn grad(&self) -> T {
-self.node.state.borrow_mut().grad()
-}
-}
 impl<T, P, const D: usize> AsNode<T> for ADTensor<P, D, T> {
 fn as_node(&self) -> &crate::node::Node<T> {
 &self.node
@@ -50,17 +40,17 @@ mod tests {
 let tensor_4 = tensor_3.matmul(&tensor_1);
 let tensor_5 = tensor_4.mul(&tensor_2);
-tensor_5.backward();
+let grads = tensor_5.backward();
-let grad_1 = tensor_1.grad();
-let grad_2 = tensor_2.grad();
+let grad_1 = grads.wrt(&tensor_1).unwrap();
+let grad_2 = grads.wrt(&tensor_2).unwrap();
 assert_eq!(
-grad_1.into_data(),
+grad_1.to_data(),
 Data::from([[593., 463.0], [487.0, 539.0]])
 );
 assert_eq!(
-grad_2.into_data(),
+grad_2.to_data(),
 Data::from([[734.0, 294.0], [1414.0, 242.0]])
 );
 }
@@ -77,19 +67,16 @@ mod tests {
 let tensor_4 = tensor_3.matmul(&tensor_1);
 let tensor_5 = tensor_4.add_scalar(&17.0).add(&tensor_2);
-tensor_5.backward();
+let grads = tensor_5.backward();
-let grad_1 = tensor_1.grad();
-let grad_2 = tensor_2.grad();
+let grad_1 = grads.wrt(&tensor_1).unwrap();
+let grad_2 = grads.wrt(&tensor_2).unwrap();
 assert_eq!(
-grad_1.into_data(),
+grad_1.to_data(),
 Data::from([[166.0, 110.0], [212.0, 156.0]])
 );
-assert_eq!(
-grad_2.into_data(),
-Data::from([[113.0, 141.0], [33.0, 41.0]])
-);
+assert_eq!(grad_2.to_data(), Data::from([[113.0, 141.0], [33.0, 41.0]]));
 }
 #[test]