diff --git a/crates/burn-autodiff/src/ops/tensor.rs b/crates/burn-autodiff/src/ops/tensor.rs
index b5626167e..2b5ad197d 100644
--- a/crates/burn-autodiff/src/ops/tensor.rs
+++ b/crates/burn-autodiff/src/ops/tensor.rs
@@ -1816,8 +1816,7 @@ impl FloatTensorOps for Autodiff
             ) {
                 let tensor: B::FloatTensorPrimitive =
                     checkpointer.retrieve_node_output(ops.state);
-                let output = B::float_abs(tensor.clone());
-                let state = B::float_div(tensor, output);
+                let state = B::float_sign(tensor);
                 unary::(ops.parents, ops.node, grads, |grad| {
                     B::float_mul(grad, state)
                 });
diff --git a/crates/burn-autodiff/src/tests/abs.rs b/crates/burn-autodiff/src/tests/abs.rs
index 2a7175824..88710bbe1 100644
--- a/crates/burn-autodiff/src/tests/abs.rs
+++ b/crates/burn-autodiff/src/tests/abs.rs
@@ -25,4 +25,32 @@ mod tests {
         let expected = TensorData::from([[84.0, 42.0], [90.0, 54.0]]);
         grad_2.to_data().assert_approx_eq(&expected, 3);
     }
+
+    // Regression test for the `abs` backward pass at zero: `data_2` holds an
+    // exact 0.0, where a gradient formed as `x / |x|` evaluates 0/0 = NaN.
+    // The expected gradient for that element is 0.0, and `grad_2` must not
+    // contain any NaN.
+    #[test]
+    fn should_diff_abs_no_nans() {
+        let data_1 = TensorData::from([[6.0, 7.0], [9.0, -10.0]]);
+        let data_2 = TensorData::from([[0.0, -1.0], [3.0, 4.0]]);
+
+        let device = Default::default();
+        let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
+        let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
+
+        let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());
+        let grads = tensor_3.backward();
+
+        let grad_1 = tensor_1.grad(&grads).unwrap();
+        let grad_2 = tensor_2.grad(&grads).unwrap();
+
+        let expected = TensorData::from([[1.0, 7.0], [1.0, 7.0]]);
+        grad_1.to_data().assert_approx_eq(&expected, 3);
+
+        let expected = TensorData::from([[0.0, -15.0], [-3.0, -3.0]]);
+        grad_2.to_data().assert_approx_eq(&expected, 3);
+
+        assert_eq!(grad_2.contains_nan().into_scalar(), false);
+    }
 }