mirror of https://github.com/tracel-ai/burn.git
Add `sign` tensor operator (#1446)
parent 56f460295a
commit 3f7e6bd5bc
@ -220,6 +220,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
 | `tensor.scatter(dim, indices, values)` | `tensor.scatter_add(dim, indices, values)` |
 | `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
 | `tensor.select_assign(dim, indices, values)` | N/A |
+| `tensor.sign()` | `tensor.sign()` |
 | `tensor.sub(other)` or `tensor - other` | `tensor - other` |
 | `tensor.sub_scalar(scalar)` or `tensor - scalar` | `tensor - scalar` |
 | `tensor.sum()` | `tensor.sum()` |
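For context (not part of the commit), a minimal usage sketch of the new operator. It assumes the `NdArray` backend re-export and the `from_floats`/`from_ints` constructors as they exist around this point in Burn's API; names may differ in other versions:

```rust
use burn::backend::NdArray;
use burn::tensor::{Int, Tensor};

fn main() {
    let device = Default::default();

    // Float sign: -1.0 for negatives, 1.0 for positives, 0.0 at zero.
    let x = Tensor::<NdArray, 1>::from_floats([-2.0, 0.0, 3.5], &device);
    println!("{}", x.sign()); // [-1.0, 0.0, 1.0]

    // Int sign follows the same convention.
    let i = Tensor::<NdArray, 1, Int>::from_ints([-4, 0, 7], &device);
    println!("{}", i.sign()); // [-1, 0, 1]
}
```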
@ -352,4 +352,8 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
     ) -> IntTensor<Self, D> {
         B::int_permute(tensor, axes)
     }
+
+    fn int_sign<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
+        B::int_sign(tensor)
+    }
 }
@ -2350,6 +2350,35 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
             OpsKind::UnTracked(prep) => prep.finish(B::float_powf(lhs.primitive, rhs.primitive)),
         }
     }
+
+    fn float_sign<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
+        #[derive(Debug)]
+        struct Sign;
+
+        retro_unary!(RetroSign, B::float_sign);
+
+        impl<B: Backend, const D: usize> Backward<B, D, 1> for Sign {
+            type State = ();
+
+            fn backward(
+                self,
+                ops: Ops<Self::State, 1>,
+                grads: &mut Gradients,
+                _checkpointer: &mut Checkpointer,
+            ) {
+                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad|
+                    // The sign function is piecewise constant, so its derivative
+                    // is zero everywhere it is defined; the gradient is always zero.
+                    B::float_mul_scalar(grad, 0.elem()));
+            }
+        }
+
+        Sign.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
+            .memory_bound()
+            .retro_forward(RetroSign::<B, D>::new(tensor.node.id.clone()))
+            .parents([&tensor])
+            .stateless(B::float_sign(tensor.primitive))
+    }
 }

 #[derive(Debug, Clone)]
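A note on the zero gradient above (standard, though not spelled out in the diff): `sign` is piecewise constant, so

$$ \frac{\partial}{\partial x}\,\operatorname{sign}(x) = 0 \qquad \text{for } x \neq 0, $$

and the derivative does not exist at the jump $x = 0$. Autodiff frameworks, including PyTorch's `SignBackward0` (see the test transcript below), conventionally take the gradient to be zero everywhere, which is exactly what the `float_mul_scalar(grad, 0)` closure implements.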
@ -43,6 +43,7 @@ mod relu;
 mod reshape;
 mod select;
 mod sigmoid;
+mod sign;
 mod sin;
 mod slice;
 mod softmax;
@ -114,5 +115,6 @@ macro_rules! testgen_all {
         burn_autodiff::testgen_ad_transpose!();
         burn_autodiff::testgen_ad_permute!();
         burn_autodiff::testgen_ad_nonzero!();
+        burn_autodiff::testgen_ad_sign!();
     };
 }
@ -0,0 +1,43 @@
+#[burn_tensor_testgen::testgen(ad_sign)]
+mod tests {
+    use super::*;
+    use burn_tensor::Data;
+
+    /// Example using the sign function with PyTorch:
+    // >>> import torch
+    // >>> # Create a tensor with requires_grad=True
+    // >>> x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
+    // >>> # Forward pass: Apply the sign function
+    // >>> y = torch.sign(x)
+    // >>> print("Forward pass:")
+    // Forward pass:
+    // >>> print("x:", x)
+    // x: tensor([-2., -1., 0., 1., 2.], requires_grad=True)
+    // >>> print("y:", y)
+    // y: tensor([-1., -1., 0., 1., 1.], grad_fn=<SignBackward0>)
+    // >>> # Compute the loss (just an example)
+    // >>> loss = y.sum()
+    // >>> # Backward pass: Compute the gradients
+    // >>> loss.backward()
+    // >>> print("\nBackward pass:")
+    // Backward pass:
+    // >>> print("x.grad:", x.grad)
+    // x.grad: tensor([0., 0., 0., 0., 0.])
+
+    #[test]
+    fn should_diff_sign() {
+        let data = Data::<f32, 1>::from([-2.0, -1.0, 0.0, 1.0, 2.0]);
+
+        let device = Default::default();
+        let x = TestAutodiffTensor::from_data(data, &device).require_grad();
+
+        let y = x.clone().sign();
+
+        let loss = y.clone().sum();
+        let grads = loss.backward();
+        let grad = x.grad(&grads).unwrap();
+
+        assert_eq!(y.to_data(), Data::from([-1., -1., 0., 1., 1.]));
+        assert_eq!(grad.to_data(), Data::from([0., 0., 0., 0., 0.]));
+    }
+}
@ -82,6 +82,7 @@ mod tests {
     burn_tensor::testgen_neg!();
     burn_tensor::testgen_permute!();
     burn_tensor::testgen_argwhere_nonzero!();
+    burn_tensor::testgen_sign!();

     // TODO: https://github.com/tracel-ai/burn/issues/1237
     //
@ -419,4 +419,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
     ) -> IntTensor<Self, D> {
         permute(tensor, axes)
     }
+
+    // TODO add sign operator once Candle supports it:
+    // https://github.com/huggingface/candle/issues/1827
 }
@ -524,4 +524,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
     ) -> FloatTensor<Self, D> {
         permute(tensor, axes)
     }
+
+    // TODO add sign operator once Candle supports it:
+    // https://github.com/huggingface/candle/issues/1827
 }
@ -2,9 +2,10 @@ use burn_tensor::Element;
 use libm::{exp, fabs, log, log1p, pow, sqrt};
 use libm::{expf, fabsf, log1pf, logf, powf, sqrtf};
 use ndarray::LinalgScalar;
+use num_traits::Signed;

 /// A float element for ndarray backend.
-pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar
+pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar + Signed
 where
     Self: Sized,
 {
@ -5,6 +5,7 @@ use core::{marker::PhantomData, ops::Range};
 use ndarray::s;
 use ndarray::Array2;
 use ndarray::Zip;
+use num_traits::Signed;

 use burn_tensor::Shape;
 use ndarray::Axis;
@ -480,6 +481,26 @@ where
     ) -> NdArrayTensor<E, D> {
         NdArrayTensor::new(lhs.array.mapv(var_name).into_shared())
     }
+
+    pub(crate) fn sign_op<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D>
+    where
+        E: Signed,
+    {
+        NdArrayTensor::new(
+            tensor
+                .array
+                .mapv(|x| {
+                    if x > E::zero() {
+                        E::one()
+                    } else if x < E::zero() {
+                        -E::one()
+                    } else {
+                        E::zero()
+                    }
+                })
+                .into_shared(),
+        )
+    }
 }

 enum CmpType {
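A side observation (not part of the diff) on why `sign_op` spells out the three-way comparison instead of calling `num_traits::Signed::signum`: for floats, `signum` follows IEEE sign semantics and maps `+0.0` to `1.0` and `-0.0` to `-1.0`, while `sign` is expected to return `0` at zero, matching `torch.sign`. A quick check using the std float `signum`, which `num_traits` mirrors:

```rust
fn main() {
    // signum: zero keeps its sign, with magnitude 1.
    assert_eq!(0.0_f32.signum(), 1.0);
    assert_eq!((-0.0_f32).signum(), -1.0);

    // The three-way comparison used by `sign_op` maps both zeros to 0.0.
    let sign = |x: f32| {
        if x > 0.0 {
            1.0
        } else if x < 0.0 {
            -1.0
        } else {
            0.0
        }
    };
    assert_eq!(sign(0.0), 0.0);
    assert_eq!(sign(-0.0), 0.0);
}
```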
@ -434,4 +434,8 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
         let array = tensor.array.permuted_axes(axes.into_dimension());
         NdArrayTensor { array }
     }
+
+    fn int_sign<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, D> {
+        NdArrayMathOps::sign_op(tensor)
+    }
 }
@ -493,4 +493,8 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
         let array = tensor.array.permuted_axes(axes.into_dimension());
         NdArrayTensor { array }
     }
+
+    fn float_sign<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
+        NdArrayMathOps::sign_op(tensor)
+    }
 }
@ -472,4 +472,8 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
             |lhs, rhs| lhs.f_pow(rhs).unwrap(),
         )
     }
+
+    pub fn sign<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
+        tensor.unary_ops(|mut tensor| tensor.sign_(), |tensor| tensor.sign())
+    }
 }
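`TchOps::sign` follows the crate's `unary_ops` pattern: an in-place variant (`sign_`) when the underlying storage can be mutated, and an allocating variant (`sign`) otherwise. Purely as an illustration of that copy-on-write idea (a generic sketch, not burn-tch's actual implementation), the dispatch might look like:

```rust
use std::sync::Arc;

// Three-way sign, shared by both variants below.
fn sign_of(x: f32) -> f32 {
    if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
}

// Mutate in place when this is the only handle to the buffer,
// otherwise allocate a new buffer.
fn unary_cow(data: Arc<Vec<f32>>) -> Arc<Vec<f32>> {
    match Arc::try_unwrap(data) {
        Ok(mut owned) => {
            owned.iter_mut().for_each(|e| *e = sign_of(*e)); // in-place, like `sign_`
            Arc::new(owned)
        }
        Err(shared) => Arc::new(shared.iter().copied().map(sign_of).collect()), // copy, like `sign`
    }
}

fn main() {
    let s = unary_cow(Arc::new(vec![-2.0, 0.0, 3.0]));
    assert_eq!(*s, vec![-1.0, 0.0, 1.0]);
}
```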
@ -465,4 +465,10 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
     ) -> burn_tensor::ops::IntTensor<Self, D> {
         TchOps::permute(tensor, axes)
     }
+
+    fn int_sign<const D: usize>(
+        tensor: <LibTorch<E> as Backend>::IntTensorPrimitive<D>,
+    ) -> <LibTorch<E> as Backend>::IntTensorPrimitive<D> {
+        TchOps::sign(tensor)
+    }
 }
@ -474,4 +474,10 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
     ) -> burn_tensor::ops::FloatTensor<Self, D> {
         TchOps::permute(tensor, axes)
     }
+
+    fn float_sign<const D: usize>(
+        tensor: <LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
+    ) -> <LibTorch<E> as Backend>::FloatTensorPrimitive<D> {
+        TchOps::sign(tensor)
+    }
 }
@ -82,6 +82,11 @@ where
         Self::new(K::neg(self.primitive))
     }
+
+    /// Returns the signs of the elements of the input tensor.
+    pub fn sign(self) -> Self {
+        Self::new(K::sign(self.primitive))
+    }

     /// Create a tensor of the given shape where each element is zero.
     pub fn zeros<S: Into<Shape<D>>>(shape: S, device: &B::Device) -> Self {
         Self::new(K::zeros(shape.into(), device))
@ -890,6 +895,26 @@
     /// which is more high-level and designed for public use.
     fn neg<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D>;
+
+    /// Returns the signs of the elements of a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor.
+    ///
+    /// # Returns
+    ///
+    /// The signs of the elements of the tensor.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and it is not recommended to
+    /// import or use this function directly.
+    ///
+    /// For getting the signs of the elements of a tensor, users should prefer the [Tensor::sign](Tensor::sign) function,
+    /// which is more high-level and designed for public use.
+    fn sign<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D>;

     /// Creates a tensor filled with zeros.
     ///
     /// # Arguments
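The `Numeric::sign` declaration is what lets the single public `Tensor::sign` method serve both `Float` and `Int` kinds with static dispatch. A self-contained toy model of that pattern, with scalars standing in for tensor primitives (illustrative names only, not Burn's real definitions):

```rust
use core::marker::PhantomData;

// Toy backend: scalar "primitives" stand in for real tensor types.
trait Backend {
    fn float_sign(x: f32) -> f32 {
        if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
    }
    fn int_sign(x: i64) -> i64 {
        x.signum()
    }
}

// Simplified `Numeric`: one associated primitive per tensor kind.
trait Numeric<B: Backend> {
    type Primitive;
    fn sign(v: Self::Primitive) -> Self::Primitive;
}

struct Float;
struct Int;

impl<B: Backend> Numeric<B> for Float {
    type Primitive = f32;
    fn sign(v: f32) -> f32 {
        B::float_sign(v) // kind-specific backend call
    }
}

impl<B: Backend> Numeric<B> for Int {
    type Primitive = i64;
    fn sign(v: i64) -> i64 {
        B::int_sign(v)
    }
}

struct Tensor<B: Backend, K: Numeric<B>> {
    primitive: K::Primitive,
    _backend: PhantomData<B>,
}

impl<B: Backend, K: Numeric<B>> Tensor<B, K> {
    // Mirrors `Tensor::sign`: one public method, statically dispatched per kind.
    fn sign(self) -> Self {
        Tensor { primitive: K::sign(self.primitive), _backend: PhantomData }
    }
}

struct MyBackend;
impl Backend for MyBackend {}

fn main() {
    let t = Tensor::<MyBackend, Float> { primitive: -3.5, _backend: PhantomData };
    assert_eq!(t.sign().primitive, -1.0);
}
```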
@ -2085,6 +2110,10 @@ impl<B: Backend> Numeric<B> for Int {
     ) -> Self::Primitive<D> {
         B::int_random(shape, distribution, device)
     }
+
+    fn sign<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D> {
+        B::int_sign(tensor)
+    }
 }

 impl<B: Backend> Numeric<B> for Float {
@ -2382,6 +2411,10 @@ impl<B: Backend> Numeric<B> for Float {
     ) -> Self::Primitive<D> {
         B::float_random(shape, distribution, device)
     }
+
+    fn sign<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D> {
+        B::float_sign(tensor)
+    }
 }

 impl<B, const D: usize, K> core::ops::Add<Self> for Tensor<B, D, K>
@ -1107,7 +1107,6 @@ pub trait IntTensorOps<B: Backend> {
     ///
     /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
     /// evaluate to True, False otherwise.
-
     fn int_all<const D: usize>(tensor: IntTensor<B, D>) -> BoolTensor<B, 1> {
         let num_elems = B::int_shape(&tensor).num_elements();
         let bool_tensor = B::int_equal_elem(tensor, 0.elem());
@ -1128,7 +1127,6 @@ pub trait IntTensorOps<B: Backend> {
     /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
     /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
     /// evaluates to True, False otherwise.
-
    fn int_all_dim<const D: usize>(tensor: IntTensor<B, D>, dim: usize) -> BoolTensor<B, D> {
         let num_elems = B::int_shape(&tensor).dims[dim];
         let bool_tensor = B::int_equal_elem(tensor, 0.elem());
@ -1136,4 +1134,23 @@ pub trait IntTensorOps<B: Backend> {
         let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
         B::int_equal_elem(sum, (num_elems as i32).elem())
     }
+
+    /// Returns the signs of the int `tensor`.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to extract the signs from.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
+    fn int_sign<const D: usize>(tensor: IntTensor<B, D>) -> IntTensor<B, D> {
+        let zeros = B::int_zeros(B::int_shape(&tensor), &B::int_device(&tensor));
+        let less_than_zero = B::int_lower_elem(tensor.clone(), 0.0f32.elem());
+        let greater_than_zero = B::int_greater_elem(tensor, 0.0f32.elem());
+
+        let mut result = B::int_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
+        result = B::int_mask_fill(result, greater_than_zero, 1.0f32.elem());
+        result
+    }
 }
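The default `int_sign` (and `float_sign` below) composes three primitives every backend already has: start from zeros, mask-fill `-1` where the input is negative, then mask-fill `1` where it is positive, so backends get `sign` for free without implementing it. A standalone sketch of the same masking logic on plain slices (a hypothetical helper for illustration, not part of the trait):

```rust
// Illustration of the default `int_sign` logic: zeros, then two mask fills.
fn sign_via_masks(input: &[i64]) -> Vec<i64> {
    let mut result = vec![0i64; input.len()]; // int_zeros
    for (r, &x) in result.iter_mut().zip(input) {
        if x < 0 {
            *r = -1; // int_mask_fill with the `less_than_zero` mask
        }
        if x > 0 {
            *r = 1; // int_mask_fill with the `greater_than_zero` mask
        }
    }
    result
}

fn main() {
    assert_eq!(sign_via_masks(&[-2, 0, 7]), vec![-1, 0, 1]);
}
```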
@ -1303,4 +1303,23 @@ pub trait FloatTensorOps<B: Backend> {
         let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
         B::float_equal_elem(sum, (num_elems as f32).elem())
     }
+
+    /// Returns the signs of the float `tensor`.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to extract the signs from.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
+    fn float_sign<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
+        let zeros = B::float_zeros(B::float_shape(&tensor), &B::float_device(&tensor));
+        let less_than_zero = B::float_lower_elem(tensor.clone(), 0.0f32.elem());
+        let greater_than_zero = B::float_greater_elem(tensor, 0.0f32.elem());
+
+        let mut result = B::float_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
+        result = B::float_mask_fill(result, greater_than_zero, 1.0f32.elem());
+        result
+    }
 }
@ -88,6 +88,7 @@ macro_rules! testgen_all {
     burn_tensor::testgen_permute!();
     burn_tensor::testgen_bool!();
     burn_tensor::testgen_argwhere_nonzero!();
+    burn_tensor::testgen_sign!();

     // test stats
     burn_tensor::testgen_var!();
@ -41,6 +41,7 @@ mod recip;
 mod repeat;
 mod reshape;
 mod select;
+mod sign;
 mod sin;
 mod slice;
 mod sqrt;
@ -0,0 +1,25 @@
+#[burn_tensor_testgen::testgen(sign)]
+mod tests {
+    use super::*;
+    use burn_tensor::{Data, Tensor};
+
+    #[test]
+    fn should_support_sign_ops_float() {
+        let tensor = TestTensor::from([[-0.2, -1.0, 2.0], [3.0, 0.0, -5.0]]);
+
+        let data_actual = tensor.sign().into_data();
+
+        let data_expected = Data::from([[-1.0, -1.0, 1.0], [1.0, 0.0, -1.0]]);
+        assert_eq!(data_actual, data_expected);
+    }
+
+    #[test]
+    fn should_support_sign_ops_int() {
+        let tensor = TestTensorInt::from([[-2, -1, 2], [3, 0, -5]]);
+
+        let data_actual = tensor.sign().into_data();
+
+        let data_expected = Data::from([[-1, -1, 1], [1, 0, -1]]);
+        assert_eq!(data_actual, data_expected);
+    }
+}