mirror of https://github.com/tracel-ai/burn.git

Add prod and prod_dim tensor ops (#1460)

This commit is contained in:
parent 80aac1dde4
commit 7a98b2f663
@@ -135,7 +135,7 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t
 Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.

 | Burn | PyTorch Equivalent |
-| ------------------------------------- | -------------------------------------------- |
+| ------------------------------------- | ------------------------------------ |
 | `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
 | `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
 | `Tensor::from_primitive(primitive)` | N/A |
@@ -151,7 +151,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
 | `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
 | `tensor.into_data()` | N/A |
 | `tensor.into_primitive()` | N/A |
-| `tensor.into_scalar()` | `tensor.item()` (for single-element tensors) |
+| `tensor.into_scalar()` | `tensor.item()` |
 | `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
 | `tensor.not_equal(other)` | `x != y` |
 | `tensor.permute(axes)` | `tensor.permute(axes)` |
@@ -203,13 +203,13 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
 | `tensor.mask_fill(mask, value)` | `tensor.masked_fill(mask, value)` |
 | `tensor.mask_where(mask, value_tensor)` | `torch.where(mask, value_tensor, tensor)` |
 | `tensor.max()` | `tensor.max()` |
-| `tensor.max_dim(dim)` | `tensor.max(dim)` |
+| `tensor.max_dim(dim)` | `tensor.max(dim, keepdim=True)` |
 | `tensor.max_dim_with_indices(dim)` | N/A |
 | `tensor.max_pair(other)` | `torch.Tensor.max(a,b)` |
 | `tensor.mean()` | `tensor.mean()` |
-| `tensor.mean_dim(dim)` | `tensor.mean(dim)` |
+| `tensor.mean_dim(dim)` | `tensor.mean(dim, keepdim=True)` |
 | `tensor.min()` | `tensor.min()` |
-| `tensor.min_dim(dim)` | `tensor.min(dim)` |
+| `tensor.min_dim(dim)` | `tensor.min(dim, keepdim=True)` |
 | `tensor.min_dim_with_indices(dim)` | N/A |
 | `tensor.min_pair(other)` | `torch.Tensor.min(a,b)` |
 | `tensor.mul(other)` or `tensor * other` | `tensor * other` |
@@ -218,6 +218,8 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
 | `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` |
 | `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` |
 | `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` |
+| `tensor.prod()` | `tensor.prod()` |
+| `tensor.prod_dim(dim)` | `tensor.prod(dim, keepdim=True)` |
 | `tensor.scatter(dim, indices, values)` | `tensor.scatter_add(dim, indices, values)` |
 | `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
 | `tensor.select_assign(dim, indices, values)` | N/A |
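An aside for readers of the table above (not part of the diff): a minimal usage sketch of the two new reduction methods, assuming any Burn backend `B` and a 2x3 float tensor `t` such as `[[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]`. The shapes follow the keep-dim behaviour listed in the table; the helper function name is hypothetical.

```rust
use burn::tensor::{backend::Backend, Tensor};

// Hypothetical helper for illustration only.
fn prod_example<B: Backend>(t: Tensor<B, 2>) -> (Tensor<B, 1>, Tensor<B, 2>) {
    // Full reduction: a rank-1 tensor holding the product of every element (240.0 for the data above).
    let all = t.clone().prod();
    // Reduction along dim 1 with that dimension kept (shape [2, 1], values [[4.0], [60.0]]),
    // mirroring PyTorch's `tensor.prod(dim, keepdim=True)`.
    let per_row = t.prod_dim(1);
    (all, per_row)
}
```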
@@ -225,7 +227,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
 | `tensor.sub(other)` or `tensor - other` | `tensor - other` |
 | `tensor.sub_scalar(scalar)` or `tensor - scalar` | `tensor - scalar` |
 | `tensor.sum()` | `tensor.sum()` |
-| `tensor.sum_dim(dim)` | `tensor.sum(dim)` |
+| `tensor.sum_dim(dim)` | `tensor.sum(dim, keepdim=True)` |
 | `tensor.tril(diagonal)` | `torch.tril(tensor, diagonal)` |
 | `tensor.triu(diagonal)` | `torch.triu(tensor, diagonal)` |
@@ -269,7 +271,7 @@ Those operations are only available for `Int` tensors.
 | ------------------------------------------------ | ------------------------------------------------------- |
 | `tensor.arange(5..10, device) ` | `tensor.arange(start=5, end=10, device=device)` |
 | `tensor.arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` |
-| `tensor.float()` | Similar to `tensor.to(torch.float)` |
+| `tensor.float()` | `tensor.to(torch.float)` |
 | `tensor.from_ints(ints)` | N/A |
 | `tensor.int_random(shape, distribution, device)` | N/A |
@@ -278,9 +280,9 @@ Those operations are only available for `Int` tensors.
 Those operations are only available for `Bool` tensors.

 | Burn API | PyTorch Equivalent |
-| ------------------- | ----------------------------------- |
-| `tensor.float()` | Similar to `tensor.to(torch.float)` |
-| `tensor.int()` | Similar to `tensor.to(torch.long)` |
+| ------------------- | ------------------------------- |
+| `tensor.float()` | `tensor.to(torch.float)` |
+| `tensor.int()` | `tensor.to(torch.long)` |
 | `tensor.not()` | `tensor.logical_not()` |
 | `tensor.argwhere()` | `tensor.argwhere()` |
 | `tensor.nonzero()` | `tensor.nonzero(as_tuple=True)` |
@@ -288,16 +290,16 @@ Those operations are only available for `Bool` tensors.
 ## Activation Functions

 | Burn API | PyTorch Equivalent |
-| ---------------------------------------- | ----------------------------------------------------- |
-| `activation::gelu(tensor)` | Similar to `nn.functional.gelu(tensor)` |
-| `activation::log_sigmoid(tensor)` | Similar to `nn.functional.log_sigmoid(tensor)` |
-| `activation::log_softmax(tensor, dim)` | Similar to `nn.functional.log_softmax(tensor, dim)` |
-| `activation::mish(tensor)` | Similar to `nn.functional.mish(tensor)` |
-| `activation::prelu(tensor,alpha)` | Similar to `nn.functional.prelu(tensor,weight)` |
-| `activation::quiet_softmax(tensor, dim)` | Similar to `nn.functional.quiet_softmax(tensor, dim)` |
-| `activation::relu(tensor)` | Similar to `nn.functional.relu(tensor)` |
-| `activation::sigmoid(tensor)` | Similar to `nn.functional.sigmoid(tensor)` |
-| `activation::silu(tensor)` | Similar to `nn.functional.silu(tensor)` |
-| `activation::softmax(tensor, dim)` | Similar to `nn.functional.softmax(tensor, dim)` |
-| `activation::softplus(tensor, beta)` | Similar to `nn.functional.softplus(tensor, beta)` |
-| `activation::tanh(tensor)` | Similar to `nn.functional.tanh(tensor)` |
+| ---------------------------------------- | ------------------------------------------ |
+| `activation::gelu(tensor)` | `nn.functional.gelu(tensor)` |
+| `activation::log_sigmoid(tensor)` | `nn.functional.log_sigmoid(tensor)` |
+| `activation::log_softmax(tensor, dim)` | `nn.functional.log_softmax(tensor, dim)` |
+| `activation::mish(tensor)` | `nn.functional.mish(tensor)` |
+| `activation::prelu(tensor,alpha)` | `nn.functional.prelu(tensor,weight)` |
+| `activation::quiet_softmax(tensor, dim)` | `nn.functional.quiet_softmax(tensor, dim)` |
+| `activation::relu(tensor)` | `nn.functional.relu(tensor)` |
+| `activation::sigmoid(tensor)` | `nn.functional.sigmoid(tensor)` |
+| `activation::silu(tensor)` | `nn.functional.silu(tensor)` |
+| `activation::softmax(tensor, dim)` | `nn.functional.softmax(tensor, dim)` |
+| `activation::softplus(tensor, beta)` | `nn.functional.softplus(tensor, beta)` |
+| `activation::tanh(tensor)` | `nn.functional.tanh(tensor)` |
@@ -356,4 +356,12 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
     fn int_sign<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
         B::int_sign(tensor)
     }
+
+    fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
+        B::int_prod(tensor)
+    }
+
+    fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
+        B::int_prod_dim(tensor, dim)
+    }
 }
@@ -2379,6 +2379,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
             .parents([&tensor])
             .stateless(B::float_sign(tensor.primitive))
     }
+
+    // TODO: Implement float_prod and float_sum
+    // https://github.com/tracel-ai/burn/issues/1458
 }

 #[derive(Debug, Clone)]
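A hedged note on the TODO above (not part of the commit): when no element is zero, the standard closed form for the gradient of a full product reduction, which a future `float_prod` backward pass could build on, is

$$\frac{\partial}{\partial x_i} \prod_j x_j = \prod_{j \neq i} x_j = \frac{\prod_j x_j}{x_i}, \qquad x_i \neq 0.$$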
@@ -313,6 +313,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
         CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
     }

+    fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
+        todo!("prod is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)")
+    }
+
+    fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
+        todo!("prod_int is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)")
+    }
+
     fn int_mean_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
         // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0.
         panic!("Not supported by Candle")
@@ -1060,6 +1060,47 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
         out
     }

+    fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
+        unary_int_ops!(ProdOps, B::int_prod);
+
+        let stream = tensor.stream;
+        let out = tensor.client.tensor_uninitialized(vec![1]);
+
+        let desc = UnaryOperationDescription {
+            input: tensor.into_description(),
+            out: out.to_description_out(),
+        };
+        out.client.register(
+            vec![stream],
+            OperationDescription::NumericInt(NumericOperationDescription::Prod(desc.clone())),
+            ProdOps::<D>::new(desc),
+        );
+
+        out
+    }
+
+    fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
+        scalar_int_ops!(ProdDimOps, B::int_prod_dim, usize, noconvert);
+
+        let stream = tensor.stream;
+        let mut shape = tensor.shape.clone();
+        shape[dim] = 1;
+        let out = tensor.client.tensor_uninitialized(shape);
+
+        let desc = ScalarOperationDescription {
+            lhs: tensor.into_description(),
+            rhs: dim,
+            out: out.to_description_out(),
+        };
+        out.client.register(
+            vec![stream],
+            OperationDescription::NumericInt(NumericOperationDescription::ProdDim(desc.clone())),
+            ProdDimOps::<D>::new(desc),
+        );
+
+        out
+    }
+
     fn int_mean<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
         unary_int_ops!(MeanOps, B::int_mean);
@@ -614,6 +614,19 @@ impl<E: Element> NumericOperationDescription<E> {
                     out: desc.out.to_relative(converter),
                 })
             }
+            NumericOperationDescription::Prod(desc) => {
+                NumericOperationDescription::Prod(UnaryOperationDescription {
+                    input: desc.input.to_relative(converter),
+                    out: desc.out.to_relative(converter),
+                })
+            }
+            NumericOperationDescription::ProdDim(desc) => {
+                NumericOperationDescription::ProdDim(ScalarOperationDescription {
+                    lhs: desc.lhs.to_relative(converter),
+                    rhs: desc.rhs, // Dim should stay the same.
+                    out: desc.out.to_relative(converter),
+                })
+            }
             NumericOperationDescription::EqualElem(desc) => {
                 NumericOperationDescription::EqualElem(ScalarOperationDescription {
                     lhs: desc.lhs.to_relative(converter),
@@ -300,6 +300,19 @@ pub enum NumericOperationDescription<E> {
     /// Float => [sum dim](burn_tensor::ops::FloatTensorOps::float_sum_dim).
     /// Int => [sum dim](burn_tensor::ops::IntTensorOps::int_sum_dim).
     SumDim(ScalarOperationDescription<usize>),
+
+    /// Operation corresponding to:
+    ///
+    /// Float => [prod](burn_tensor::ops::FloatTensorOps::float_prod).
+    /// Int => [prod](burn_tensor::ops::IntTensorOps::int_prod).
+    Prod(UnaryOperationDescription),
+
+    /// Operation corresponding to:
+    ///
+    /// Float => [prod dim](burn_tensor::ops::FloatTensorOps::float_prod_dim).
+    /// Int => [prod dim](burn_tensor::ops::IntTensorOps::int_prod_dim).
+    ProdDim(ScalarOperationDescription<usize>),
+
     /// Operation corresponding to:
     ///
     /// Float => [equal elem](burn_tensor::ops::FloatTensorOps::float_equal_elem).
@@ -1141,6 +1154,12 @@ impl<E: Element> NumericOperationDescription<E> {
             NumericOperationDescription::SumDim(desc) => {
                 vec![&desc.lhs, &desc.out]
             }
+            NumericOperationDescription::Prod(desc) => {
+                vec![&desc.input, &desc.out]
+            }
+            NumericOperationDescription::ProdDim(desc) => {
+                vec![&desc.lhs, &desc.out]
+            }
             NumericOperationDescription::Max(desc) => {
                 vec![&desc.input, &desc.out]
             }
@@ -1358,6 +1377,8 @@ impl<E> core::hash::Hash for NumericOperationDescription<E> {
             NumericOperationDescription::Mean(desc) => desc.hash(state),
             NumericOperationDescription::Sum(desc) => desc.hash(state),
             NumericOperationDescription::SumDim(desc) => desc.hash(state),
+            NumericOperationDescription::Prod(desc) => desc.hash(state),
+            NumericOperationDescription::ProdDim(desc) => desc.hash(state),
             NumericOperationDescription::EqualElem(desc) => desc.hash(state),
             NumericOperationDescription::Greater(desc) => desc.hash(state),
             NumericOperationDescription::GreaterElem(desc) => desc.hash(state),
@@ -255,6 +255,19 @@ impl<R: Runtime> IntTensorOps<Self> for JitBackend<R> {
         kernel::reduce::sum_dim(tensor, dim, Default::default())
     }

+    fn int_prod<const D: usize>(_tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
+        // kernel::reduce::prod(tensor, Default::default())
+        todo!("prod for int tensor")
+    }
+
+    fn int_prod_dim<const D: usize>(
+        _tensor: IntTensor<Self, D>,
+        _dim: usize,
+    ) -> IntTensor<Self, D> {
+        // kernel::reduce::prod_dim(tensor, dim, Default::default())
+        todo!("prod_dim for int tensor")
+    }
+
     fn int_mean_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
         kernel::reduce::mean_dim(tensor, dim, Default::default())
     }
@@ -2,6 +2,7 @@ use burn_tensor::Element;
 use libm::{exp, fabs, log, log1p, pow, sqrt};
 use libm::{expf, fabsf, log1pf, logf, powf, sqrtf};
 use ndarray::LinalgScalar;
+use num_traits::One;
 use num_traits::Signed;

 /// A float element for ndarray backend.
@@ -14,6 +15,7 @@
 /// A general element for ndarray backend.
 pub trait NdArrayElement:
     Element
+    + One
     + ndarray::LinalgScalar
     + ndarray::ScalarOperand
     + ExpElement
@@ -14,7 +14,7 @@ use ndarray::IxDyn;
 use ndarray::SliceInfoElem;

 use crate::element::NdArrayElement;
-use crate::ops::macros::{keepdim, mean_dim, sum_dim};
+use crate::ops::macros::{keepdim, mean_dim, prod_dim, sum_dim};
 use crate::{reshape, tensor::NdArrayTensor};

 pub struct NdArrayOps<E> {
@@ -202,6 +202,11 @@
         NdArrayTensor::from_data(data)
     }

+    pub fn prod<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
+        let data = Data::from([tensor.array.product()]);
+        NdArrayTensor::from_data(data)
+    }
+
     pub fn mean_dim<const D: usize>(
         tensor: NdArrayTensor<E, D>,
         dim: usize,
@@ -229,6 +234,21 @@
         }
     }

+    pub fn prod_dim<const D: usize>(
+        tensor: NdArrayTensor<E, D>,
+        dim: usize,
+    ) -> NdArrayTensor<E, D> {
+        match D {
+            1 => keepdim!(0, dim, tensor, prod),
+            2 => keepdim!(1, dim, tensor, prod),
+            3 => keepdim!(2, dim, tensor, prod),
+            4 => keepdim!(3, dim, tensor, prod),
+            5 => keepdim!(4, dim, tensor, prod),
+            6 => keepdim!(5, dim, tensor, prod),
+            _ => panic!("Dim not supported {D}"),
+        }
+    }
+
     pub fn gather<const D: usize>(
         dim: usize,
         mut tensor: NdArrayTensor<E, D>,
@@ -279,6 +279,17 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
         NdArrayMathOps::sum_dim(tensor, dim)
     }

+    fn int_prod<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, 1> {
+        NdArrayMathOps::prod(tensor)
+    }
+
+    fn int_prod_dim<const D: usize>(
+        tensor: NdArrayTensor<i64, D>,
+        dim: usize,
+    ) -> NdArrayTensor<i64, D> {
+        NdArrayMathOps::prod_dim(tensor, dim)
+    }
+
     fn int_mean<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, 1> {
         NdArrayMathOps::mean(tensor)
     }
@@ -21,6 +21,17 @@ macro_rules! keepdim {
         shape.dims[$dim] = 1;
         NdArrayOps::reshape(tensor, shape)
     }};
+    (
+        $D:expr,
+        $dim:expr,
+        $self:expr,
+        prod
+    ) => {{
+        let tensor: NdArrayTensor<E, $D> = prod_dim($self.clone(), $dim);
+        let mut shape = $self.shape();
+        shape.dims[$dim] = 1;
+        NdArrayOps::reshape(tensor, shape)
+    }};
 }

 pub(crate) use keepdim;
@@ -45,3 +56,15 @@ pub(crate) fn sum_dim<E: NdArrayElement, const D1: usize, const D2: usize>(

     NdArrayTensor { array }
 }
+
+pub(crate) fn prod_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
+    tensor: NdArrayTensor<E, D1>,
+    dim: usize,
+) -> NdArrayTensor<E, D2> {
+    let array = tensor
+        .array
+        .fold_axis(Axis(dim), E::one(), |acc, &x| acc.mul(x.elem()))
+        .into_shared();
+
+    NdArrayTensor { array }
+}
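As a standalone illustration of the `fold_axis` pattern used in `prod_dim` above (plain `ndarray` values rather than Burn's wrapper types; the `insert_axis` call stands in for the `keepdim!` reshape):

```rust
use ndarray::{array, Axis};

fn main() {
    let a = array![[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]];

    // Fold along axis 1, starting from 1.0 and multiplying each element: [4.0, 60.0].
    let prod = a.fold_axis(Axis(1), 1.0, |acc, &x| acc * x);

    // Reinsert the reduced axis so the result keeps the dimension: shape [2, 1].
    let keepdim = prod.insert_axis(Axis(1));

    assert_eq!(keepdim, array![[4.0], [60.0]]);
}
```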
@@ -312,11 +312,6 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
         TchTensor::new(tensor)
     }

-    pub fn sum<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
-        let tensor = tensor.tensor.sum(E::KIND);
-        TchTensor::new(tensor)
-    }
-
     pub fn mean_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
         TchTensor::from_existing(
             tensor

@@ -326,6 +321,11 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
         )
     }

+    pub fn sum<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
+        let tensor = tensor.tensor.sum(E::KIND);
+        TchTensor::new(tensor)
+    }
+
     pub fn sum_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
         TchTensor::from_existing(
             tensor
@@ -335,6 +335,18 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
         )
     }

+    pub fn prod<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
+        let tensor = tensor.tensor.prod(E::KIND);
+        TchTensor::new(tensor)
+    }
+
+    pub fn prod_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
+        TchTensor::from_existing(
+            tensor.tensor.prod_dim_int(dim as i64, true, E::KIND),
+            tensor.storage,
+        )
+    }
+
     pub fn argmax<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<i64, D> {
         let storage = tensor.storage.clone();
         let tensor = tensor.tensor.argmax(dim as i64, true);
@@ -263,6 +263,14 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
         TchOps::sum_dim(tensor, dim)
     }

+    fn int_prod<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, 1> {
+        TchOps::prod(tensor)
+    }
+
+    fn int_prod_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
+        TchOps::prod_dim(tensor, dim)
+    }
+
     fn int_mean<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, 1> {
         let tensor: TchTensor<f64, D> =
             TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
@@ -318,12 +318,20 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
         TchOps::sum(tensor)
     }

+    fn float_sum_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
+        TchOps::sum_dim(tensor, dim)
+    }
+
     fn float_mean_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
         TchOps::mean_dim(tensor, dim)
     }

-    fn float_sum_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
-        TchOps::sum_dim(tensor, dim)
+    fn float_prod<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
+        TchOps::prod(tensor)
     }

+    fn float_prod_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
+        TchOps::prod_dim(tensor, dim)
+    }
+
     fn float_to_full_precision<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<f32, D> {
@@ -116,18 +116,33 @@
         Tensor::new(K::sum(self.primitive))
     }

-    /// Aggregate all elements along the given *dimension* or *axis* in the tensor with the mean operation.
+    /// Aggregate all elements along the given *dimension* or *axis*
+    /// in the tensor with the mean operation.
     pub fn mean_dim(self, dim: usize) -> Self {
         check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
         Self::new(K::mean_dim(self.primitive, dim))
     }

-    /// Aggregate all elements along the given *dimension* or *axis* in the tensor with the sum operation.
+    /// Aggregate all elements along the given *dimension* or *axis*
+    /// in the tensor with the sum operation.
     pub fn sum_dim(self, dim: usize) -> Self {
         check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
         Self::new(K::sum_dim(self.primitive, dim))
     }

+    /// Aggregate all elements along the given *dimension* or *axis*
+    /// in the tensor with the product operation.
+    pub fn prod(self) -> Tensor<B, 1, K> {
+        Tensor::new(K::prod(self.primitive))
+    }
+
+    /// Aggregate all elements along the given *dimension* or *axis*
+    /// in the tensor with the product operation.
+    pub fn prod_dim(self, dim: usize) -> Self {
+        check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
+        Self::new(K::prod_dim(self.primitive, dim))
+    }
+
     /// Applies element wise equal comparison and returns a boolean tensor.
     pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
         K::equal_elem::<D>(self.primitive, other.elem())
@@ -1024,6 +1039,51 @@
     /// which is more high-level and designed for public use.
     fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;

+    /// Computes the product of all the elements of the tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the product of.
+    ///
+    /// # Returns
+    ///
+    /// The product of all the elements of the tensor.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For computing the product of all the elements of a tensor, users should prefer the
+    /// [Tensor::prod](Tensor::prod) function,
+    /// which is more high-level and designed for public use.
+    fn prod<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
+
+    /// Computes the product of all the elements of the tensor along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the product of.
+    /// * `dim` - The dimension along which to compute the product.
+    ///
+    /// # Returns
+    ///
+    /// The product of all the elements of the tensor along the specified dimension.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For computing the product of all the elements of a tensor along a dimension, users should
+    /// prefer the [Tensor::prod_dim](Tensor::prod_dim) function,
+    /// which is more high-level and designed for public use.
+    ///
+    ///
+    fn prod_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
+
     /// Computes the mean of all the elements of the tensor.
     ///
     /// # Arguments
@@ -1881,12 +1941,23 @@ impl<B: Backend> Numeric<B> for Int {
     ) -> Self::Primitive<D> {
         B::int_full(shape, fill_value.elem(), device)
     }

     fn sum<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
         B::int_sum(tensor)
     }

     fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
         B::int_sum_dim(tensor, dim)
     }

+    fn prod<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
+        B::int_prod(tensor)
+    }
+
+    fn prod_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
+        B::int_prod_dim(tensor, dim)
+    }
+
     fn mean<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
         B::int_mean(tensor)
     }
@@ -2174,6 +2245,7 @@ impl<B: Backend> Numeric<B> for Float {
     fn ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D> {
         B::float_ones(shape, device)
     }
+
     fn full<const D: usize, E: ElementConversion>(
         shape: Shape<D>,
         fill_value: E,
@@ -2181,15 +2253,27 @@
     ) -> Self::Primitive<D> {
         B::float_full(shape, fill_value.elem(), device)
     }

     fn sum<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
         B::float_sum(tensor)
     }

     fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
         B::float_sum_dim(tensor, dim)
     }

+    fn prod<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
+        B::float_prod(tensor)
+    }
+
+    fn prod_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
+        B::float_prod_dim(tensor, dim)
+    }
+
     fn mean<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
         B::float_mean(tensor)
     }

     fn mean_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
         B::float_mean_dim(tensor, dim)
     }
@@ -1,12 +1,13 @@
 use crate::Distribution;
 use half::{bf16, f16};
-use num_traits::{identities::Zero, ToPrimitive};
+use num_traits::{identities::Zero, One, ToPrimitive};
 use rand::RngCore;

 /// Element trait for tensor.
 pub trait Element:
     ToPrimitive
     + Zero
+    + One
     + ElementRandom
     + ElementConversion
     + ElementPrecision
@@ -757,6 +757,29 @@ pub trait IntTensorOps<B: Backend> {
     /// The sum of all elements in the tensor along the dimension.
     fn int_sum_dim<const D: usize>(tensor: IntTensor<B, D>, dim: usize) -> IntTensor<B, D>;

+    /// Computes the product of all elements in the tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the product of.
+    ///
+    /// # Returns
+    ///
+    /// The product of all elements in the tensor.
+    fn int_prod<const D: usize>(tensor: IntTensor<B, D>) -> IntTensor<B, 1>;
+
+    /// Computes the product of all elements in the tensor along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the product of.
+    /// * `dim` - The dimension to compute the product along.
+    ///
+    /// # Returns
+    ///
+    /// The product of all elements in the tensor along the dimension.
+    fn int_prod_dim<const D: usize>(tensor: IntTensor<B, D>, dim: usize) -> IntTensor<B, D>;
+
     /// Computes the mean of all elements in the tensor.
     ///
     /// # Arguments
@@ -831,6 +831,34 @@ pub trait FloatTensorOps<B: Backend> {
     /// A tensor with the sum of all elements in `tensor` along `dim`.
     fn float_sum_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D>;

+    /// Product of all elements in a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to product.
+    ///
+    /// # Returns
+    ///
+    /// A scalar tensor with the product of all elements in `tensor`.
+    fn float_prod<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, 1> {
+        // Product of all elements in a tensor
+        B::float_exp(B::float_sum(B::float_log(tensor)))
+    }
+
+    /// Product of all elements in a tensor along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to product.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the product of all elements in `tensor` along `dim`.
+    fn float_prod_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D> {
+        // Product of all elements in a tensor along a dimension
+        B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
+    }
+
     /// Mean of all elements in a tensor.
     ///
     /// # Arguments
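A short note on the default implementations above: they rely on the identity

$$\prod_i x_i = \exp\Big(\sum_i \log x_i\Big),$$

which only holds for strictly positive elements, since `float_log` of a zero or negative entry is not finite. Backends that need exact results in those cases can override `float_prod` and `float_prod_dim` with a native kernel, as the LibTorch backend does earlier in this diff.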
@@ -122,4 +122,59 @@ mod tests {
             Data::new(vec![5.0, 5.0, 3.0, 11.0, -3.0, 6.0], Shape::new([2, 1, 3]))
         );
     }
+
+    #[test]
+    fn test_prod_float() {
+        let tensor = TestTensor::from([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+        let data_actual = tensor.prod().to_data();
+
+        // 2 * 1 * 2 * 3 * 4 * 5 = 240 but we need to check the precision because of the float
+        Data::from([240.0]).assert_approx_eq(&data_actual, 4);
+
+        let tensor_with_zero = TestTensor::from([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]);
+        let data_actual = tensor_with_zero.prod().to_data();
+
+        assert_eq!(data_actual, Data::from([0.0]));
+    }
+
+    #[test]
+    #[ignore = "Not implemented for all backends yet"]
+    fn test_prod_int() {
+        let tensor = TestTensorInt::from([[2, 1, 2], [3, 4, 5]]);
+        let data_actual = tensor.prod().to_data();
+
+        assert_eq!(data_actual, Data::from([240]));
+
+        let tensor_with_zero = TestTensorInt::from([[2, 0, 2], [3, 4, 5]]);
+        let data_actual = tensor_with_zero.prod().to_data();
+
+        assert_eq!(data_actual, Data::from([0]));
+    }
+
+    #[test]
+    fn test_prod_dim_float() {
+        let tensor = TestTensor::from([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+        let data_actual = tensor.prod_dim(1).to_data();
+
+        Data::from([[4.0], [60.0]]).assert_approx_eq(&data_actual, 4);
+
+        let tensor_with_zero = TestTensor::from([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]);
+        let data_actual = tensor_with_zero.prod_dim(1).to_data();
+
+        Data::from([[0.0], [60.0]]).assert_approx_eq(&data_actual, 4);
+    }
+
+    #[test]
+    #[ignore = "Not implemented for all backends yet"]
+    fn test_prod_dim_int() {
+        let tensor = TestTensorInt::from([[2, 1, 2], [3, 4, 5]]);
+        let data_actual = tensor.prod_dim(1).to_data();
+
+        assert_eq!(data_actual, Data::from([[4], [60]]));
+
+        let tensor_with_zero = TestTensorInt::from([[2, 0, 2], [3, 4, 5]]);
+        let data_actual = tensor_with_zero.prod_dim(1).to_data();
+
+        assert_eq!(data_actual, Data::from([[0], [60]]));
+    }
 }