Tensor expand operator (#1508)

* Improve CI cache - remove burn-tch artifacts

* PyTorch config deserializer from .pt file

* Update pytorch-model.md

* WIP

* Rename broadcast_to to expand

* Rename broadcast_to file to expand

* Implement fusion backend and fix bugs

* Remove old files

* Remove unused state

* Rename to the correct op name

* Add missing comment

* Fix expand check function doc

* Rename the leftover names

* Rename leftover names
Dilshod Tadjibaev 2024-03-22 16:33:53 -05:00 committed by GitHub
parent dc45cf1700
commit 6feda90a8c
38 changed files with 783 additions and 34 deletions

View File

@ -144,6 +144,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.expand(shape)` | `tensor.expand(shape)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
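A minimal usage sketch of the new call, written as a generic helper over any backend `B` (the helper name is hypothetical):

use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

// Broadcast a rank-1 tensor to a [3, 3] matrix; every row repeats [1.0, 2.0, 3.0].
// The equivalent PyTorch call would be `tensor.expand(3, 3)`.
fn expand_example<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
    let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], device);
    tensor.expand([3, 3])
}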

View File

@ -128,4 +128,11 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
fn bool_nonzero<const D: usize>(tensor: BoolTensor<B, D>) -> Vec<IntTensor<B, 1>> {
B::bool_nonzero(tensor)
}
fn bool_expand<const D: usize, const D2: usize>(
tensor: BoolTensor<B, D>,
shape: Shape<D2>,
) -> BoolTensor<B, D2> {
B::bool_expand(tensor, shape)
}
}

View File

@ -369,6 +369,13 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_prod_dim(tensor, dim)
}
fn int_expand<const D: usize, const D2: usize>(
tensor: IntTensor<B, D>,
shape: Shape<D2>,
) -> IntTensor<B, D2> {
B::int_expand(tensor, shape)
}
fn int_sort<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,

View File

@ -2437,6 +2437,81 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
.stateless(B::float_sign(tensor.primitive))
}
fn float_expand<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
shape: Shape<D2>,
) -> FloatTensor<Self, D2> {
#[derive(Debug)]
struct ExpandDim<const D1: usize, const D2: usize>;
#[derive(new, Debug)]
struct RetroExpand<B: Backend, const D1: usize, const D2: usize> {
input_id: NodeID,
shape: Shape<D2>,
_backend: PhantomData<B>,
}
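// With memory-bound checkpointing, the expanded output is not stored: RetroExpand
// recomputes it from the checkpointed input whenever the backward pass needs it.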
impl<B: Backend, const D1: usize, const D2: usize> RetroForward for RetroExpand<B, D1, D2> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
let input = states.get_state::<B::FloatTensorPrimitive<D1>>(&self.input_id);
let out = B::float_expand(input, self.shape.clone());
states.save(out_node, out)
}
}
impl<B: Backend, const D1: usize, const D2: usize> Backward<B, D2, 1> for ExpandDim<D1, D2> {
type State = Shape<D1>;
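// Backward rule: expand only repeats data, so the incoming gradient is summed over
// every broadcast dimension (original size 1, or a leading dimension added by the
// expand) and then reshaped back to the original shape. For example, expanding a
// [4] tensor to [4, 4] sums the [4, 4] gradient over dim 0 and reshapes it to [4].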
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
let shape_original = ops.state;
let mut shape_expanded = [1; D2];
debug_assert!(D2 >= D1);
for i in 0..D1 {
shape_expanded[i + (D2 - D1)] = shape_original.dims[i];
}
unary::<B, D2, D1, _>(ops.parents, ops.node, grads, |grad| {
let shape_grad = B::float_shape(&grad);
let mut grad = grad;
#[allow(clippy::needless_range_loop)]
for i in 0..D2 {
if shape_expanded[i] == 1 && shape_grad.dims[i] != 1 {
grad = B::float_sum_dim(grad, i);
}
}
B::float_reshape(grad, shape_original)
});
}
}
match ExpandDim
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.memory_bound()
.retro_forward(RetroExpand::<B, D1, D2>::new(
tensor.node.id.clone(),
shape.clone(),
))
.parents([&tensor])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
B::float_shape(&tensor.primitive),
B::float_expand(tensor.primitive, shape),
),
OpsKind::UnTracked(prep) => prep.finish(B::float_expand(tensor.primitive, shape)),
}
}
fn float_sort<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,

View File

@ -0,0 +1,38 @@
#[burn_tensor_testgen::testgen(ad_expand)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
#[test]
fn should_diff_expand() {
// Python code to generate the test case values
// import torch
// x1 = torch.tensor([4.0, 7.0, 2.0, 3.0], requires_grad=True)
// x2 = torch.tensor([2.0, 4.5, 7.0, 3.0], requires_grad=True)
// y = x1.expand(4, 4)
// z = (x2 * y).sum()
// z.backward()
// print("x1", x1.grad)
// print("x2", x2.grad)
let device = Default::default();
let data_1: Data<f32, 1> = Data::from([4.0, 7.0, 2.0, 3.0]);
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let data_2: Data<f32, 1> = Data::from([2.0, 4.5, 7.0, 3.0]);
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().expand([4, 4]);
// Use unsqueeze to make tensor_2 have the same shape as tensor_3
let tensor_4 = tensor_2.clone().unsqueeze().mul(tensor_3).sum();
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
assert_eq!(grad_1.to_data(), Data::from([8., 18., 28., 12.]));
assert_eq!(grad_2.to_data(), Data::from([16., 28., 8., 12.]));
}
}

View File

@ -21,6 +21,7 @@ mod cross_entropy;
mod div;
mod erf;
mod exp;
mod expand;
mod flip;
mod gather_scatter;
mod gelu;
@ -119,6 +120,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_flip!();
burn_autodiff::testgen_ad_nonzero!();
burn_autodiff::testgen_ad_sign!();
burn_autodiff::testgen_ad_expand!();
burn_autodiff::testgen_ad_sort!();
};
}

View File

@ -102,6 +102,7 @@ mod tests {
burn_tensor::testgen_sub!();
burn_tensor::testgen_tanh!();
burn_tensor::testgen_transpose!();
burn_tensor::testgen_expand!();
// test stats
burn_tensor::testgen_var!();
@ -157,4 +158,5 @@ mod tests {
burn_autodiff::testgen_ad_sub!();
burn_autodiff::testgen_ad_tanh!();
burn_autodiff::testgen_ad_transpose!();
burn_autodiff::testgen_ad_expand!();
}

View File

@ -142,3 +142,10 @@ pub fn chunk<E: CandleElement, const D: usize>(
Err(e) => panic!("error chunk from Candle: {e}"),
}
}
pub fn expand<E: CandleElement, const D1: usize, const D2: usize>(
tensor: CandleTensor<E, D1>,
shape: Shape<D2>,
) -> CandleTensor<E, D2> {
CandleTensor::new(tensor.tensor.broadcast_as(&shape.dims).unwrap())
}

View File

@ -8,6 +8,8 @@ use crate::{
Candle, CandleTensor,
};
use super::base::{expand, permute};
impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<F, I> {
fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
super::base::empty(shape, device)
@ -140,4 +142,11 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
) -> BoolTensor<Self, D> {
super::base::flip(tensor, axes)
}
fn bool_expand<const D1: usize, const D2: usize>(
tensor: BoolTensor<Self, D1>,
shape: Shape<D2>,
) -> BoolTensor<Self, D2> {
expand(tensor, shape)
}
}

View File

@ -8,6 +8,8 @@ use crate::{
Candle, CandleTensor,
};
use super::base::{expand, permute};
impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
super::base::empty(shape, device)
@ -430,6 +432,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
super::base::flip(tensor, axes)
}
fn int_expand<const D1: usize, const D2: usize>(
tensor: IntTensor<Self, D1>,
shape: Shape<D2>,
) -> IntTensor<Self, D2> {
expand(tensor, shape)
}
// TODO add sign operator once Candle supports it:
// https://github.com/huggingface/candle/issues/1827
}

View File

@ -11,6 +11,8 @@ use crate::{
Candle, CandleTensor,
};
use super::base::{expand, permute};
impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
fn float_from_data<const D: usize>(
data: Data<F, D>,
@ -530,6 +532,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
super::base::flip(tensor, axes)
}
fn float_expand<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
shape: Shape<D2>,
) -> FloatTensor<Self, D2> {
expand(tensor, shape)
}
// TODO add sign operator once Candle supports it:
// https://github.com/huggingface/candle/issues/1827
}

View File

@ -4,9 +4,10 @@ use crate::{
ops::binary::binary_ops_shape,
stream::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, FlipOperationDescription, Operation, OperationDescription,
PermuteOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
SliceOperationDescription, StreamId, SwapDimsDescription, UnaryOperationDescription,
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, Operation,
OperationDescription, PermuteOperationDescription, ReshapeDescription,
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
UnaryOperationDescription,
},
Fusion, FusionBackend,
};
@ -467,6 +468,44 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
out
}
fn bool_expand<const D1: usize, const D2: usize>(
tensor: BoolTensor<Self, D1>,
shape: Shape<D2>,
) -> BoolTensor<Self, D2> {
#[derive(new)]
struct ExpandOps<const D: usize, const D2: usize> {
desc: ExpandOperationDescription,
}
impl<const D: usize, const D2: usize, B: FusionBackend> Operation<B> for ExpandOps<D, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let shape: [usize; D2] = self.desc.shape.try_into().unwrap();
let output = B::bool_expand(input, shape.into());
handles.register_bool_tensor(&self.desc.out.id, output);
}
}
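// The fusion backend records the expand lazily: allocate an uninitialized output with
// the target shape, describe the operation, and register it on the input's stream;
// ExpandOps runs the concrete backend expand once the stream executes.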
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(shape.dims.into());
let desc = ExpandOperationDescription {
input: tensor.into_description(),
shape: shape.dims.into(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::BaseBool(BaseOperationDescription::Expand(desc.clone())),
ExpandOps::<D1, D2>::new(desc),
);
out
}
fn bool_flip<const D: usize>(
tensor: BoolTensor<Self, D>,
axes: &[usize],

View File

@ -6,13 +6,14 @@ use crate::{
scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
stream::{
BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
ClampOperationDescription, FlipOperationDescription, FloatOperationDescription,
GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
RandomOperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription,
ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
StreamId, SwapDimsDescription, UnaryOperationDescription,
ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription,
MaskWhereOperationDescription, NumericOperationDescription, Operation,
OperationDescription, PermuteOperationDescription, RandomOperationDescription,
ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
UnaryOperationDescription,
},
unary_float_ops, Fusion, FusionBackend, TensorDescription,
};
@ -1847,6 +1848,44 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out
}
fn float_expand<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
shape: Shape<D2>,
) -> FloatTensor<Self, D2> {
#[derive(new)]
struct ExpandOps<const D: usize, const D2: usize> {
desc: ExpandOperationDescription,
}
impl<const D: usize, const D2: usize, B: FusionBackend> Operation<B> for ExpandOps<D, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.input);
let shape: [usize; D2] = self.desc.shape.try_into().unwrap();
let output = B::float_expand(input, shape.into());
handles.register_float_tensor(&self.desc.out.id, output);
}
}
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(shape.dims.into());
let desc = ExpandOperationDescription {
input: tensor.into_description(),
shape: shape.dims.into(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::Expand(desc.clone())),
ExpandOps::<D1, D2>::new(desc),
);
out
}
fn float_flip<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: &[usize],

View File

@ -6,13 +6,13 @@ use crate::{
scalar_int_cmp_ops, scalar_int_ops,
stream::{
self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
ClampOperationDescription, FlipOperationDescription, GatherOperationDescription,
MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
Operation, OperationDescription, PermuteOperationDescription, RandomOperationDescription,
ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
UnaryOperationDescription,
ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
RandomOperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription,
ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
StreamId, SwapDimsDescription, UnaryOperationDescription,
},
unary_int_ops, Fusion, FusionBackend, TensorDescription,
};
@ -1553,6 +1553,43 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out
}
fn int_expand<const D1: usize, const D2: usize>(
tensor: IntTensor<Self, D1>,
shape: Shape<D2>,
) -> IntTensor<Self, D2> {
#[derive(new)]
struct ExpandOps<const D: usize, const D2: usize> {
desc: ExpandOperationDescription,
}
impl<const D: usize, const D2: usize, B: FusionBackend> Operation<B> for ExpandOps<D, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_int_tensor::<D>(&self.desc.input);
let shape: [usize; D2] = self.desc.shape.try_into().unwrap();
let output = B::int_expand(input, shape.into());
handles.register_int_tensor(&self.desc.out.id, output);
}
}
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(shape.dims.into());
let desc = ExpandOperationDescription {
input: tensor.into_description(),
shape: shape.dims.into(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::BaseInt(BaseOperationDescription::Expand(desc.clone())),
ExpandOps::<D1, D2>::new(desc),
);
out
}
fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
#[derive(new)]
struct FlipDimsOps<const D: usize> {

View File

@ -4,16 +4,17 @@ use super::{
AvgPool2dBackwardDescription, AvgPool2dDescription, BaseOperationDescription,
BinaryOperationDescription, BoolOperationDescription, ClampOperationDescription,
Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, ConvTranspose2dDescription,
EmbeddingBackwardDescription, EmbeddingDescription, FlipOperationDescription,
FloatOperationDescription, GatherOperationDescription, IntOperationDescription,
InterpolateBackwardDescription, InterpolateDescription, MaskFillOperationDescription,
MaskWhereOperationDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
MaxPool1dWithIndicesDescription, MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription,
MaxPool2dWithIndicesDescription, ModuleOperationDescription, NumericOperationDescription,
OperationDescription, PermuteOperationDescription, RandomOperationDescription,
ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
EmbeddingBackwardDescription, EmbeddingDescription, ExpandOperationDescription,
FlipOperationDescription, FloatOperationDescription, GatherOperationDescription,
IntOperationDescription, InterpolateBackwardDescription, InterpolateDescription,
MaskFillOperationDescription, MaskWhereOperationDescription, MaxPool1dDescription,
MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription, MaxPool2dDescription,
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription,
ModuleOperationDescription, NumericOperationDescription, OperationDescription,
PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription,
ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription,
SelectAssignOperationDescription, SelectOperationDescription, SliceOperationDescription,
SwapDimsDescription, UnaryOperationDescription,
};
use crate::{FusionBackend, HandleContainer, TensorDescription, TensorId};
use burn_tensor::{Element, ElementConversion};
@ -798,6 +799,13 @@ impl BaseOperationDescription {
axes: desc.axes.clone(),
})
}
BaseOperationDescription::Expand(desc) => {
BaseOperationDescription::Expand(ExpandOperationDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
shape: desc.shape.clone(),
})
}
BaseOperationDescription::Flip(desc) => {
BaseOperationDescription::Flip(FlipOperationDescription {
input: desc.input.to_relative(converter),

View File

@ -141,6 +141,7 @@ pub enum BaseOperationDescription {
/// Int => [reshape](burn_tensor::ops::IntTensorOps::int_reshape).
/// Bool => [reshape](burn_tensor::ops::BoolTensorOps::bool_reshape).
Reshape(ReshapeDescription),
/// Operation corresponding to:
///
/// Float => [swap_dims](burn_tensor::ops::FloatTensorOps::float_swap_dims).
@ -161,6 +162,13 @@ pub enum BaseOperationDescription {
/// Bool => [flip](burn_tensor::ops::BoolTensorOps::bool_flip).
Flip(FlipOperationDescription),
/// Operation corresponding to:
///
/// Float => [expand](burn_tensor::ops::FloatTensorOps::float_expand).
/// Int => [expand](burn_tensor::ops::IntTensorOps::int_expand).
/// Bool => [expand](burn_tensor::ops::BoolTensorOps::bool_expand).
Expand(ExpandOperationDescription),
/// Operation corresponding to:
///
/// Float => [slice](burn_tensor::ops::FloatTensorOps::float_slice).
@ -462,6 +470,17 @@ pub struct PermuteOperationDescription {
pub axes: Vec<usize>,
}
/// Expand operation description.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct ExpandOperationDescription {
/// Input tensor description.
pub input: TensorDescription,
/// Output tensor description.
pub out: TensorDescription,
/// The new shape.
pub shape: Vec<usize>,
}
/// Flip operation description.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FlipOperationDescription {
@ -487,6 +506,13 @@ pub struct ReshapeDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ExpandDescription {
pub input: TensorDescription,
pub out: TensorDescription,
}
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct BinaryOperationDescription {
@ -1054,6 +1080,11 @@ impl BaseOperationDescription {
BaseOperationDescription::Permute(desc) => {
vec![&desc.input, &desc.out]
}
BaseOperationDescription::Expand(desc) => {
vec![&desc.input, &desc.out]
}
BaseOperationDescription::Flip(desc) => {
vec![&desc.input, &desc.out]
}

Binary file not shown.

View File

@ -1,3 +1,5 @@
use std::marker::PhantomData;
use crate::{element::JitElement, kernel, tensor::JitTensor, Runtime};
use burn_tensor::{Data, Reader, Shape};
@ -80,6 +82,55 @@ pub(crate) fn permute<R: Runtime, E: JitElement, const D: usize>(
tensor
}
pub(crate) fn expand<R: Runtime, E: JitElement, const D: usize, const D_OUT: usize>(
tensor: JitTensor<R, E, D>,
target_shape: Shape<D_OUT>,
) -> JitTensor<R, E, D_OUT> {
// Initialize new strides with zeros
let mut new_strides = [0usize; D_OUT];
// Calculate the difference in dimensions
let dim_diff = D_OUT.saturating_sub(D);
// Compare dimensions from the end, setting strides for matching dimensions or broadcasted ones
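// For example, expanding a contiguous [1, 4] tensor (strides [4, 1]) to [3, 4] yields
// strides [0, 1]: the matching last dimension keeps its stride, while the broadcast
// first dimension gets stride 0 so every row aliases the same underlying data.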
let mut tensor_dim_iter = tensor.shape.dims.iter().rev();
for i in (0..D_OUT).rev() {
if i >= dim_diff {
if let Some(&tensor_dim) = tensor_dim_iter.next() {
if tensor_dim == target_shape.dims[i] || tensor_dim == 1 {
// Copy stride for non-broadcast dimensions or set to 0 for broadcast ones
new_strides[i] = if tensor_dim == target_shape.dims[i] {
tensor.strides[i - dim_diff]
} else {
0
};
} else {
// Error handling: Dimension mismatch for broadcasting
panic!(
"Dimension mismatch: cannot broadcast dimension {} of tensor to target shape",
tensor_dim
);
}
} else {
// If the input tensor has fewer dimensions, treat missing dimensions as 1
// and set stride to 0 (broadcasting)
new_strides[i] = 0;
}
} else {
// For extra dimensions in the target shape, set stride to 0 (broadcasting)
new_strides[i] = 0;
}
}
JitTensor {
client: tensor.client,
device: tensor.device,
shape: target_shape,
strides: new_strides,
handle: tensor.handle,
elem: PhantomData,
}
}
pub(crate) fn reshape<R: Runtime, E: JitElement, const D1: usize, const D2: usize>(
tensor: JitTensor<R, E, D1>,

View File

@ -4,7 +4,7 @@ use burn_tensor::Reader;
use burn_tensor::{ops::BoolTensorOps, Data, Shape};
use std::ops::Range;
use super::permute;
use super::{expand, permute};
impl<R: Runtime> BoolTensorOps<Self> for JitBackend<R> {
fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
@ -114,6 +114,13 @@ impl<R: Runtime> BoolTensorOps<Self> for JitBackend<R> {
permute(tensor, axes)
}
fn bool_expand<const D1: usize, const D2: usize>(
tensor: BoolTensor<Self, D1>,
shape: Shape<D2>,
) -> BoolTensor<Self, D2> {
expand(tensor, shape)
}
fn bool_flip<const D: usize>(
tensor: BoolTensor<Self, D>,
axes: &[usize],

View File

@ -1,4 +1,4 @@
use super::{numeric, permute};
use super::{expand, numeric, permute};
use crate::codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope, UnaryOperator};
use crate::kernel::matmul::{matmul, MatmulStrategy};
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
@ -521,6 +521,13 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
permute(tensor, axes)
}
fn float_expand<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
shape: Shape<D2>,
) -> FloatTensor<Self, D2> {
expand(tensor, shape)
}
fn float_flip<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: &[usize],

View File

@ -1,4 +1,4 @@
use super::{numeric, permute};
use super::{expand, numeric, permute};
use crate::codegen::dialect::gpu::{Elem, Item, Operator, Scope, UnaryOperator};
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
use crate::{kernel, unary, JitBackend, Runtime};
@ -340,6 +340,13 @@ impl<R: Runtime> IntTensorOps<Self> for JitBackend<R> {
permute(tensor, axes)
}
fn int_expand<const D1: usize, const D2: usize>(
tensor: IntTensor<Self, D1>,
shape: Shape<D2>,
) -> IntTensor<Self, D2> {
expand(tensor, shape)
}
fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
kernel::flip(tensor, axes)
}

View File

@ -176,7 +176,7 @@ where
for d in 0..D {
let stride = self.strides[D - 1 - d];
if stride < current_stride {
if stride <= current_stride {
return false;
}

View File

@ -1,9 +1,11 @@
use alloc::vec::Vec;
use burn_tensor::Data;
use burn_tensor::ElementConversion;
use core::fmt::Debug;
use core::{marker::PhantomData, ops::Range};
use ndarray::s;
use ndarray::Array2;
use ndarray::IntoDimension;
use ndarray::SliceInfo;
use ndarray::Zip;
use num_traits::Signed;
@ -28,7 +30,7 @@ pub(crate) struct NdArrayMathOps<E> {
impl<E> NdArrayOps<E>
where
E: Copy,
E: Copy + Debug,
{
pub fn slice<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
@ -113,6 +115,22 @@ where
NdArrayTensor::new(array)
}
/// Broadcasts the tensor to the given shape
pub(crate) fn expand<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
shape: Shape<D2>,
) -> NdArrayTensor<E, D2> {
let array = tensor
.array
.broadcast(shape.dims.into_dimension())
.expect("The shapes should be broadcastable")
// need to convert view to owned array because NdArrayTensor expects owned array
// and try_into_owned_nocopy() panics for broadcasted arrays (zero strides)
.into_owned()
.into_shared();
NdArrayTensor { array }
}
pub fn flip<const D: usize>(
tensor: NdArrayTensor<E, D>,
axes: &[usize],

View File

@ -138,6 +138,13 @@ impl<E: FloatNdArrayElement> BoolTensorOps<Self> for NdArray<E> {
NdArrayTensor { array }
}
fn bool_expand<const D1: usize, const D2: usize>(
tensor: burn_tensor::ops::BoolTensor<Self, D1>,
shape: Shape<D2>,
) -> burn_tensor::ops::BoolTensor<Self, D2> {
NdArrayOps::expand(tensor, shape)
}
fn bool_flip<const D: usize>(
tensor: burn_tensor::ops::BoolTensor<Self, D>,
axes: &[usize],

View File

@ -456,4 +456,11 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
fn int_sign<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, D> {
NdArrayMathOps::sign_op(tensor)
}
fn int_expand<const D1: usize, const D2: usize>(
tensor: burn_tensor::ops::IntTensor<Self, D1>,
shape: Shape<D2>,
) -> burn_tensor::ops::IntTensor<Self, D2> {
NdArrayOps::expand(tensor, shape)
}
}

View File

@ -504,4 +504,11 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
fn float_sign<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
NdArrayMathOps::sign_op(tensor)
}
fn float_expand<const D1: usize, const D2: usize>(
tensor: burn_tensor::ops::FloatTensor<Self, D1>,
shape: Shape<D2>,
) -> burn_tensor::ops::FloatTensor<Self, D2> {
NdArrayOps::expand(tensor, shape)
}
}

View File

@ -495,6 +495,14 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
tensor.unary_ops(|mut tensor| tensor.sign_(), |tensor| tensor.sign())
}
pub fn expand<const D: usize, const D2: usize>(
tensor: TchTensor<E, D>,
shape: Shape<D2>,
) -> TchTensor<E, D2> {
let tensor = tensor.tensor.broadcast_to(shape.dims.map(|x| x as i64));
TchTensor::new(tensor)
}
pub fn sort<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,

View File

@ -159,4 +159,11 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
.map(TchTensor::new)
.collect()
}
fn bool_expand<const D1: usize, const D2: usize>(
tensor: burn_tensor::ops::BoolTensor<Self, D1>,
shape: Shape<D2>,
) -> burn_tensor::ops::BoolTensor<Self, D2> {
TchOps::expand(tensor, shape)
}
}

View File

@ -487,6 +487,13 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
TchOps::sign(tensor)
}
fn int_expand<const D1: usize, const D2: usize>(
tensor: burn_tensor::ops::IntTensor<Self, D1>,
shape: Shape<D2>,
) -> burn_tensor::ops::IntTensor<Self, D2> {
TchOps::expand(tensor, shape)
}
fn int_sort<const D: usize>(
tensor: <LibTorch<E> as Backend>::IntTensorPrimitive<D>,
dim: usize,

View File

@ -496,6 +496,13 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
TchOps::sign(tensor)
}
fn float_expand<const D1: usize, const D2: usize>(
tensor: burn_tensor::ops::FloatTensor<Self, D1>,
shape: Shape<D2>,
) -> burn_tensor::ops::FloatTensor<Self, D2> {
TchOps::expand(tensor, shape)
}
fn float_sort<const D: usize>(
tensor: <LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
dim: usize,

View File

@ -10,6 +10,7 @@ use alloc::string::String;
use alloc::vec;
use burn_common::{reader::Reader, stub::Mutex};
use core::iter::repeat;
use core::{fmt::Debug, ops::Range};
use serde::{Deserialize, Deserializer};
@ -750,6 +751,29 @@ where
let data = self.into_data().await;
data.value[0]
}
/// Broadcast the tensor to the given shape.
///
/// # Arguments
///
/// * `shape` - The shape to broadcast the tensor to.
/// Can contain -1, which keeps the size of the corresponding dimension of the original tensor.
/// The number of dimensions in the target shape must be greater than or equal to
/// the number of dimensions of the tensor.
///
/// # Panics
///
/// If the tensor cannot be broadcasted to the given shape.
///
/// # Returns
///
/// A new tensor with the given shape.
pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
let shape = shape.into_shape(&self.shape());
check!(TensorCheck::expand("expand", &self.shape(), &shape,));
Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
}
}
/// Iterator given by (Tensor::iter_dim).
@ -1461,7 +1485,6 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly. Users should prefer the [Tensor::all](Tensor::all) function,
/// which is more high-level and designed for public use.
fn all<const D: usize>(tensor: Self::Primitive<D>) -> Tensor<B, 1, Bool>;
/// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
@ -1481,8 +1504,22 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly. Users should prefer the [Tensor::all_dim](Tensor::all_dim) function,
/// which is more high-level and designed for public use.
fn all_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Tensor<B, D, Bool>;
/// Broadcasts the given tensor to the specified shape.
///
/// # Arguments
///
/// * `tensor` - The tensor to broadcast.
/// * `shape` - The shape to broadcast to.
///
/// # Returns
///
/// The broadcasted tensor.
fn expand<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
shape: Shape<D2>,
) -> Self::Primitive<D2>;
}
impl<B: Backend> BasicOps<B> for Float {
@ -1491,6 +1528,7 @@ impl<B: Backend> BasicOps<B> for Float {
fn empty<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D> {
B::float_empty(shape, device)
}
fn shape<const D: usize>(tensor: &Self::Primitive<D>) -> Shape<D> {
B::float_shape(tensor)
}
@ -1598,6 +1636,13 @@ impl<B: Backend> BasicOps<B> for Float {
B::float_permute(tensor, axes)
}
fn expand<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
shape: Shape<D2>,
) -> Self::Primitive<D2> {
B::float_expand(tensor, shape)
}
fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
B::float_flip(tensor, axes)
}
@ -1716,6 +1761,13 @@ impl<B: Backend> BasicOps<B> for Int {
B::int_permute(tensor, axes)
}
fn expand<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
shape: Shape<D2>,
) -> Self::Primitive<D2> {
B::int_expand(tensor, shape)
}
fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
B::int_flip(tensor, axes)
}
@ -1834,6 +1886,13 @@ impl<B: Backend> BasicOps<B> for Bool {
B::bool_permute(tensor, axes)
}
fn expand<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
shape: Shape<D2>,
) -> Self::Primitive<D2> {
B::bool_expand(tensor, shape)
}
fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
B::bool_flip(tensor, axes)
}
@ -1925,6 +1984,55 @@ impl<const D2: usize> ReshapeArgs<D2> for [i32; D2] {
}
}
/// Trait used for broadcast arguments.
pub trait BroadcastArgs<const D1: usize, const D2: usize> {
/// Converts to a shape.
fn into_shape(self, shape: &Shape<D1>) -> Shape<D2>;
}
impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape<D2> {
fn into_shape(self, _shape: &Shape<D1>) -> Shape<D2> {
self
}
}
impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [usize; D2] {
fn into_shape(self, _shape: &Shape<D1>) -> Shape<D2> {
Shape::from(self)
}
}
impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [i32; D2] {
// Passing -1 as the size for a dimension means not changing the size of that dimension.
fn into_shape(self, shape: &Shape<D1>) -> Shape<D2> {
if self.len() < shape.dims.len() {
panic!("Broadcast arguments must be greater than the number of dimensions");
}
if self.iter().any(|&x| x < -1 || x == 0) {
panic!("Broadcast arguments must be positive or -1");
}
// Zip the two shapes in reverse order and replace -1 with the actual dimension value.
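// For example, a tensor of shape [3] with arguments [2, -1]: the reversed pairs are
// (-1, 3) and (2, 0-padding), which map to [3, 2] and reverse back to [2, 3].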
let new_shape: Vec<_> = self
.iter()
.rev()
.zip(shape.dims.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
.map(|(&x, &y)| if x == -1 { y } else { x as usize })
.collect::<Vec<_>>()
.into_iter()
.rev()
.collect();
if new_shape.iter().any(|&x| x == 0) {
panic!("Cannot substitute -1 for a non-existing dimension");
}
let new_shape: [usize; D2] = new_shape.try_into().unwrap();
Shape::from(new_shape)
}
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
where

View File

@ -889,6 +889,56 @@ impl TensorCheck {
false => self,
}
}
/// Checks if expand operation is possible for the given shapes.
pub fn expand<const D1: usize, const D2: usize>(
ops: &str,
shape: &Shape<D1>,
to: &Shape<D2>,
) -> Self {
let mut check = TensorCheck::Ok;
let max_dims = core::cmp::max(D1, D2);
// Calculate the starting indices for each shape array, ensuring alignment from the right.
let start_index_shape = max_dims.saturating_sub(D1);
let start_index_to = max_dims.saturating_sub(D2);
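// For example, checking shape [2, 3] against target [4, 2, 3]: max_dims = 3,
// start_index_shape = 1 and start_index_to = 0, so dimension 0 of the target is
// compared against an implicit size of 1 and is therefore always compatible.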
for i in 0..max_dims {
// Use 1 as the default dimension size for dimensions beyond the tensor's rank.
let d_shape = if i >= start_index_shape {
shape.dims[i - start_index_shape]
} else {
1
};
let d_to = if i >= start_index_to {
to.dims[i - start_index_to]
} else {
1
};
if d_shape != d_to && d_shape != 1 && d_to != 1 {
// Register an incompatibility error.
check = check.register(
ops,
TensorError::new(
"The provided tensor can't be broadcasted to the target shape.",
)
.details(format!(
"Incompatible size at dimension '{}' => '{} != {}', which can't be \
broadcasted. Tensor shape {:?}, Target shape {:?}.",
max_dims - i - 1,
d_shape,
d_to,
shape.dims,
to.dims,
)),
);
break; // Incompatibility found, no need to check further.
}
}
check
}
}
pub(crate) struct FailedTensorCheck {

View File

@ -457,4 +457,10 @@ pub trait BoolTensorOps<B: Backend> {
.map(|t| B::int_reshape(t, Shape::new([dims[0]])))
.collect()
}
/// Broadcasts the bool `tensor` to the given `shape`.
fn bool_expand<const D1: usize, const D2: usize>(
tensor: BoolTensor<B, D1>,
shape: Shape<D2>,
) -> BoolTensor<B, D2>;
}

View File

@ -1199,6 +1199,12 @@ pub trait IntTensorOps<B: Backend> {
result
}
/// Broadcasts the int `tensor` to the given `shape`.
fn int_expand<const D1: usize, const D2: usize>(
tensor: IntTensor<B, D1>,
shape: Shape<D2>,
) -> IntTensor<B, D2>;
/// Sort the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).

View File

@ -1373,6 +1373,12 @@ pub trait FloatTensorOps<B: Backend> {
result
}
/// Broadcasts the float `tensor` to the given `shape`.
fn float_expand<const D1: usize, const D2: usize>(
tensor: FloatTensor<B, D1>,
shape: Shape<D2>,
) -> FloatTensor<B, D2>;
/// Sort the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).

View File

@ -90,6 +90,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_bool!();
burn_tensor::testgen_argwhere_nonzero!();
burn_tensor::testgen_sign!();
burn_tensor::testgen_expand!();
burn_tensor::testgen_tri_mask!();
burn_tensor::testgen_sort_argsort!();
burn_tensor::testgen_topk!();

View File

@ -0,0 +1,111 @@
#[burn_tensor_testgen::testgen(expand)]
mod tests {
use super::*;
use burn_tensor::{Data, Shape, Tensor};
#[test]
fn expand_2d() {
let tensor = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0], &Default::default());
let expanded_tensor = tensor.expand([3, 3]);
let expected_data = Data::from([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]);
assert_eq!(expanded_tensor.into_data(), expected_data);
let tensor =
Tensor::<TestBackend, 1>::from_floats([4.0, 7.0, 2.0, 3.0], &Default::default());
let expanded_tensor = tensor.expand([2, 4]);
let expected_data = Data::from([[4.0, 7.0, 2.0, 3.0], [4.0, 7.0, 2.0, 3.0]]);
assert_eq!(expanded_tensor.into_data(), expected_data);
}
#[test]
fn expand_3d() {
let tensor =
Tensor::<TestBackend, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &Default::default());
let expanded_tensor = tensor.expand([3, 2, 2]);
let expected_data = Data::from([
[[1.0, 2.0], [3.0, 4.0]],
[[1.0, 2.0], [3.0, 4.0]],
[[1.0, 2.0], [3.0, 4.0]],
]);
assert_eq!(expanded_tensor.into_data(), expected_data);
}
#[test]
fn expand_higher_dimensions() {
let tensor =
Tensor::<TestBackend, 2>::from_floats([[1.0, 2.0, 3.0, 4.0]], &Default::default());
let expanded_tensor = tensor.expand([2, 3, 4]);
let expected_data = Data::from([
[
[1.0, 2.0, 3.0, 4.0],
[1.0, 2.0, 3.0, 4.0],
[1.0, 2.0, 3.0, 4.0],
],
[
[1.0, 2.0, 3.0, 4.0],
[1.0, 2.0, 3.0, 4.0],
[1.0, 2.0, 3.0, 4.0],
],
]);
assert_eq!(expanded_tensor.into_data(), expected_data);
}
#[test]
fn broadcast_single() {
let tensor = Tensor::<TestBackend, 1>::from_floats([1.0], &Default::default());
let expanded_tensor = tensor.expand([2, 3]);
let expected_data = Data::from([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);
assert_eq!(expanded_tensor.into_data(), expected_data);
}
#[test]
#[should_panic]
fn should_fail_expand_incompatible_shapes() {
let tensor = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0], &Default::default());
let _expanded_tensor = tensor.expand([2, 2]);
}
#[test]
fn expand_2d_bool() {
let tensor = TestTensorBool::from([false, true, false]);
let expanded_tensor = tensor.expand([3, 3]);
let expected_data = Data::from([
[false, true, false],
[false, true, false],
[false, true, false],
]);
assert_eq!(expanded_tensor.into_data(), expected_data);
}
#[test]
fn expand_2d_int() {
let tensor = TestTensorInt::from([1, 2, 3]);
let expanded_tensor = tensor.expand([3, 3]);
let expected_data = Data::from([[1, 2, 3], [1, 2, 3], [1, 2, 3]]);
assert_eq!(expanded_tensor.into_data(), expected_data);
}
#[test]
fn should_all_negative_one() {
let tensor = TestTensorInt::from([1, 2, 3]);
let expanded_tensor = tensor.expand([2, -1]);
let expected_data = Data::from([[1, 2, 3], [1, 2, 3]]);
assert_eq!(expanded_tensor.into_data(), expected_data);
}
#[test]
#[should_panic]
fn should_panic_negative_one_on_non_existing_dim() {
let tensor = TestTensorInt::from([1, 2, 3]);
let _expanded_tensor = tensor.expand([-1, 3]);
}
}

View File

@ -18,6 +18,7 @@ mod create_like;
mod div;
mod erf;
mod exp;
mod expand;
mod flatten;
mod flip;
mod full;