Migrate/jit/adaptive avg pool backward (#1530)

* separate forward backward

* refactor with pool strategy

* refactor further

* pooling refactored

* refactoring for adaptive wip

* wip adaptive

* adaptive

* delete some wgsl

* avg pool backward

* clippy

* minor refactor

* works

* delete wgsl
This commit is contained in:
Louis Fortier-Dubois 2024-03-26 08:38:06 -04:00 committed by GitHub
parent a77979e0b6
commit 37b61ea646
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 257 additions and 139 deletions

View File

@ -1,17 +1,244 @@
use crate::{
compute::StaticKernel,
element::JitElement,
kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
tensor::JitTensor,
Runtime,
};
use burn_compute::server::Handle;
use std::marker::PhantomData;
// Legacy WGSL template kernel; superseded by the eager GPU-IR kernel below.
kernel_wgsl!(
    AdaptiveAvgPool2dBackward,
    "../../template/pool/adaptive_avg_pool2d_backward.wgsl"
);
use crate::{
codegen::{
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
OutputInfo, WorkgroupLaunch,
},
element::JitElement,
gpu::{gpu, Elem, Scope, Variable, Visibility},
kernel::{DynamicKernelSource, SourceTemplate},
tensor::JitTensor,
Compiler, Runtime, RuntimeInt,
};
/// Eagerly-expanded (GPU IR) kernel for the backward pass of adaptive average
/// pooling 2D. The type parameters only select the runtime and element type;
/// no runtime data is stored, hence the `PhantomData` fields.
#[derive(new)]
struct AdaptiveAvgPool2dBackwardEagerKernel<R, E> {
    _runtime: PhantomData<R>,
    _elem: PhantomData<E>,
}
/// Shader variables for the backward computation:
/// `grad` is the incoming gradient w.r.t. the pooled output (global input
/// array), `output` is the gradient w.r.t. the pooling input to be written.
struct AdaptiveAvgPool2dBackwardComputeShader {
    grad: Variable,
    output: Variable,
}
impl AdaptiveAvgPool2dBackwardComputeShader {
    /// Emit the GPU IR computing the input gradient of adaptive average
    /// pooling 2D (NCHW layout).
    ///
    /// One invocation (indexed by `id`) handles one element of `output`
    /// (the input gradient): it visits every pooled output cell whose window
    /// could have covered this input position and accumulates
    /// `grad[cell] / window_element_count`, mirroring the forward average.
    fn expand(self, scope: &mut Scope) {
        let grad = self.grad;
        let output = self.output;
        // Linear index of the input-gradient element handled by this thread.
        let id = Variable::Id;

        let grad_stride_0 = scope.create_local(Elem::UInt);
        let grad_stride_1 = scope.create_local(Elem::UInt);
        let grad_stride_2 = scope.create_local(Elem::UInt);
        let grad_stride_3 = scope.create_local(Elem::UInt);

        let grad_shape_2 = scope.create_local(Elem::UInt);
        let grad_shape_3 = scope.create_local(Elem::UInt);

        let output_stride_0 = scope.create_local(Elem::UInt);
        let output_stride_1 = scope.create_local(Elem::UInt);
        let output_stride_2 = scope.create_local(Elem::UInt);
        let output_stride_3 = scope.create_local(Elem::UInt);

        let output_shape_0 = scope.create_local(Elem::UInt);
        let output_shape_1 = scope.create_local(Elem::UInt);
        let output_shape_2 = scope.create_local(Elem::UInt);
        let output_shape_3 = scope.create_local(Elem::UInt);

        // Load strides/shapes of both tensors from the runtime metadata.
        gpu!(scope, grad_stride_0 = stride(grad, 0u32));
        gpu!(scope, grad_stride_1 = stride(grad, 1u32));
        gpu!(scope, grad_stride_2 = stride(grad, 2u32));
        gpu!(scope, grad_stride_3 = stride(grad, 3u32));

        gpu!(scope, grad_shape_2 = shape(grad, 2u32));
        gpu!(scope, grad_shape_3 = shape(grad, 3u32));

        gpu!(scope, output_stride_0 = stride(output, 0u32));
        gpu!(scope, output_stride_1 = stride(output, 1u32));
        gpu!(scope, output_stride_2 = stride(output, 2u32));
        gpu!(scope, output_stride_3 = stride(output, 3u32));

        gpu!(scope, output_shape_0 = shape(output, 0u32));
        gpu!(scope, output_shape_1 = shape(output, 1u32));
        gpu!(scope, output_shape_2 = shape(output, 2u32));
        gpu!(scope, output_shape_3 = shape(output, 3u32));

        // Decompose the linear id into (batch, channel, row, col) of the
        // input-gradient tensor via its strides and shapes.
        let b = scope.create_local(Elem::UInt);
        let c = scope.create_local(Elem::UInt);
        let ih = scope.create_local(Elem::UInt);
        let iw = scope.create_local(Elem::UInt);

        gpu!(scope, b = id / output_stride_0);
        gpu!(scope, b = b % output_shape_0);

        gpu!(scope, c = id / output_stride_1);
        gpu!(scope, c = c % output_shape_1);

        gpu!(scope, ih = id / output_stride_2);
        gpu!(scope, ih = ih % output_shape_2);

        gpu!(scope, iw = id / output_stride_3);
        gpu!(scope, iw = iw % output_shape_3);

        // Range of pooled output cells whose adaptive windows may contain
        // input position (ih, iw); derived from the forward window formulas.
        let oh_start = Self::start_index(scope, ih, output_shape_2, grad_shape_2);
        let oh_end = Self::end_index(scope, ih, output_shape_2, grad_shape_2);
        let ow_start = Self::start_index(scope, iw, output_shape_3, grad_shape_3);
        let ow_end = Self::end_index(scope, iw, output_shape_3, grad_shape_3);

        // NOTE(review): `grad_acc` is accumulated without an explicit zero
        // assignment — presumably `create_local` zero-initializes locals in
        // the generated shader; confirm against the codegen backend.
        let grad_acc = scope.create_local(output.item());
        let contributed_h = scope.create_local(Elem::Bool);
        let contributed_w = scope.create_local(Elem::Bool);
        let contributed_tmp = scope.create_local(Elem::Bool);

        let count = scope.create_local(Elem::UInt);
        let count_tmp = scope.create_local(Elem::UInt);
        let count_float = scope.create_local(output.item());
        let the_grad = scope.create_local(output.item());
        let avg = scope.create_local(output.item());

        // Base linear offset of grad[b, c, :, :].
        let index_base = scope.create_local(Elem::UInt);
        let index_tmp = scope.create_local(Elem::UInt);
        let index = scope.create_local(Elem::UInt);
        gpu!(scope, index_base = b * grad_stride_0);
        gpu!(scope, index_tmp = c * grad_stride_1);
        gpu!(scope, index_base += index_tmp);

        gpu!(
            scope,
            range(oh_start, oh_end).for_each(|oh, scope| {
                // Forward window rows of output cell `oh`.
                let ih_start = Self::start_index(scope, oh, grad_shape_2, output_shape_2);
                let ih_end = Self::end_index(scope, oh, grad_shape_2, output_shape_2);

                // (ih, iw) contributed to cell (oh, ow) iff it lies inside
                // that cell's window; checked per axis.
                gpu!(scope, contributed_h = ih >= ih_start);
                gpu!(scope, contributed_tmp = ih < ih_end);
                gpu!(scope, contributed_h = contributed_h && contributed_tmp);

                gpu!(scope, if(contributed_h).then(|scope|{
                    gpu!(
                        scope,
                        range(ow_start, ow_end).for_each(|ow, scope| {
                            // Forward window columns of output cell `ow`.
                            let iw_start = Self::start_index(scope, ow, grad_shape_3, output_shape_3);
                            let iw_end = Self::end_index(scope, ow, grad_shape_3, output_shape_3);

                            gpu!(scope, contributed_w = iw >= iw_start);
                            gpu!(scope, contributed_tmp = iw < iw_end);
                            gpu!(scope, contributed_w = contributed_w && contributed_tmp);

                            gpu!(scope, if(contributed_w).then(|scope|{
                                // Window element count — the divisor the
                                // forward average used for this cell.
                                gpu!(scope, count = ih_end - ih_start);
                                gpu!(scope, count_tmp = iw_end - iw_start);
                                gpu!(scope, count *= count_tmp);
                                gpu!(scope, count_float = cast(count));

                                // Linear offset of grad[b, c, oh, ow].
                                gpu!(scope, index = index_base);
                                gpu!(scope, index_tmp = oh * grad_stride_2);
                                gpu!(scope, index += index_tmp);
                                gpu!(scope, index_tmp = ow * grad_stride_3);
                                gpu!(scope, index += index_tmp);

                                gpu!(scope, the_grad = grad[index]);
                                gpu!(scope, avg = the_grad / count_float);
                                gpu!(scope, grad_acc += avg);
                            }));
                        })
                    );
                }));
            })
        );

        // Write the accumulated gradient for this input position.
        gpu!(scope, output[id] = grad_acc);
    }

    /// IR for `floor(output_size_index * input_size / output_size)`:
    /// the first input index covered by pooled cell `output_size_index`
    /// (same formula as the forward adaptive-pool window start).
    fn start_index(
        scope: &mut Scope,
        output_size_index: Variable,
        output_size: Variable,
        input_size: Variable,
    ) -> Variable {
        let numerator_float = scope.create_local(Elem::Float);
        let div = scope.create_local(Elem::Float);
        let index = scope.create_local(Elem::UInt);

        gpu!(scope, index = output_size_index * input_size);
        gpu!(scope, numerator_float = cast(index));
        gpu!(scope, div = cast(output_size));
        gpu!(scope, div = numerator_float / div);
        gpu!(scope, div = floor(div));
        gpu!(scope, index = cast(div));
        index
    }

    /// IR for `min(ceil((output_size_index + 1) * input_size / output_size),
    /// input_size)`: one past the last input index covered by pooled cell
    /// `output_size_index`, clamped to the input extent.
    fn end_index(
        scope: &mut Scope,
        output_size_index: Variable,
        output_size: Variable,
        input_size: Variable,
    ) -> Variable {
        let numerator_float = scope.create_local(Elem::Float);
        let div = scope.create_local(Elem::Float);
        let index = scope.create_local(Elem::UInt);
        let min = scope.create_local(Elem::Bool);
        let end_index = scope.create_local(Elem::UInt);

        gpu!(scope, index = output_size_index + 1u32);
        gpu!(scope, index *= input_size);
        gpu!(scope, numerator_float = cast(index));
        gpu!(scope, div = cast(output_size));
        gpu!(scope, div = numerator_float / div);
        gpu!(scope, div = ceil(div));
        gpu!(scope, index = cast(div));

        // Clamp: end index must not exceed the input size.
        gpu!(scope, min = input_size < index);
        gpu!(scope, if(min).then(|scope|{
            gpu!(scope, end_index = input_size);
        }).else(|scope|{
            gpu!(scope, end_index = index);
        }));
        end_index
    }
}
impl<R: Runtime, E: JitElement> DynamicKernelSource for AdaptiveAvgPool2dBackwardEagerKernel<R, E> {
    /// Expand the backward compute shader into GPU IR and compile it with the
    /// runtime's compiler, yielding the kernel source.
    fn source(&self) -> SourceTemplate {
        let mut scope = Scope::root();
        let item = E::gpu_elem().into();

        // Binding 0 (input): output gradient; binding 0 (output): input gradient.
        let grad = Variable::GlobalInputArray(0, item);
        let output = Variable::GlobalOutputArray(0, item);

        // Register `output` as written by the shader before expansion.
        scope.write_global_custom(output);

        AdaptiveAvgPool2dBackwardComputeShader { grad, output }.expand(&mut scope);

        let grad = InputInfo::Array {
            item,
            visibility: Visibility::Read,
        };

        // NOTE(review): six UInt scalars are declared here, yet the launch
        // site passes `None` for scalars — confirm whether this input is
        // required by the compilation layout or is a leftover to remove.
        let scalars = InputInfo::Scalar {
            elem: Elem::UInt,
            size: 6,
        };

        let output = OutputInfo::Array { item };

        let info = CompilationInfo {
            inputs: vec![grad, scalars],
            outputs: vec![output],
            scope,
        };

        let settings = CompilationSettings::default();
        let shader = Compilation::new(info).compile(settings);
        let shader = <R::Compiler as Compiler>::compile(shader);
        SourceTemplate::new(shader.to_string())
    }

    /// Kernel identifier: unique per (runtime, element) monomorphization.
    fn id(&self) -> String {
        format!("{:?}", core::any::TypeId::of::<Self>(),)
    }
}
pub(crate) fn adaptive_avg_pool2d_backward<R: Runtime, E: JitElement>(
x: JitTensor<R, E, 4>,
@ -27,45 +254,24 @@ pub(crate) fn adaptive_avg_pool2d_backward<R: Runtime, E: JitElement>(
output_buffer,
);
let kernel = StaticKernel::<
KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
output.shape.num_elements(),
WORKGROUP_DEFAULT,
));
let kernel = AdaptiveAvgPool2dBackwardEagerKernel::new();
let info_handle = build_info(&x, &out_grad);
x.client.execute(
Box::new(kernel),
&[&out_grad.handle, &output.handle, &info_handle],
execute_dynamic::<R, AdaptiveAvgPool2dBackwardEagerKernel<R, E>, RuntimeInt<R>>(
&[EagerHandle::new(
&out_grad.handle,
&out_grad.strides,
&out_grad.shape.dims,
)],
&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
x.client,
);
output
}
/// Pack the 4D strides and shapes of `x` and `output` into one 16-word
/// metadata buffer consumed by the template kernel. Layout:
/// `[x.strides, x.shape, output.strides, output.shape]`, 4 words each.
fn build_info<R: Runtime, E: JitElement>(
    x: &JitTensor<R, E, 4>,
    output: &JitTensor<R, E, 4>,
) -> Handle<R::Server> {
    let mut info = [0u32; 16];
    for dim in 0..4 {
        info[dim] = x.strides[dim] as u32;
        info[4 + dim] = x.shape.dims[dim] as u32;
        info[8 + dim] = output.strides[dim] as u32;
        info[12 + dim] = output.shape.dims[dim] as u32;
    }
    output.client.create(bytemuck::cast_slice(&info))
}

View File

@ -11,7 +11,7 @@ pub(crate) use adaptive_pool2d_shader::*;
pub(crate) use pool2d_shader::*;
pub(crate) use adaptive_avg_pool2d::*;
pub use adaptive_avg_pool2d_backward::*;
pub(crate) use adaptive_avg_pool2d_backward::*;
pub(crate) use avg_pool2d::*;
pub(crate) use avg_pool2d_backward::*;
pub(super) use base::*;

View File

@ -1,88 +0,0 @@
// Incoming gradient w.r.t. the pooled output (read-only).
@group(0)
@binding(0)
var<storage, read> grad: array<{{ elem }}>;

// Gradient w.r.t. the pooling input, written by this kernel.
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;

// Tensor metadata: [0..8) input strides+shape, [8..16) grad strides+shape.
@group(0)
@binding(2)
var<storage, read> info: array<u32, 16>;

const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    // One invocation per element of the input gradient `output`;
    // flatten the 2D dispatch into a single linear index.
    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;

    // info layout (NCHW): [0..4) input strides, [4..8) input shape,
    // [8..12) grad strides, [12..16) grad shape.
    let input_stride_0 = info[0];
    let input_stride_1 = info[1];
    let input_stride_2 = info[2];
    let input_stride_3 = info[3];
    let input_shape_0 = info[4];
    let input_shape_1 = info[5];
    let input_shape_2 = info[6];
    let input_shape_3 = info[7];
    let grad_stride_0 = info[8];
    let grad_stride_1 = info[9];
    let grad_stride_2 = info[10];
    let grad_stride_3 = info[11];
    // grad_shape_0/1 are loaded for layout symmetry but unused below.
    let grad_shape_0 = info[12];
    let grad_shape_1 = info[13];
    let grad_shape_2 = info[14];
    let grad_shape_3 = info[15];

    // Decompose the linear id into (batch, channel, row, col) of the input.
    let b = id / input_stride_0 % input_shape_0;
    let c = id / input_stride_1 % input_shape_1;
    let ih = id / input_stride_2 % input_shape_2;
    let iw = id / input_stride_3 % input_shape_3;

    // Range of pooled output cells whose adaptive windows may contain (ih, iw).
    let oh_start = start_index(ih, input_shape_2, grad_shape_2);
    let oh_end = end_index(ih, input_shape_2, grad_shape_2);
    let ow_start = start_index(iw, input_shape_3, grad_shape_3);
    let ow_end = end_index(iw, input_shape_3, grad_shape_3);

    var grad_acc = 0.0;
    for (var oh = oh_start; oh < oh_end; oh++) {
        for (var ow = ow_start; ow < ow_end; ow++) {
            // Forward window of output cell (oh, ow).
            let ih_start = start_index(oh, grad_shape_2, input_shape_2);
            let ih_end = end_index(oh, grad_shape_2, input_shape_2);
            let iw_start = start_index(ow, grad_shape_3, input_shape_3);
            let iw_end = end_index(ow, grad_shape_3, input_shape_3);

            // (ih, iw) received a share of grad[oh, ow] iff it lies inside
            // that cell's window.
            let contributed_h = ih >= ih_start && ih < ih_end;
            let contributed_w = iw >= iw_start && iw < iw_end;

            // If no contribution skip
            if !contributed_h || !contributed_w {
                continue;
            }

            let index = b * grad_stride_0 + c * grad_stride_1 + oh * grad_stride_2 + ow * grad_stride_3;
            // Divide by the window element count — the divisor the forward
            // average used for this cell.
            let count = {{ elem }}((ih_end - ih_start) * (iw_end - iw_start));
            grad_acc += grad[index] / count;
        }
    }

    output[id] = grad_acc;
}
// First input index covered by pooled cell `output_size_index`:
// floor(output_size_index * input_size / output_size).
fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
    let numerator = f32(output_size_index) * f32(input_size);
    return u32(floor(numerator / f32(output_size)));
}
// One past the last input index covered by pooled cell `output_size_index`:
// ceil((output_size_index + 1) * input_size / output_size), clamped to input_size.
fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
    let numerator = f32(output_size_index + 1u) * f32(input_size);
    let index = u32(ceil(numerator / f32(output_size)));
    return min(index, input_size);
}