From ad7d41fdf12e4f4e68dce4d6bb7c3c494776cac4 Mon Sep 17 00:00:00 2001 From: louisfd Date: Fri, 22 Mar 2024 08:38:16 -0400 Subject: [PATCH] delete some wgsl --- .../kernel/pool/adaptive_avg_pool2d copy.rs | 100 ----- .../src/kernel/pool/max_pool2d copy.rs | 403 ------------------ .../matmul/blocktiling_2d/unpadded.wgsl | 243 ----------- .../template/matmul/blocktiling_2d/vec4.wgsl | 165 ------- .../src/template/matmul/mem_coalescing.wgsl | 70 --- .../template/pool/adaptive_avg_pool2d.wgsl | 74 ---- 6 files changed, 1055 deletions(-) delete mode 100644 crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d copy.rs delete mode 100644 crates/burn-jit/src/kernel/pool/max_pool2d copy.rs delete mode 100644 crates/burn-jit/src/template/matmul/blocktiling_2d/unpadded.wgsl delete mode 100644 crates/burn-jit/src/template/matmul/blocktiling_2d/vec4.wgsl delete mode 100644 crates/burn-jit/src/template/matmul/mem_coalescing.wgsl delete mode 100644 crates/burn-jit/src/template/pool/adaptive_avg_pool2d.wgsl diff --git a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d copy.rs b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d copy.rs deleted file mode 100644 index a2c73c36a..000000000 --- a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d copy.rs +++ /dev/null @@ -1,100 +0,0 @@ -// use crate::{ -// compute::StaticKernel, -// element::JitElement, -// kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, -// kernel_wgsl, -// ops::numeric::empty_device, -// tensor::JitTensor, -// Runtime, -// }; -// use burn_compute::server::Handle; -// use burn_tensor::Shape; - -// kernel_wgsl!( -// AdaptiveAvgPool2d, -// "../../template/pool/adaptive_avg_pool2d.wgsl" -// ); -// kernel_wgsl!( -// AdaptiveAvgPool2dBackward, -// "../../template/pool/adaptive_avg_pool2d_backward.wgsl" -// ); - -// pub(crate) fn adaptive_avg_pool2d( -// x: JitTensor, -// output_size: [usize; 2], -// ) -> JitTensor { -// let [batch_size, channels, _, _] = x.shape.dims; - -// let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]); -// let output = empty_device(x.client.clone(), x.device.clone(), output_shape); - -// let kernel = StaticKernel::< -// KernelSettings, -// >::new(elemwise_workgroup( -// output.shape.num_elements(), -// WORKGROUP_DEFAULT, -// )); - -// let info_handle = build_info(&x, &output); -// x.client -// .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]); - -// output -// } - -// pub(crate) fn adaptive_avg_pool2d_backward( -// x: JitTensor, -// out_grad: JitTensor, -// ) -> JitTensor { -// let output_shape = x.shape.clone(); -// let num_elems = output_shape.num_elements(); -// let output_buffer = x.client.empty(num_elems * core::mem::size_of::()); -// let output = JitTensor::new( -// x.client.clone(), -// x.device.clone(), -// output_shape, -// output_buffer, -// ); - -// let kernel = StaticKernel::< -// KernelSettings, -// >::new(elemwise_workgroup( -// output.shape.num_elements(), -// WORKGROUP_DEFAULT, -// )); - -// let info_handle = build_info(&x, &out_grad); - -// x.client.execute( -// Box::new(kernel), -// &[&out_grad.handle, &output.handle, &info_handle], -// ); - -// output -// } - -// fn build_info( -// x: &JitTensor, -// output: &JitTensor, -// ) -> Handle { -// let mut info: [u32; 16] = [0; 16]; -// info[0] = x.strides[0] as u32; -// info[1] = x.strides[1] as u32; -// info[2] = x.strides[2] as u32; -// info[3] = x.strides[3] as u32; -// info[4] = x.shape.dims[0] as u32; -// info[5] = x.shape.dims[1] as u32; -// info[6] = x.shape.dims[2] as u32; -// 
info[7] = x.shape.dims[3] as u32; - -// info[8] = output.strides[0] as u32; -// info[9] = output.strides[1] as u32; -// info[10] = output.strides[2] as u32; -// info[11] = output.strides[3] as u32; -// info[12] = output.shape.dims[0] as u32; -// info[13] = output.shape.dims[1] as u32; -// info[14] = output.shape.dims[2] as u32; -// info[15] = output.shape.dims[3] as u32; - -// output.client.create(bytemuck::cast_slice(&info)) -// } diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d copy.rs b/crates/burn-jit/src/kernel/pool/max_pool2d copy.rs deleted file mode 100644 index b6c21a8f4..000000000 --- a/crates/burn-jit/src/kernel/pool/max_pool2d copy.rs +++ /dev/null @@ -1,403 +0,0 @@ -// use burn_tensor::{ops::conv::calculate_pool_output_size, ElementConversion, Shape}; -// use std::marker::PhantomData; - -// use crate::{ -// codegen::{ -// dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility}, -// execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, -// InputInfo, OutputInfo, WorkgroupLaunch, -// }, -// element::JitElement, -// kernel::{DynamicKernelSource, SourceTemplate}, -// ops::numeric::empty_device, -// tensor::JitTensor, -// Runtime, RuntimeInt, -// }; - -// #[derive(new)] -// struct MaxPool2dEagerKernel { -// kernel_size: [usize; 2], -// _runtime: PhantomData, -// _elem: PhantomData, -// } - -// #[derive(new)] -// struct MaxPool2dWithIndicesEagerKernel { -// kernel_size: [usize; 2], -// _runtime: PhantomData, -// _elem: PhantomData, -// } - -// struct MaxPool2dComputeShader { -// x: Variable, -// output: Variable, -// kernel_size: [usize; 2], -// indices: Option, -// _elem: PhantomData, -// } - -// impl MaxPool2dComputeShader { -// fn expand(self, scope: &mut Scope) { -// let x = self.x; -// let output = self.output; -// let id = Variable::Id; - -// let input_stride_0 = scope.create_local(Elem::UInt); -// let input_stride_1 = scope.create_local(Elem::UInt); -// let input_stride_2 = scope.create_local(Elem::UInt); -// let input_stride_3 = scope.create_local(Elem::UInt); - -// let input_shape_2 = scope.create_local(Elem::UInt); -// let input_shape_3 = scope.create_local(Elem::UInt); - -// let output_stride_0 = scope.create_local(Elem::UInt); -// let output_stride_1 = scope.create_local(Elem::UInt); -// let output_stride_2 = scope.create_local(Elem::UInt); -// let output_stride_3 = scope.create_local(Elem::UInt); - -// let output_shape_0 = scope.create_local(Elem::UInt); -// let output_shape_1 = scope.create_local(Elem::UInt); -// let output_shape_2 = scope.create_local(Elem::UInt); -// let output_shape_3 = scope.create_local(Elem::UInt); - -// gpu!(scope, input_stride_0 = stride(x, 0u32)); -// gpu!(scope, input_stride_1 = stride(x, 1u32)); -// gpu!(scope, input_stride_2 = stride(x, 2u32)); -// gpu!(scope, input_stride_3 = stride(x, 3u32)); - -// gpu!(scope, input_shape_2 = shape(x, 2u32)); -// gpu!(scope, input_shape_3 = shape(x, 3u32)); - -// gpu!(scope, output_stride_0 = stride(output, 0u32)); -// gpu!(scope, output_stride_1 = stride(output, 1u32)); -// gpu!(scope, output_stride_2 = stride(output, 2u32)); -// gpu!(scope, output_stride_3 = stride(output, 3u32)); - -// gpu!(scope, output_shape_0 = shape(output, 0u32)); -// gpu!(scope, output_shape_1 = shape(output, 1u32)); -// gpu!(scope, output_shape_2 = shape(output, 2u32)); -// gpu!(scope, output_shape_3 = shape(output, 3u32)); - -// let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt); -// let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt); -// let dilation_0 = 
Variable::GlobalScalar(2, Elem::UInt); -// let dilation_1 = Variable::GlobalScalar(3, Elem::UInt); -// let padding_0 = Variable::GlobalScalar(4, Elem::UInt); -// let padding_1 = Variable::GlobalScalar(5, Elem::UInt); - -// let [kernel_size_0, kernel_size_1] = self.kernel_size; - -// let b = scope.create_local(Elem::UInt); -// let c = scope.create_local(Elem::UInt); -// let oh = scope.create_local(Elem::UInt); -// let ow = scope.create_local(Elem::UInt); - -// gpu!(scope, b = id / output_stride_0); -// gpu!(scope, b = b % output_shape_0); - -// gpu!(scope, c = id / output_stride_1); -// gpu!(scope, c = c % output_shape_1); - -// gpu!(scope, oh = id / output_stride_2); -// gpu!(scope, oh = oh % output_shape_2); - -// gpu!(scope, ow = id / output_stride_3); -// gpu!(scope, ow = ow % output_shape_3); - -// let tmp = scope.create_local(Elem::UInt); -// let ih = scope.create_local(Elem::UInt); -// let iw = scope.create_local(Elem::UInt); - -// let ih_pad = scope.create_local(Elem::UInt); -// let iw_pad = scope.create_local(Elem::UInt); -// let result = scope.create_local(x.item()); - -// let cond = scope.create_local(Elem::Bool); -// let cond_tmp = scope.create_local(Elem::Bool); - -// let index_input = scope.create_local(Elem::UInt); -// let index_input_1 = scope.create_local(Elem::UInt); -// let index_input_2 = scope.create_local(Elem::UInt); -// let index_input_3 = scope.create_local(Elem::UInt); -// let index_input_4 = scope.create_local(Elem::UInt); - -// let is_max = scope.create_local(Elem::Bool); -// let max_index = self.indices.map(|_| scope.create_local(Elem::UInt)); - -// let max_val = scope.create_local(x.item()); -// let max_initial = -// Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), x.item().elem()); -// gpu!(scope, max_val = max_initial); - -// (0..kernel_size_0).for_each(|kh| { -// gpu!(scope, ih = oh * pool_stride_0); -// gpu!(scope, tmp = kh * dilation_0); -// gpu!(scope, ih += tmp); - -// // Up -// gpu!(scope, cond = ih < padding_0); -// // Down -// gpu!(scope, tmp = input_shape_2 + padding_0); -// gpu!(scope, cond_tmp = ih >= tmp); -// gpu!(scope, cond = cond || cond_tmp); -// gpu!(scope, cond = !cond); - -// gpu!(scope, if (cond).then(|scope| { -// (0..kernel_size_1).for_each(|kw| { -// gpu!(scope, iw = ow * pool_stride_1); -// gpu!(scope, tmp = kw * dilation_1); -// gpu!(scope, iw = iw + tmp); - -// // Left -// gpu!(scope, cond = iw < padding_1); -// // Right -// gpu!(scope, tmp = input_shape_3 + padding_1); -// gpu!(scope, cond_tmp = iw >= tmp); -// gpu!(scope, cond = cond || cond_tmp); -// gpu!(scope, cond = !cond); - -// gpu!(scope, if (cond).then(|scope| { -// gpu!(scope, ih_pad = ih - padding_0); -// gpu!(scope, iw_pad = iw - padding_1); - -// gpu!(scope, index_input_1 = b * input_stride_0); -// gpu!(scope, index_input_2 = c * input_stride_1); -// gpu!(scope, index_input_3 = ih_pad * input_stride_2); -// gpu!(scope, index_input_4 = iw_pad * input_stride_3); - -// gpu!(scope, index_input = index_input_1); -// gpu!(scope, index_input += index_input_2); -// gpu!(scope, index_input += index_input_3); -// gpu!(scope, index_input += index_input_4); - -// gpu!(scope, result = x[index_input]); - -// gpu!(scope, is_max = result > max_val); - -// gpu!(scope, if(is_max).then(|scope|{ -// gpu!(scope, max_val = result); -// if let Some(max_index) = max_index { -// gpu!(scope, max_index = ih_pad * input_shape_2); -// gpu!(scope, max_index += iw_pad); -// } -// })); -// })); -// }); -// })); -// }); - -// gpu!(scope, output[id] = max_val); - -// if let Some(indices) 
= self.indices { -// let max_index = max_index.unwrap(); -// gpu!(scope, indices[id] = max_index); -// } -// } -// } - -// impl DynamicKernelSource for MaxPool2dEagerKernel { -// fn source(&self) -> crate::kernel::SourceTemplate { -// let mut scope = Scope::root(); -// let item = E::gpu_elem().into(); - -// let x = Variable::GlobalInputArray(0, item); -// let output = Variable::GlobalOutputArray(0, item); - -// scope.write_global_custom(output); - -// MaxPool2dComputeShader { -// x, -// output, -// kernel_size: self.kernel_size, -// indices: None, -// _elem: PhantomData::, -// } -// .expand(&mut scope); - -// let input = InputInfo::Array { -// item, -// visibility: Visibility::Read, -// }; -// let scalars = InputInfo::Scalar { -// elem: Elem::UInt, -// size: 6, -// }; -// let output = OutputInfo::Array { item }; - -// let info = CompilationInfo { -// inputs: vec![input, scalars], -// outputs: vec![output], -// scope, -// }; - -// let settings = CompilationSettings::default(); -// let shader = Compilation::new(info).compile(settings); -// let shader = ::compile(shader); -// SourceTemplate::new(shader.to_string()) -// } - -// fn id(&self) -> String { -// format!( -// "{:?}k={:?}", -// core::any::TypeId::of::(), -// self.kernel_size, -// ) -// } -// } - -// impl DynamicKernelSource for MaxPool2dWithIndicesEagerKernel { -// fn source(&self) -> crate::kernel::SourceTemplate { -// let mut scope = Scope::root(); -// let item = E::gpu_elem().into(); - -// let x = Variable::GlobalInputArray(0, item); -// let output = Variable::GlobalOutputArray(0, item); -// let indices = Variable::GlobalOutputArray(1, Item::Scalar(Elem::Int)); - -// scope.write_global_custom(output); - -// MaxPool2dComputeShader { -// x, -// output, -// kernel_size: self.kernel_size, -// indices: Some(indices), -// _elem: PhantomData::, -// } -// .expand(&mut scope); - -// let input = InputInfo::Array { -// item, -// visibility: Visibility::Read, -// }; -// let scalars = InputInfo::Scalar { -// elem: Elem::UInt, -// size: 6, -// }; -// let output = OutputInfo::Array { item }; -// let indices = OutputInfo::Array { -// item: Item::Scalar(Elem::Int), -// }; - -// let info = CompilationInfo { -// inputs: vec![input, scalars], -// outputs: vec![output, indices], -// scope, -// }; - -// let settings = CompilationSettings::default(); -// let shader = Compilation::new(info).compile(settings); -// let shader = ::compile(shader); -// SourceTemplate::new(shader.to_string()) -// } - -// fn id(&self) -> String { -// format!( -// "{:?}k={:?}", -// core::any::TypeId::of::(), -// self.kernel_size, -// ) -// } -// } - -pub(crate) fn max_pool2d_with_indices( - x: JitTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], -) -> (JitTensor, JitTensor) { - let [batch_size, channels, _, _] = x.shape.dims; - - let size_0 = calculate_pool_output_size( - kernel_size[0], - stride[0], - padding[0], - dilation[0], - x.shape.dims[2], - ); - let size_1 = calculate_pool_output_size( - kernel_size[1], - stride[1], - padding[1], - dilation[1], - x.shape.dims[3], - ); - - let shape_out = Shape::new([batch_size, channels, size_0, size_1]); - let output = empty_device(x.client.clone(), x.device.clone(), shape_out.clone()); - let indices = empty_device(x.client.clone(), x.device.clone(), shape_out); - - let kernel = MaxPool2dWithIndicesEagerKernel::new(kernel_size); - execute_dynamic::, I>( - &[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)], - &[ - EagerHandle::new(&output.handle, &output.strides, 
&output.shape.dims), - EagerHandle::new(&indices.handle, &indices.strides, &indices.shape.dims), - ], - Some(&[ - (stride[0] as i32).elem(), - (stride[1] as i32).elem(), - (dilation[0] as i32).elem(), - (dilation[1] as i32).elem(), - (padding[0] as i32).elem(), - (padding[1] as i32).elem(), - ]), - kernel, - WorkgroupLaunch::Output { pos: 0 }, - x.client, - ); - - (output, indices) -} - -// pub(crate) fn max_pool2d( -// x: JitTensor, -// kernel_size: [usize; 2], -// stride: [usize; 2], -// padding: [usize; 2], -// dilation: [usize; 2], -// ) -> JitTensor { -// let [batch_size, channels, _, _] = x.shape.dims; - -// let size_0 = calculate_pool_output_size( -// kernel_size[0], -// stride[0], -// padding[0], -// dilation[0], -// x.shape.dims[2], -// ); -// let size_1 = calculate_pool_output_size( -// kernel_size[1], -// stride[1], -// padding[1], -// dilation[1], -// x.shape.dims[3], -// ); - -// let shape_out = Shape::new([batch_size, channels, size_0, size_1]); -// let output = empty_device(x.client.clone(), x.device.clone(), shape_out); - -// let kernel = MaxPool2dEagerKernel::new(kernel_size); - -// execute_dynamic::, RuntimeInt>( -// &[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)], -// &[EagerHandle::new( -// &output.handle, -// &output.strides, -// &output.shape.dims, -// )], -// Some(&[ -// (stride[0] as i32).elem(), -// (stride[1] as i32).elem(), -// (dilation[0] as i32).elem(), -// (dilation[1] as i32).elem(), -// (padding[0] as i32).elem(), -// (padding[1] as i32).elem(), -// ]), -// kernel, -// WorkgroupLaunch::Output { pos: 0 }, -// x.client, -// ); - -// output -// } diff --git a/crates/burn-jit/src/template/matmul/blocktiling_2d/unpadded.wgsl b/crates/burn-jit/src/template/matmul/blocktiling_2d/unpadded.wgsl deleted file mode 100644 index 9dd3f9407..000000000 --- a/crates/burn-jit/src/template/matmul/blocktiling_2d/unpadded.wgsl +++ /dev/null @@ -1,243 +0,0 @@ -@group(0) -@binding(0) -var lhs: array<{{ elem }}>; - -@group(0) -@binding(1) -var rhs: array<{{ elem }}>; - -@group(0) -@binding(2) -var output: array<{{ elem }}>; - -@group(0) -@binding(3) -var info: array; - -const B_M = {{b_m}}u; -const B_N = {{b_n}}u; -const B_K = {{b_k}}u; -const B_M_X_B_K_4 = {{bm_x_bk_4}}u; -const B_K_X_B_N_4 = {{bk_x_bn_4}}u; - -const T_M = 4u; -const T_N = 4u; -const T_M_X_T_N = 16u; - -var shared_lhs: array, B_M_X_B_K_4>; -var shared_rhs: array, B_K_X_B_N_4>; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }}) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_index) local_idx: u32, - @builtin(workgroup_id) workgroup_id: vec3, -) { - let skip_row = workgroup_id.x * B_M; - let skip_col = workgroup_id.y * B_N; - - let n_thread_per_row = ((B_N - 1u) / T_N) + 1u; - - // Position of the first element of the thread, relative to the block - let thread_row = (local_idx / n_thread_per_row) * T_M; - let thread_col = (local_idx % n_thread_per_row) * T_N; - - // Position of the first element of the thread, in absolute (in one batch) - let row = skip_row + thread_row; - let col = skip_col + thread_col; - - let batch = global_id.z; - - // Basic information - let dim = info[0]; - let n_rows = info[6u * dim - 1u]; - let n_cols = info[6u * dim]; - let K = info[5u * dim - 1u]; - - // Row / col strides - let lhs_stride_row = info[dim - 1u]; - let lhs_stride_col = info[dim]; - let rhs_stride_row = info[2u * dim - 1u]; - let rhs_stride_col = info[2u * dim]; - let out_stride_row = info [3u * dim - 1u]; - let out_stride_col 
= info [3u * dim]; - - // Calculate the corresponding offsets with support for broadcasting. - let offset_output = batch * n_rows * n_cols; - var offset_lhs: u32 = skip_row * lhs_stride_row; - var offset_rhs: u32 = skip_col * rhs_stride_col; - - let batch_dims = dim - 2u; - for (var b: u32 = 1u; b <= batch_dims; b++) { - let stride_lhs = info[b]; - let stride_rhs = info[b + dim]; - let stride_output = info[b + 2u * dim]; - let shape_lhs = info[b + 3u * dim]; - let shape_rhs = info[b + 4u * dim]; - - offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs; - offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs; - } - - // Registers used in the compute pass - var results: array<{{ elem }}, T_M_X_T_N>; - var register_M: vec4<{{ elem }}>; - var register_N: vec4<{{ elem }}>; - - // How close is the thread to the end of the matrix. - // If < 4 then it is an edge case - let remain_row_lhs = n_rows - row; - let remain_col_rhs = n_cols - col; - - for (var k = 0u; k < K; k += B_K) { - - // LHS LOAD PASS - - // For the 4 vec4 columns of this thread - for (var j = 0u; j < 4u; j++) { - - // The precise - let current_col = thread_col + j; - - // Position of the column vec4 in shared memory - let lhs_sm_position = (thread_row/4u) * B_K + current_col; - - // To avoid overwriting following row in share memory - if current_col < B_K { - // To pad with zeros if outside lhs - if current_col + k < K && remain_row_lhs >= 1u { - let lhs_position0 = offset_lhs + (k + current_col) * lhs_stride_col + thread_row * lhs_stride_row; - let lhs_position1 = lhs_position0 + lhs_stride_row; - let lhs_position2 = lhs_position1 + lhs_stride_row; - let lhs_position3 = lhs_position2 + lhs_stride_row; - - if remain_row_lhs >= 4u { - shared_lhs[lhs_sm_position] = vec4( - lhs[lhs_position0], - lhs[lhs_position1], - lhs[lhs_position2], - lhs[lhs_position3], - ); - } else if remain_row_lhs == 3u { - shared_lhs[lhs_sm_position] = vec4( - lhs[lhs_position0], - lhs[lhs_position1], - lhs[lhs_position2], - 0. - ); - } else if remain_row_lhs == 2u { - shared_lhs[lhs_sm_position] = vec4( - lhs[lhs_position0], - lhs[lhs_position1], - 0., - 0. - ); - } else if remain_row_lhs == 1u { - shared_lhs[lhs_sm_position] = vec4( - lhs[lhs_position0], - 0., - 0., - 0. - ); - } - } else { - shared_lhs[lhs_sm_position] = vec4(0.,0.,0.,0.); - } - } - } - - // RHS LOAD PASS - - for (var i = 0u; i < 4u; i++) { - let current_row = thread_row + i; - - let rhs_sm_position = (current_row * B_N + thread_col) / 4u; - - if current_row < B_K { - if current_row + k < K && remain_col_rhs >= 1u { - - let rhs_position0 = offset_rhs + (k + current_row) * rhs_stride_row + thread_col * rhs_stride_col; - let rhs_position1 = rhs_position0 + rhs_stride_col; - let rhs_position2 = rhs_position1 + rhs_stride_col; - let rhs_position3 = rhs_position2 + rhs_stride_col; - - if remain_col_rhs >= 4u { - shared_rhs[rhs_sm_position] = vec4( - rhs[rhs_position0], - rhs[rhs_position1], - rhs[rhs_position2], - rhs[rhs_position3], - ); - } else if remain_col_rhs == 3u { - shared_rhs[rhs_sm_position] = vec4( - rhs[rhs_position0], - rhs[rhs_position1], - rhs[rhs_position2], - 0. - ); - } else if remain_col_rhs == 2u { - shared_rhs[rhs_sm_position] = vec4( - rhs[rhs_position0], - rhs[rhs_position1], - 0., - 0. - ); - } else if remain_col_rhs == 1u { - shared_rhs[rhs_sm_position] = vec4( - rhs[rhs_position0], - 0., - 0., - 0. 
- ); - } - } else { - shared_rhs[rhs_sm_position] = vec4(0.,0.,0.,0.); - } - } - } - - workgroupBarrier(); - - // COMPUTE PASS - - // Compute intermediate results - // Results are cumulated in results array and updated at each block - // Outer loop indicates which subcolumns/subrows to read from shared memories - for (var dot_index = 0u; dot_index < B_K; dot_index++) { - - // Load a subcolumn of values from lhs - let lhs_sm_position = (thread_row/4u) * B_K + dot_index; - register_M = shared_lhs[lhs_sm_position]; - - // Load a subrow of values from rhs - let rhs_sm_position = (dot_index * B_N + thread_col) / 4u; - register_N = shared_rhs[rhs_sm_position]; - - // Multiply subcolumn and subrow and store results - for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) { - for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { - results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N]; - } - } - } - - workgroupBarrier(); - } - - // OUTPUT PASS - - // Write output matrix - // Each thread is responsible of writing T_M x T_N results - for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) { - for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { - let row_index = row + res_idx_M; - let col_index = col + res_idx_N; - if row_index < n_rows && col_index < n_cols { - let result_position = res_idx_M * T_N + res_idx_N; - let output_position = offset_output + row_index * out_stride_row + col_index * out_stride_col; - output[output_position] = results[result_position]; - } - } - } -} diff --git a/crates/burn-jit/src/template/matmul/blocktiling_2d/vec4.wgsl b/crates/burn-jit/src/template/matmul/blocktiling_2d/vec4.wgsl deleted file mode 100644 index e22d9d0f7..000000000 --- a/crates/burn-jit/src/template/matmul/blocktiling_2d/vec4.wgsl +++ /dev/null @@ -1,165 +0,0 @@ -@group(0) -@binding(0) -var lhs: array<{{ elem }}>; - -@group(0) -@binding(1) -var rhs: array<{{ elem }}>; - -@group(0) -@binding(2) -var output: array<{{ elem }}>; - -@group(0) -@binding(3) -var info: array; - -const B_M = {{b_m}}u; -const B_N = {{b_n}}u; -const B_K = {{b_k}}u; -const B_M_X_B_K_4 = {{bm_x_bk_4}}u; -const B_K_X_B_N_4 = {{bk_x_bn_4}}u; - -const T_M = 4u; -const T_N = 4u; -const T_M_X_T_N = 16u; - -var shared_lhs: array, B_M_X_B_K_4>; -var shared_rhs: array, B_K_X_B_N_4>; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }}) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_index) local_idx: u32, - @builtin(workgroup_id) workgroup_id: vec3, -) { - let skip_row = workgroup_id.x * B_M; - let skip_col = workgroup_id.y * B_N; - - let n_thread_per_row = ((B_N - 1u) / T_N) + 1u; - let thread_row = (local_idx / n_thread_per_row) * T_M; - let thread_col = (local_idx % n_thread_per_row) * T_N; - - let row = skip_row + thread_row; - let col = skip_col + thread_col; - - let batch = global_id.z; - - // Basic information - let dim = info[0]; - let n_rows = info[6u * dim - 1u]; - let n_cols = info[6u * dim]; - let K = info[5u * dim - 1u]; - - // Row / col strides - let lhs_stride_row = info[dim - 1u]; - let lhs_stride_col = info[dim]; - let rhs_stride_row = info[2u * dim - 1u]; - let rhs_stride_col = info[2u * dim]; - let out_stride_row = info [3u * dim - 1u]; - let out_stride_col = info [3u * dim]; - - // Calculate the corresponding offsets with support for broadcasting. 
- let offset_output = batch * n_rows * n_cols; - var offset_lhs: u32 = skip_row * lhs_stride_row; - var offset_rhs: u32 = skip_col * rhs_stride_col; - - let batch_dims = dim - 2u; - for (var b: u32 = 1u; b <= batch_dims; b++) { - let stride_lhs = info[b]; - let stride_rhs = info[b + dim]; - let stride_output = info[b + 2u * dim]; - let shape_lhs = info[b + 3u * dim]; - let shape_rhs = info[b + 4u * dim]; - - offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs; - offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs; - } - - var results: array<{{ elem }}, T_M_X_T_N>; - var register_M: vec4<{{ elem }}>; - var register_N: vec4<{{ elem }}>; - - for (var k = 0u; k < K; k += B_K) { - // Load data into shared memories - // Each thread is responsible of loading T_M x T_N values from both lhs and rhs - - for (var j = 0u; j < 4u; j++) { - let current_col = thread_col + j; - - if current_col < B_K { // so that threads who work on between B_K and B_N store nothing - - let lhs_sm_position = (thread_row/4u) * B_K + current_col; - - let lhs_position0 = offset_lhs + (k + current_col) * lhs_stride_col + thread_row * lhs_stride_row; - let lhs_position1 = lhs_position0 + lhs_stride_row; - let lhs_position2 = lhs_position1 + lhs_stride_row; - let lhs_position3 = lhs_position2 + lhs_stride_row; - - shared_lhs[lhs_sm_position] = vec4( - lhs[lhs_position0], - lhs[lhs_position1], - lhs[lhs_position2], - lhs[lhs_position3], - ); - } - } - - for (var i = 0u; i < 4u; i++) { - let current_row = thread_row + i; - - if current_row < B_K { // so that threads who work on between B_K and B_N store nothing - - let rhs_sm_position = (current_row * B_N + thread_col) / 4u; - - let rhs_position0 = offset_rhs + (k + current_row) * rhs_stride_row + thread_col * rhs_stride_col; - let rhs_position1 = rhs_position0 + rhs_stride_col; - let rhs_position2 = rhs_position1 + rhs_stride_col; - let rhs_position3 = rhs_position2 + rhs_stride_col; - - shared_rhs[rhs_sm_position] = vec4( - rhs[rhs_position0], - rhs[rhs_position1], - rhs[rhs_position2], - rhs[rhs_position3], - ); - } - } - - workgroupBarrier(); - - // Compute intermediate results - // Results are cumulated in results array and updated at each block - // Outer loop indicates which subcolumns/subrows to read from shared memories - for (var dot_index = 0u; dot_index < B_K; dot_index++) { - - // Load a subcolumn of values from lhs - let lhs_sm_position = (thread_row/4u) * B_K + dot_index; - register_M = shared_lhs[lhs_sm_position]; - - // Load a subrow of values from rhs - let rhs_sm_position = (dot_index * B_N + thread_col) / 4u; - register_N = shared_rhs[rhs_sm_position]; - - // Multiply subcolumn and subrow and store results - for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) { - for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { - results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N]; - } - } - } - - workgroupBarrier(); - } - - // Write output matrix - // Each thread is responsible of writing T_M x T_N results - for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) { - for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { - let result_position = res_idx_M * T_N + res_idx_N; - let output_position = offset_output + (row + res_idx_M) * out_stride_row + (col + res_idx_N) * out_stride_col; - output[output_position] = results[result_position]; - } - } -} diff --git a/crates/burn-jit/src/template/matmul/mem_coalescing.wgsl b/crates/burn-jit/src/template/matmul/mem_coalescing.wgsl deleted file mode 100644 
index f6e28202d..000000000 --- a/crates/burn-jit/src/template/matmul/mem_coalescing.wgsl +++ /dev/null @@ -1,70 +0,0 @@ -@group(0) -@binding(0) -var lhs: array<{{ elem }}>; - -@group(0) -@binding(1) -var rhs: array<{{ elem }}>; - -@group(0) -@binding(2) -var output: array<{{ elem }}>; - -@group(0) -@binding(3) -var info: array; - -const BLOCK_SIZE = {{ workgroup_size_x }}u; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_index) local_idx: u32, - @builtin(workgroup_id) workgroup_id: vec3, -) { - // Indices - let row = workgroup_id.x * BLOCK_SIZE + (local_idx / BLOCK_SIZE); - let col = workgroup_id.y * BLOCK_SIZE + (local_idx % BLOCK_SIZE); - let batch = global_id.z; - - // Basic information - let dim = info[0]; - let n_rows = info[6u * dim - 1u]; - let n_cols = info[6u * dim]; - let K = info[5u * dim - 1u]; - - // Returns if outside the output dimension - if row >= n_rows || col >= n_cols { - return; - } - - // Calculate the corresponding offsets with support for broadcasting. - let offset_output = batch * n_rows * n_cols; - var offset_lhs: u32 = 0u; - var offset_rhs: u32 = 0u; - - let batch_dims = dim - 2u; - for (var b: u32 = 1u; b <= batch_dims; b++) { - let stride_lhs = info[b]; - let stride_rhs = info[b + dim]; - let stride_output = info[b + 2u * dim]; - let shape_lhs = info[b + 3u * dim]; - let shape_rhs = info[b + 4u * dim]; - - offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs; - offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs; - } - - // Basic matmul implementation - var sum = 0.0; - for (var k: u32 = 0u; k < K; k++) { - let lhs_index = row * K + k; - let rhs_index = k * n_cols + col; - - sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index]; - } - - let output_index = row * n_cols + col; - output[offset_output + output_index] = sum; -} diff --git a/crates/burn-jit/src/template/pool/adaptive_avg_pool2d.wgsl b/crates/burn-jit/src/template/pool/adaptive_avg_pool2d.wgsl deleted file mode 100644 index f1b57e70f..000000000 --- a/crates/burn-jit/src/template/pool/adaptive_avg_pool2d.wgsl +++ /dev/null @@ -1,74 +0,0 @@ -@group(0) -@binding(0) -var x: array<{{ elem }}>; - -@group(0) -@binding(1) -var output: array<{{ elem }}>; - -@group(0) -@binding(2) -var info: array; - -const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; - - let input_stride_0 = info[0]; - let input_stride_1 = info[1]; - let input_stride_2 = info[2]; - let input_stride_3 = info[3]; - let input_shape_0 = info[4]; - let input_shape_1 = info[5]; - let input_shape_2 = info[6]; - let input_shape_3 = info[7]; - - let output_stride_0 = info[8]; - let output_stride_1 = info[9]; - let output_stride_2 = info[10]; - let output_stride_3 = info[11]; - let output_shape_0 = info[12]; - let output_shape_1 = info[13]; - let output_shape_2 = info[14]; - let output_shape_3 = info[15]; - - let b = id / output_stride_0 % output_shape_0; - let c = id / output_stride_1 % output_shape_1; - let oh = id / output_stride_2 % output_shape_2; - let ow = id / output_stride_3 % output_shape_3; - - let ih_start = start_index(oh, output_shape_2, input_shape_2); - let ih_end = end_index(oh, output_shape_2, input_shape_2); 
- - let iw_start = start_index(ow, output_shape_3, input_shape_3); - let iw_end = end_index(ow, output_shape_3, input_shape_3); - - var sum = 0.0; - - for (var ih = ih_start; ih < ih_end; ih++) { - for (var iw = iw_start; iw < iw_end; iw++) { - let index_input = b * input_stride_0 + c * input_stride_1 + ih * input_stride_2 + iw * input_stride_3; - sum += x[index_input]; - } - } - - let count = {{ elem }}((ih_end - ih_start) * (iw_end - iw_start)); - output[id] = sum / count; -} - -fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 { - return u32(floor((f32(output_size_index) * f32(input_size)) / f32(output_size))); -} - -fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 { - let index = u32(ceil((f32(output_size_index + 1u) * f32(input_size)) / f32(output_size))); - - return min(index, input_size); -} -
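For reference, the adaptive_avg_pool2d.wgsl template deleted above derives each output element's pooling window from its start_index/end_index helpers and averages the inputs they cover. Below is a minimal, standalone Rust sketch of that index math, reduced to a single spatial axis; it is an illustration under stated assumptions, not crate code (only start_index and end_index mirror names from the template, and integer arithmetic stands in for the template's f32 floor/ceil, which is equivalent for non-negative sizes).

// Standalone sketch (not part of burn-jit): reference for the window bounds
// used by the deleted adaptive_avg_pool2d.wgsl template, on one axis only.

fn start_index(output_index: usize, output_size: usize, input_size: usize) -> usize {
    // floor(output_index * input_size / output_size), via integer division
    (output_index * input_size) / output_size
}

fn end_index(output_index: usize, output_size: usize, input_size: usize) -> usize {
    // ceil((output_index + 1) * input_size / output_size), clamped to input_size
    let index = ((output_index + 1) * input_size + output_size - 1) / output_size;
    index.min(input_size)
}

// Adaptive average pooling over a single axis, for illustration only.
fn adaptive_avg_pool1d(input: &[f32], output_size: usize) -> Vec<f32> {
    let input_size = input.len();
    (0..output_size)
        .map(|o| {
            let start = start_index(o, output_size, input_size);
            let end = end_index(o, output_size, input_size);
            let sum: f32 = input[start..end].iter().sum();
            sum / (end - start) as f32
        })
        .collect()
}

fn main() {
    let input: Vec<f32> = (0..7).map(|i| i as f32).collect();
    // Pooling 7 inputs down to 3 outputs uses the overlapping windows
    // [0, 3), [2, 5), [4, 7), matching the template's per-output reads.
    println!("{:?}", adaptive_avg_pool1d(&input, 3));
}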
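The deleted matmul templates all compute each output element as a dot product over K: mem_coalescing.wgsl does so directly, one element per invocation, while the block-tiled unpadded.wgsl and vec4.wgsl stage B_M x B_K and B_K x B_N tiles of lhs and rhs in workgroup shared memory and accumulate T_M x T_N results per thread to reduce global-memory traffic. As a rough CPU reference for the mem_coalescing inner loop only, assuming contiguous row-major buffers and ignoring the batch/broadcast offsets resolved from the info array (function and variable names below are illustrative, not crate APIs):

// Standalone sketch (not part of burn-jit): same per-element indexing as the
// deleted mem_coalescing.wgsl, for contiguous row-major lhs [m, k] and rhs [k, n].
fn matmul_naive(lhs: &[f32], rhs: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(lhs.len(), m * k);
    assert_eq!(rhs.len(), k * n);
    let mut output = vec![0.0f32; m * n];
    for row in 0..m {
        for col in 0..n {
            // Mirrors the WGSL loop: sum += lhs[row * K + k] * rhs[k * n_cols + col];
            let mut sum = 0.0f32;
            for i in 0..k {
                sum += lhs[row * k + i] * rhs[i * n + col];
            }
            output[row * n + col] = sum;
        }
    }
    output
}

fn main() {
    // A 2x3 by 3x2 product as a quick check of the indexing.
    let lhs: [f32; 6] = [1., 2., 3., 4., 5., 6.];
    let rhs: [f32; 6] = [7., 8., 9., 10., 11., 12.];
    println!("{:?}", matmul_naive(&lhs, &rhs, 2, 3, 2));
}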