diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/add.rs b/burn-tensor/src/tensor/backend/autodiff/ops/add.rs
index 53ba416d9..fb0df8a0f 100644
--- a/burn-tensor/src/tensor/backend/autodiff/ops/add.rs
+++ b/burn-tensor/src/tensor/backend/autodiff/ops/add.rs
@@ -90,13 +90,13 @@ mod tests {
         let tensor_2 = ADTchTensor::from_data(data_2.clone());

         let tensor_3 = tensor_1.clone() + tensor_2.clone();
-        tensor_3.backward();
+        let grads = tensor_3.backward();

-        let grad_1 = tensor_1.grad();
-        let grad_2 = tensor_2.grad();
+        let grad_1 = grads.wrt(&tensor_1).unwrap();
+        let grad_2 = grads.wrt(&tensor_2).unwrap();

-        assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0]));
-        assert_eq!(grad_2.into_data(), Data::from([1.0, 1.0]));
+        assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0]));
+        assert_eq!(grad_2.to_data(), Data::from([1.0, 1.0]));
         assert_eq!(tensor_3.into_data(), Data::from([6.0, 6.0]));
     }

@@ -106,10 +106,11 @@ mod tests {

         let tensor = ADTchTensor::from_data(data.clone());
         let tensor_out = tensor.clone() + 5.0;
-        tensor_out.backward();
+        let grads = tensor_out.backward();

-        let grad = tensor.grad();
-        assert_eq!(grad.into_data(), Data::from([1.0, 1.0]));
+        let grad = grads.wrt(&tensor).unwrap();
+
+        assert_eq!(grad.to_data(), Data::from([1.0, 1.0]));
         assert_eq!(tensor_out.into_data(), Data::from([7.0, 15.0]));
     }

@@ -131,12 +132,12 @@ mod tests {
             .add(&tensor_2);

         let tensor_6 = tensor_1.add(&tensor_5);
-        tensor_6.backward();
+        let grads = tensor_6.backward();

-        let grad_1 = tensor_1.grad();
-        let grad_2 = tensor_2.grad();
+        let grad_1 = grads.wrt(&tensor_1).unwrap();
+        let grad_2 = grads.wrt(&tensor_2).unwrap();

-        assert_eq!(grad_1.into_data(), Data::from([[3.0, 3.0], [3.0, 3.0]]));
-        assert_eq!(grad_2.into_data(), Data::from([[2.0, 2.0], [2.0, 2.0]]));
+        assert_eq!(grad_1.to_data(), Data::from([[3.0, 3.0], [3.0, 3.0]]));
+        assert_eq!(grad_2.to_data(), Data::from([[2.0, 2.0], [2.0, 2.0]]));
     }
 }
diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/matmul.rs b/burn-tensor/src/tensor/backend/autodiff/ops/matmul.rs
index 7b0f76854..cdc8c2f4f 100644
--- a/burn-tensor/src/tensor/backend/autodiff/ops/matmul.rs
+++ b/burn-tensor/src/tensor/backend/autodiff/ops/matmul.rs
@@ -51,13 +51,13 @@ mod tests {
         let tensor_2 = ADTchTensor::from_data(data_2.clone());

         let tensor_3 = &tensor_1.matmul(&tensor_2);
-        tensor_3.backward();
+        let grads = tensor_3.backward();

-        let grad_1 = tensor_1.grad();
-        let grad_2 = tensor_2.grad();
+        let grad_1 = grads.wrt(&tensor_1).unwrap();
+        let grad_2 = grads.wrt(&tensor_2).unwrap();

-        assert_eq!(grad_1.into_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
-        assert_eq!(grad_2.into_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
+        assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
+        assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
         assert_eq!(
             tensor_3.clone().into_data(),
             Data::from([[18.0, 28.0], [14.0, 23.0]])
@@ -77,13 +77,13 @@ mod tests {
         let tensor_4 = tensor_1.matmul(&tensor_2);
         let tensor_5 = tensor_4.matmul(&tensor_3);
-        tensor_5.backward();
+        let grads = tensor_5.backward();

-        let grad_1 = tensor_1.grad();
-        let grad_2 = tensor_2.grad();
+        let grad_1 = grads.wrt(&tensor_1).unwrap();
+        let grad_2 = grads.wrt(&tensor_2).unwrap();

-        assert_eq!(grad_1.into_data(), Data::from([[44.0, 20.0], [44.0, 20.0]]));
-        assert_eq!(grad_2.into_data(), Data::from([[56.0, 56.0], [16.0, 16.0]]));
+        assert_eq!(grad_1.to_data(), Data::from([[44.0, 20.0], [44.0, 20.0]]));
+        assert_eq!(grad_2.to_data(), Data::from([[56.0, 56.0], [16.0, 16.0]]));
     }

     #[test]
     fn test_matmul_complex_2() {
@@ -99,17 +99,17 @@ mod tests {

         let tensor_5 = tensor_4.matmul(&tensor_3);
         let tensor_6 = tensor_1.matmul(&tensor_5);
-        tensor_6.backward();
+        let grads = tensor_6.backward();

-        let grad_1 = tensor_1.grad();
-        let grad_2 = tensor_2.grad();
+        let grad_1 = grads.wrt(&tensor_1).unwrap();
+        let grad_2 = grads.wrt(&tensor_2).unwrap();

         assert_eq!(
-            grad_1.into_data(),
+            grad_1.to_data(),
             Data::from([[800.0, 792.0], [360.0, 592.0]])
         );
         assert_eq!(
-            grad_2.into_data(),
+            grad_2.to_data(),
             Data::from([[264., 264.0], [344.0, 344.0]])
         );
     }
diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs
index 57cce06d0..b956ef329 100644
--- a/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs
+++ b/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs
@@ -90,13 +90,13 @@ mod tests {
         let tensor_2 = ADTchTensor::from_data(data_2.clone());

         let tensor_3 = tensor_1.clone() * tensor_2.clone();
-        tensor_3.backward();
+        let grads = tensor_3.backward();

-        let grad_1 = tensor_1.grad();
-        let grad_2 = tensor_2.grad();
+        let grad_1 = grads.wrt(&tensor_1).unwrap();
+        let grad_2 = grads.wrt(&tensor_2).unwrap();

-        assert_eq!(grad_1.into_data(), data_2);
-        assert_eq!(grad_2.into_data(), data_1);
+        assert_eq!(grad_1.to_data(), data_2);
+        assert_eq!(grad_2.to_data(), data_1);
         assert_eq!(tensor_3.into_data(), Data::from([4.0, 49.0]));
     }

@@ -106,11 +106,12 @@ mod tests {

         let tensor = ADTchTensor::from_data(data.clone());
         let tensor_out = tensor.clone() * 4.0;

-        tensor_out.backward();
-        let grad = tensor.grad();
+        let grads = tensor_out.backward();
+        let grad = grads.wrt(&tensor).unwrap();
+
         assert_eq!(tensor_out.into_data(), Data::from([8.0, 20.0]));
-        assert_eq!(grad.into_data(), Data::from([4.0, 4.0]));
+        assert_eq!(grad.to_data(), Data::from([4.0, 4.0]));
     }

     #[test]
@@ -127,15 +128,15 @@ mod tests {

         let tensor_5 = tensor_4.mul(&tensor_3);
         let tensor_6 = tensor_1.mul(&tensor_5);
-        tensor_6.backward();
+        let grads = tensor_6.backward();

-        let grad_1 = tensor_1.grad();
-        let grad_2 = tensor_2.grad();
+        let grad_1 = grads.wrt(&tensor_1).unwrap();
+        let grad_2 = grads.wrt(&tensor_2).unwrap();

         assert_eq!(
-            grad_1.into_data(),
+            grad_1.to_data(),
             Data::from([[16.0, 196.0], [104.0, -36.0]])
         );
-        assert_eq!(grad_2.into_data(), Data::from([[2.0, 98.0], [338.0, 18.0]]));
+        assert_eq!(grad_2.to_data(), Data::from([[2.0, 98.0], [338.0, 18.0]]));
     }
 }
diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs b/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs
index 24a249a42..2fd943cab 100644
--- a/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs
+++ b/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs
@@ -90,13 +90,13 @@ mod tests {
         let tensor_2 = ADTchTensor::from_data(data_2.clone());

         let tensor_3 = tensor_1.clone() - tensor_2.clone();
-        tensor_3.backward();
+        let grads = tensor_3.backward();

-        let grad_1 = tensor_1.grad();
-        let grad_2 = tensor_2.grad();
+        let grad_1 = grads.wrt(&tensor_1).unwrap();
+        let grad_2 = grads.wrt(&tensor_2).unwrap();

-        assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0]));
-        assert_eq!(grad_2.into_data(), Data::from([-1.0, -1.0]));
+        assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0]));
+        assert_eq!(grad_2.to_data(), Data::from([-1.0, -1.0]));
         assert_eq!(tensor_3.into_data(), Data::from([-2.0, 4.0]));
     }

@@ -105,10 +105,11 @@ mod tests {
         let data = Data::from([2.0, 10.0]);
         let tensor = ADTchTensor::from_data(data.clone());
         let tensor_out = tensor.clone() - 5.0;
-        tensor_out.backward();
+        let grads = tensor_out.backward();

-        let grad = tensor.grad();
-        assert_eq!(grad.into_data(), Data::from([1.0, 1.0]));
+        let grad = grads.wrt(&tensor).unwrap();
+
+        assert_eq!(grad.to_data(), Data::from([1.0, 1.0]));
         assert_eq!(tensor_out.into_data(), Data::from([-3.0, 5.0]));
     }

@@ -126,12 +127,12 @@ mod tests {
         let tensor_5 = tensor_4.sub(&tensor_3).sub_scalar(&5.0);

         let tensor_6 = tensor_1.sub(&tensor_5);
-        tensor_6.backward();
+        let grads = tensor_6.backward();

-        let grad_1 = tensor_1.grad();
-        let grad_2 = tensor_2.grad();
+        let grad_1 = grads.wrt(&tensor_1).unwrap();
+        let grad_2 = grads.wrt(&tensor_2).unwrap();

-        assert_eq!(grad_1.into_data(), Data::from([[0.0, 0.0], [0.0, 0.0]]));
-        assert_eq!(grad_2.into_data(), Data::from([[1.0, 1.0], [1.0, 1.0]]));
+        assert_eq!(grad_1.to_data(), Data::from([[0.0, 0.0], [0.0, 0.0]]));
+        assert_eq!(grad_2.to_data(), Data::from([[1.0, 1.0], [1.0, 1.0]]));
     }
 }
diff --git a/burn-tensor/src/tensor/backend/autodiff/tensor/ad.rs b/burn-tensor/src/tensor/backend/autodiff/tensor/ad.rs
index c88844be6..35f230d95 100644
--- a/burn-tensor/src/tensor/backend/autodiff/tensor/ad.rs
+++ b/burn-tensor/src/tensor/backend/autodiff/tensor/ad.rs
@@ -15,16 +15,6 @@ where
     }
 }

-impl<T> ADTensor<T>
-where
-    T: Zeros + Clone + Add,
-    T: std::fmt::Debug,
-{
-    pub fn grad(&self) -> T {
-        self.node.state.borrow_mut().grad()
-    }
-}
-
 impl<T> AsNode<T> for ADTensor<T> {
     fn as_node(&self) -> &crate::node::Node<T> {
         &self.node
@@ -50,17 +40,17 @@ mod tests {
         let tensor_4 = tensor_3.matmul(&tensor_1);
         let tensor_5 = tensor_4.mul(&tensor_2);
-        tensor_5.backward();
+        let grads = tensor_5.backward();

-        let grad_1 = tensor_1.grad();
-        let grad_2 = tensor_2.grad();
+        let grad_1 = grads.wrt(&tensor_1).unwrap();
+        let grad_2 = grads.wrt(&tensor_2).unwrap();

         assert_eq!(
-            grad_1.into_data(),
+            grad_1.to_data(),
             Data::from([[593., 463.0], [487.0, 539.0]])
         );
         assert_eq!(
-            grad_2.into_data(),
+            grad_2.to_data(),
             Data::from([[734.0, 294.0], [1414.0, 242.0]])
         );
     }

@@ -77,19 +67,16 @@ mod tests {

         let tensor_4 = tensor_3.matmul(&tensor_1);
         let tensor_5 = tensor_4.add_scalar(&17.0).add(&tensor_2);
-        tensor_5.backward();
+        let grads = tensor_5.backward();

-        let grad_1 = tensor_1.grad();
-        let grad_2 = tensor_2.grad();
+        let grad_1 = grads.wrt(&tensor_1).unwrap();
+        let grad_2 = grads.wrt(&tensor_2).unwrap();

         assert_eq!(
-            grad_1.into_data(),
+            grad_1.to_data(),
             Data::from([[166.0, 110.0], [212.0, 156.0]])
         );
-        assert_eq!(
-            grad_2.into_data(),
-            Data::from([[113.0, 141.0], [33.0, 41.0]])
-        );
+        assert_eq!(grad_2.to_data(), Data::from([[113.0, 141.0], [33.0, 41.0]]));
     }

     #[test]
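
Taken together, the hunks above make one mechanical substitution across all five files: gradients are no longer stored behind each tensor and read back with `tensor.grad()` (the accessor removed from `ad.rs`); `backward()` now returns a gradients container queried per input tensor via `wrt(...)`, and gradient values are read with the borrowing `to_data()` rather than the consuming `into_data()`. A minimal sketch of the resulting usage pattern, assembled from the updated tests, follows. Only `from_data`, `matmul`, `backward`, `wrt`, and `to_data` are taken from the diff itself; the `Option`-style return of `wrt` is inferred from the `unwrap()` calls, and anything beyond that is an assumption, not a claim about the actual API.

    // Sketch of the new gradient flow, in the style of the updated tests.
    let a = ADTchTensor::from_data(Data::from([[1.0, 2.0], [3.0, 4.0]]));
    let b = ADTchTensor::from_data(Data::from([[5.0, 6.0], [7.0, 8.0]]));

    // The forward pass builds the autodiff graph exactly as before.
    let c = a.matmul(&b);

    // backward() now *returns* the gradients instead of storing them on
    // each tensor, so nothing is read back through the tensors themselves.
    let grads = c.backward();

    // Each gradient is looked up with respect to an input tensor. The
    // unwrap() mirrors the tests; presumably wrt yields nothing for a
    // tensor that never entered the graph.
    let grad_a = grads.wrt(&a).unwrap();
    let grad_b = grads.wrt(&b).unwrap();

    // to_data() borrows, so a gradient can be inspected without being
    // consumed (into_data() remains in use for the forward tensors).
    println!("{:?} {:?}", grad_a.to_data(), grad_b.to_data());

One side effect visible in `ad.rs`: because `backward()` hands the gradients back directly, the `Zeros + Clone + Add` bounds that the deleted `grad()` accessor imposed on the tensor type disappear along with it.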