mirror of https://github.com/tracel-ai/burn.git
feat: add max & min ops (#339)
This commit is contained in:
parent 69001b0d69
commit a88357ce1d
@@ -0,0 +1,20 @@
use super::{unary, Backward, Ops};
use crate::grads::Gradients;
use burn_tensor::{backend::Backend, Shape};

#[derive(Debug)]
pub(crate) struct MaxMinDim;

impl<B: Backend, const D: usize> Backward<B, D, 1> for MaxMinDim {
    type State = (B::IntTensorPrimitive<D>, Shape<D>);

    fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
        unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
            let (indexes, shape) = ops.state;
            let device = B::device(&grad);
            let zeros = B::zeros(shape, &device);

            B::index_select_assign(zeros, indexes, grad)
        });
    }
}
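The backward pass routes the incoming gradient only to the elements that produced the max (or min): it allocates zeros in the forward input's shape and scatters `grad` into it at the saved indexes. A standalone sketch of the same rule on a plain row-major buffer, with dim = 1 and hypothetical names; illustrative only, not burn's API:

// Gradient of max_dim(1) on a row-major [rows, cols] buffer.
// `winners[r]` is the column index saved by the forward pass for row r.
fn max_dim1_backward(grad: &[f32], winners: &[usize], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0.0; rows * cols]; // zeros, like B::zeros(shape, &device)
    for r in 0..rows {
        out[r * cols + winners[r]] = grad[r]; // scatter, like B::index_select_assign
    }
    out
}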
@@ -6,6 +6,8 @@ mod int_tensor;
mod module;
mod tensor;

pub(crate) mod maxmin;

pub use backward::*;
pub use base::*;
pub use int_tensor::*;
@@ -11,6 +11,8 @@ use crate::{
use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape, Tensor};

use super::maxmin::MaxMinDim;

impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
    fn from_data<const D: usize>(
        data: Data<FloatElem<B>, D>,
@@ -1304,6 +1306,66 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
        let ops = CatStep::<B, D>::new(nodes, output.node.clone(), dim);
        output.register_step(ops)
    }

    fn max_dim<const D: usize>(tensor: ADTensor<B, D>, dim: usize) -> ADTensor<B, D> {
        match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
            OpsKind::Tracked(prep) => {
                let shape = B::shape(&tensor.primitive);
                let (tensor, index) = B::max_dim_with_indexes(tensor.primitive, dim);
                prep.finish((index, shape), tensor)
            }
            OpsKind::UnTracked(prep) => prep.finish(B::max_dim(tensor.primitive, dim)),
        }
    }

    fn max_dim_with_indexes<const D: usize>(
        tensor: ADTensor<B, D>,
        dim: usize,
    ) -> (ADTensor<B, D>, IntTensor<B, D>) {
        match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
            OpsKind::Tracked(prep) => {
                let shape = B::shape(&tensor.primitive);
                let (tensor, index) = B::max_dim_with_indexes(tensor.primitive, dim);
                let tensor = prep.finish((index.clone(), shape), tensor);

                (tensor, index)
            }
            OpsKind::UnTracked(prep) => {
                let (tensor, index) = B::max_dim_with_indexes(tensor.primitive, dim);
                let tensor = prep.finish(tensor);

                (tensor, index)
            }
        }
    }

    fn min_dim<const D: usize>(tensor: ADTensor<B, D>, dim: usize) -> ADTensor<B, D> {
        match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
            OpsKind::Tracked(prep) => {
                let shape = B::shape(&tensor.primitive);
                let (tensor, index) = B::min_dim_with_indexes(tensor.primitive, dim);
                prep.finish((index, shape), tensor)
            }
            OpsKind::UnTracked(prep) => prep.finish(B::min_dim(tensor.primitive, dim)),
        }
    }

    fn min_dim_with_indexes<const D: usize>(
        tensor: ADTensor<B, D>,
        dim: usize,
    ) -> (ADTensor<B, D>, IntTensor<B, D>) {
        match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
            OpsKind::Tracked(prep) => {
                let shape = B::shape(&tensor.primitive);
                let (tensor, index) = B::min_dim_with_indexes(tensor.primitive, dim);
                let tensor = prep.finish((index.clone(), shape), tensor);

                (tensor, index)
            }
            OpsKind::UnTracked(prep) => {
                let (tensor, index) = B::min_dim_with_indexes(tensor.primitive, dim);
                let tensor = prep.finish(tensor);

                (tensor, index)
            }
        }
    }
}

/// Make sure the grad tensor has the given shape.
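Each of the four ops follows the same prepare/finish pattern: `statefull()` inspects the parent nodes and yields `OpsKind::Tracked` only when some parent requires gradients, so the `(index, shape)` state is captured, and the extra index computation in `max_dim`/`min_dim` paid for, only on the tracked path. A minimal self-contained sketch of that control flow, with illustrative names rather than burn's real types:

// Illustrative only: a stripped-down version of the tracked/untracked split.
enum OpsKind<S> {
    Tracked(S),
    UnTracked(S),
}

fn dispatch(requires_grad: bool) -> OpsKind<&'static str> {
    // The real `prepare(..).statefull()` derives this flag from the graph nodes.
    if requires_grad {
        OpsKind::Tracked("capture (index, shape) for backward")
    } else {
        OpsKind::UnTracked("skip state capture entirely")
    }
}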
@@ -0,0 +1,45 @@
#[burn_tensor_testgen::testgen(ad_maxmin)]
mod tests {
    use super::*;
    use burn_tensor::Data;

    #[test]
    fn should_diff_max_dim() {
        let tensor_1 = TestADTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]]).require_grad();
        let tensor_2 = TestADTensor::from_floats([[4.0, -7.0], [2.0, 3.0]]).require_grad();

        let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
        let tensor_4 = tensor_1.clone().mul(tensor_3.max_dim(1).unsqueeze());
        let grads = tensor_4.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        grad_1
            .to_data()
            .assert_approx_eq(&Data::from([[50.0, 34.0], [40.0, -10.0]]), 5);
        grad_2
            .to_data()
            .assert_approx_eq(&Data::from([[8.0, 10.0], [56.0, 15.0]]), 5);
    }

    #[test]
    fn should_diff_min_dim() {
        let tensor_1 = TestADTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]]).require_grad();
        let tensor_2 = TestADTensor::from_floats([[4.0, -7.0], [2.0, 3.0]]).require_grad();

        let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
        let tensor_4 = tensor_1.clone().mul(tensor_3.min_dim(1).unsqueeze());
        let grads = tensor_4.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        grad_1
            .to_data()
            .assert_approx_eq(&Data::from([[-42.0, 38.0], [-34.0, -24.0]]), 5);
        grad_2
            .to_data()
            .assert_approx_eq(&Data::from([[10.0, 8.0], [15.0, 56.0]]), 5);
    }
}
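As a sanity check of should_diff_max_dim: tensor_3 = tensor_1 · tensor_2 = [[18, 14], [-14, 5]], so max_dim(1) keeps 18 (column 0) and 5 (column 1). With tensor_4 the element-wise product of tensor_1 and the broadcast maxes, the direct gradient into tensor_1 is [[18, 18], [5, 5]], while the gradient into the maxes is the row sums of tensor_1, [8, -5], scattered back into tensor_3 as [[8, 0], [0, -5]]. Propagating through the matmul adds [[8, 0], [0, -5]] · tensor_2ᵀ = [[32, 16], [35, -15]] to tensor_1's gradient, giving the asserted [[50, 34], [40, -10]], and yields tensor_1ᵀ · [[8, 0], [0, -5]] = [[8, 10], [56, 15]] for tensor_2, also matching.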
@@ -20,6 +20,7 @@ mod log;
mod log1p;
mod mask;
mod matmul;
mod maxmin;
mod maxpool2d;
mod mul;
mod multithread;
@@ -59,6 +60,7 @@ macro_rules! testgen_all {
    burn_autodiff::testgen_ad_multithread!();
    burn_autodiff::testgen_ad_add!();
    burn_autodiff::testgen_ad_aggregation!();
    burn_autodiff::testgen_ad_maxmin!();
    burn_autodiff::testgen_ad_cat!();
    burn_autodiff::testgen_ad_cos!();
    burn_autodiff::testgen_ad_cross_entropy_loss!();
@@ -15,6 +15,9 @@ impl<E: TchElement> ActivationOps<TchBackend<E>> for TchBackend<E> {
         tensor: TchTensor<E, D>,
         grad: TchTensor<E, D>,
     ) -> TchTensor<E, D> {
-        TchTensor::new(tensor.tensor.gelu_backward(&grad.tensor, "none"))
+        let storage = tensor.storage.clone();
+        let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none");
+
+        TchTensor::from_existing(tensor, storage)
     }
 }
@@ -331,6 +331,46 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
        TchTensor::from_existing(tensor, storage)
    }

    fn max_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
        let storage = tensor.storage.clone();
        let (tensor, _indexes) = tensor.tensor.max_dim(dim as i64, true);

        TchTensor::from_existing(tensor, storage)
    }

    fn max_dim_with_indexes<const D: usize>(
        tensor: TchTensor<E, D>,
        dim: usize,
    ) -> (TchTensor<E, D>, TchTensor<i64, D>) {
        let storage = tensor.storage.clone();
        let (tensor, indexes) = tensor.tensor.max_dim(dim as i64, true);

        let tensor = TchTensor::from_existing(tensor, storage);
        let indexes = TchTensor::new(indexes);

        (tensor, indexes)
    }

    fn min_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
        let storage = tensor.storage.clone();
        let (tensor, _indexes) = tensor.tensor.min_dim(dim as i64, true);

        TchTensor::from_existing(tensor, storage)
    }

    fn min_dim_with_indexes<const D: usize>(
        tensor: TchTensor<E, D>,
        dim: usize,
    ) -> (TchTensor<E, D>, TchTensor<i64, D>) {
        let storage = tensor.storage.clone();
        let (tensor, indexes) = tensor.tensor.min_dim(dim as i64, true);

        let tensor = TchTensor::from_existing(tensor, storage);
        let indexes = TchTensor::new(indexes);

        (tensor, indexes)
    }

    fn argmin<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<i64, D> {
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.argmin(dim as i64, true);
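Both tch implementations pass `true` as the keepdim flag, so the reduced dimension is kept with size 1 and the index tensor has the same rank as the input, which is what the `TchTensor<i64, D>` return type requires. A hedged shape-contract check, assuming the tch crate's generated `max_dim` binding:

use tch::Tensor;

fn keepdim_demo() {
    let t = Tensor::of_slice(&[0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]).reshape(&[2, 3]);
    // keepdim = true: values and indexes both come back as [2, 1], not [2].
    let (values, indexes) = t.max_dim(1, true);
    assert_eq!(values.size(), vec![2, 1]);
    assert_eq!(indexes.size(), vec![2, 1]);
}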
@@ -247,6 +247,57 @@ where
        Tensor::new(B::argmax(self.primitive, dim))
    }

    /// Find the maximum value.
    pub fn max(self) -> Tensor<B, 1> {
        Tensor::new(B::max(self.primitive))
    }

    /// Find the maximum value along the given dimension.
    pub fn max_dim(self, dim: usize) -> Tensor<B, D> {
        check!(TensorCheck::aggregate_dim::<D>("Max", dim));

        Tensor::new(B::max_dim(self.primitive, dim))
    }

    /// Find the maximum value along the given dimension.
    ///
    /// Also returns the indexes.
    pub fn max_dim_with_indexes(self, dim: usize) -> (Tensor<B, D>, Tensor<B, D, Int>) {
        check!(TensorCheck::aggregate_dim::<D>("Max", dim));

        let (tensor, index) = B::max_dim_with_indexes(self.primitive, dim);

        let tensor = Tensor::new(tensor);
        let index = Tensor::new(index);

        (tensor, index)
    }

    /// Find the minimum value.
    pub fn min(self) -> Tensor<B, 1> {
        Tensor::new(B::min(self.primitive))
    }

    /// Find the minimum value along the given dimension.
    pub fn min_dim(self, dim: usize) -> Tensor<B, D> {
        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
        Tensor::new(B::min_dim(self.primitive, dim))
    }

    /// Find the minimum value along the given dimension.
    ///
    /// Also returns the indexes.
    pub fn min_dim_with_indexes(self, dim: usize) -> (Tensor<B, D>, Tensor<B, D, Int>) {
        check!(TensorCheck::aggregate_dim::<D>("Min", dim));

        let (tensor, index) = B::min_dim_with_indexes(self.primitive, dim);

        let tensor = Tensor::new(tensor);
        let index = Tensor::new(index);

        (tensor, index)
    }

    /// Applies the argmin function along the given dimension and returns an integer tensor.
    ///
    /// # Example
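A hedged usage sketch of the new public methods, written generic over any backend so no concrete device setup is assumed; the helper name is illustrative, not part of the diff. Reducing a [rows, cols] tensor along dim 1 yields shape [rows, 1]:

use burn_tensor::{backend::Backend, Int, Tensor};

fn reduce_rows<B: Backend>(t: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2, Int>) {
    let _global_max: Tensor<B, 1> = t.clone().max(); // single-element tensor
    let _row_maxes: Tensor<B, 2> = t.clone().max_dim(1); // shape [rows, 1]
    t.min_dim_with_indexes(1) // row minima plus their integer indexes
}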
@@ -241,4 +241,44 @@ pub trait TensorOps<B: Backend> {
        tensors: Vec<B::TensorPrimitive<D>>,
        dim: usize,
    ) -> B::TensorPrimitive<D>;
    fn max<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1> {
        let shape = B::shape(&tensor);
        let tensor = B::reshape(tensor, Shape::new([shape.num_elements()]));

        B::max_dim(tensor, 0)
    }
    fn max_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D> {
        let index = B::argmax(tensor.clone(), dim);

        B::index_select(tensor, index)
    }
    fn max_dim_with_indexes<const D: usize>(
        tensor: B::TensorPrimitive<D>,
        dim: usize,
    ) -> (B::TensorPrimitive<D>, B::IntTensorPrimitive<D>) {
        let index = B::argmax(tensor.clone(), dim);
        let values = B::index_select(tensor, index.clone());

        (values, index)
    }
    fn min<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1> {
        let shape = B::shape(&tensor);
        let tensor = B::reshape(tensor, Shape::new([shape.num_elements()]));

        B::min_dim(tensor, 0)
    }
    fn min_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D> {
        let index = B::argmin(tensor.clone(), dim);

        B::index_select(tensor, index)
    }
    fn min_dim_with_indexes<const D: usize>(
        tensor: B::TensorPrimitive<D>,
        dim: usize,
    ) -> (B::TensorPrimitive<D>, B::IntTensorPrimitive<D>) {
        let index = B::argmin(tensor.clone(), dim);
        let values = B::index_select(tensor, index.clone());

        (values, index)
    }
}
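The trait defaults compose existing primitives: the global max/min reshapes to 1-D and reduces along dim 0, while the per-dim variants run argmax/argmin and gather the winning values with index_select, so a backend only needs to override these when it has a faster fused kernel (as the tch backend does above). A plain-Rust rendering of the argmax-then-gather idea on a row-major buffer, illustrative names only:

// max_dim along dim 1 as "argmax, then gather".
fn max_dim1(t: &[f32], rows: usize, cols: usize) -> (Vec<f32>, Vec<usize>) {
    let mut values = Vec::with_capacity(rows);
    let mut indexes = Vec::with_capacity(rows);
    for r in 0..rows {
        let row = &t[r * cols..(r + 1) * cols];
        // argmax over the row (B::argmax)...
        let mut idx = 0;
        for (i, &v) in row.iter().enumerate() {
            if v > row[idx] {
                idx = i;
            }
        }
        // ...then gather the winning value (B::index_select).
        indexes.push(idx);
        values.push(row[idx]);
    }
    (values, indexes)
}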
@@ -25,6 +25,7 @@ macro_rules! testgen_all {
    burn_tensor::testgen_add!();
    burn_tensor::testgen_aggregation!();
    burn_tensor::testgen_arg!();
    burn_tensor::testgen_maxmin!();
    burn_tensor::testgen_cos!();
    burn_tensor::testgen_div!();
    burn_tensor::testgen_erf!();
@@ -0,0 +1,51 @@
#[burn_tensor_testgen::testgen(maxmin)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Tensor};

    #[test]
    fn test_max_dim_2d() {
        let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

        let output_actual = tensor.max_dim(1);

        let output_expected = Data::from([[2.], [5.]]);
        assert_eq!(output_expected, output_actual.into_data());
    }

    #[test]
    fn test_max_dim_with_indexes_2d() {
        let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

        let (output_actual, index_actual) = tensor.max_dim_with_indexes(1);

        let output_expected = Data::from([[2.], [5.]]);
        let index_expected = Data::from([[2], [2]]);

        assert_eq!(output_expected, output_actual.into_data());
        assert_eq!(index_expected, index_actual.into_data());
    }

    #[test]
    fn test_min_dim_2d() {
        let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

        let output_actual = tensor.min_dim(1);

        let output_expected = Data::from([[0.], [3.]]);
        assert_eq!(output_expected, output_actual.into_data());
    }

    #[test]
    fn test_min_dim_with_indexes_2d() {
        let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

        let (output_actual, index_actual) = tensor.min_dim_with_indexes(1);

        let output_expected = Data::from([[0.], [3.]]);
        let index_expected = Data::from([[0], [0]]);

        assert_eq!(output_expected, output_actual.into_data());
        assert_eq!(index_expected, index_actual.into_data());
    }
}
@@ -14,6 +14,7 @@ mod log1p;
mod map_comparison;
mod mask;
mod matmul;
mod maxmin;
mod mul;
mod neg;
mod powf;