feat: add max & min ops (#339)

Author: Nathaniel Simard, 2023-05-09 15:49:33 -04:00 (committed by GitHub)
Parent: 69001b0d69
Commit: a88357ce1d
12 changed files with 319 additions and 1 deletion

@@ -0,0 +1,20 @@
use super::{unary, Backward, Ops};
use crate::grads::Gradients;
use burn_tensor::{backend::Backend, Shape};

#[derive(Debug)]
pub(crate) struct MaxMinDim;

impl<B: Backend, const D: usize> Backward<B, D, 1> for MaxMinDim {
    type State = (B::IntTensorPrimitive<D>, Shape<D>);

    fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
        unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
            let (indexes, shape) = ops.state;
            let device = B::device(&grad);
            let zeros = B::zeros(shape, &device);

            B::index_select_assign(zeros, indexes, grad)
        });
    }
}
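
The closure above is the whole gradient rule for max/min along a dimension: the incoming gradient flows only to the elements the forward pass selected, and every other position gets zero. A plain-Rust sketch of that routing, with a hypothetical 1-D scatter_grad helper standing in for B::index_select_assign:

// Hypothetical 1-D stand-in for the gradient routing performed by
// B::index_select_assign(zeros, indexes, grad) above.
fn scatter_grad(len: usize, indexes: &[usize], grad: &[f32]) -> Vec<f32> {
    let mut out = vec![0.0; len]; // B::zeros(shape, &device)
    for (&i, &g) in indexes.iter().zip(grad) {
        out[i] += g; // gradient reaches only the selected positions
    }
    out
}

fn main() {
    // Forward: max over [1.0, 7.0, 3.0] picked index 1.
    // Backward: an incoming gradient of 2.0 reaches only that slot.
    assert_eq!(scatter_grad(3, &[1], &[2.0]), vec![0.0, 2.0, 0.0]);
}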

@@ -6,6 +6,8 @@ mod int_tensor;
mod module;
mod tensor;
pub(crate) mod maxmin;
pub use backward::*;
pub use base::*;
pub use int_tensor::*;

@@ -11,6 +11,8 @@ use crate::{
use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape, Tensor};

use super::maxmin::MaxMinDim;

impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
    fn from_data<const D: usize>(
        data: Data<FloatElem<B>, D>,
@@ -1304,6 +1306,66 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
        let ops = CatStep::<B, D>::new(nodes, output.node.clone(), dim);
        output.register_step(ops)
    }

    fn max_dim<const D: usize>(tensor: ADTensor<B, D>, dim: usize) -> ADTensor<B, D> {
        match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
            OpsKind::Tracked(prep) => {
                let shape = B::shape(&tensor.primitive);
                let (tensor, index) = B::max_dim_with_indexes(tensor.primitive, dim);
                prep.finish((index, shape), tensor)
            }
            OpsKind::UnTracked(prep) => prep.finish(B::max_dim(tensor.primitive, dim)),
        }
    }

    fn max_dim_with_indexes<const D: usize>(
        tensor: ADTensor<B, D>,
        dim: usize,
    ) -> (ADTensor<B, D>, IntTensor<B, D>) {
        match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
            OpsKind::Tracked(prep) => {
                let shape = B::shape(&tensor.primitive);
                let (tensor, index) = B::max_dim_with_indexes(tensor.primitive, dim);
                let tensor = prep.finish((index.clone(), shape), tensor);

                (tensor, index)
            }
            OpsKind::UnTracked(prep) => {
                let (tensor, index) = B::max_dim_with_indexes(tensor.primitive, dim);
                let tensor = prep.finish(tensor);

                (tensor, index)
            }
        }
    }

    fn min_dim<const D: usize>(tensor: ADTensor<B, D>, dim: usize) -> ADTensor<B, D> {
        match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
            OpsKind::Tracked(prep) => {
                let shape = B::shape(&tensor.primitive);
                let (tensor, index) = B::min_dim_with_indexes(tensor.primitive, dim);
                prep.finish((index, shape), tensor)
            }
            OpsKind::UnTracked(prep) => prep.finish(B::min_dim(tensor.primitive, dim)),
        }
    }

    fn min_dim_with_indexes<const D: usize>(
        tensor: ADTensor<B, D>,
        dim: usize,
    ) -> (ADTensor<B, D>, IntTensor<B, D>) {
        match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
            OpsKind::Tracked(prep) => {
                let shape = B::shape(&tensor.primitive);
                let (tensor, index) = B::min_dim_with_indexes(tensor.primitive, dim);
                let tensor = prep.finish((index.clone(), shape), tensor);

                (tensor, index)
            }
            OpsKind::UnTracked(prep) => {
                let (tensor, index) = B::min_dim_with_indexes(tensor.primitive, dim);
                let tensor = prep.finish(tensor);

                (tensor, index)
            }
        }
    }
}
/// Make sure the grad tensor has the given shape.
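
A note on the four autodiff ops above: the tracked arm must record, at forward time, everything MaxMinDim::backward will need, which is why max_dim and min_dim call the _with_indexes backend variant when tracked and stash (index, shape) as state, while the untracked arm takes the cheaper kernel with no bookkeeping. A simplified stand-in for that dispatch (hypothetical types, not burn's actual OpsKind/prepare API):

// Simplified sketch of the tracked/untracked split used above.
enum OpsKind<S> {
    Tracked(S), // a backward step will consume this state
    UnTracked,  // nothing listens to this op; skip the bookkeeping
}

fn max_dim_forward(values: &[f32], tracked: bool) -> (f32, OpsKind<usize>) {
    let (index, max) = values
        .iter()
        .copied()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(&b.1))
        .expect("non-empty input");
    if tracked {
        (max, OpsKind::Tracked(index)) // keep the index for backward
    } else {
        (max, OpsKind::UnTracked) // cheaper path, no state saved
    }
}

fn main() {
    let (max, kind) = max_dim_forward(&[1.0, 7.0, 3.0], true);
    assert_eq!(max, 7.0);
    assert!(matches!(kind, OpsKind::Tracked(1)));
}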

@@ -0,0 +1,45 @@
#[burn_tensor_testgen::testgen(ad_maxmin)]
mod tests {
    use super::*;
    use burn_tensor::Data;

    #[test]
    fn should_diff_max_dim() {
        let tensor_1 = TestADTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]]).require_grad();
        let tensor_2 = TestADTensor::from_floats([[4.0, -7.0], [2.0, 3.0]]).require_grad();

        let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
        let tensor_4 = tensor_1.clone().mul(tensor_3.max_dim(1).unsqueeze());
        let grads = tensor_4.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        grad_1
            .to_data()
            .assert_approx_eq(&Data::from([[50.0, 34.0], [40.0, -10.0]]), 5);
        grad_2
            .to_data()
            .assert_approx_eq(&Data::from([[8.0, 10.0], [56.0, 15.0]]), 5);
    }

    #[test]
    fn should_diff_min_dim() {
        let tensor_1 = TestADTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]]).require_grad();
        let tensor_2 = TestADTensor::from_floats([[4.0, -7.0], [2.0, 3.0]]).require_grad();

        let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
        let tensor_4 = tensor_1.clone().mul(tensor_3.min_dim(1).unsqueeze());
        let grads = tensor_4.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        grad_1
            .to_data()
            .assert_approx_eq(&Data::from([[-42.0, 38.0], [-34.0, -24.0]]), 5);
        grad_2
            .to_data()
            .assert_approx_eq(&Data::from([[10.0, 8.0], [15.0, 56.0]]), 5);
    }
}

@@ -20,6 +20,7 @@ mod log;
mod log1p;
mod mask;
mod matmul;
mod maxmin;
mod maxpool2d;
mod mul;
mod multithread;
@@ -59,6 +60,7 @@ macro_rules! testgen_all {
        burn_autodiff::testgen_ad_multithread!();
        burn_autodiff::testgen_ad_add!();
        burn_autodiff::testgen_ad_aggregation!();
        burn_autodiff::testgen_ad_maxmin!();
        burn_autodiff::testgen_ad_cat!();
        burn_autodiff::testgen_ad_cos!();
        burn_autodiff::testgen_ad_cross_entropy_loss!();

@@ -15,6 +15,9 @@ impl<E: TchElement> ActivationOps<TchBackend<E>> for TchBackend<E> {
        tensor: TchTensor<E, D>,
        grad: TchTensor<E, D>,
    ) -> TchTensor<E, D> {
-        TchTensor::new(tensor.tensor.gelu_backward(&grad.tensor, "none"))
+        let storage = tensor.storage.clone();
+        let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none");
+
+        TchTensor::from_existing(tensor, storage)
    }
}
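
The hunk above contains the commit's single deletion: gelu_backward now builds its result with TchTensor::from_existing, keeping a handle to the input's storage, instead of TchTensor::new, which starts a fresh one. A rough sketch of why that handle is kept, with hypothetical stand-in types (burn-tch's actual bookkeeping may differ):

// Hypothetical sketch: sharing the storage handle lets the backend count
// aliases before mutating a buffer in place.
use std::rc::Rc;

struct Storage; // stand-in for a shared buffer handle

struct MyTensor {
    storage: Rc<Storage>,
    data: Vec<f32>,
}

impl MyTensor {
    fn from_existing(data: Vec<f32>, storage: Rc<Storage>) -> Self {
        Self { storage, data } // new values, same storage lineage
    }

    fn can_mutate_in_place(&self) -> bool {
        Rc::strong_count(&self.storage) == 1 // no other alias holds the buffer
    }
}

fn main() {
    let parent = MyTensor {
        storage: Rc::new(Storage),
        data: vec![1.0, 2.0],
    };
    let child = MyTensor::from_existing(vec![3.0, 4.0], parent.storage.clone());
    assert!(!child.can_mutate_in_place()); // parent still aliases the storage
    assert_eq!(child.data, vec![3.0, 4.0]);
}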

@@ -331,6 +331,46 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
        TchTensor::from_existing(tensor, storage)
    }

    fn max_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
        let storage = tensor.storage.clone();
        let (tensor, _indexes) = tensor.tensor.max_dim(dim as i64, true);

        TchTensor::from_existing(tensor, storage)
    }

    fn max_dim_with_indexes<const D: usize>(
        tensor: TchTensor<E, D>,
        dim: usize,
    ) -> (TchTensor<E, D>, TchTensor<i64, D>) {
        let storage = tensor.storage.clone();
        let (tensor, indexes) = tensor.tensor.max_dim(dim as i64, true);

        let tensor = TchTensor::from_existing(tensor, storage);
        let indexes = TchTensor::new(indexes);

        (tensor, indexes)
    }

    fn min_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
        let storage = tensor.storage.clone();
        let (tensor, _indexes) = tensor.tensor.min_dim(dim as i64, true);

        TchTensor::from_existing(tensor, storage)
    }

    fn min_dim_with_indexes<const D: usize>(
        tensor: TchTensor<E, D>,
        dim: usize,
    ) -> (TchTensor<E, D>, TchTensor<i64, D>) {
        let storage = tensor.storage.clone();
        let (tensor, indexes) = tensor.tensor.min_dim(dim as i64, true);

        let tensor = TchTensor::from_existing(tensor, storage);
        let indexes = TchTensor::new(indexes);

        (tensor, indexes)
    }

    fn argmin<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<i64, D> {
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.argmin(dim as i64, true);
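
All four reductions above call the tch kernels with keepdim = true, so a [2, 3] input reduced over dim 1 comes back as [2, 1] rather than [2], matching burn's convention that max_dim and min_dim preserve rank. A standalone sketch of that assumption against the tch crate (Tensor::of_slice and the max_dim signature are assumptions about the tch API of this era, not taken from this diff):

// Sketch: tch's max_dim with keepdim = true keeps the reduced dim as size 1.
use tch::Tensor;

fn main() {
    let t = Tensor::of_slice(&[0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]).reshape(&[2, 3]);
    let (values, indexes) = t.max_dim(1, true);

    assert_eq!(values.size(), vec![2, 1]); // rank kept, dim 1 reduced to size 1
    assert_eq!(indexes.size(), vec![2, 1]);
}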

@@ -247,6 +247,57 @@ where
        Tensor::new(B::argmax(self.primitive, dim))
    }

    /// Find the maximum value.
    pub fn max(self) -> Tensor<B, 1> {
        Tensor::new(B::max(self.primitive))
    }

    /// Find the maximum value along the given dimension.
    pub fn max_dim(self, dim: usize) -> Tensor<B, D> {
        check!(TensorCheck::aggregate_dim::<D>("Max", dim));

        Tensor::new(B::max_dim(self.primitive, dim))
    }

    /// Find the maximum value along the given dimension.
    ///
    /// Also returns the indexes.
    pub fn max_dim_with_indexes(self, dim: usize) -> (Tensor<B, D>, Tensor<B, D, Int>) {
        check!(TensorCheck::aggregate_dim::<D>("Max", dim));

        let (tensor, index) = B::max_dim_with_indexes(self.primitive, dim);

        let tensor = Tensor::new(tensor);
        let index = Tensor::new(index);

        (tensor, index)
    }

    /// Find the minimum value.
    pub fn min(self) -> Tensor<B, 1> {
        Tensor::new(B::min(self.primitive))
    }

    /// Find the minimum value along the given dimension.
    pub fn min_dim(self, dim: usize) -> Tensor<B, D> {
        check!(TensorCheck::aggregate_dim::<D>("Min", dim));

        Tensor::new(B::min_dim(self.primitive, dim))
    }

    /// Find the minimum value along the given dimension.
    ///
    /// Also returns the indexes.
    pub fn min_dim_with_indexes(self, dim: usize) -> (Tensor<B, D>, Tensor<B, D, Int>) {
        check!(TensorCheck::aggregate_dim::<D>("Min", dim));

        let (tensor, index) = B::min_dim_with_indexes(self.primitive, dim);

        let tensor = Tensor::new(tensor);
        let index = Tensor::new(index);

        (tensor, index)
    }
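
A quick usage sketch of the new public methods (generic over any backend; the describe helper below is hypothetical, not part of this commit):

use burn_tensor::{backend::Backend, Tensor};

// Hypothetical helper showing the shapes the new API returns.
fn describe<B: Backend, const D: usize>(tensor: Tensor<B, D>) {
    let (values, indexes) = tensor.max_dim_with_indexes(D - 1);
    // Both outputs keep the input rank, with the reduced dim shrunk to 1;
    // indexes is an integer tensor of the same shape as values.
    println!("values: {:?}, indexes: {:?}", values.shape(), indexes.shape());
}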
    /// Applies the argmin function along the given dimension and returns an integer tensor.
    ///
    /// # Example
@@ -241,4 +241,44 @@ pub trait TensorOps<B: Backend> {
        tensors: Vec<B::TensorPrimitive<D>>,
        dim: usize,
    ) -> B::TensorPrimitive<D>;

    fn max<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1> {
        let shape = B::shape(&tensor);
        let tensor = B::reshape(tensor, Shape::new([shape.num_elements()]));

        B::max_dim(tensor, 0)
    }

    fn max_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D> {
        let index = B::argmax(tensor.clone(), dim);

        B::index_select(tensor, index)
    }

    fn max_dim_with_indexes<const D: usize>(
        tensor: B::TensorPrimitive<D>,
        dim: usize,
    ) -> (B::TensorPrimitive<D>, B::IntTensorPrimitive<D>) {
        let index = B::argmax(tensor.clone(), dim);
        let values = B::index_select(tensor, index.clone());

        (values, index)
    }

    fn min<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1> {
        let shape = B::shape(&tensor);
        let tensor = B::reshape(tensor, Shape::new([shape.num_elements()]));

        B::min_dim(tensor, 0)
    }

    fn min_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D> {
        let index = B::argmin(tensor.clone(), dim);

        B::index_select(tensor, index)
    }

    fn min_dim_with_indexes<const D: usize>(
        tensor: B::TensorPrimitive<D>,
        dim: usize,
    ) -> (B::TensorPrimitive<D>, B::IntTensorPrimitive<D>) {
        let index = B::argmin(tensor.clone(), dim);
        let values = B::index_select(tensor, index.clone());

        (values, index)
    }
}
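
These defaults mean a backend only has to supply argmax, argmin, and index_select to get all six ops; burn-tch overrides them above because LibTorch ships fused kernels. The composition is simply argmax followed by a gather along the same dimension, as in this plain-Rust sketch over nested Vecs (an illustration, not burn code):

// Sketch: max_dim over dim 1 as argmax followed by a gather.
fn max_dim_rows(m: &[Vec<f32>]) -> (Vec<f32>, Vec<usize>) {
    let indexes: Vec<usize> = m
        .iter()
        .map(|row| {
            row.iter()
                .enumerate()
                .max_by(|a, b| a.1.total_cmp(b.1))
                .map(|(i, _)| i)
                .expect("non-empty row")
        })
        .collect(); // B::argmax(tensor, 1)
    let values = m.iter().zip(&indexes).map(|(row, &i)| row[i]).collect(); // B::index_select
    (values, indexes)
}

fn main() {
    let m = vec![vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0]];
    let (values, indexes) = max_dim_rows(&m);
    assert_eq!(values, vec![2.0, 5.0]);
    assert_eq!(indexes, vec![2, 2]);
}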

@@ -25,6 +25,7 @@ macro_rules! testgen_all {
        burn_tensor::testgen_add!();
        burn_tensor::testgen_aggregation!();
        burn_tensor::testgen_arg!();
        burn_tensor::testgen_maxmin!();
        burn_tensor::testgen_cos!();
        burn_tensor::testgen_div!();
        burn_tensor::testgen_erf!();

@@ -0,0 +1,51 @@
#[burn_tensor_testgen::testgen(maxmin)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Tensor};

    #[test]
    fn test_max_dim_2d() {
        let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

        let output_actual = tensor.max_dim(1);

        let output_expected = Data::from([[2.], [5.]]);
        assert_eq!(output_expected, output_actual.into_data());
    }

    #[test]
    fn test_max_dim_with_indexes_2d() {
        let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

        let (output_actual, index_actual) = tensor.max_dim_with_indexes(1);

        let output_expected = Data::from([[2.], [5.]]);
        let index_expected = Data::from([[2], [2]]);
        assert_eq!(output_expected, output_actual.into_data());
        assert_eq!(index_expected, index_actual.into_data());
    }

    #[test]
    fn test_min_dim_2d() {
        let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

        let output_actual = tensor.min_dim(1);

        let output_expected = Data::from([[0.], [3.]]);
        assert_eq!(output_expected, output_actual.into_data());
    }

    #[test]
    fn test_min_dim_with_indexes_2d() {
        let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

        let (output_actual, index_actual) = tensor.min_dim_with_indexes(1);

        let output_expected = Data::from([[0.], [3.]]);
        let index_expected = Data::from([[0], [0]]);
        assert_eq!(output_expected, output_actual.into_data());
        assert_eq!(index_expected, index_actual.into_data());
    }
}

@@ -14,6 +14,7 @@ mod log1p;
mod map_comparison;
mod mask;
mod matmul;
mod maxmin;
mod mul;
mod neg;
mod powf;