mirror of https://github.com/tracel-ai/burn.git
Repeat ops autodiff & fusion + fix autodiff ones & zeros (#1600)

* added repeat to autodiff and fusion + zero/one backend init in autodiff
* autodiff for repeat
This commit is contained in:
parent 15f2e49aca
commit bdb62fbcd0
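
What the commit enables end to end: `repeat` is now tracked by autodiff, so gradients flow through it. A minimal sketch, assuming the burn API at this commit; the `Autodiff<NdArray>` backend pair and the `burn::backend` import paths are illustrative assumptions, not part of the diff:

// Sketch only: backend choice and import paths are assumptions.
use burn::backend::{Autodiff, NdArray};
use burn::tensor::{Data, Tensor};

type B = Autodiff<NdArray>;

fn main() {
    let device = Default::default();
    // Shape [2, 1]; repeat targets the size-1 dimension.
    let x = Tensor::<B, 2>::from_data(Data::<f32, 2>::from([[4.0], [2.0]]), &device)
        .require_grad();

    // After this commit, repeat participates in the autodiff graph.
    let y = x.clone().repeat(1, 3); // shape [2, 3]

    let grads = y.backward();
    // Backward of repeat sums over the repeated dim: the all-ones seed
    // gradient of shape [2, 3] summed over dim 1 gives [[3.0], [3.0]].
    let grad_x = x.grad(&grads).unwrap();
    assert_eq!(grad_x.to_data(), Data::from([[3.0], [3.0]]));
}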

@@ -135,4 +135,12 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
     ) -> BoolTensor<B, D2> {
         B::bool_expand(tensor, shape)
     }
+
+    fn bool_repeat<const D: usize>(
+        tensor: BoolTensor<B, D>,
+        dim: usize,
+        times: usize,
+    ) -> BoolTensor<B, D> {
+        B::bool_repeat(tensor, dim, times)
+    }
 }

@@ -40,11 +40,11 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
     }
 
     fn float_zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
-        Self::float_from_data(Data::zeros(shape), device)
+        AutodiffTensor::new(B::float_zeros(shape, device))
     }
 
     fn float_ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
-        Self::float_from_data(Data::ones(shape), device)
+        AutodiffTensor::new(B::float_ones(shape, device))
     }
 
     fn float_shape<const D: usize>(tensor: &FloatTensor<Self, D>) -> Shape<D> {

@@ -2410,6 +2410,61 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         B::float_argsort(tensor.primitive, dim, descending)
     }
 
+    fn float_repeat<const D: usize>(
+        tensor: FloatTensor<Self, D>,
+        dim: usize,
+        times: usize,
+    ) -> FloatTensor<Self, D> {
+        #[derive(Debug)]
+        struct Repeat;
+
+        #[derive(new, Debug)]
+        struct RetroRepeat<B: Backend, const D: usize> {
+            tensor_id: NodeID,
+            dim: usize,
+            times: usize,
+            _backend: PhantomData<B>,
+        }
+
+        impl<B: Backend, const D: usize> RetroForward for RetroRepeat<B, D> {
+            fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
+                let tensor = states.get_state::<B::FloatTensorPrimitive<D>>(&self.tensor_id);
+                let out = B::float_repeat(tensor, self.dim, self.times);
+                states.save(out_node, out)
+            }
+        }
+
+        impl<B: Backend, const D: usize> Backward<B, D, 1> for Repeat {
+            type State = usize;
+
+            fn backward(
+                self,
+                ops: Ops<Self::State, 1>,
+                grads: &mut Gradients,
+                _checkpointer: &mut Checkpointer,
+            ) {
+                let dim = ops.state;
+
+                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
+                    B::float_sum_dim(grad, dim)
+                });
+            }
+        }
+
+        match Repeat
+            .prepare::<C>([tensor.node.clone()])
+            .memory_bound()
+            .retro_forward(RetroRepeat::<B, D>::new(tensor.node.id, dim, times))
+            .parents([&tensor])
+            .stateful()
+        {
+            OpsKind::Tracked(prep) => {
+                prep.finish(dim, B::float_repeat(tensor.primitive, dim, times))
+            }
+            OpsKind::UnTracked(prep) => prep.finish(B::float_repeat(tensor.primitive, dim, times)),
+        }
+    }
+
     // TODO: Implement float_prod and float_sum
     // https://github.com/tracel-ai/burn/issues/1458
 }
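
Why `float_sum_dim` is the right backward: each element of the size-1 input dimension is copied `times` times by the forward pass, so the upstream gradients of all copies accumulate onto the single source element. A standalone arithmetic check (plain Rust, independent of burn) reproducing the gradient the new test below expects:

// Check of the repeat backward rule: for y = repeat(x, dim, times) with x
// having size 1 along `dim`, dL/dx = sum of dL/dy over that dim.
fn main() {
    // Upstream gradient reaching the repeat node in the ad_repeat test:
    // d(matmul)/d(tensor_3) = tensor_1^T . ones(2x3), computed by hand.
    let grad_y: [[f32; 3]; 2] = [[-1.0, -1.0, -1.0], [4.0, 4.0, 4.0]];

    // Backward of repeat along dim 1: collapse the repeated columns by summing.
    let grad_x: Vec<f32> = grad_y.iter().map(|row| row.iter().sum()).collect();

    // Matches the expected gradient asserted in the new test: [[-3.0], [12.0]].
    assert_eq!(grad_x, vec![-3.0, 12.0]);
}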

@@ -44,6 +44,7 @@ mod permute;
 mod pow;
 mod recip;
 mod relu;
+mod repeat;
 mod reshape;
 mod select;
 mod sigmoid;

@@ -126,5 +127,6 @@ macro_rules! testgen_all {
         burn_autodiff::testgen_ad_sign!();
         burn_autodiff::testgen_ad_expand!();
         burn_autodiff::testgen_ad_sort!();
+        burn_autodiff::testgen_ad_repeat!();
     };
 }

@@ -0,0 +1,24 @@
+#[burn_tensor_testgen::testgen(ad_repeat)]
+mod tests {
+    use super::*;
+    use burn_tensor::{activation, Data};
+
+    #[test]
+    fn should_diff_repeat() {
+        let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
+        let data_2 = Data::<f32, 2>::from([[4.0], [2.0]]);
+
+        let device = Default::default();
+        let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
+        let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
+
+        let tensor_3 = tensor_2.clone().repeat(1, 3);
+
+        let tensor_3 = tensor_1.matmul(tensor_3);
+        let grads = tensor_3.backward();
+
+        let grad_2 = tensor_2.grad(&grads).unwrap();
+
+        assert_eq!(grad_2.to_data(), Data::from([[-3.0], [12.0]]));
+    }
+}

@@ -5,9 +5,9 @@ use crate::{
     stream::{
         BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
         CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, Operation,
-        OperationDescription, PermuteOperationDescription, ReshapeDescription,
-        SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
-        UnaryOperationDescription,
+        OperationDescription, PermuteOperationDescription, RepeatOperationDescription,
+        ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, StreamId,
+        SwapDimsDescription, UnaryOperationDescription,
     },
     Fusion, FusionBackend,
 };

@@ -540,4 +540,44 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
 
         out
     }
+
+    fn bool_repeat<const D: usize>(
+        tensor: BoolTensor<Self, D>,
+        dim: usize,
+        times: usize,
+    ) -> BoolTensor<Self, D> {
+        #[derive(new)]
+        struct RepeatOps<const D: usize> {
+            desc: RepeatOperationDescription,
+        }
+
+        impl<const D: usize, B: FusionBackend> Operation<B> for RepeatOps<D> {
+            fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
+                let tensor = handles.get_bool_tensor::<D>(&self.desc.tensor);
+
+                let output = B::bool_repeat::<D>(tensor, self.desc.dim, self.desc.times);
+
+                handles.register_bool_tensor(&self.desc.out.id, output);
+            }
+        }
+
+        let stream = tensor.stream;
+        let mut shape = tensor.shape.clone();
+        shape[dim] = times;
+        let out = tensor.client.tensor_uninitialized(shape);
+
+        let desc = RepeatOperationDescription {
+            tensor: tensor.into_description(),
+            dim,
+            times,
+            out: out.to_description_out(),
+        };
+        out.client.register(
+            vec![stream],
+            OperationDescription::BaseBool(BaseOperationDescription::Repeat(desc.clone())),
+            RepeatOps::<D>::new(desc),
+        );
+
+        out
+    }
 }
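
All three fused repeat implementations (bool above, float and int below) follow the same pattern: build a `RepeatOperationDescription`, allocate an uninitialized output whose shape sets the repeated dim to `times`, and register the description plus a deferred `RepeatOps` executor on the tensor's stream. The `shape[dim] = times` step is consistent with the `float_sum_dim` backward above, which suggests repeat targets size-1 dimensions at this point in the codebase. An illustrative helper (not part of the diff) capturing just that shape rule:

// Hypothetical helper: output shape of the fused repeat, under the stated
// assumption that the repeated dim has size 1, so 1 * times == times.
fn repeat_output_shape(mut shape: Vec<usize>, dim: usize, times: usize) -> Vec<usize> {
    debug_assert_eq!(shape[dim], 1, "repeat assumes a size-1 dimension here");
    shape[dim] = times;
    shape
}

// e.g. repeat_output_shape(vec![2, 1], 1, 3) == vec![2, 3]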

@@ -10,10 +10,10 @@ use crate::{
         FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription,
         MaskWhereOperationDescription, NumericOperationDescription, Operation,
         OperationDescription, PermuteOperationDescription, RandomOperationDescription,
-        ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
-        ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
-        SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
-        UnaryOperationDescription,
+        ReduceDimWithIndicesDescription, RepeatOperationDescription, ReshapeDescription,
+        ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
+        SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
+        StreamId, SwapDimsDescription, UnaryOperationDescription,
     },
     unary_float_ops, Fusion, FusionBackend, TensorDescription,
 };

@@ -1571,6 +1571,46 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         out
     }
 
+    fn float_repeat<const D: usize>(
+        tensor: FloatTensor<Self, D>,
+        dim: usize,
+        times: usize,
+    ) -> FloatTensor<Self, D> {
+        #[derive(new)]
+        struct RepeatOps<const D: usize> {
+            desc: RepeatOperationDescription,
+        }
+
+        impl<const D: usize, B: FusionBackend> Operation<B> for RepeatOps<D> {
+            fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
+                let tensor = handles.get_float_tensor::<D>(&self.desc.tensor);
+
+                let output = B::float_repeat::<D>(tensor, self.desc.dim, self.desc.times);
+
+                handles.register_float_tensor(&self.desc.out.id, output);
+            }
+        }
+
+        let stream = tensor.stream;
+        let mut shape = tensor.shape.clone();
+        shape[dim] = times;
+        let out = tensor.client.tensor_uninitialized(shape);
+
+        let desc = RepeatOperationDescription {
+            tensor: tensor.into_description(),
+            dim,
+            times,
+            out: out.to_description_out(),
+        };
+        out.client.register(
+            vec![stream],
+            OperationDescription::BaseFloat(BaseOperationDescription::Repeat(desc.clone())),
+            RepeatOps::<D>::new(desc),
+        );
+
+        out
+    }
+
     fn float_argmin<const D: usize>(
         tensor: FloatTensor<Self, D>,
         dim: usize,

@@ -9,10 +9,11 @@ use crate::{
         ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
         GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
         NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
-        RandomOperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription,
-        ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
-        SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
-        StreamId, SwapDimsDescription, UnaryOperationDescription,
+        RandomOperationDescription, ReduceDimWithIndicesDescription, RepeatOperationDescription,
+        ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription,
+        SelectAssignOperationDescription, SelectOperationDescription,
+        SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
+        UnaryOperationDescription,
     },
     unary_int_ops, Fusion, FusionBackend, TensorDescription,
 };

@@ -1623,4 +1624,44 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
 
         out
     }
+
+    fn int_repeat<const D: usize>(
+        tensor: IntTensor<Self, D>,
+        dim: usize,
+        times: usize,
+    ) -> IntTensor<Self, D> {
+        #[derive(new)]
+        struct RepeatOps<const D: usize> {
+            desc: RepeatOperationDescription,
+        }
+
+        impl<const D: usize, B: FusionBackend> Operation<B> for RepeatOps<D> {
+            fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
+                let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
+
+                let output = B::int_repeat::<D>(tensor, self.desc.dim, self.desc.times);
+
+                handles.register_int_tensor(&self.desc.out.id, output);
+            }
+        }
+
+        let stream = tensor.stream;
+        let mut shape = tensor.shape.clone();
+        shape[dim] = times;
+        let out = tensor.client.tensor_uninitialized(shape);
+
+        let desc = RepeatOperationDescription {
+            tensor: tensor.into_description(),
+            dim,
+            times,
+            out: out.to_description_out(),
+        };
+        out.client.register(
+            vec![stream],
+            OperationDescription::BaseInt(BaseOperationDescription::Repeat(desc.clone())),
+            RepeatOps::<D>::new(desc),
+        );
+
+        out
+    }
 }