Add `sign` tensor operator (#1446)

Dilshod Tadjibaev 2024-03-11 10:39:30 -05:00 committed by GitHub
parent 56f460295a
commit 3f7e6bd5bc
21 changed files with 231 additions and 3 deletions

View File

@@ -220,6 +220,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.scatter(dim, indices, values)` | `tensor.scatter_add(dim, indices, values)` |
| `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
| `tensor.select_assign(dim, indices, values)` | N/A |
| `tensor.sign()` | `tensor.sign()` |
| `tensor.sub(other)` or `tensor - other` | `tensor - other` |
| `tensor.sub_scalar(scalar)` or `tensor - scalar` | `tensor - scalar` |
| `tensor.sum()` | `tensor.sum()` |
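For reference, a minimal usage sketch of the new operator (not part of the commit; it assumes the `NdArray` backend feature is enabled and uses the `from_floats`/`from_ints` constructors that take a device):

use burn::backend::NdArray;
use burn::tensor::{Int, Tensor};

fn main() {
    let device = Default::default();

    // Float: negatives map to -1.0, zero to 0.0, positives to 1.0.
    let x = Tensor::<NdArray, 1>::from_floats([-2.0, 0.0, 3.5], &device);
    println!("{:?}", x.sign().into_data()); // [-1.0, 0.0, 1.0]

    // Int: the same mapping over integer elements.
    let y = Tensor::<NdArray, 1, Int>::from_ints([-4, 0, 7], &device);
    println!("{:?}", y.sign().into_data()); // [-1, 0, 1]
}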

View File

@@ -352,4 +352,8 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
) -> IntTensor<Self, D> {
B::int_permute(tensor, axes)
}
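// Int tensors are not tracked by autodiff, so `int_sign` simply forwards to
// the inner backend.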
fn int_sign<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
B::int_sign(tensor)
}
}

View File

@@ -2350,6 +2350,35 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
OpsKind::UnTracked(prep) => prep.finish(B::float_powf(lhs.primitive, rhs.primitive)),
}
}
fn float_sign<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
#[derive(Debug)]
struct Sign;
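// `retro_unary!` generates a `RetroSign` op that can recompute this node's
// forward value from its parent, which the memory-bound checkpoint strategy
// below relies on instead of storing the output.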
retro_unary!(RetroSign, B::float_sign);
impl<B: Backend, const D: usize> Backward<B, D, 1> for Sign {
type State = ();
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad|
// The sign function is piecewise constant, so its derivative is zero
// almost everywhere (and undefined at 0); propagate a zero gradient,
// matching PyTorch's behavior.
B::float_mul_scalar(grad, 0.elem()));
}
}
Sign.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.memory_bound()
.retro_forward(RetroSign::<B, D>::new(tensor.node.id.clone()))
.parents([&tensor])
.stateless(B::float_sign(tensor.primitive))
}
}
#[derive(Debug, Clone)]

View File

@@ -43,6 +43,7 @@ mod relu;
mod reshape;
mod select;
mod sigmoid;
mod sign;
mod sin;
mod slice;
mod softmax;
@@ -114,5 +115,6 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_transpose!();
burn_autodiff::testgen_ad_permute!();
burn_autodiff::testgen_ad_nonzero!();
burn_autodiff::testgen_ad_sign!();
};
}

View File

@@ -0,0 +1,43 @@
#[burn_tensor_testgen::testgen(ad_sign)]
mod tests {
use super::*;
use burn_tensor::Data;
// Example using the sign function with PyTorch:
// >>> import torch
// >>> # Create a tensor with requires_grad=True
// >>> x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
// >>> # Forward pass: Apply the sign function
// >>> y = torch.sign(x)
// >>> print("Forward pass:")
// Forward pass:
// >>> print("x:", x)
// x: tensor([-2., -1., 0., 1., 2.], requires_grad=True)
// >>> print("y:", y)
// y: tensor([-1., -1., 0., 1., 1.], grad_fn=<SignBackward0>)
// >>> # Compute the loss (just an example)
// >>> loss = y.sum()
// >>> # Backward pass: Compute the gradients
// >>> loss.backward()
// >>> print("\nBackward pass:")
// Backward pass:
// >>> print("x.grad:", x.grad)
// x.grad: tensor([0., 0., 0., 0., 0.])
#[test]
fn should_diff_sign() {
let data = Data::<f32, 1>::from([-2.0, -1.0, 0.0, 1.0, 2.0]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(data, &device).require_grad();
let y = x.clone().sign();
let loss = y.clone().sum();
let grads = loss.backward();
let grad = x.grad(&grads).unwrap();
assert_eq!(y.to_data(), Data::from([-1., -1., 0., 1., 1.]));
assert_eq!(grad.to_data(), Data::from([0., 0., 0., 0., 0.]));
}
}

View File

@@ -82,6 +82,7 @@ mod tests {
burn_tensor::testgen_neg!();
burn_tensor::testgen_permute!();
burn_tensor::testgen_argwhere_nonzero!();
burn_tensor::testgen_sign!();
// TODO: https://github.com/tracel-ai/burn/issues/1237
//

View File

@@ -419,4 +419,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
) -> IntTensor<Self, D> {
permute(tensor, axes)
}
// TODO add sign operator once Candle supports it:
// https://github.com/huggingface/candle/issues/1827
}

View File

@@ -524,4 +524,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
) -> FloatTensor<Self, D> {
permute(tensor, axes)
}
// TODO add sign operator once Candle supports it:
// https://github.com/huggingface/candle/issues/1827
}
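Note that `sign` still works on the Candle backend in the meantime: `float_sign` and `int_sign` ship with default, mask-based implementations in `burn-tensor` (shown later in this commit), so the TODO is only about swapping in a native Candle kernel once one exists.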

View File

@@ -2,9 +2,10 @@ use burn_tensor::Element;
use libm::{exp, fabs, log, log1p, pow, sqrt};
use libm::{expf, fabsf, log1pf, logf, powf, sqrtf};
use ndarray::LinalgScalar;
use num_traits::Signed;
/// A float element for ndarray backend.
pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar
pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar + Signed
where
Self: Sized,
{
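The new `sign_op` below requires `E: Signed` (which, among other things, supplies negation for `-E::one()`); adding the bound here ensures every float element type used by the ndarray backend satisfies it at the call sites.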

View File

@@ -5,6 +5,7 @@ use core::{marker::PhantomData, ops::Range};
use ndarray::s;
use ndarray::Array2;
use ndarray::Zip;
use num_traits::Signed;
use burn_tensor::Shape;
use ndarray::Axis;
@@ -480,6 +481,26 @@ where
) -> NdArrayTensor<E, D> {
NdArrayTensor::new(lhs.array.mapv(var_name).into_shared())
}
pub(crate) fn sign_op<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D>
where
E: Signed,
{
NdArrayTensor::new(
tensor
.array
.mapv(|x| {
if x > E::zero() {
E::one()
} else if x < E::zero() {
-E::one()
} else {
E::zero()
}
})
.into_shared(),
)
}
}
enum CmpType {
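One might expect the closure to simply call `signum()` from the `Signed` trait. A plausible reason for the explicit three-way comparison (an observation, not something stated in the commit) is the float behaviour at zero, which the standard library and `num_traits` share:

fn main() {
    // Integer signum already matches sign(): signum(0) == 0.
    assert_eq!(0i64.signum(), 0);
    // Float signum does not: +0.0 maps to 1.0 and -0.0 to -1.0,
    // whereas sign(0.0) must be 0.0 -- hence the manual mapping.
    assert_eq!(0.0f64.signum(), 1.0);
    assert_eq!((-0.0f64).signum(), -1.0);
}

With the explicit comparisons, both +0.0 and -0.0 compare equal to zero and therefore map to `E::zero()`.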

View File

@@ -434,4 +434,8 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
let array = tensor.array.permuted_axes(axes.into_dimension());
NdArrayTensor { array }
}
fn int_sign<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, D> {
NdArrayMathOps::sign_op(tensor)
}
}

View File

@@ -493,4 +493,8 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
let array = tensor.array.permuted_axes(axes.into_dimension());
NdArrayTensor { array }
}
fn float_sign<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
NdArrayMathOps::sign_op(tensor)
}
}

View File

@@ -472,4 +472,8 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
)
}
pub fn sign<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.sign_(), |tensor| tensor.sign())
}
}

View File

@@ -465,4 +465,10 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
) -> burn_tensor::ops::IntTensor<Self, D> {
TchOps::permute(tensor, axes)
}
fn int_sign<const D: usize>(
tensor: <LibTorch<E> as Backend>::IntTensorPrimitive<D>,
) -> <LibTorch<E> as Backend>::IntTensorPrimitive<D> {
TchOps::sign(tensor)
}
}

View File

@@ -474,4 +474,10 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
) -> burn_tensor::ops::FloatTensor<Self, D> {
TchOps::permute(tensor, axes)
}
fn float_sign<const D: usize>(
tensor: <LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
) -> <LibTorch<E> as Backend>::FloatTensorPrimitive<D> {
TchOps::sign(tensor)
}
}

View File

@@ -82,6 +82,11 @@ where
Self::new(K::neg(self.primitive))
}
/// Returns the signs of the elements of the input tensor.
pub fn sign(self) -> Self {
Self::new(K::sign(self.primitive))
}
/// Create a tensor of the given shape where each element is zero.
pub fn zeros<S: Into<Shape<D>>>(shape: S, device: &B::Device) -> Self {
Self::new(K::zeros(shape.into(), device))
@@ -890,6 +895,26 @@
/// which is more high-level and designed for public use.
fn neg<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D>;
/// Returns the signs of the elements of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The signs of the elements of the tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct use, and it is not recommended to import
/// or call this function directly.
///
/// For getting the signs of the elements of a tensor, users should prefer the [Tensor::sign](Tensor::sign) function,
/// which is more high-level and designed for public use.
fn sign<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D>;
/// Creates a tensor filled with zeros.
///
/// # Arguments
@@ -2085,6 +2110,10 @@ impl<B: Backend> Numeric<B> for Int {
) -> Self::Primitive<D> {
B::int_random(shape, distribution, device)
}
fn sign<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D> {
B::int_sign(tensor)
}
}
impl<B: Backend> Numeric<B> for Float {
@@ -2382,6 +2411,10 @@
) -> Self::Primitive<D> {
B::float_random(shape, distribution, device)
}
fn sign<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D> {
B::float_sign(tensor)
}
}
impl<B, const D: usize, K> core::ops::Add<Self> for Tensor<B, D, K>

View File

@@ -1107,7 +1107,6 @@ pub trait IntTensorOps<B: Backend> {
///
/// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
/// evaluate to True, False otherwise.
fn int_all<const D: usize>(tensor: IntTensor<B, D>) -> BoolTensor<B, 1> {
let num_elems = B::int_shape(&tensor).num_elements();
let bool_tensor = B::int_equal_elem(tensor, 0.elem());
@@ -1128,7 +1127,6 @@
/// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
/// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
/// evaluates to True, False otherwise.
fn int_all_dim<const D: usize>(tensor: IntTensor<B, D>, dim: usize) -> BoolTensor<B, D> {
let num_elems = B::int_shape(&tensor).dims[dim];
let bool_tensor = B::int_equal_elem(tensor, 0.elem());
@@ -1136,4 +1134,23 @@ pub trait IntTensorOps<B: Backend> {
let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
B::int_equal_elem(sum, (num_elems as i32).elem())
}
/// Returns the signs of the int `tensor`.
///
/// # Arguments
///
/// * `tensor` - The tensor to extract the signs from.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
fn int_sign<const D: usize>(tensor: IntTensor<B, D>) -> IntTensor<B, D> {
let zeros = B::int_zeros(B::int_shape(&tensor), &B::int_device(&tensor));
let less_than_zero = B::int_lower_elem(tensor.clone(), 0.0f32.elem());
let greater_than_zero = B::int_greater_elem(tensor, 0.0f32.elem());
let mut result = B::int_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
result = B::int_mask_fill(result, greater_than_zero, 1.0f32.elem());
result
}
}
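To make the mask-based fallback concrete, here is a hand-trace of the default `int_sign` on a small input (the float fallback `float_sign` below follows the identical pattern):

// input:              [-2,  0,  3]
// zeros:              [ 0,  0,  0]
// less_than_zero:     [ T,  F,  F]   (input < 0)
// mask_fill(-1):      [-1,  0,  0]
// greater_than_zero:  [ F,  F,  T]   (input > 0)
// mask_fill(+1):      [-1,  0,  1]   => the elementwise sign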

View File

@@ -1303,4 +1303,23 @@ pub trait FloatTensorOps<B: Backend> {
let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
B::float_equal_elem(sum, (num_elems as f32).elem())
}
/// Returns the signs of the float `tensor`.
///
/// # Arguments
///
/// * `tensor` - The tensor to extract the signs from.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
fn float_sign<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
let zeros = B::float_zeros(B::float_shape(&tensor), &B::float_device(&tensor));
let less_than_zero = B::float_lower_elem(tensor.clone(), 0.0f32.elem());
let greater_than_zero = B::float_greater_elem(tensor, 0.0f32.elem());
let mut result = B::float_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
result = B::float_mask_fill(result, greater_than_zero, 1.0f32.elem());
result
}
}

View File

@@ -88,6 +88,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_permute!();
burn_tensor::testgen_bool!();
burn_tensor::testgen_argwhere_nonzero!();
burn_tensor::testgen_sign!();
// test stats
burn_tensor::testgen_var!();

View File

@@ -41,6 +41,7 @@ mod recip;
mod repeat;
mod reshape;
mod select;
mod sign;
mod sin;
mod slice;
mod sqrt;

View File

@@ -0,0 +1,25 @@
#[burn_tensor_testgen::testgen(sign)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
#[test]
fn should_support_sign_ops_float() {
let tensor = TestTensor::from([[-0.2, -1.0, 2.0], [3.0, 0.0, -5.0]]);
let data_actual = tensor.sign().into_data();
let data_expected = Data::from([[-1.0, -1.0, 1.0], [1.0, 0.0, -1.0]]);
assert_eq!(data_actual, data_expected);
}
#[test]
fn should_support_sign_ops_int() {
let tensor = TestTensorInt::from([[-2, -1, 2], [3, 0, -5]]);
let data_actual = tensor.sign().into_data();
let data_expected = Data::from([[-1, -1, 1], [1, 0, -1]]);
assert_eq!(data_actual, data_expected);
}
}