mirror of https://github.com/tracel-ai/burn.git
feat: only use gradients struct for getting grad
This commit is contained in:
parent
e430a62795
commit
f9cbcd4db4
|
@ -90,13 +90,13 @@ mod tests {
|
|||
let tensor_2 = ADTchTensor::from_data(data_2.clone());
|
||||
|
||||
let tensor_3 = tensor_1.clone() + tensor_2.clone();
|
||||
tensor_3.backward();
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad();
|
||||
let grad_2 = tensor_2.grad();
|
||||
let grad_1 = grads.wrt(&tensor_1).unwrap();
|
||||
let grad_2 = grads.wrt(&tensor_2).unwrap();
|
||||
|
||||
assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0]));
|
||||
assert_eq!(grad_2.into_data(), Data::from([1.0, 1.0]));
|
||||
assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0]));
|
||||
assert_eq!(grad_2.to_data(), Data::from([1.0, 1.0]));
|
||||
assert_eq!(tensor_3.into_data(), Data::from([6.0, 6.0]));
|
||||
}
|
||||
|
||||
|
@ -106,10 +106,11 @@ mod tests {
|
|||
|
||||
let tensor = ADTchTensor::from_data(data.clone());
|
||||
let tensor_out = tensor.clone() + 5.0;
|
||||
tensor_out.backward();
|
||||
let grads = tensor_out.backward();
|
||||
|
||||
let grad = tensor.grad();
|
||||
assert_eq!(grad.into_data(), Data::from([1.0, 1.0]));
|
||||
let grad = grads.wrt(&tensor).unwrap();
|
||||
|
||||
assert_eq!(grad.to_data(), Data::from([1.0, 1.0]));
|
||||
assert_eq!(tensor_out.into_data(), Data::from([7.0, 15.0]));
|
||||
}
|
||||
|
||||
|
@ -131,12 +132,12 @@ mod tests {
|
|||
.add(&tensor_2);
|
||||
let tensor_6 = tensor_1.add(&tensor_5);
|
||||
|
||||
tensor_6.backward();
|
||||
let grads = tensor_6.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad();
|
||||
let grad_2 = tensor_2.grad();
|
||||
let grad_1 = grads.wrt(&tensor_1).unwrap();
|
||||
let grad_2 = grads.wrt(&tensor_2).unwrap();
|
||||
|
||||
assert_eq!(grad_1.into_data(), Data::from([[3.0, 3.0], [3.0, 3.0]]));
|
||||
assert_eq!(grad_2.into_data(), Data::from([[2.0, 2.0], [2.0, 2.0]]));
|
||||
assert_eq!(grad_1.to_data(), Data::from([[3.0, 3.0], [3.0, 3.0]]));
|
||||
assert_eq!(grad_2.to_data(), Data::from([[2.0, 2.0], [2.0, 2.0]]));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -51,13 +51,13 @@ mod tests {
|
|||
let tensor_2 = ADTchTensor::from_data(data_2.clone());
|
||||
|
||||
let tensor_3 = &tensor_1.matmul(&tensor_2);
|
||||
tensor_3.backward();
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad();
|
||||
let grad_2 = tensor_2.grad();
|
||||
let grad_1 = grads.wrt(&tensor_1).unwrap();
|
||||
let grad_2 = grads.wrt(&tensor_2).unwrap();
|
||||
|
||||
assert_eq!(grad_1.into_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
|
||||
assert_eq!(grad_2.into_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
|
||||
assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
|
||||
assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
|
||||
assert_eq!(
|
||||
tensor_3.clone().into_data(),
|
||||
Data::from([[18.0, 28.0], [14.0, 23.0]])
|
||||
|
@ -77,13 +77,13 @@ mod tests {
|
|||
let tensor_4 = tensor_1.matmul(&tensor_2);
|
||||
let tensor_5 = tensor_4.matmul(&tensor_3);
|
||||
|
||||
tensor_5.backward();
|
||||
let grads = tensor_5.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad();
|
||||
let grad_2 = tensor_2.grad();
|
||||
let grad_1 = grads.wrt(&tensor_1).unwrap();
|
||||
let grad_2 = grads.wrt(&tensor_2).unwrap();
|
||||
|
||||
assert_eq!(grad_1.into_data(), Data::from([[44.0, 20.0], [44.0, 20.0]]));
|
||||
assert_eq!(grad_2.into_data(), Data::from([[56.0, 56.0], [16.0, 16.0]]));
|
||||
assert_eq!(grad_1.to_data(), Data::from([[44.0, 20.0], [44.0, 20.0]]));
|
||||
assert_eq!(grad_2.to_data(), Data::from([[56.0, 56.0], [16.0, 16.0]]));
|
||||
}
|
||||
#[test]
|
||||
fn test_matmul_complex_2() {
|
||||
|
@ -99,17 +99,17 @@ mod tests {
|
|||
let tensor_5 = tensor_4.matmul(&tensor_3);
|
||||
let tensor_6 = tensor_1.matmul(&tensor_5);
|
||||
|
||||
tensor_6.backward();
|
||||
let grads = tensor_6.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad();
|
||||
let grad_2 = tensor_2.grad();
|
||||
let grad_1 = grads.wrt(&tensor_1).unwrap();
|
||||
let grad_2 = grads.wrt(&tensor_2).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
grad_1.into_data(),
|
||||
grad_1.to_data(),
|
||||
Data::from([[800.0, 792.0], [360.0, 592.0]])
|
||||
);
|
||||
assert_eq!(
|
||||
grad_2.into_data(),
|
||||
grad_2.to_data(),
|
||||
Data::from([[264., 264.0], [344.0, 344.0]])
|
||||
);
|
||||
}
|
||||
|
|
|
@ -90,13 +90,13 @@ mod tests {
|
|||
let tensor_2 = ADTchTensor::from_data(data_2.clone());
|
||||
|
||||
let tensor_3 = tensor_1.clone() * tensor_2.clone();
|
||||
tensor_3.backward();
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad();
|
||||
let grad_2 = tensor_2.grad();
|
||||
let grad_1 = grads.wrt(&tensor_1).unwrap();
|
||||
let grad_2 = grads.wrt(&tensor_2).unwrap();
|
||||
|
||||
assert_eq!(grad_1.into_data(), data_2);
|
||||
assert_eq!(grad_2.into_data(), data_1);
|
||||
assert_eq!(grad_1.to_data(), data_2);
|
||||
assert_eq!(grad_2.to_data(), data_1);
|
||||
assert_eq!(tensor_3.into_data(), Data::from([4.0, 49.0]));
|
||||
}
|
||||
|
||||
|
@ -106,11 +106,12 @@ mod tests {
|
|||
|
||||
let tensor = ADTchTensor::from_data(data.clone());
|
||||
let tensor_out = tensor.clone() * 4.0;
|
||||
tensor_out.backward();
|
||||
|
||||
let grad = tensor.grad();
|
||||
let grads = tensor_out.backward();
|
||||
let grad = grads.wrt(&tensor).unwrap();
|
||||
|
||||
assert_eq!(tensor_out.into_data(), Data::from([8.0, 20.0]));
|
||||
assert_eq!(grad.into_data(), Data::from([4.0, 4.0]));
|
||||
assert_eq!(grad.to_data(), Data::from([4.0, 4.0]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -127,15 +128,15 @@ mod tests {
|
|||
let tensor_5 = tensor_4.mul(&tensor_3);
|
||||
let tensor_6 = tensor_1.mul(&tensor_5);
|
||||
|
||||
tensor_6.backward();
|
||||
let grads = tensor_6.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad();
|
||||
let grad_2 = tensor_2.grad();
|
||||
let grad_1 = grads.wrt(&tensor_1).unwrap();
|
||||
let grad_2 = grads.wrt(&tensor_2).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
grad_1.into_data(),
|
||||
grad_1.to_data(),
|
||||
Data::from([[16.0, 196.0], [104.0, -36.0]])
|
||||
);
|
||||
assert_eq!(grad_2.into_data(), Data::from([[2.0, 98.0], [338.0, 18.0]]));
|
||||
assert_eq!(grad_2.to_data(), Data::from([[2.0, 98.0], [338.0, 18.0]]));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -90,13 +90,13 @@ mod tests {
|
|||
let tensor_2 = ADTchTensor::from_data(data_2.clone());
|
||||
|
||||
let tensor_3 = tensor_1.clone() - tensor_2.clone();
|
||||
tensor_3.backward();
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad();
|
||||
let grad_2 = tensor_2.grad();
|
||||
let grad_1 = grads.wrt(&tensor_1).unwrap();
|
||||
let grad_2 = grads.wrt(&tensor_2).unwrap();
|
||||
|
||||
assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0]));
|
||||
assert_eq!(grad_2.into_data(), Data::from([-1.0, -1.0]));
|
||||
assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0]));
|
||||
assert_eq!(grad_2.to_data(), Data::from([-1.0, -1.0]));
|
||||
assert_eq!(tensor_3.into_data(), Data::from([-2.0, 4.0]));
|
||||
}
|
||||
|
||||
|
@ -105,10 +105,11 @@ mod tests {
|
|||
let data = Data::from([2.0, 10.0]);
|
||||
let tensor = ADTchTensor::from_data(data.clone());
|
||||
let tensor_out = tensor.clone() - 5.0;
|
||||
tensor_out.backward();
|
||||
let grads = tensor_out.backward();
|
||||
|
||||
let grad = tensor.grad();
|
||||
assert_eq!(grad.into_data(), Data::from([1.0, 1.0]));
|
||||
let grad = grads.wrt(&tensor).unwrap();
|
||||
|
||||
assert_eq!(grad.to_data(), Data::from([1.0, 1.0]));
|
||||
assert_eq!(tensor_out.into_data(), Data::from([-3.0, 5.0]));
|
||||
}
|
||||
|
||||
|
@ -126,12 +127,12 @@ mod tests {
|
|||
let tensor_5 = tensor_4.sub(&tensor_3).sub_scalar(&5.0);
|
||||
let tensor_6 = tensor_1.sub(&tensor_5);
|
||||
|
||||
tensor_6.backward();
|
||||
let grads = tensor_6.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad();
|
||||
let grad_2 = tensor_2.grad();
|
||||
let grad_1 = grads.wrt(&tensor_1).unwrap();
|
||||
let grad_2 = grads.wrt(&tensor_2).unwrap();
|
||||
|
||||
assert_eq!(grad_1.into_data(), Data::from([[0.0, 0.0], [0.0, 0.0]]));
|
||||
assert_eq!(grad_2.into_data(), Data::from([[1.0, 1.0], [1.0, 1.0]]));
|
||||
assert_eq!(grad_1.to_data(), Data::from([[0.0, 0.0], [0.0, 0.0]]));
|
||||
assert_eq!(grad_2.to_data(), Data::from([[1.0, 1.0], [1.0, 1.0]]));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,16 +15,6 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<T, P, const D: usize> ADTensor<P, D, T>
|
||||
where
|
||||
T: Zeros<T> + Clone + Add<Output = T>,
|
||||
T: std::fmt::Debug,
|
||||
{
|
||||
pub fn grad(&self) -> T {
|
||||
self.node.state.borrow_mut().grad()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, P, const D: usize> AsNode<T> for ADTensor<P, D, T> {
|
||||
fn as_node(&self) -> &crate::node::Node<T> {
|
||||
&self.node
|
||||
|
@ -50,17 +40,17 @@ mod tests {
|
|||
let tensor_4 = tensor_3.matmul(&tensor_1);
|
||||
let tensor_5 = tensor_4.mul(&tensor_2);
|
||||
|
||||
tensor_5.backward();
|
||||
let grads = tensor_5.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad();
|
||||
let grad_2 = tensor_2.grad();
|
||||
let grad_1 = grads.wrt(&tensor_1).unwrap();
|
||||
let grad_2 = grads.wrt(&tensor_2).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
grad_1.into_data(),
|
||||
grad_1.to_data(),
|
||||
Data::from([[593., 463.0], [487.0, 539.0]])
|
||||
);
|
||||
assert_eq!(
|
||||
grad_2.into_data(),
|
||||
grad_2.to_data(),
|
||||
Data::from([[734.0, 294.0], [1414.0, 242.0]])
|
||||
);
|
||||
}
|
||||
|
@ -77,19 +67,16 @@ mod tests {
|
|||
let tensor_4 = tensor_3.matmul(&tensor_1);
|
||||
let tensor_5 = tensor_4.add_scalar(&17.0).add(&tensor_2);
|
||||
|
||||
tensor_5.backward();
|
||||
let grads = tensor_5.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad();
|
||||
let grad_2 = tensor_2.grad();
|
||||
let grad_1 = grads.wrt(&tensor_1).unwrap();
|
||||
let grad_2 = grads.wrt(&tensor_2).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
grad_1.into_data(),
|
||||
grad_1.to_data(),
|
||||
Data::from([[166.0, 110.0], [212.0, 156.0]])
|
||||
);
|
||||
assert_eq!(
|
||||
grad_2.into_data(),
|
||||
Data::from([[113.0, 141.0], [33.0, 41.0]])
|
||||
);
|
||||
assert_eq!(grad_2.to_data(), Data::from([[113.0, 141.0], [33.0, 41.0]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
Loading…
Reference in New Issue