mirror of https://github.com/tracel-ai/burn.git
Tensor expand operator (#1508)
* Improve CI cache - remove burn-tch artifacts
* PyTorch config deserializer from .pt file
* Update pytorch-model.md
* WIP
* Rename broadcast_to to expand
* Rename broadcast_to file to expand
* Implement fusion backend and fix bugs
* Remove old files
* Remove unused state
* Rename to the correct op name
* Add missing comment
* Fix expand check function doc
* Rename the leftover names
This commit is contained in:
parent dc45cf1700
commit 6feda90a8c
@@ -144,6 +144,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.

| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.expand(shape)` | `tensor.expand(shape)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
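Editorial note, to make the new table row concrete (an illustrative sketch, not taken from the book page): Burn's `expand` mirrors PyTorch's `Tensor.expand`, broadcasting a smaller tensor to the requested shape.

```rust
fn example<B: burn_tensor::backend::Backend>(device: &B::Device) {
    // Burn: broadcast a [3] tensor to [3, 3] (PyTorch: y = x.expand(3, 3)).
    let x = burn_tensor::Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], device);
    let y = x.expand([3, 3]);
    assert_eq!(y.dims(), [3, 3]);
}
```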
@@ -128,4 +128,11 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {

    fn bool_nonzero<const D: usize>(tensor: BoolTensor<B, D>) -> Vec<IntTensor<B, 1>> {
        B::bool_nonzero(tensor)
    }

    fn bool_expand<const D: usize, const D2: usize>(
        tensor: BoolTensor<B, D>,
        shape: Shape<D2>,
    ) -> BoolTensor<B, D2> {
        B::bool_expand(tensor, shape)
    }
}
@@ -369,6 +369,13 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
        B::int_prod_dim(tensor, dim)
    }

    fn int_expand<const D: usize, const D2: usize>(
        tensor: IntTensor<B, D>,
        shape: Shape<D2>,
    ) -> IntTensor<B, D2> {
        B::int_expand(tensor, shape)
    }

    fn int_sort<const D: usize>(
        tensor: IntTensor<Self, D>,
        dim: usize,
@@ -2437,6 +2437,81 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
            .stateless(B::float_sign(tensor.primitive))
    }

    fn float_expand<const D1: usize, const D2: usize>(
        tensor: FloatTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> FloatTensor<Self, D2> {
        #[derive(Debug)]
        struct ExpandDim<const D1: usize, const D2: usize>;

        #[derive(new, Debug)]
        struct RetroExpand<B: Backend, const D1: usize, const D2: usize> {
            input_id: NodeID,
            shape: Shape<D2>,
            _backend: PhantomData<B>,
        }

        impl<B: Backend, const D1: usize, const D2: usize> RetroForward for RetroExpand<B, D1, D2> {
            fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
                let input = states.get_state::<B::FloatTensorPrimitive<D1>>(&self.input_id);
                let out = B::float_expand(input, self.shape.clone());
                states.save(out_node, out)
            }
        }

        impl<B: Backend, const D1: usize, const D2: usize> Backward<B, D2, 1> for ExpandDim<D1, D2> {
            type State = Shape<D1>;

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                _checkpointer: &mut Checkpointer,
            ) {
                let shape_original = ops.state;

                let mut shape_expanded = [1; D2];

                debug_assert!(D2 >= D1);

                for i in 0..D1 {
                    shape_expanded[i + (D2 - D1)] = shape_original.dims[i];
                }

                unary::<B, D2, D1, _>(ops.parents, ops.node, grads, |grad| {
                    let shape_grad = B::float_shape(&grad);
                    let mut grad = grad;

                    #[allow(clippy::needless_range_loop)]
                    for i in 0..D2 {
                        if shape_expanded[i] == 1 && shape_grad.dims[i] != 1 {
                            grad = B::float_sum_dim(grad, i);
                        }
                    }

                    B::float_reshape(grad, shape_original)
                });
            }
        }

        match ExpandDim
            .prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
            .memory_bound()
            .retro_forward(RetroExpand::<B, D1, D2>::new(
                tensor.node.id.clone(),
                shape.clone(),
            ))
            .parents([&tensor])
            .stateful()
        {
            OpsKind::Tracked(prep) => prep.finish(
                B::float_shape(&tensor.primitive),
                B::float_expand(tensor.primitive, shape),
            ),
            OpsKind::UnTracked(prep) => prep.finish(B::float_expand(tensor.primitive, shape)),
        }
    }

    fn float_sort<const D: usize>(
        tensor: FloatTensor<Self, D>,
        dim: usize,
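Editorial note on the backward pass above: the gradient of `expand` sums the incoming gradient over every axis that was broadcast (size 1 in the right-aligned input shape but larger in the output) and then reshapes back to the input shape. A minimal, self-contained sketch of that reduction rule, with hypothetical helper names not taken from the codebase:

```rust
// Which axes must be summed in the gradient of an expand?
// `original` is the input shape right-aligned into the output rank (missing leading
// dims treated as 1); `grad_dims` is the gradient's (i.e. output's) shape.
fn dims_to_sum(original: &[usize], grad_dims: &[usize]) -> Vec<usize> {
    let mut dims = Vec::new();
    for i in 0..grad_dims.len() {
        if original[i] == 1 && grad_dims[i] != 1 {
            dims.push(i);
        }
    }
    dims
}

fn main() {
    // Expanding [4] -> [4, 4]: the input right-aligns to [1, 4], so axis 0 is summed.
    assert_eq!(dims_to_sum(&[1, 4], &[4, 4]), vec![0]);
}
```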
@@ -0,0 +1,38 @@
#[burn_tensor_testgen::testgen(ad_expand)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Tensor};

    #[test]
    fn should_diff_expand() {
        // Python code to generate the test case values
        // import torch
        // x1 = torch.tensor([4.0, 7.0, 2.0, 3.0], requires_grad=True)
        // x2 = torch.tensor([2.0, 4.5, 7.0, 3.0], requires_grad=True)
        // y = x1.expand(4, 4)
        // z = (x2 * y).sum()
        // z.backward()
        // print("x1", x1.grad)
        // print("x2", x2.grad)

        let device = Default::default();

        let data_1: Data<f32, 1> = Data::from([4.0, 7.0, 2.0, 3.0]);
        let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();

        let data_2: Data<f32, 1> = Data::from([2.0, 4.5, 7.0, 3.0]);
        let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();

        let tensor_3 = tensor_1.clone().expand([4, 4]);

        // Use unsqueeze to make tensor_2 have the same shape as tensor_3
        let tensor_4 = tensor_2.clone().unsqueeze().mul(tensor_3).sum();
        let grads = tensor_4.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        assert_eq!(grad_1.to_data(), Data::from([8., 18., 28., 12.]));
        assert_eq!(grad_2.to_data(), Data::from([16., 28., 8., 12.]));
    }
}
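Editorial note on the expected values: since `y[i][j] = x1[j]` and the unsqueezed `x2` broadcasts to `x2[j]` in every row, `z = 4 * Σ_j x1[j] * x2[j]`; hence `dz/dx1 = 4 * x2 = [8, 18, 28, 12]` and `dz/dx2 = 4 * x1 = [16, 28, 8, 12]`, matching the assertions above.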
@@ -21,6 +21,7 @@ mod cross_entropy;
mod div;
mod erf;
mod exp;
mod expand;
mod flip;
mod gather_scatter;
mod gelu;

@@ -119,6 +120,7 @@ macro_rules! testgen_all {
        burn_autodiff::testgen_ad_flip!();
        burn_autodiff::testgen_ad_nonzero!();
        burn_autodiff::testgen_ad_sign!();
        burn_autodiff::testgen_ad_expand!();
        burn_autodiff::testgen_ad_sort!();
    };
}
@@ -102,6 +102,7 @@ mod tests {
    burn_tensor::testgen_sub!();
    burn_tensor::testgen_tanh!();
    burn_tensor::testgen_transpose!();
    burn_tensor::testgen_expand!();

    // test stats
    burn_tensor::testgen_var!();

@@ -157,4 +158,5 @@ mod tests {
    burn_autodiff::testgen_ad_sub!();
    burn_autodiff::testgen_ad_tanh!();
    burn_autodiff::testgen_ad_transpose!();
    burn_autodiff::testgen_ad_expand!();
}
@@ -142,3 +142,10 @@ pub fn chunk<E: CandleElement, const D: usize>(
        Err(e) => panic!("error chunk from Candle"),
    }
}

pub fn expand<E: CandleElement, const D1: usize, const D2: usize>(
    tensor: CandleTensor<E, D1>,
    shape: Shape<D2>,
) -> CandleTensor<E, D2> {
    CandleTensor::new(tensor.tensor.broadcast_as(&shape.dims).unwrap())
}
@@ -8,6 +8,8 @@ use crate::{
    Candle, CandleTensor,
};

use super::base::{expand, permute};

impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<F, I> {
    fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
        super::base::empty(shape, device)

@@ -140,4 +142,11 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
    ) -> BoolTensor<Self, D> {
        super::base::flip(tensor, axes)
    }

    fn bool_expand<const D1: usize, const D2: usize>(
        tensor: BoolTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> BoolTensor<Self, D2> {
        expand(tensor, shape)
    }
}
@@ -8,6 +8,8 @@ use crate::{
    Candle, CandleTensor,
};

use super::base::{expand, permute};

impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
    fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
        super::base::empty(shape, device)

@@ -430,6 +432,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
        super::base::flip(tensor, axes)
    }

    fn int_expand<const D1: usize, const D2: usize>(
        tensor: IntTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> IntTensor<Self, D2> {
        expand(tensor, shape)
    }

    // TODO add sign operator once Candle supports it:
    // https://github.com/huggingface/candle/issues/1827
}
@@ -11,6 +11,8 @@ use crate::{
    Candle, CandleTensor,
};

use super::base::{expand, permute};

impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
    fn float_from_data<const D: usize>(
        data: Data<F, D>,

@@ -530,6 +532,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
        super::base::flip(tensor, axes)
    }

    fn float_expand<const D1: usize, const D2: usize>(
        tensor: FloatTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> FloatTensor<Self, D2> {
        expand(tensor, shape)
    }

    // TODO add sign operator once Candle supports it:
    // https://github.com/huggingface/candle/issues/1827
}
@@ -4,9 +4,10 @@ use crate::{
    ops::binary::binary_ops_shape,
    stream::{
        BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
        CatOperationDescription, FlipOperationDescription, Operation, OperationDescription,
        PermuteOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
        SliceOperationDescription, StreamId, SwapDimsDescription, UnaryOperationDescription,
        CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, Operation,
        OperationDescription, PermuteOperationDescription, ReshapeDescription,
        SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
        UnaryOperationDescription,
    },
    Fusion, FusionBackend,
};

@@ -467,6 +468,44 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
        out
    }

    fn bool_expand<const D1: usize, const D2: usize>(
        tensor: BoolTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> BoolTensor<Self, D2> {
        #[derive(new)]
        struct ExpandOps<const D: usize, const D2: usize> {
            desc: ExpandOperationDescription,
        }

        impl<const D: usize, const D2: usize, B: FusionBackend> Operation<B> for ExpandOps<D, D2> {
            fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
                let input = handles.get_bool_tensor::<D>(&self.desc.input);
                let shape: [usize; D2] = self.desc.shape.try_into().unwrap();
                let output = B::bool_expand(input, shape.into());

                handles.register_bool_tensor(&self.desc.out.id, output);
            }
        }

        let stream = tensor.stream;

        let out = tensor.client.tensor_uninitialized(shape.dims.into());

        let desc = ExpandOperationDescription {
            input: tensor.into_description(),
            shape: shape.dims.into(),
            out: out.to_description_out(),
        };

        out.client.register(
            vec![stream],
            OperationDescription::BaseBool(BaseOperationDescription::Expand(desc.clone())),
            ExpandOps::<D1, D2>::new(desc),
        );

        out
    }

    fn bool_flip<const D: usize>(
        tensor: BoolTensor<Self, D>,
        axes: &[usize],
@@ -6,13 +6,14 @@ use crate::{
    scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
    stream::{
        BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
        ClampOperationDescription, FlipOperationDescription, FloatOperationDescription,
        GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
        NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
        RandomOperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription,
        ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
        SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
        StreamId, SwapDimsDescription, UnaryOperationDescription,
        ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
        FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription,
        MaskWhereOperationDescription, NumericOperationDescription, Operation,
        OperationDescription, PermuteOperationDescription, RandomOperationDescription,
        ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
        ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
        SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
        UnaryOperationDescription,
    },
    unary_float_ops, Fusion, FusionBackend, TensorDescription,
};

@@ -1847,6 +1848,44 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
        out
    }

    fn float_expand<const D1: usize, const D2: usize>(
        tensor: FloatTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> FloatTensor<Self, D2> {
        #[derive(new)]
        struct ExpandOps<const D: usize, const D2: usize> {
            desc: ExpandOperationDescription,
        }

        impl<const D: usize, const D2: usize, B: FusionBackend> Operation<B> for ExpandOps<D, D2> {
            fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
                let input = handles.get_float_tensor::<D>(&self.desc.input);
                let shape: [usize; D2] = self.desc.shape.try_into().unwrap();
                let output = B::float_expand(input, shape.into());

                handles.register_float_tensor(&self.desc.out.id, output);
            }
        }

        let stream = tensor.stream;

        let out = tensor.client.tensor_uninitialized(shape.dims.into());

        let desc = ExpandOperationDescription {
            input: tensor.into_description(),
            shape: shape.dims.into(),
            out: out.to_description_out(),
        };

        out.client.register(
            vec![stream],
            OperationDescription::BaseFloat(BaseOperationDescription::Expand(desc.clone())),
            ExpandOps::<D1, D2>::new(desc),
        );

        out
    }

    fn float_flip<const D: usize>(
        tensor: FloatTensor<Self, D>,
        axes: &[usize],
@@ -6,13 +6,13 @@ use crate::{
    scalar_int_cmp_ops, scalar_int_ops,
    stream::{
        self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
        ClampOperationDescription, FlipOperationDescription, GatherOperationDescription,
        MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
        Operation, OperationDescription, PermuteOperationDescription, RandomOperationDescription,
        ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
        ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
        SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
        UnaryOperationDescription,
        ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
        GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
        NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
        RandomOperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription,
        ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
        SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
        StreamId, SwapDimsDescription, UnaryOperationDescription,
    },
    unary_int_ops, Fusion, FusionBackend, TensorDescription,
};

@@ -1553,6 +1553,43 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
        out
    }

    fn int_expand<const D1: usize, const D2: usize>(
        tensor: IntTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> IntTensor<Self, D2> {
        #[derive(new)]
        struct ExpandOps<const D: usize, const D2: usize> {
            desc: ExpandOperationDescription,
        }

        impl<const D: usize, const D2: usize, B: FusionBackend> Operation<B> for ExpandOps<D, D2> {
            fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
                let input = handles.get_int_tensor::<D>(&self.desc.input);
                let shape: [usize; D2] = self.desc.shape.try_into().unwrap();
                let output = B::int_expand(input, shape.into());
                handles.register_int_tensor(&self.desc.out.id, output);
            }
        }

        let stream = tensor.stream;

        let out = tensor.client.tensor_uninitialized(shape.dims.into());

        let desc = ExpandOperationDescription {
            input: tensor.into_description(),
            shape: shape.dims.into(),
            out: out.to_description_out(),
        };

        out.client.register(
            vec![stream],
            OperationDescription::BaseInt(BaseOperationDescription::Expand(desc.clone())),
            ExpandOps::<D1, D2>::new(desc),
        );

        out
    }

    fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
        #[derive(new)]
        struct FlipDimsOps<const D: usize> {
@@ -4,16 +4,17 @@ use super::{
    AvgPool2dBackwardDescription, AvgPool2dDescription, BaseOperationDescription,
    BinaryOperationDescription, BoolOperationDescription, ClampOperationDescription,
    Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, ConvTranspose2dDescription,
    EmbeddingBackwardDescription, EmbeddingDescription, FlipOperationDescription,
    FloatOperationDescription, GatherOperationDescription, IntOperationDescription,
    InterpolateBackwardDescription, InterpolateDescription, MaskFillOperationDescription,
    MaskWhereOperationDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
    MaxPool1dWithIndicesDescription, MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription,
    MaxPool2dWithIndicesDescription, ModuleOperationDescription, NumericOperationDescription,
    OperationDescription, PermuteOperationDescription, RandomOperationDescription,
    ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
    ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
    SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
    EmbeddingBackwardDescription, EmbeddingDescription, ExpandOperationDescription,
    FlipOperationDescription, FloatOperationDescription, GatherOperationDescription,
    IntOperationDescription, InterpolateBackwardDescription, InterpolateDescription,
    MaskFillOperationDescription, MaskWhereOperationDescription, MaxPool1dDescription,
    MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription, MaxPool2dDescription,
    MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription,
    ModuleOperationDescription, NumericOperationDescription, OperationDescription,
    PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription,
    ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription,
    SelectAssignOperationDescription, SelectOperationDescription, SliceOperationDescription,
    SwapDimsDescription, UnaryOperationDescription,
};
use crate::{FusionBackend, HandleContainer, TensorDescription, TensorId};
use burn_tensor::{Element, ElementConversion};

@@ -798,6 +799,13 @@ impl BaseOperationDescription {
                    axes: desc.axes.clone(),
                })
            }
            BaseOperationDescription::Expand(desc) => {
                BaseOperationDescription::Expand(ExpandOperationDescription {
                    input: desc.input.to_relative(converter),
                    out: desc.out.to_relative(converter),
                    shape: desc.shape.clone(),
                })
            }
            BaseOperationDescription::Flip(desc) => {
                BaseOperationDescription::Flip(FlipOperationDescription {
                    input: desc.input.to_relative(converter),
@@ -141,6 +141,7 @@ pub enum BaseOperationDescription {
    /// Int => [reshape](burn_tensor::ops::IntTensorOps::int_reshape).
    /// Bool => [reshape](burn_tensor::ops::BoolTensorOps::bool_reshape).
    Reshape(ReshapeDescription),

    /// Operation corresponding to:
    ///
    /// Float => [swap_dims](burn_tensor::ops::FloatTensorOps::float_swap_dims).

@@ -161,6 +162,13 @@ pub enum BaseOperationDescription {
    /// Bool => [flip](burn_tensor::ops::BoolTensorOps::bool_flip).
    Flip(FlipOperationDescription),

    /// Operation corresponding to:
    ///
    /// Float => [expand](burn_tensor::ops::FloatTensorOps::float_expand).
    /// Int => [expand](burn_tensor::ops::IntTensorOps::int_expand).
    /// Bool => [expand](burn_tensor::ops::BoolTensorOps::bool_expand).
    Expand(ExpandOperationDescription),

    /// Operation corresponding to:
    ///
    /// Float => [slice](burn_tensor::ops::FloatTensorOps::float_slice).

@@ -462,6 +470,17 @@ pub struct PermuteOperationDescription {
    pub axes: Vec<usize>,
}

/// Expand operation description.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct ExpandOperationDescription {
    /// Input tensor description.
    pub input: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
    /// The new shape.
    pub shape: Vec<usize>,
}

/// Flip operation description.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FlipOperationDescription {

@@ -487,6 +506,13 @@ pub struct ReshapeDescription {
    pub out: TensorDescription,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ExpandDescription {
    pub input: TensorDescription,
    pub out: TensorDescription,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct BinaryOperationDescription {

@@ -1054,6 +1080,11 @@ impl BaseOperationDescription {
            BaseOperationDescription::Permute(desc) => {
                vec![&desc.input, &desc.out]
            }

            BaseOperationDescription::Expand(desc) => {
                vec![&desc.input, &desc.out]
            }

            BaseOperationDescription::Flip(desc) => {
                vec![&desc.input, &desc.out]
            }
Binary file not shown.
@@ -1,3 +1,5 @@
use std::marker::PhantomData;

use crate::{element::JitElement, kernel, tensor::JitTensor, Runtime};
use burn_tensor::{Data, Reader, Shape};

@@ -80,6 +82,55 @@ pub(crate) fn permute<R: Runtime, E: JitElement, const D: usize>(

    tensor
}

pub(crate) fn expand<R: Runtime, E: JitElement, const D: usize, const D_OUT: usize>(
    tensor: JitTensor<R, E, D>,
    target_shape: Shape<D_OUT>,
) -> JitTensor<R, E, D_OUT> {
    // Initialize new strides with zeros
    let mut new_strides = [0usize; D_OUT];

    // Calculate the difference in dimensions
    let dim_diff = D_OUT.saturating_sub(D);

    // Compare dimensions from the end, setting strides for matching dimensions or broadcasted ones
    let mut tensor_dim_iter = tensor.shape.dims.iter().rev();
    for i in (0..D_OUT).rev() {
        if i >= dim_diff {
            if let Some(&tensor_dim) = tensor_dim_iter.next() {
                if tensor_dim == target_shape.dims[i] || tensor_dim == 1 {
                    // Copy stride for non-broadcast dimensions or set to 0 for broadcast ones
                    new_strides[i] = if tensor_dim == target_shape.dims[i] {
                        tensor.strides[i - dim_diff]
                    } else {
                        0
                    };
                } else {
                    // Error handling: Dimension mismatch for broadcasting
                    panic!(
                        "Dimension mismatch: cannot broadcast dimension {} of tensor to target shape",
                        tensor_dim
                    );
                }
            } else {
                // If the input tensor has fewer dimensions, treat missing dimensions as 1
                // and set stride to 0 (broadcasting)
                new_strides[i] = 0;
            }
        } else {
            // For extra dimensions in the target shape, set stride to 0 (broadcasting)
            new_strides[i] = 0;
        }
    }

    JitTensor {
        client: tensor.client,
        device: tensor.device,
        shape: target_shape,
        strides: new_strides,
        handle: tensor.handle,
        elem: PhantomData,
    }
}

pub(crate) fn reshape<R: Runtime, E: JitElement, const D1: usize, const D2: usize>(
    tensor: JitTensor<R, E, D1>,
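Editorial note: the JIT `expand` above avoids launching a kernel by reusing the input buffer and giving every broadcast dimension a stride of 0, so all indices along such a dimension read the same memory. A small self-contained sketch of the same stride computation, assuming the shapes are already broadcast-compatible (the helper name is illustrative):

```rust
// Illustrative only: compute strides for broadcasting `in_shape`/`in_strides`
// (right-aligned) to `out_shape`; broadcast dimensions get stride 0.
fn broadcast_strides(in_shape: &[usize], in_strides: &[usize], out_shape: &[usize]) -> Vec<usize> {
    let offset = out_shape.len() - in_shape.len();
    (0..out_shape.len())
        .map(|i| {
            if i >= offset && in_shape[i - offset] == out_shape[i] {
                in_strides[i - offset] // dimension kept as-is
            } else {
                0 // new leading dimension, or a size-1 dimension being broadcast
            }
        })
        .collect()
}

fn main() {
    // Expanding shape [4] (strides [1]) to [4, 4]: the new leading axis gets stride 0.
    assert_eq!(broadcast_strides(&[4], &[1], &[4, 4]), vec![0, 1]);
}
```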
@@ -4,7 +4,7 @@ use burn_tensor::Reader;
use burn_tensor::{ops::BoolTensorOps, Data, Shape};
use std::ops::Range;

use super::permute;
use super::{expand, permute};

impl<R: Runtime> BoolTensorOps<Self> for JitBackend<R> {
    fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {

@@ -114,6 +114,13 @@ impl<R: Runtime> BoolTensorOps<Self> for JitBackend<R> {
        permute(tensor, axes)
    }

    fn bool_expand<const D1: usize, const D2: usize>(
        tensor: BoolTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> BoolTensor<Self, D2> {
        expand(tensor, shape)
    }

    fn bool_flip<const D: usize>(
        tensor: BoolTensor<Self, D>,
        axes: &[usize],
@@ -1,4 +1,4 @@
use super::{numeric, permute};
use super::{expand, numeric, permute};
use crate::codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope, UnaryOperator};
use crate::kernel::matmul::{matmul, MatmulStrategy};
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};

@@ -521,6 +521,13 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
        permute(tensor, axes)
    }

    fn float_expand<const D1: usize, const D2: usize>(
        tensor: FloatTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> FloatTensor<Self, D2> {
        expand(tensor, shape)
    }

    fn float_flip<const D: usize>(
        tensor: FloatTensor<Self, D>,
        axes: &[usize],
@@ -1,4 +1,4 @@
use super::{numeric, permute};
use super::{expand, numeric, permute};
use crate::codegen::dialect::gpu::{Elem, Item, Operator, Scope, UnaryOperator};
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
use crate::{kernel, unary, JitBackend, Runtime};

@@ -340,6 +340,13 @@ impl<R: Runtime> IntTensorOps<Self> for JitBackend<R> {
        permute(tensor, axes)
    }

    fn int_expand<const D1: usize, const D2: usize>(
        tensor: IntTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> IntTensor<Self, D2> {
        expand(tensor, shape)
    }

    fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
        kernel::flip(tensor, axes)
    }
@@ -176,7 +176,7 @@ where
        for d in 0..D {
            let stride = self.strides[D - 1 - d];

            if stride < current_stride {
            if stride <= current_stride {
                return false;
            }
@@ -1,9 +1,11 @@
use alloc::vec::Vec;
use burn_tensor::Data;
use burn_tensor::ElementConversion;
use core::fmt::Debug;
use core::{marker::PhantomData, ops::Range};
use ndarray::s;
use ndarray::Array2;
use ndarray::IntoDimension;
use ndarray::SliceInfo;
use ndarray::Zip;
use num_traits::Signed;

@@ -28,7 +30,7 @@ pub(crate) struct NdArrayMathOps<E> {

impl<E> NdArrayOps<E>
where
    E: Copy,
    E: Copy + Debug,
{
    pub fn slice<const D1: usize, const D2: usize>(
        tensor: NdArrayTensor<E, D1>,

@@ -113,6 +115,22 @@ where
        NdArrayTensor::new(array)
    }

    /// Broadcasts the tensor to the given shape
    pub(crate) fn expand<const D1: usize, const D2: usize>(
        tensor: NdArrayTensor<E, D1>,
        shape: Shape<D2>,
    ) -> NdArrayTensor<E, D2> {
        let array = tensor
            .array
            .broadcast(shape.dims.into_dimension())
            .expect("The shapes should be broadcastable")
            // need to convert view to owned array because NdArrayTensor expects owned array
            // and try_into_owned_nocopy() panics for broadcasted arrays (zero strides)
            .into_owned()
            .into_shared();
        NdArrayTensor { array }
    }

    pub fn flip<const D: usize>(
        tensor: NdArrayTensor<E, D>,
        axes: &[usize],
@@ -138,6 +138,13 @@ impl<E: FloatNdArrayElement> BoolTensorOps<Self> for NdArray<E> {
        NdArrayTensor { array }
    }

    fn bool_expand<const D1: usize, const D2: usize>(
        tensor: burn_tensor::ops::BoolTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> burn_tensor::ops::BoolTensor<Self, D2> {
        NdArrayOps::expand(tensor, shape)
    }

    fn bool_flip<const D: usize>(
        tensor: burn_tensor::ops::BoolTensor<Self, D>,
        axes: &[usize],
@@ -456,4 +456,11 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
    fn int_sign<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, D> {
        NdArrayMathOps::sign_op(tensor)
    }

    fn int_expand<const D1: usize, const D2: usize>(
        tensor: burn_tensor::ops::IntTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> burn_tensor::ops::IntTensor<Self, D2> {
        NdArrayOps::expand(tensor, shape)
    }
}
@@ -504,4 +504,11 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
    fn float_sign<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
        NdArrayMathOps::sign_op(tensor)
    }

    fn float_expand<const D1: usize, const D2: usize>(
        tensor: burn_tensor::ops::FloatTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> burn_tensor::ops::FloatTensor<Self, D2> {
        NdArrayOps::expand(tensor, shape)
    }
}
@@ -495,6 +495,14 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
        tensor.unary_ops(|mut tensor| tensor.sign_(), |tensor| tensor.sign())
    }

    pub fn expand<const D: usize, const D2: usize>(
        tensor: TchTensor<E, D>,
        shape: Shape<D2>,
    ) -> TchTensor<E, D2> {
        let tensor = tensor.tensor.broadcast_to(shape.dims.map(|x| x as i64));
        TchTensor::new(tensor)
    }

    pub fn sort<const D: usize>(
        tensor: TchTensor<E, D>,
        dim: usize,
@@ -159,4 +159,11 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
            .map(TchTensor::new)
            .collect()
    }

    fn bool_expand<const D1: usize, const D2: usize>(
        tensor: burn_tensor::ops::BoolTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> burn_tensor::ops::BoolTensor<Self, D2> {
        TchOps::expand(tensor, shape)
    }
}
@@ -487,6 +487,13 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
        TchOps::sign(tensor)
    }

    fn int_expand<const D1: usize, const D2: usize>(
        tensor: burn_tensor::ops::IntTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> burn_tensor::ops::IntTensor<Self, D2> {
        TchOps::expand(tensor, shape)
    }

    fn int_sort<const D: usize>(
        tensor: <LibTorch<E> as Backend>::IntTensorPrimitive<D>,
        dim: usize,
@@ -496,6 +496,13 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
        TchOps::sign(tensor)
    }

    fn float_expand<const D1: usize, const D2: usize>(
        tensor: burn_tensor::ops::FloatTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> burn_tensor::ops::FloatTensor<Self, D2> {
        TchOps::expand(tensor, shape)
    }

    fn float_sort<const D: usize>(
        tensor: <LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
        dim: usize,
@@ -10,6 +10,7 @@ use alloc::string::String;
use alloc::vec;

use burn_common::{reader::Reader, stub::Mutex};
use core::iter::repeat;
use core::{fmt::Debug, ops::Range};
use serde::{Deserialize, Deserializer};

@@ -750,6 +751,29 @@ where
        let data = self.into_data().await;
        data.value[0]
    }

    /// Broadcast the tensor to the given shape.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape to broadcast the tensor to.
    ///   Can contain -1 for dimensions that should be inferred.
    ///   The number of elements in the shape must be greater than or equal to
    ///   the number of dimensions of the tensor.
    ///
    /// # Panics
    ///
    /// If the tensor cannot be broadcasted to the given shape.
    ///
    /// # Returns
    ///
    /// A new tensor with the given shape.
    pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
        let shape = shape.into_shape(&self.shape());
        check!(TensorCheck::expand("expand", &self.shape(), &shape,));

        Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
    }
}

/// Iterator given by (Tensor::iter_dim).

@@ -1461,7 +1485,6 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly. Users should prefer the [Tensor::all](Tensor::all) function,
    /// which is more high-level and designed for public use.
    fn all<const D: usize>(tensor: Self::Primitive<D>) -> Tensor<B, 1, Bool>;

    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.

@@ -1481,8 +1504,22 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly. Users should prefer the [Tensor::all_dim](Tensor::all_dim) function,
    /// which is more high-level and designed for public use.
    fn all_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Tensor<B, D, Bool>;

    /// Broadcasts the given tensor to the specified shape.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The shape to broadcast to.
    ///
    /// # Returns
    ///
    /// The broadcasted tensor.
    fn expand<const D1: usize, const D2: usize>(
        tensor: Self::Primitive<D1>,
        shape: Shape<D2>,
    ) -> Self::Primitive<D2>;
}

impl<B: Backend> BasicOps<B> for Float {

@@ -1491,6 +1528,7 @@ impl<B: Backend> BasicOps<B> for Float {
    fn empty<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D> {
        B::float_empty(shape, device)
    }

    fn shape<const D: usize>(tensor: &Self::Primitive<D>) -> Shape<D> {
        B::float_shape(tensor)
    }

@@ -1598,6 +1636,13 @@ impl<B: Backend> BasicOps<B> for Float {
        B::float_permute(tensor, axes)
    }

    fn expand<const D1: usize, const D2: usize>(
        tensor: Self::Primitive<D1>,
        shape: Shape<D2>,
    ) -> Self::Primitive<D2> {
        B::float_expand(tensor, shape)
    }

    fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
        B::float_flip(tensor, axes)
    }

@@ -1716,6 +1761,13 @@ impl<B: Backend> BasicOps<B> for Int {
        B::int_permute(tensor, axes)
    }

    fn expand<const D1: usize, const D2: usize>(
        tensor: Self::Primitive<D1>,
        shape: Shape<D2>,
    ) -> Self::Primitive<D2> {
        B::int_expand(tensor, shape)
    }

    fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
        B::int_flip(tensor, axes)
    }

@@ -1834,6 +1886,13 @@ impl<B: Backend> BasicOps<B> for Bool {
        B::bool_permute(tensor, axes)
    }

    fn expand<const D1: usize, const D2: usize>(
        tensor: Self::Primitive<D1>,
        shape: Shape<D2>,
    ) -> Self::Primitive<D2> {
        B::bool_expand(tensor, shape)
    }

    fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
        B::bool_flip(tensor, axes)
    }

@@ -1925,6 +1984,55 @@ impl<const D2: usize> ReshapeArgs<D2> for [i32; D2] {
    }
}

/// Trait used for broadcast arguments.
pub trait BroadcastArgs<const D1: usize, const D2: usize> {
    /// Converts to a shape.
    fn into_shape(self, shape: &Shape<D1>) -> Shape<D2>;
}

impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape<D2> {
    fn into_shape(self, _shape: &Shape<D1>) -> Shape<D2> {
        self
    }
}

impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [usize; D2] {
    fn into_shape(self, _shape: &Shape<D1>) -> Shape<D2> {
        Shape::from(self)
    }
}

impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [i32; D2] {
    // Passing -1 as the size for a dimension means not changing the size of that dimension.
    fn into_shape(self, shape: &Shape<D1>) -> Shape<D2> {
        if self.len() < shape.dims.len() {
            panic!("Broadcast arguments must be greater than the number of dimensions");
        }

        if self.iter().any(|&x| x < -1 || x == 0) {
            panic!("Broadcast arguments must be positive or -1");
        }

        // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
        let new_shape: Vec<_> = self
            .iter()
            .rev()
            .zip(shape.dims.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
            .map(|(&x, &y)| if x == -1 { y } else { x as usize })
            .collect::<Vec<_>>()
            .into_iter()
            .rev()
            .collect();

        if new_shape.iter().any(|&x| x == 0) {
            panic!("Cannot substitute -1 for a non-existing dimension");
        }

        let new_shape: [usize; D2] = new_shape.try_into().unwrap();

        Shape::from(new_shape)
    }
}

#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
where
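Editorial note: the `BroadcastArgs` impls above let `expand` accept three argument forms. A minimal sketch, assuming a backend `B` and a device are available:

```rust
use burn_tensor::{Shape, Tensor};

fn demo<B: burn_tensor::backend::Backend>(device: &B::Device) {
    let t = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], device);

    // An explicit Shape, a [usize; D2] array, or an [i32; D2] array where -1 keeps
    // the matching right-aligned dimension of the input.
    let _a = t.clone().expand(Shape::new([2, 3]));
    let _b = t.clone().expand([2usize, 3]);
    let _c = t.expand([2, -1]); // -1 resolves to 3 here
}
```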
@@ -889,6 +889,56 @@ impl TensorCheck {
            false => self,
        }
    }

    /// Checks if expand operation is possible for the given shapes.
    pub fn expand<const D1: usize, const D2: usize>(
        ops: &str,
        shape: &Shape<D1>,
        to: &Shape<D2>,
    ) -> Self {
        let mut check = TensorCheck::Ok;
        let max_dims = core::cmp::max(D1, D2);

        // Calculate the starting indices for each shape array, ensuring alignment from the right.
        let start_index_shape = max_dims.saturating_sub(D1);
        let start_index_to = max_dims.saturating_sub(D2);

        for i in 0..max_dims {
            // Use 1 as the default dimension size for dimensions beyond the tensor's rank.
            let d_shape = if i >= start_index_shape {
                shape.dims[i - start_index_shape]
            } else {
                1
            };
            let d_to = if i >= start_index_to {
                to.dims[i - start_index_to]
            } else {
                1
            };

            if d_shape != d_to && d_shape != 1 && d_to != 1 {
                // Register an incompatibility error.
                check = check.register(
                    ops,
                    TensorError::new(
                        "The provided tensor can't be broadcasted to the target shape.",
                    )
                    .details(format!(
                        "Incompatible size at dimension '{}' => '{} != {}', which can't be \
                         broadcasted. Tensor shape {:?}, Target shape {:?}.",
                        max_dims - i - 1,
                        d_shape,
                        d_to,
                        shape.dims,
                        to.dims,
                    )),
                );
                break; // Incompatibility found, no need to check further.
            }
        }

        check
    }
}

pub(crate) struct FailedTensorCheck {
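Editorial note: the check above implements the usual right-aligned broadcasting rule; each pair of sizes must be equal or one of them must be 1. A few illustrative cases, not taken from the diff:

```rust
// [4]    -> [2, 4]  ok: missing leading dims are treated as 1
// [1, 3] -> [5, 3]  ok: the 1 broadcasts to 5
// [3]    -> [2, 2]  rejected: 3 != 2 and neither side is 1
```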
@@ -457,4 +457,10 @@ pub trait BoolTensorOps<B: Backend> {
            .map(|t| B::int_reshape(t, Shape::new([dims[0]])))
            .collect()
    }

    /// Broadcasts the bool `tensor` to the given `shape`.
    fn bool_expand<const D1: usize, const D2: usize>(
        tensor: BoolTensor<B, D1>,
        shape: Shape<D2>,
    ) -> BoolTensor<B, D2>;
}
@@ -1199,6 +1199,12 @@ pub trait IntTensorOps<B: Backend> {
        result
    }

    /// Broadcasts the int `tensor` to the given `shape`.
    fn int_expand<const D1: usize, const D2: usize>(
        tensor: IntTensor<B, D1>,
        shape: Shape<D2>,
    ) -> IntTensor<B, D2>;

    /// Sort the elements of the input `tensor` by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
@@ -1373,6 +1373,12 @@ pub trait FloatTensorOps<B: Backend> {
        result
    }

    /// Broadcasts the float `tensor` to the given `shape`.
    fn float_expand<const D1: usize, const D2: usize>(
        tensor: FloatTensor<B, D1>,
        shape: Shape<D2>,
    ) -> FloatTensor<B, D2>;

    /// Sort the elements of the input `tensor` by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
@@ -90,6 +90,7 @@ macro_rules! testgen_all {
        burn_tensor::testgen_bool!();
        burn_tensor::testgen_argwhere_nonzero!();
        burn_tensor::testgen_sign!();
        burn_tensor::testgen_expand!();
        burn_tensor::testgen_tri_mask!();
        burn_tensor::testgen_sort_argsort!();
        burn_tensor::testgen_topk!();
@@ -0,0 +1,111 @@
#[burn_tensor_testgen::testgen(expand)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Shape, Tensor};

    #[test]
    fn expand_2d() {
        let tensor = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0], &Default::default());
        let expanded_tensor = tensor.expand([3, 3]);

        let expected_data = Data::from([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]);
        assert_eq!(expanded_tensor.into_data(), expected_data);

        let tensor =
            Tensor::<TestBackend, 1>::from_floats([4.0, 7.0, 2.0, 3.0], &Default::default());
        let expanded_tensor = tensor.expand([2, 4]);

        let expected_data = Data::from([[4.0, 7.0, 2.0, 3.0], [4.0, 7.0, 2.0, 3.0]]);
        assert_eq!(expanded_tensor.into_data(), expected_data);
    }

    #[test]
    fn expand_3d() {
        let tensor =
            Tensor::<TestBackend, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &Default::default());
        let expanded_tensor = tensor.expand([3, 2, 2]);

        let expected_data = Data::from([
            [[1.0, 2.0], [3.0, 4.0]],
            [[1.0, 2.0], [3.0, 4.0]],
            [[1.0, 2.0], [3.0, 4.0]],
        ]);
        assert_eq!(expanded_tensor.into_data(), expected_data);
    }

    #[test]
    fn expand_higher_dimensions() {
        let tensor =
            Tensor::<TestBackend, 2>::from_floats([[1.0, 2.0, 3.0, 4.0]], &Default::default());
        let expanded_tensor = tensor.expand([2, 3, 4]);

        let expected_data = Data::from([
            [
                [1.0, 2.0, 3.0, 4.0],
                [1.0, 2.0, 3.0, 4.0],
                [1.0, 2.0, 3.0, 4.0],
            ],
            [
                [1.0, 2.0, 3.0, 4.0],
                [1.0, 2.0, 3.0, 4.0],
                [1.0, 2.0, 3.0, 4.0],
            ],
        ]);

        assert_eq!(expanded_tensor.into_data(), expected_data);
    }

    #[test]
    fn broadcast_single() {
        let tensor = Tensor::<TestBackend, 1>::from_floats([1.0], &Default::default());
        let expanded_tensor = tensor.expand([2, 3]);

        let expected_data = Data::from([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);
        assert_eq!(expanded_tensor.into_data(), expected_data);
    }

    #[test]
    #[should_panic]
    fn should_fail_expand_incompatible_shapes() {
        let tensor = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0], &Default::default());
        let _expanded_tensor = tensor.expand([2, 2]);
    }

    #[test]
    fn expand_2d_bool() {
        let tensor = TestTensorBool::from([false, true, false]);
        let expanded_tensor = tensor.expand([3, 3]);

        let expected_data = Data::from([
            [false, true, false],
            [false, true, false],
            [false, true, false],
        ]);
        assert_eq!(expanded_tensor.into_data(), expected_data);
    }

    #[test]
    fn expand_2d_int() {
        let tensor = TestTensorInt::from([1, 2, 3]);
        let expanded_tensor = tensor.expand([3, 3]);

        let expected_data = Data::from([[1, 2, 3], [1, 2, 3], [1, 2, 3]]);
        assert_eq!(expanded_tensor.into_data(), expected_data);
    }

    #[test]
    fn should_all_negative_one() {
        let tensor = TestTensorInt::from([1, 2, 3]);
        let expanded_tensor = tensor.expand([2, -1]);

        let expected_data = Data::from([[1, 2, 3], [1, 2, 3]]);
        assert_eq!(expanded_tensor.into_data(), expected_data);
    }

    #[test]
    #[should_panic]
    fn should_panic_negative_one_on_non_existing_dim() {
        let tensor = TestTensorInt::from([1, 2, 3]);
        let _expanded_tensor = tensor.expand([-1, 3]);
    }
}
@@ -18,6 +18,7 @@ mod create_like;
mod div;
mod erf;
mod exp;
mod expand;
mod flatten;
mod flip;
mod full;