Repeat ops autodiff & fusion + fix autodiff ones & zeros (#1600)

* added repeat to autodiff and fusion + zero one backend init in autodiff

* autodiff for repeat
This commit is contained in:
Louis Fortier-Dubois 2024-04-11 11:32:45 -04:00 committed by GitHub
parent 15f2e49aca
commit bdb62fbcd0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 223 additions and 13 deletions

View File

@ -135,4 +135,12 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
) -> BoolTensor<B, D2> {
B::bool_expand(tensor, shape)
}
/// Repeats `tensor` `times` along dimension `dim` by delegating to the
/// inner backend.
///
/// Bool tensors carry no gradient, so no autodiff node is registered —
/// this is a pure passthrough.
fn bool_repeat<const D: usize>(
tensor: BoolTensor<B, D>,
dim: usize,
times: usize,
) -> BoolTensor<B, D> {
B::bool_repeat(tensor, dim, times)
}
}

View File

@ -40,11 +40,11 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
/// Creates a zero-filled float tensor on `device`, wrapped as an autodiff tensor.
/// NOTE(review): this is a rendered diff hunk — the first body line is the
/// removed implementation (host-side round-trip through `Data`) and the second
/// is its replacement, which delegates directly to the inner backend's
/// `float_zeros`; only one of the two lines belongs in the final source.
fn float_zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
Self::float_from_data(Data::zeros(shape), device)
AutodiffTensor::new(B::float_zeros(shape, device))
}
/// Creates a one-filled float tensor on `device`, wrapped as an autodiff tensor.
/// NOTE(review): rendered diff hunk — first body line is the removed host-side
/// `Data` round-trip, second is the replacement delegating to the inner
/// backend's `float_ones`; only one of the two lines belongs in the final source.
fn float_ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
Self::float_from_data(Data::ones(shape), device)
AutodiffTensor::new(B::float_ones(shape, device))
}
fn float_shape<const D: usize>(tensor: &FloatTensor<Self, D>) -> Shape<D> {
@ -2410,6 +2410,61 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
B::float_argsort(tensor.primitive, dim, descending)
}
/// Repeats `tensor` `times` along dimension `dim`, tracked for autodiff.
///
/// The forward pass delegates to the inner backend. The op is registered as
/// memory-bound with a retro-forward so the output can be recomputed from the
/// saved input during checkpoint replay instead of being stored.
fn float_repeat<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
times: usize,
) -> FloatTensor<Self, D> {
// Marker type carrying the backward implementation for this op.
#[derive(Debug)]
struct Repeat;
// Recomputes the forward output from checkpointed state during backward
// replay; `#[derive(new)]` generates the constructor used below.
#[derive(new, Debug)]
struct RetroRepeat<B: Backend, const D: usize> {
tensor_id: NodeID,
dim: usize,
times: usize,
_backend: PhantomData<B>,
}
impl<B: Backend, const D: usize> RetroForward for RetroRepeat<B, D> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
// Fetch the saved input, re-run the repeat, and store the result
// under the output node so downstream ops can consume it.
let tensor = states.get_state::<B::FloatTensorPrimitive<D>>(&self.tensor_id);
let out = B::float_repeat(tensor, self.dim, self.times);
states.save(out_node, out)
}
}
impl<B: Backend, const D: usize> Backward<B, D, 1> for Repeat {
// Only the repeated dimension is needed to compute the gradient.
type State = usize;
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
let dim = ops.state;
// Gradient of repeat: sum the output gradient over the repeated
// dimension. NOTE(review): `float_sum_dim` collapses `dim` to size 1,
// which matches the input only if the input had size 1 along `dim` —
// confirm against the backend's repeat contract.
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
B::float_sum_dim(grad, dim)
});
}
}
match Repeat
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroRepeat::<B, D>::new(tensor.node.id, dim, times))
.parents([&tensor])
.stateful()
{
// Tracked: store `dim` as op state so backward knows which axis to sum.
OpsKind::Tracked(prep) => {
prep.finish(dim, B::float_repeat(tensor.primitive, dim, times))
}
// Untracked: no gradient required, just run the forward computation.
OpsKind::UnTracked(prep) => prep.finish(B::float_repeat(tensor.primitive, dim, times)),
}
}
// TODO: Implement float_prod and float_sum
// https://github.com/tracel-ai/burn/issues/1458
}

View File

@ -44,6 +44,7 @@ mod permute;
mod pow;
mod recip;
mod relu;
mod repeat;
mod reshape;
mod select;
mod sigmoid;
@ -126,5 +127,6 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_sign!();
burn_autodiff::testgen_ad_expand!();
burn_autodiff::testgen_ad_sort!();
burn_autodiff::testgen_ad_repeat!();
};
}

View File

@ -0,0 +1,24 @@
#[burn_tensor_testgen::testgen(ad_repeat)]
mod tests {
    use super::*;
    // `activation` was imported but never used in this module; keeping only
    // `Data` avoids an `unused_imports` warning.
    use burn_tensor::Data;

    /// Checks the gradient of `repeat`: a [2, 1] tensor is repeated to
    /// [2, 3], fed through a matmul, and the resulting gradient must equal
    /// the analytic value (the output gradient summed over the repeated
    /// dimension): tensor_1^T @ ones(2, 3) summed over dim 1 = [[-3], [12]].
    #[test]
    fn should_diff_repeat() {
        let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
        let data_2 = Data::<f32, 2>::from([[4.0], [2.0]]);

        let device = Default::default();
        let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
        let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();

        let tensor_3 = tensor_2.clone().repeat(1, 3);
        let tensor_3 = tensor_1.matmul(tensor_3);
        let grads = tensor_3.backward();

        let grad_2 = tensor_2.grad(&grads).unwrap();
        assert_eq!(grad_2.to_data(), Data::from([[-3.0], [12.0]]));
    }
}

View File

@ -5,9 +5,9 @@ use crate::{
stream::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, Operation,
OperationDescription, PermuteOperationDescription, ReshapeDescription,
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
UnaryOperationDescription,
OperationDescription, PermuteOperationDescription, RepeatOperationDescription,
ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, StreamId,
SwapDimsDescription, UnaryOperationDescription,
},
Fusion, FusionBackend,
};
@ -540,4 +540,44 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
out
}
/// Records a bool `repeat` in the fusion stream; the backend call is deferred
/// until the stream is executed.
fn bool_repeat<const D: usize>(
    tensor: BoolTensor<Self, D>,
    dim: usize,
    times: usize,
) -> BoolTensor<Self, D> {
    // Deferred execution: replays the recorded repeat on the concrete backend.
    #[derive(new)]
    struct RepeatOps<const D: usize> {
        desc: RepeatOperationDescription,
    }

    impl<const D: usize, B: FusionBackend> Operation<B> for RepeatOps<D> {
        fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
            let input = handles.get_bool_tensor::<D>(&self.desc.tensor);
            let repeated = B::bool_repeat::<D>(input, self.desc.dim, self.desc.times);
            handles.register_bool_tensor(&self.desc.out.id, repeated);
        }
    }

    let stream = tensor.stream;

    // The output keeps every dimension except `dim`, which becomes `times`.
    let mut out_shape = tensor.shape.clone();
    out_shape[dim] = times;
    let out = tensor.client.tensor_uninitialized(out_shape);

    let desc = RepeatOperationDescription {
        tensor: tensor.into_description(),
        dim,
        times,
        out: out.to_description_out(),
    };
    out.client.register(
        vec![stream],
        OperationDescription::BaseBool(BaseOperationDescription::Repeat(desc.clone())),
        RepeatOps::<D>::new(desc),
    );

    out
}
}

View File

@ -10,10 +10,10 @@ use crate::{
FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription,
MaskWhereOperationDescription, NumericOperationDescription, Operation,
OperationDescription, PermuteOperationDescription, RandomOperationDescription,
ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
UnaryOperationDescription,
ReduceDimWithIndicesDescription, RepeatOperationDescription, ReshapeDescription,
ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
StreamId, SwapDimsDescription, UnaryOperationDescription,
},
unary_float_ops, Fusion, FusionBackend, TensorDescription,
};
@ -1571,6 +1571,46 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out
}
/// Records a float `repeat` in the fusion stream; the backend call is deferred
/// until the stream is executed.
fn float_repeat<const D: usize>(
    tensor: FloatTensor<Self, D>,
    dim: usize,
    times: usize,
) -> FloatTensor<Self, D> {
    // Deferred execution: replays the recorded repeat on the concrete backend.
    #[derive(new)]
    struct RepeatOps<const D: usize> {
        desc: RepeatOperationDescription,
    }

    impl<const D: usize, B: FusionBackend> Operation<B> for RepeatOps<D> {
        fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
            let input = handles.get_float_tensor::<D>(&self.desc.tensor);
            let repeated = B::float_repeat::<D>(input, self.desc.dim, self.desc.times);
            handles.register_float_tensor(&self.desc.out.id, repeated);
        }
    }

    let stream = tensor.stream;

    // The output keeps every dimension except `dim`, which becomes `times`.
    let mut out_shape = tensor.shape.clone();
    out_shape[dim] = times;
    let out = tensor.client.tensor_uninitialized(out_shape);

    let desc = RepeatOperationDescription {
        tensor: tensor.into_description(),
        dim,
        times,
        out: out.to_description_out(),
    };
    out.client.register(
        vec![stream],
        OperationDescription::BaseFloat(BaseOperationDescription::Repeat(desc.clone())),
        RepeatOps::<D>::new(desc),
    );

    out
}
fn float_argmin<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,

View File

@ -9,10 +9,11 @@ use crate::{
ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
RandomOperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription,
ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
StreamId, SwapDimsDescription, UnaryOperationDescription,
RandomOperationDescription, ReduceDimWithIndicesDescription, RepeatOperationDescription,
ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription,
SelectAssignOperationDescription, SelectOperationDescription,
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
UnaryOperationDescription,
},
unary_int_ops, Fusion, FusionBackend, TensorDescription,
};
@ -1623,4 +1624,44 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out
}
/// Records an int `repeat` in the fusion stream; the backend call is deferred
/// until the stream is executed.
fn int_repeat<const D: usize>(
    tensor: IntTensor<Self, D>,
    dim: usize,
    times: usize,
) -> IntTensor<Self, D> {
    // Deferred execution: replays the recorded repeat on the concrete backend.
    #[derive(new)]
    struct RepeatOps<const D: usize> {
        desc: RepeatOperationDescription,
    }

    impl<const D: usize, B: FusionBackend> Operation<B> for RepeatOps<D> {
        fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
            let input = handles.get_int_tensor::<D>(&self.desc.tensor);
            let repeated = B::int_repeat::<D>(input, self.desc.dim, self.desc.times);
            handles.register_int_tensor(&self.desc.out.id, repeated);
        }
    }

    let stream = tensor.stream;

    // The output keeps every dimension except `dim`, which becomes `times`.
    let mut out_shape = tensor.shape.clone();
    out_shape[dim] = times;
    let out = tensor.client.tensor_uninitialized(out_shape);

    let desc = RepeatOperationDescription {
        tensor: tensor.into_description(),
        dim,
        times,
        out: out.to_description_out(),
    };
    out.client.register(
        vec![stream],
        OperationDescription::BaseInt(BaseOperationDescription::Repeat(desc.clone())),
        RepeatOps::<D>::new(desc),
    );

    out
}
}