From bdb62fbcd01357e5303b01eed8b7570c3e293cf8 Mon Sep 17 00:00:00 2001 From: Louis Fortier-Dubois Date: Thu, 11 Apr 2024 11:32:45 -0400 Subject: [PATCH] Repeat ops autodiff & fusion + fix autodiff ones & zeros (#1600) * added repeat to autodiff and fusion + zero one backend init in autodiff * autodiff for repeat --- crates/burn-autodiff/src/ops/bool_tensor.rs | 8 +++ crates/burn-autodiff/src/ops/tensor.rs | 59 ++++++++++++++++++++- crates/burn-autodiff/src/tests/mod.rs | 2 + crates/burn-autodiff/src/tests/repeat.rs | 24 +++++++++ crates/burn-fusion/src/ops/boolean.rs | 46 ++++++++++++++-- crates/burn-fusion/src/ops/float.rs | 48 +++++++++++++++-- crates/burn-fusion/src/ops/int.rs | 49 +++++++++++++++-- 7 files changed, 223 insertions(+), 13 deletions(-) create mode 100644 crates/burn-autodiff/src/tests/repeat.rs diff --git a/crates/burn-autodiff/src/ops/bool_tensor.rs b/crates/burn-autodiff/src/ops/bool_tensor.rs index 214117327..3ecc39ce1 100644 --- a/crates/burn-autodiff/src/ops/bool_tensor.rs +++ b/crates/burn-autodiff/src/ops/bool_tensor.rs @@ -135,4 +135,12 @@ impl BoolTensorOps for Autodiff { ) -> BoolTensor { B::bool_expand(tensor, shape) } + + fn bool_repeat( + tensor: BoolTensor, + dim: usize, + times: usize, + ) -> BoolTensor { + B::bool_repeat(tensor, dim, times) + } } diff --git a/crates/burn-autodiff/src/ops/tensor.rs b/crates/burn-autodiff/src/ops/tensor.rs index 6cc82d4be..113c300d0 100644 --- a/crates/burn-autodiff/src/ops/tensor.rs +++ b/crates/burn-autodiff/src/ops/tensor.rs @@ -40,11 +40,11 @@ impl FloatTensorOps for Autodiff } fn float_zeros(shape: Shape, device: &Device) -> FloatTensor { - Self::float_from_data(Data::zeros(shape), device) + AutodiffTensor::new(B::float_zeros(shape, device)) } fn float_ones(shape: Shape, device: &Device) -> FloatTensor { - Self::float_from_data(Data::ones(shape), device) + AutodiffTensor::new(B::float_ones(shape, device)) } fn float_shape(tensor: &FloatTensor) -> Shape { @@ -2410,6 +2410,61 @@ impl FloatTensorOps for Autodiff B::float_argsort(tensor.primitive, dim, descending) } + fn float_repeat( + tensor: FloatTensor, + dim: usize, + times: usize, + ) -> FloatTensor { + #[derive(Debug)] + struct Repeat; + + #[derive(new, Debug)] + struct RetroRepeat { + tensor_id: NodeID, + dim: usize, + times: usize, + _backend: PhantomData, + } + + impl RetroForward for RetroRepeat { + fn forward(&self, states: &mut BackwardStates, out_node: NodeID) { + let tensor = states.get_state::>(&self.tensor_id); + let out = B::float_repeat(tensor, self.dim, self.times); + states.save(out_node, out) + } + } + + impl Backward for Repeat { + type State = usize; + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + _checkpointer: &mut Checkpointer, + ) { + let dim = ops.state; + + unary::(ops.parents, ops.node, grads, |grad| { + B::float_sum_dim(grad, dim) + }); + } + } + + match Repeat + .prepare::([tensor.node.clone()]) + .memory_bound() + .retro_forward(RetroRepeat::::new(tensor.node.id, dim, times)) + .parents([&tensor]) + .stateful() + { + OpsKind::Tracked(prep) => { + prep.finish(dim, B::float_repeat(tensor.primitive, dim, times)) + } + OpsKind::UnTracked(prep) => prep.finish(B::float_repeat(tensor.primitive, dim, times)), + } + } + // TODO: Implement float_prod and float_sum // https://github.com/tracel-ai/burn/issues/1458 } diff --git a/crates/burn-autodiff/src/tests/mod.rs b/crates/burn-autodiff/src/tests/mod.rs index 7128e2e6d..d9420ec62 100644 --- a/crates/burn-autodiff/src/tests/mod.rs +++ b/crates/burn-autodiff/src/tests/mod.rs @@ -44,6 +44,7 @@ mod permute; mod pow; mod recip; mod relu; +mod repeat; mod reshape; mod select; mod sigmoid; @@ -126,5 +127,6 @@ macro_rules! testgen_all { burn_autodiff::testgen_ad_sign!(); burn_autodiff::testgen_ad_expand!(); burn_autodiff::testgen_ad_sort!(); + burn_autodiff::testgen_ad_repeat!(); }; } diff --git a/crates/burn-autodiff/src/tests/repeat.rs b/crates/burn-autodiff/src/tests/repeat.rs new file mode 100644 index 000000000..1abf9f5fd --- /dev/null +++ b/crates/burn-autodiff/src/tests/repeat.rs @@ -0,0 +1,24 @@ +#[burn_tensor_testgen::testgen(ad_repeat)] +mod tests { + use super::*; + use burn_tensor::{activation, Data}; + + #[test] + fn should_diff_repeat() { + let data_1 = Data::::from([[1.0, 7.0], [-2.0, -3.0]]); + let data_2 = Data::::from([[4.0], [2.0]]); + + let device = Default::default(); + let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); + + let tensor_3 = tensor_2.clone().repeat(1, 3); + + let tensor_3 = tensor_1.matmul(tensor_3); + let grads = tensor_3.backward(); + + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_2.to_data(), Data::from([[-3.0], [12.0]])); + } +} diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index da3c86cd0..97e82008c 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -5,9 +5,9 @@ use crate::{ stream::{ BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, Operation, - OperationDescription, PermuteOperationDescription, ReshapeDescription, - SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription, - UnaryOperationDescription, + OperationDescription, PermuteOperationDescription, RepeatOperationDescription, + ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, StreamId, + SwapDimsDescription, UnaryOperationDescription, }, Fusion, FusionBackend, }; @@ -540,4 +540,44 @@ impl BoolTensorOps for Fusion { out } + + fn bool_repeat( + tensor: BoolTensor, + dim: usize, + times: usize, + ) -> BoolTensor { + #[derive(new)] + struct RepeatOps { + desc: RepeatOperationDescription, + } + + impl Operation for RepeatOps { + fn execute(self: Box, handles: &mut crate::HandleContainer) { + let tensor = handles.get_bool_tensor::(&self.desc.tensor); + + let output = B::bool_repeat::(tensor, self.desc.dim, self.desc.times); + + handles.register_bool_tensor(&self.desc.out.id, output); + } + } + + let stream = tensor.stream; + let mut shape = tensor.shape.clone(); + shape[dim] = times; + let out = tensor.client.tensor_uninitialized(shape); + + let desc = RepeatOperationDescription { + tensor: tensor.into_description(), + dim, + times, + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + OperationDescription::BaseBool(BaseOperationDescription::Repeat(desc.clone())), + RepeatOps::::new(desc), + ); + + out + } } diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 3b5cecff1..ce9040865 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -10,10 +10,10 @@ use crate::{ FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription, RandomOperationDescription, - ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription, - ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, - SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription, - UnaryOperationDescription, + ReduceDimWithIndicesDescription, RepeatOperationDescription, ReshapeDescription, + ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription, + SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, + StreamId, SwapDimsDescription, UnaryOperationDescription, }, unary_float_ops, Fusion, FusionBackend, TensorDescription, }; @@ -1571,6 +1571,46 @@ impl FloatTensorOps for Fusion { out } + fn float_repeat( + tensor: FloatTensor, + dim: usize, + times: usize, + ) -> FloatTensor { + #[derive(new)] + struct RepeatOps { + desc: RepeatOperationDescription, + } + + impl Operation for RepeatOps { + fn execute(self: Box, handles: &mut crate::HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + + let output = B::float_repeat::(tensor, self.desc.dim, self.desc.times); + + handles.register_float_tensor(&self.desc.out.id, output); + } + } + + let stream = tensor.stream; + let mut shape = tensor.shape.clone(); + shape[dim] = times; + let out = tensor.client.tensor_uninitialized(shape); + + let desc = RepeatOperationDescription { + tensor: tensor.into_description(), + dim, + times, + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::Repeat(desc.clone())), + RepeatOps::::new(desc), + ); + + out + } + fn float_argmin( tensor: FloatTensor, dim: usize, diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index 7aae2a980..45e5d78c0 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -9,10 +9,11 @@ use crate::{ ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription, - RandomOperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription, - ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription, - SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, - StreamId, SwapDimsDescription, UnaryOperationDescription, + RandomOperationDescription, ReduceDimWithIndicesDescription, RepeatOperationDescription, + ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription, + SelectAssignOperationDescription, SelectOperationDescription, + SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription, + UnaryOperationDescription, }, unary_int_ops, Fusion, FusionBackend, TensorDescription, }; @@ -1623,4 +1624,44 @@ impl IntTensorOps for Fusion { out } + + fn int_repeat( + tensor: IntTensor, + dim: usize, + times: usize, + ) -> IntTensor { + #[derive(new)] + struct RepeatOps { + desc: RepeatOperationDescription, + } + + impl Operation for RepeatOps { + fn execute(self: Box, handles: &mut crate::HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + + let output = B::int_repeat::(tensor, self.desc.dim, self.desc.times); + + handles.register_int_tensor(&self.desc.out.id, output); + } + } + + let stream = tensor.stream; + let mut shape = tensor.shape.clone(); + shape[dim] = times; + let out = tensor.client.tensor_uninitialized(shape); + + let desc = RepeatOperationDescription { + tensor: tensor.into_description(), + dim, + times, + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + OperationDescription::BaseInt(BaseOperationDescription::Repeat(desc.clone())), + RepeatOps::::new(desc), + ); + + out + } }