Add prod and prod_dim tensor ops (#1460)

Dilshod Tadjibaev 2024-03-12 14:00:02 -05:00 committed by GitHub
parent 80aac1dde4
commit 7a98b2f663
20 changed files with 452 additions and 68 deletions

View File

@ -134,37 +134,37 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t
Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| Burn | PyTorch Equivalent |
| ------------------------------------- | -------------------------------------------- |
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` (for single-element tensors) |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
| Burn | PyTorch Equivalent |
| ------------------------------------- | ------------------------------------ |
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
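A rough usage sketch (not part of this commit) of a few of the general ops above; `x` is assumed to be an existing 2-D `Float` tensor and `device` its device:

```rust
// Sketch only: `x` and `device` are assumed to exist already.
let stacked = Tensor::cat(vec![x.clone(), x.clone()], 0); // like torch.cat
let dims = stacked.dims();                                // like tensor.size()
let moved = stacked.to_device(&device);                   // like tensor.to(device)
```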
### Numeric Operations
@ -203,13 +203,13 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.mask_fill(mask, value)` | `tensor.masked_fill(mask, value)` |
| `tensor.mask_where(mask, value_tensor)` | `torch.where(mask, value_tensor, tensor)` |
| `tensor.max()` | `tensor.max()` |
| `tensor.max_dim(dim)` | `tensor.max(dim)` |
| `tensor.max_dim(dim)` | `tensor.max(dim, keepdim=True)` |
| `tensor.max_dim_with_indices(dim)` | N/A |
| `tensor.max_pair(other)` | `torch.Tensor.max(a,b)` |
| `tensor.mean()` | `tensor.mean()` |
| `tensor.mean_dim(dim)` | `tensor.mean(dim)` |
| `tensor.mean_dim(dim)` | `tensor.mean(dim, keepdim=True)` |
| `tensor.min()` | `tensor.min()` |
| `tensor.min_dim(dim)` | `tensor.min(dim)` |
| `tensor.min_dim(dim)` | `tensor.min(dim, keepdim=True)` |
| `tensor.min_dim_with_indices(dim)` | N/A |
| `tensor.min_pair(other)` | `torch.Tensor.min(a,b)` |
| `tensor.mul(other)` or `tensor * other` | `tensor * other` |
@ -218,6 +218,8 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` |
| `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` |
| `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` |
| `tensor.prod()` | `tensor.prod()` |
| `tensor.prod_dim(dim)` | `tensor.prod(dim, keepdim=True)` |
| `tensor.scatter(dim, indices, values)` | `tensor.scatter_add(dim, indices, values)` |
| `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
| `tensor.select_assign(dim, indices, values)` | N/A |
@ -225,7 +227,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.sub(other)` or `tensor - other` | `tensor - other` |
| `tensor.sub_scalar(scalar)` or `tensor - scalar` | `tensor - scalar` |
| `tensor.sum()` | `tensor.sum()` |
| `tensor.sum_dim(dim)` | `tensor.sum(dim)` |
| `tensor.sum_dim(dim)` | `tensor.sum(dim, keepdim=True)` |
| `tensor.tril(diagonal)` | `torch.tril(tensor, diagonal)` |
| `tensor.triu(diagonal)` | `torch.triu(tensor, diagonal)` |
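As a rough sketch (not part of this diff) of the newly added product reductions, assuming `x` is an existing 2-D `Float` tensor:

```rust
let total = x.clone().prod(); // 1-element tensor, like x.prod()
let rows = x.prod_dim(1);     // reduced dim kept with size 1, like x.prod(1, keepdim=True)
```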
@ -269,7 +271,7 @@ Those operations are only available for `Int` tensors.
| ------------------------------------------------ | ------------------------------------------------------- |
| `tensor.arange(5..10, device) ` | `tensor.arange(start=5, end=10, device=device)` |
| `tensor.arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` |
| `tensor.float()` | Similar to `tensor.to(torch.float)` |
| `tensor.float()` | `tensor.to(torch.float)` |
| `tensor.from_ints(ints)` | N/A |
| `tensor.int_random(shape, distribution, device)` | N/A |
@ -277,27 +279,27 @@ Those operations are only available for `Int` tensors.
Those operations are only available for `Bool` tensors.
| Burn API | PyTorch Equivalent |
| ------------------- | ----------------------------------- |
| `tensor.float()` | Similar to `tensor.to(torch.float)` |
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
| `tensor.not()` | `tensor.logical_not()` |
| `tensor.argwhere()` | `tensor.argwhere()` |
| `tensor.nonzero()` | `tensor.nonzero(as_tuple=True)` |
| Burn API | PyTorch Equivalent |
| ------------------- | ------------------------------- |
| `tensor.float()` | `tensor.to(torch.float)` |
| `tensor.int()` | `tensor.to(torch.long)` |
| `tensor.not()` | `tensor.logical_not()` |
| `tensor.argwhere()` | `tensor.argwhere()` |
| `tensor.nonzero()` | `tensor.nonzero(as_tuple=True)` |
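For example (sketch only, not part of this diff), assuming `mask` is an existing `Bool` tensor:

```rust
let as_float = mask.clone().float(); // like mask.to(torch.float)
let inverted = mask.not();           // like mask.logical_not()
```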
## Activation Functions
| Burn API | PyTorch Equivalent |
| ---------------------------------------- | ----------------------------------------------------- |
| `activation::gelu(tensor)` | Similar to `nn.functional.gelu(tensor)` |
| `activation::log_sigmoid(tensor)` | Similar to `nn.functional.log_sigmoid(tensor)` |
| `activation::log_softmax(tensor, dim)` | Similar to `nn.functional.log_softmax(tensor, dim)` |
| `activation::mish(tensor)` | Similar to `nn.functional.mish(tensor)` |
| `activation::prelu(tensor,alpha)` | Similar to `nn.functional.prelu(tensor,weight)` |
| `activation::quiet_softmax(tensor, dim)` | Similar to `nn.functional.quiet_softmax(tensor, dim)` |
| `activation::relu(tensor)` | Similar to `nn.functional.relu(tensor)` |
| `activation::sigmoid(tensor)` | Similar to `nn.functional.sigmoid(tensor)` |
| `activation::silu(tensor)` | Similar to `nn.functional.silu(tensor)` |
| `activation::softmax(tensor, dim)` | Similar to `nn.functional.softmax(tensor, dim)` |
| `activation::softplus(tensor, beta)` | Similar to `nn.functional.softplus(tensor, beta)` |
| `activation::tanh(tensor)` | Similar to `nn.functional.tanh(tensor)` |
| Burn API | PyTorch Equivalent |
| ---------------------------------------- | ------------------------------------------ |
| `activation::gelu(tensor)` | `nn.functional.gelu(tensor)` |
| `activation::log_sigmoid(tensor)` | `nn.functional.log_sigmoid(tensor)` |
| `activation::log_softmax(tensor, dim)` | `nn.functional.log_softmax(tensor, dim)` |
| `activation::mish(tensor)` | `nn.functional.mish(tensor)` |
| `activation::prelu(tensor,alpha)` | `nn.functional.prelu(tensor,weight)` |
| `activation::quiet_softmax(tensor, dim)` | `nn.functional.quiet_softmax(tensor, dim)` |
| `activation::relu(tensor)` | `nn.functional.relu(tensor)` |
| `activation::sigmoid(tensor)` | `nn.functional.sigmoid(tensor)` |
| `activation::silu(tensor)` | `nn.functional.silu(tensor)` |
| `activation::softmax(tensor, dim)` | `nn.functional.softmax(tensor, dim)` |
| `activation::softplus(tensor, beta)` | `nn.functional.softplus(tensor, beta)` |
| `activation::tanh(tensor)` | `nn.functional.tanh(tensor)` |
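A short sketch (not part of this diff) of calling these free functions, assuming `x` is an existing 2-D `Float` tensor:

```rust
use burn::tensor::activation;

let probs = activation::softmax(x.clone(), 1); // like nn.functional.softmax(x, dim=1)
let hidden = activation::relu(x);              // like nn.functional.relu(x)
```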

View File

@ -356,4 +356,12 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
fn int_sign<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
B::int_sign(tensor)
}
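// Editorial note (not in the original diff): integer tensors are not tracked by autodiff,
// so the two reductions added below simply forward to the inner backend `B`.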
fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
B::int_prod(tensor)
}
fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
B::int_prod_dim(tensor, dim)
}
}

View File

@ -2379,6 +2379,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
.parents([&tensor])
.stateless(B::float_sign(tensor.primitive))
}
// TODO: Implement float_prod and float_sum
// https://github.com/tracel-ai/burn/issues/1458
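// Editorial sketch (not from this commit): a backward pass for a full product p = prod(x)
// would need dp/dx_i = p / x_i for x_i != 0 (a zero entry instead needs the product of the
// remaining elements), scaled by the incoming gradient; prod_dim applies the same identity
// independently along the reduced dimension.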
}
#[derive(Debug, Clone)]

View File

@ -313,6 +313,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
}
fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
todo!("prod is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)")
}
fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
todo!("prod_int is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)")
}
fn int_mean_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
// Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0.
panic!("Not supported by Candle")

View File

@ -1060,6 +1060,47 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out
}
fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
unary_int_ops!(ProdOps, B::int_prod);
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(vec![1]);
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::Prod(desc.clone())),
ProdOps::<D>::new(desc),
);
out
}
fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
scalar_int_ops!(ProdDimOps, B::int_prod_dim, usize, noconvert);
let stream = tensor.stream;
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.tensor_uninitialized(shape);
let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::ProdDim(desc.clone())),
ProdDimOps::<D>::new(desc),
);
out
}
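// Editorial note (not in the original diff): nothing is computed eagerly here. Each method
// allocates an uninitialized output with the reduced shape, builds a Prod/ProdDim operation
// description, and registers it on the tensor's stream so the fusion backend can optimize
// and execute it later; the fallback ops then call `B::int_prod` / `B::int_prod_dim`.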
fn int_mean<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
unary_int_ops!(MeanOps, B::int_mean);

View File

@ -614,6 +614,19 @@ impl<E: Element> NumericOperationDescription<E> {
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::Prod(desc) => {
NumericOperationDescription::Prod(UnaryOperationDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::ProdDim(desc) => {
NumericOperationDescription::ProdDim(ScalarOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: desc.rhs, // Dim should stay the same.
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::EqualElem(desc) => {
NumericOperationDescription::EqualElem(ScalarOperationDescription {
lhs: desc.lhs.to_relative(converter),

View File

@ -300,6 +300,19 @@ pub enum NumericOperationDescription<E> {
/// Float => [sum dim](burn_tensor::ops::FloatTensorOps::float_sum_dim).
/// Int => [sum dim](burn_tensor::ops::IntTensorOps::int_sum_dim).
SumDim(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [prod](burn_tensor::ops::FloatTensorOps::float_prod).
/// Int => [prod](burn_tensor::ops::IntTensorOps::int_prod).
Prod(UnaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [prod dim](burn_tensor::ops::FloatTensorOps::float_prod_dim).
/// Int => [prod dim](burn_tensor::ops::IntTensorOps::int_prod_dim).
ProdDim(ScalarOperationDescription<usize>),
/// Operation corresponding to:
///
/// Float => [equal elem](burn_tensor::ops::FloatTensorOps::float_equal_elem).
@ -1141,6 +1154,12 @@ impl<E: Element> NumericOperationDescription<E> {
NumericOperationDescription::SumDim(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::Prod(desc) => {
vec![&desc.input, &desc.out]
}
NumericOperationDescription::ProdDim(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::Max(desc) => {
vec![&desc.input, &desc.out]
}
@ -1358,6 +1377,8 @@ impl<E> core::hash::Hash for NumericOperationDescription<E> {
NumericOperationDescription::Mean(desc) => desc.hash(state),
NumericOperationDescription::Sum(desc) => desc.hash(state),
NumericOperationDescription::SumDim(desc) => desc.hash(state),
NumericOperationDescription::Prod(desc) => desc.hash(state),
NumericOperationDescription::ProdDim(desc) => desc.hash(state),
NumericOperationDescription::EqualElem(desc) => desc.hash(state),
NumericOperationDescription::Greater(desc) => desc.hash(state),
NumericOperationDescription::GreaterElem(desc) => desc.hash(state),

View File

@ -255,6 +255,19 @@ impl<R: Runtime> IntTensorOps<Self> for JitBackend<R> {
kernel::reduce::sum_dim(tensor, dim, Default::default())
}
fn int_prod<const D: usize>(_tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
// kernel::reduce::prod(tensor, Default::default())
todo!("prod for int tensor")
}
fn int_prod_dim<const D: usize>(
_tensor: IntTensor<Self, D>,
_dim: usize,
) -> IntTensor<Self, D> {
// kernel::reduce::prod_dim(tensor, dim, Default::default())
todo!("prod_dim for int tensor")
}
fn int_mean_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
kernel::reduce::mean_dim(tensor, dim, Default::default())
}

View File

@ -2,6 +2,7 @@ use burn_tensor::Element;
use libm::{exp, fabs, log, log1p, pow, sqrt};
use libm::{expf, fabsf, log1pf, logf, powf, sqrtf};
use ndarray::LinalgScalar;
use num_traits::One;
use num_traits::Signed;
/// A float element for ndarray backend.
@ -14,6 +15,7 @@ where
/// A general element for ndarray backend.
pub trait NdArrayElement:
Element
+ One
+ ndarray::LinalgScalar
+ ndarray::ScalarOperand
+ ExpElement

View File

@ -14,7 +14,7 @@ use ndarray::IxDyn;
use ndarray::SliceInfoElem;
use crate::element::NdArrayElement;
use crate::ops::macros::{keepdim, mean_dim, sum_dim};
use crate::ops::macros::{keepdim, mean_dim, prod_dim, sum_dim};
use crate::{reshape, tensor::NdArrayTensor};
pub struct NdArrayOps<E> {
@ -202,6 +202,11 @@ where
NdArrayTensor::from_data(data)
}
pub fn prod<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
let data = Data::from([tensor.array.product()]);
NdArrayTensor::from_data(data)
}
pub fn mean_dim<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
@ -229,6 +234,21 @@ where
}
}
pub fn prod_dim<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
) -> NdArrayTensor<E, D> {
match D {
1 => keepdim!(0, dim, tensor, prod),
2 => keepdim!(1, dim, tensor, prod),
3 => keepdim!(2, dim, tensor, prod),
4 => keepdim!(3, dim, tensor, prod),
5 => keepdim!(4, dim, tensor, prod),
6 => keepdim!(5, dim, tensor, prod),
_ => panic!("Dim not supported {D}"),
}
}
pub fn gather<const D: usize>(
dim: usize,
mut tensor: NdArrayTensor<E, D>,

View File

@ -279,6 +279,17 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
NdArrayMathOps::sum_dim(tensor, dim)
}
fn int_prod<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, 1> {
NdArrayMathOps::prod(tensor)
}
fn int_prod_dim<const D: usize>(
tensor: NdArrayTensor<i64, D>,
dim: usize,
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::prod_dim(tensor, dim)
}
fn int_mean<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, 1> {
NdArrayMathOps::mean(tensor)
}

View File

@ -21,6 +21,17 @@ macro_rules! keepdim {
shape.dims[$dim] = 1;
NdArrayOps::reshape(tensor, shape)
}};
(
$D:expr,
$dim:expr,
$self:expr,
prod
) => {{
let tensor: NdArrayTensor<E, $D> = prod_dim($self.clone(), $dim);
let mut shape = $self.shape();
shape.dims[$dim] = 1;
NdArrayOps::reshape(tensor, shape)
}};
}
pub(crate) use keepdim;
@ -45,3 +56,15 @@ pub(crate) fn sum_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
NdArrayTensor { array }
}
pub(crate) fn prod_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
dim: usize,
) -> NdArrayTensor<E, D2> {
let array = tensor
.array
.fold_axis(Axis(dim), E::one(), |acc, &x| acc.mul(x.elem()))
.into_shared();
NdArrayTensor { array }
}
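// Editorial sketch (not part of this commit): the `One` bound added to the element traits
// supplies `E::one()` as the fold seed above. A minimal standalone example of the same
// `fold_axis`-based product reduction with the `ndarray` crate:
//
//     use ndarray::{array, Axis};
//
//     let a = array![[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]];
//     let p = a.fold_axis(Axis(1), 1.0, |acc, &x| acc * x);
//     assert_eq!(p, array![4.0, 60.0]);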

View File

@ -312,11 +312,6 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
TchTensor::new(tensor)
}
pub fn sum<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
let tensor = tensor.tensor.sum(E::KIND);
TchTensor::new(tensor)
}
pub fn mean_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchTensor::from_existing(
tensor
@ -326,6 +321,11 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
)
}
pub fn sum<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
let tensor = tensor.tensor.sum(E::KIND);
TchTensor::new(tensor)
}
pub fn sum_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchTensor::from_existing(
tensor
@ -335,6 +335,18 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
)
}
pub fn prod<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
let tensor = tensor.tensor.prod(E::KIND);
TchTensor::new(tensor)
}
pub fn prod_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchTensor::from_existing(
tensor.tensor.prod_dim_int(dim as i64, true, E::KIND),
tensor.storage,
)
}
pub fn argmax<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<i64, D> {
let storage = tensor.storage.clone();
let tensor = tensor.tensor.argmax(dim as i64, true);

View File

@ -263,6 +263,14 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
TchOps::sum_dim(tensor, dim)
}
fn int_prod<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, 1> {
TchOps::prod(tensor)
}
fn int_prod_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
TchOps::prod_dim(tensor, dim)
}
fn int_mean<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, 1> {
let tensor: TchTensor<f64, D> =
TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));

View File

@ -318,12 +318,20 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
TchOps::sum(tensor)
}
fn float_sum_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchOps::sum_dim(tensor, dim)
}
fn float_mean_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchOps::mean_dim(tensor, dim)
}
fn float_sum_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchOps::sum_dim(tensor, dim)
}
fn float_prod<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
TchOps::prod(tensor)
}
fn float_prod_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
TchOps::prod_dim(tensor, dim)
}
fn float_to_full_precision<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<f32, D> {

View File

@ -116,18 +116,33 @@ where
Tensor::new(K::sum(self.primitive))
}
/// Aggregate all elements along the given *dimension* or *axis* in the tensor with the mean operation.
/// Aggregate all elements along the given *dimension* or *axis*
/// in the tensor with the mean operation.
pub fn mean_dim(self, dim: usize) -> Self {
check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
Self::new(K::mean_dim(self.primitive, dim))
}
/// Aggregate all elements along the given *dimension* or *axis* in the tensor with the sum operation.
/// Aggregate all elements along the given *dimension* or *axis*
/// in the tensor with the sum operation.
pub fn sum_dim(self, dim: usize) -> Self {
check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
Self::new(K::sum_dim(self.primitive, dim))
}
/// Aggregate all elements in the tensor with the product operation.
pub fn prod(self) -> Tensor<B, 1, K> {
Tensor::new(K::prod(self.primitive))
}
/// Aggregate all elements along the given *dimension* or *axis*
/// in the tensor with the product operation.
pub fn prod_dim(self, dim: usize) -> Self {
check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
Self::new(K::prod_dim(self.primitive, dim))
}
/// Applies element wise equal comparison and returns a boolean tensor.
pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
K::equal_elem::<D>(self.primitive, other.elem())
@ -1024,6 +1039,51 @@ where
/// which is more high-level and designed for public use.
fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
/// Computes the product of all the elements of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the product of.
///
/// # Returns
///
/// The product of all the elements of the tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For computing the product of all the elements of a tensor, users should prefer the
/// [Tensor::prod](Tensor::prod) function,
/// which is more high-level and designed for public use.
fn prod<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
/// Computes the product of all the elements of the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the product of.
/// * `dim` - The dimension along which to compute the product.
///
/// # Returns
///
/// The product of all the elements of the tensor along the specified dimension.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For computing the product of all the elements of a tensor along a dimension, users should
/// prefer the [Tensor::prod_dim](Tensor::prod_dim) function,
/// which is more high-level and designed for public use.
fn prod_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
/// Computes the mean of all the elements of the tensor.
///
/// # Arguments
@ -1881,12 +1941,23 @@ impl<B: Backend> Numeric<B> for Int {
) -> Self::Primitive<D> {
B::int_full(shape, fill_value.elem(), device)
}
fn sum<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
B::int_sum(tensor)
}
fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
B::int_sum_dim(tensor, dim)
}
fn prod<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
B::int_prod(tensor)
}
fn prod_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
B::int_prod_dim(tensor, dim)
}
fn mean<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
B::int_mean(tensor)
}
@ -2174,6 +2245,7 @@ impl<B: Backend> Numeric<B> for Float {
fn ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D> {
B::float_ones(shape, device)
}
fn full<const D: usize, E: ElementConversion>(
shape: Shape<D>,
fill_value: E,
@ -2181,15 +2253,27 @@ impl<B: Backend> Numeric<B> for Float {
) -> Self::Primitive<D> {
B::float_full(shape, fill_value.elem(), device)
}
fn sum<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
B::float_sum(tensor)
}
fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
B::float_sum_dim(tensor, dim)
}
fn prod<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
B::float_prod(tensor)
}
fn prod_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
B::float_prod_dim(tensor, dim)
}
fn mean<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
B::float_mean(tensor)
}
fn mean_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
B::float_mean_dim(tensor, dim)
}

View File

@ -1,12 +1,13 @@
use crate::Distribution;
use half::{bf16, f16};
use num_traits::{identities::Zero, ToPrimitive};
use num_traits::{identities::Zero, One, ToPrimitive};
use rand::RngCore;
/// Element trait for tensor.
pub trait Element:
ToPrimitive
+ Zero
+ One
+ ElementRandom
+ ElementConversion
+ ElementPrecision

View File

@ -757,6 +757,29 @@ pub trait IntTensorOps<B: Backend> {
/// The sum of all elements in the tensor along the dimension.
fn int_sum_dim<const D: usize>(tensor: IntTensor<B, D>, dim: usize) -> IntTensor<B, D>;
/// Computes the product of all elements in the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the product of.
///
/// # Returns
///
/// The product of all elements in the tensor.
fn int_prod<const D: usize>(tensor: IntTensor<B, D>) -> IntTensor<B, 1>;
/// Computes the product of all elements in the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the product of.
/// * `dim` - The dimension to compute the product along.
///
/// # Returns
///
/// The product of all elements in the tensor along the dimension.
fn int_prod_dim<const D: usize>(tensor: IntTensor<B, D>, dim: usize) -> IntTensor<B, D>;
/// Computes the mean of all elements in the tensor.
///
/// # Arguments

View File

@ -831,6 +831,34 @@ pub trait FloatTensorOps<B: Backend> {
/// A tensor with the sum of all elements in `tensor` along `dim`.
fn float_sum_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D>;
/// Product of all elements in a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the product of.
///
/// # Returns
///
/// A scalar tensor with the product of all elements in `tensor`.
fn float_prod<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, 1> {
// Default implementation: exp(sum(log(tensor))). Only valid when all elements are non-negative.
B::float_exp(B::float_sum(B::float_log(tensor)))
}
/// Product of all elements in a tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the product of.
///
/// # Returns
///
/// A tensor with the product of all elements in `tensor` along `dim`.
fn float_prod_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D> {
// Default implementation: exp(sum_dim(log(tensor), dim)). Only valid when all elements are non-negative.
B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
}
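// Editorial note (not part of this commit): the default implementations above rely on the
// identity prod(x) = exp(sum(ln x)). Zeros still work (ln 0 = -inf, exp(-inf) = 0), but any
// negative element turns the result into NaN, which is why concrete backends such as the
// LibTorch one in this diff override them with a native product. A minimal check of the
// identity in plain Rust:
//
//     let xs = [2.0_f64, 1.0, 2.0, 3.0, 4.0, 5.0];
//     let direct: f64 = xs.iter().product();                          // 240.0
//     let via_logs = xs.iter().map(|x| x.ln()).sum::<f64>().exp();    // ~240.0
//     assert!((direct - via_logs).abs() < 1e-9);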
/// Mean of all elements in a tensor.
///
/// # Arguments

View File

@ -122,4 +122,59 @@ mod tests {
Data::new(vec![5.0, 5.0, 3.0, 11.0, -3.0, 6.0], Shape::new([2, 1, 3]))
);
}
#[test]
fn test_prod_float() {
let tensor = TestTensor::from([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let data_actual = tensor.prod().to_data();
// 2 * 1 * 2 * 3 * 4 * 5 = 240; compare approximately because of float precision
Data::from([240.0]).assert_approx_eq(&data_actual, 4);
let tensor_with_zero = TestTensor::from([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]);
let data_actual = tensor_with_zero.prod().to_data();
assert_eq!(data_actual, Data::from([0.0]));
}
#[test]
#[ignore = "Not implemented for all backends yet"]
fn test_prod_int() {
let tensor = TestTensorInt::from([[2, 1, 2], [3, 4, 5]]);
let data_actual = tensor.prod().to_data();
assert_eq!(data_actual, Data::from([240]));
let tensor_with_zero = TestTensorInt::from([[2, 0, 2], [3, 4, 5]]);
let data_actual = tensor_with_zero.prod().to_data();
assert_eq!(data_actual, Data::from([0]));
}
#[test]
fn test_prod_dim_float() {
let tensor = TestTensor::from([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let data_actual = tensor.prod_dim(1).to_data();
Data::from([[4.0], [60.0]]).assert_approx_eq(&data_actual, 4);
let tensor_with_zero = TestTensor::from([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]);
let data_actual = tensor_with_zero.prod_dim(1).to_data();
Data::from([[0.0], [60.0]]).assert_approx_eq(&data_actual, 4);
}
#[test]
#[ignore = "Not implemented for all backends yet"]
fn test_prod_dim_int() {
let tensor = TestTensorInt::from([[2, 1, 2], [3, 4, 5]]);
let data_actual = tensor.prod_dim(1).to_data();
assert_eq!(data_actual, Data::from([[4], [60]]));
let tensor_with_zero = TestTensorInt::from([[2, 0, 2], [3, 4, 5]]);
let data_actual = tensor_with_zero.prod_dim(1).to_data();
assert_eq!(data_actual, Data::from([[0], [60]]));
}
}