mirror of https://github.com/tracel-ai/burn.git
Migrate/jit/adaptive avg pool backward (#1530)
* separate forward backward * refactor with pool strategy * refactor further * pooling refactored * refactoring for adaptive wip * wip adaptive * adaptive * delete some wgsl * avg pool backward * clippy * minor refactor * works * delete wgsl
This commit is contained in:
parent
a77979e0b6
commit
37b61ea646
|
@ -1,17 +1,244 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::JitElement,
|
||||
kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
};
|
||||
use burn_compute::server::Handle;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
kernel_wgsl!(
|
||||
AdaptiveAvgPool2dBackward,
|
||||
"../../template/pool/adaptive_avg_pool2d_backward.wgsl"
|
||||
);
|
||||
use crate::{
|
||||
codegen::{
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
|
||||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
kernel::{DynamicKernelSource, SourceTemplate},
|
||||
tensor::JitTensor,
|
||||
Compiler, Runtime, RuntimeInt,
|
||||
};
|
||||
|
||||
#[derive(new)]
|
||||
struct AdaptiveAvgPool2dBackwardEagerKernel<R, E> {
|
||||
_runtime: PhantomData<R>,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
struct AdaptiveAvgPool2dBackwardComputeShader {
|
||||
grad: Variable,
|
||||
output: Variable,
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool2dBackwardComputeShader {
|
||||
fn expand(self, scope: &mut Scope) {
|
||||
let grad = self.grad;
|
||||
let output = self.output;
|
||||
let id = Variable::Id;
|
||||
|
||||
let grad_stride_0 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_1 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_2 = scope.create_local(Elem::UInt);
|
||||
let grad_stride_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let grad_shape_2 = scope.create_local(Elem::UInt);
|
||||
let grad_shape_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let output_stride_0 = scope.create_local(Elem::UInt);
|
||||
let output_stride_1 = scope.create_local(Elem::UInt);
|
||||
let output_stride_2 = scope.create_local(Elem::UInt);
|
||||
let output_stride_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
let output_shape_0 = scope.create_local(Elem::UInt);
|
||||
let output_shape_1 = scope.create_local(Elem::UInt);
|
||||
let output_shape_2 = scope.create_local(Elem::UInt);
|
||||
let output_shape_3 = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, grad_stride_0 = stride(grad, 0u32));
|
||||
gpu!(scope, grad_stride_1 = stride(grad, 1u32));
|
||||
gpu!(scope, grad_stride_2 = stride(grad, 2u32));
|
||||
gpu!(scope, grad_stride_3 = stride(grad, 3u32));
|
||||
|
||||
gpu!(scope, grad_shape_2 = shape(grad, 2u32));
|
||||
gpu!(scope, grad_shape_3 = shape(grad, 3u32));
|
||||
|
||||
gpu!(scope, output_stride_0 = stride(output, 0u32));
|
||||
gpu!(scope, output_stride_1 = stride(output, 1u32));
|
||||
gpu!(scope, output_stride_2 = stride(output, 2u32));
|
||||
gpu!(scope, output_stride_3 = stride(output, 3u32));
|
||||
|
||||
gpu!(scope, output_shape_0 = shape(output, 0u32));
|
||||
gpu!(scope, output_shape_1 = shape(output, 1u32));
|
||||
gpu!(scope, output_shape_2 = shape(output, 2u32));
|
||||
gpu!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let c = scope.create_local(Elem::UInt);
|
||||
let ih = scope.create_local(Elem::UInt);
|
||||
let iw = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, b = id / output_stride_0);
|
||||
gpu!(scope, b = b % output_shape_0);
|
||||
|
||||
gpu!(scope, c = id / output_stride_1);
|
||||
gpu!(scope, c = c % output_shape_1);
|
||||
|
||||
gpu!(scope, ih = id / output_stride_2);
|
||||
gpu!(scope, ih = ih % output_shape_2);
|
||||
|
||||
gpu!(scope, iw = id / output_stride_3);
|
||||
gpu!(scope, iw = iw % output_shape_3);
|
||||
|
||||
let oh_start = Self::start_index(scope, ih, output_shape_2, grad_shape_2);
|
||||
let oh_end = Self::end_index(scope, ih, output_shape_2, grad_shape_2);
|
||||
|
||||
let ow_start = Self::start_index(scope, iw, output_shape_3, grad_shape_3);
|
||||
let ow_end = Self::end_index(scope, iw, output_shape_3, grad_shape_3);
|
||||
|
||||
let grad_acc = scope.create_local(output.item());
|
||||
let contributed_h = scope.create_local(Elem::Bool);
|
||||
let contributed_w = scope.create_local(Elem::Bool);
|
||||
let contributed_tmp = scope.create_local(Elem::Bool);
|
||||
|
||||
let count = scope.create_local(Elem::UInt);
|
||||
let count_tmp = scope.create_local(Elem::UInt);
|
||||
let count_float = scope.create_local(output.item());
|
||||
let the_grad = scope.create_local(output.item());
|
||||
let avg = scope.create_local(output.item());
|
||||
|
||||
let index_base = scope.create_local(Elem::UInt);
|
||||
let index_tmp = scope.create_local(Elem::UInt);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, index_base = b * grad_stride_0);
|
||||
gpu!(scope, index_tmp = c * grad_stride_1);
|
||||
gpu!(scope, index_base += index_tmp);
|
||||
|
||||
gpu!(
|
||||
scope,
|
||||
range(oh_start, oh_end).for_each(|oh, scope| {
|
||||
let ih_start = Self::start_index(scope, oh, grad_shape_2, output_shape_2);
|
||||
let ih_end = Self::end_index(scope, oh, grad_shape_2, output_shape_2);
|
||||
gpu!(scope, contributed_h = ih >= ih_start);
|
||||
gpu!(scope, contributed_tmp = ih < ih_end);
|
||||
gpu!(scope, contributed_h = contributed_h && contributed_tmp);
|
||||
|
||||
gpu!(scope, if(contributed_h).then(|scope|{
|
||||
gpu!(
|
||||
scope,
|
||||
range(ow_start, ow_end).for_each(|ow, scope| {
|
||||
let iw_start = Self::start_index(scope, ow, grad_shape_3, output_shape_3);
|
||||
let iw_end = Self::end_index(scope, ow, grad_shape_3, output_shape_3);
|
||||
|
||||
gpu!(scope, contributed_w = iw >= iw_start);
|
||||
gpu!(scope, contributed_tmp = iw < iw_end);
|
||||
gpu!(scope, contributed_w = contributed_w && contributed_tmp);
|
||||
|
||||
|
||||
gpu!(scope, if(contributed_w).then(|scope|{
|
||||
gpu!(scope, count = ih_end - ih_start);
|
||||
gpu!(scope, count_tmp = iw_end - iw_start);
|
||||
gpu!(scope, count *= count_tmp);
|
||||
gpu!(scope, count_float = cast(count));
|
||||
|
||||
gpu!(scope, index = index_base);
|
||||
gpu!(scope, index_tmp = oh * grad_stride_2);
|
||||
gpu!(scope, index += index_tmp);
|
||||
gpu!(scope, index_tmp = ow * grad_stride_3);
|
||||
gpu!(scope, index += index_tmp);
|
||||
|
||||
gpu!(scope, the_grad = grad[index]);
|
||||
gpu!(scope, avg = the_grad / count_float);
|
||||
gpu!(scope, grad_acc += avg);
|
||||
}));
|
||||
})
|
||||
);
|
||||
}));
|
||||
})
|
||||
);
|
||||
|
||||
gpu!(scope, output[id] = grad_acc);
|
||||
}
|
||||
|
||||
fn start_index(
|
||||
scope: &mut Scope,
|
||||
output_size_index: Variable,
|
||||
output_size: Variable,
|
||||
input_size: Variable,
|
||||
) -> Variable {
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let div = scope.create_local(Elem::Float);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, index = output_size_index * input_size);
|
||||
gpu!(scope, numerator_float = cast(index));
|
||||
gpu!(scope, div = cast(output_size));
|
||||
gpu!(scope, div = numerator_float / div);
|
||||
gpu!(scope, div = floor(div));
|
||||
gpu!(scope, index = cast(div));
|
||||
index
|
||||
}
|
||||
|
||||
fn end_index(
|
||||
scope: &mut Scope,
|
||||
output_size_index: Variable,
|
||||
output_size: Variable,
|
||||
input_size: Variable,
|
||||
) -> Variable {
|
||||
let numerator_float = scope.create_local(Elem::Float);
|
||||
let div = scope.create_local(Elem::Float);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let min = scope.create_local(Elem::Bool);
|
||||
let end_index = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(scope, index = output_size_index + 1u32);
|
||||
gpu!(scope, index *= input_size);
|
||||
gpu!(scope, numerator_float = cast(index));
|
||||
gpu!(scope, div = cast(output_size));
|
||||
gpu!(scope, div = numerator_float / div);
|
||||
gpu!(scope, div = ceil(div));
|
||||
gpu!(scope, index = cast(div));
|
||||
|
||||
gpu!(scope, min = input_size < index);
|
||||
gpu!(scope, if(min).then(|scope|{
|
||||
gpu!(scope, end_index = input_size);
|
||||
}).else(|scope|{
|
||||
gpu!(scope, end_index = index);
|
||||
}));
|
||||
end_index
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime, E: JitElement> DynamicKernelSource for AdaptiveAvgPool2dBackwardEagerKernel<R, E> {
|
||||
fn source(&self) -> SourceTemplate {
|
||||
let mut scope = Scope::root();
|
||||
let item = E::gpu_elem().into();
|
||||
|
||||
let grad = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
AdaptiveAvgPool2dBackwardComputeShader { grad, output }.expand(&mut scope);
|
||||
|
||||
let grad = InputInfo::Array {
|
||||
item,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
let scalars = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
size: 6,
|
||||
};
|
||||
let output = OutputInfo::Array { item };
|
||||
|
||||
let info = CompilationInfo {
|
||||
inputs: vec![grad, scalars],
|
||||
outputs: vec![output],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = CompilationSettings::default();
|
||||
let shader = Compilation::new(info).compile(settings);
|
||||
let shader = <R::Compiler as Compiler>::compile(shader);
|
||||
SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}", core::any::TypeId::of::<Self>(),)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn adaptive_avg_pool2d_backward<R: Runtime, E: JitElement>(
|
||||
x: JitTensor<R, E, 4>,
|
||||
|
@ -27,45 +254,24 @@ pub(crate) fn adaptive_avg_pool2d_backward<R: Runtime, E: JitElement>(
|
|||
output_buffer,
|
||||
);
|
||||
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
output.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
let kernel = AdaptiveAvgPool2dBackwardEagerKernel::new();
|
||||
|
||||
let info_handle = build_info(&x, &out_grad);
|
||||
|
||||
x.client.execute(
|
||||
Box::new(kernel),
|
||||
&[&out_grad.handle, &output.handle, &info_handle],
|
||||
execute_dynamic::<R, AdaptiveAvgPool2dBackwardEagerKernel<R, E>, RuntimeInt<R>>(
|
||||
&[EagerHandle::new(
|
||||
&out_grad.handle,
|
||||
&out_grad.strides,
|
||||
&out_grad.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
x.client,
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn build_info<R: Runtime, E: JitElement>(
|
||||
x: &JitTensor<R, E, 4>,
|
||||
output: &JitTensor<R, E, 4>,
|
||||
) -> Handle<R::Server> {
|
||||
let mut info: [u32; 16] = [0; 16];
|
||||
info[0] = x.strides[0] as u32;
|
||||
info[1] = x.strides[1] as u32;
|
||||
info[2] = x.strides[2] as u32;
|
||||
info[3] = x.strides[3] as u32;
|
||||
info[4] = x.shape.dims[0] as u32;
|
||||
info[5] = x.shape.dims[1] as u32;
|
||||
info[6] = x.shape.dims[2] as u32;
|
||||
info[7] = x.shape.dims[3] as u32;
|
||||
|
||||
info[8] = output.strides[0] as u32;
|
||||
info[9] = output.strides[1] as u32;
|
||||
info[10] = output.strides[2] as u32;
|
||||
info[11] = output.strides[3] as u32;
|
||||
info[12] = output.shape.dims[0] as u32;
|
||||
info[13] = output.shape.dims[1] as u32;
|
||||
info[14] = output.shape.dims[2] as u32;
|
||||
info[15] = output.shape.dims[3] as u32;
|
||||
|
||||
output.client.create(bytemuck::cast_slice(&info))
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ pub(crate) use adaptive_pool2d_shader::*;
|
|||
pub(crate) use pool2d_shader::*;
|
||||
|
||||
pub(crate) use adaptive_avg_pool2d::*;
|
||||
pub use adaptive_avg_pool2d_backward::*;
|
||||
pub(crate) use adaptive_avg_pool2d_backward::*;
|
||||
pub(crate) use avg_pool2d::*;
|
||||
pub(crate) use avg_pool2d_backward::*;
|
||||
pub(super) use base::*;
|
||||
|
|
|
@ -1,88 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> grad: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> info: array<u32, 16>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
|
||||
let input_stride_0 = info[0];
|
||||
let input_stride_1 = info[1];
|
||||
let input_stride_2 = info[2];
|
||||
let input_stride_3 = info[3];
|
||||
let input_shape_0 = info[4];
|
||||
let input_shape_1 = info[5];
|
||||
let input_shape_2 = info[6];
|
||||
let input_shape_3 = info[7];
|
||||
|
||||
let grad_stride_0 = info[8];
|
||||
let grad_stride_1 = info[9];
|
||||
let grad_stride_2 = info[10];
|
||||
let grad_stride_3 = info[11];
|
||||
let grad_shape_0 = info[12];
|
||||
let grad_shape_1 = info[13];
|
||||
let grad_shape_2 = info[14];
|
||||
let grad_shape_3 = info[15];
|
||||
|
||||
let b = id / input_stride_0 % input_shape_0;
|
||||
let c = id / input_stride_1 % input_shape_1;
|
||||
let ih = id / input_stride_2 % input_shape_2;
|
||||
let iw = id / input_stride_3 % input_shape_3;
|
||||
|
||||
let oh_start = start_index(ih, input_shape_2, grad_shape_2);
|
||||
let oh_end = end_index(ih, input_shape_2, grad_shape_2);
|
||||
|
||||
let ow_start = start_index(iw, input_shape_3, grad_shape_3);
|
||||
let ow_end = end_index(iw, input_shape_3, grad_shape_3);
|
||||
|
||||
var grad_acc = 0.0;
|
||||
|
||||
for (var oh = oh_start; oh < oh_end; oh++) {
|
||||
for (var ow = ow_start; ow < ow_end; ow++) {
|
||||
let ih_start = start_index(oh, grad_shape_2, input_shape_2);
|
||||
let ih_end = end_index(oh, grad_shape_2, input_shape_2);
|
||||
|
||||
let iw_start = start_index(ow, grad_shape_3, input_shape_3);
|
||||
let iw_end = end_index(ow, grad_shape_3, input_shape_3);
|
||||
|
||||
let contributed_h = ih >= ih_start && ih < ih_end;
|
||||
let contributed_w = iw >= iw_start && iw < iw_end;
|
||||
|
||||
// If no contribution skip
|
||||
if !contributed_h || !contributed_w {
|
||||
continue;
|
||||
}
|
||||
|
||||
let index = b * grad_stride_0 + c * grad_stride_1 + oh * grad_stride_2 + ow * grad_stride_3;
|
||||
let count = {{ elem }}((ih_end - ih_start) * (iw_end - iw_start));
|
||||
|
||||
grad_acc += grad[index] / count;
|
||||
}
|
||||
}
|
||||
|
||||
output[id] = grad_acc;
|
||||
}
|
||||
|
||||
fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
|
||||
return u32(floor((f32(output_size_index) * f32(input_size)) / f32(output_size)));
|
||||
}
|
||||
|
||||
fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
|
||||
let index = u32(ceil((f32(output_size_index + 1u) * f32(input_size)) / f32(output_size)));
|
||||
|
||||
return min(index, input_size);
|
||||
}
|
Loading…
Reference in New Issue