Fix burn-jit conv2d excessive loop unrolling (#2263)

* Related to issue #2260
This commit is contained in:
Asher Jingkong Chen 2024-09-09 23:16:13 +08:00 committed by GitHub
parent 94cd8a2556
commit ccb5b2214e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 10 additions and 6 deletions

View File

@@ -33,7 +33,6 @@ fn conv2d_kernel<F: Float>(
bias: &Tensor<F>, bias: &Tensor<F>,
output: &mut Tensor<F>, output: &mut Tensor<F>,
args: &Conv2dArgs, args: &Conv2dArgs,
kernel_size_0_unroll: Comptime<Option<UInt>>,
kernel_size_1_unroll: Comptime<Option<UInt>>, kernel_size_1_unroll: Comptime<Option<UInt>>,
) { ) {
if ABSOLUTE_POS >= output.len() { if ABSOLUTE_POS >= output.len() {
@@ -42,8 +41,7 @@ fn conv2d_kernel<F: Float>(
let in_channels = weight.shape(1); let in_channels = weight.shape(1);
let kernel_size_0 = Comptime::unwrap_or_else(kernel_size_0_unroll, || weight.shape(2)); let kernel_size_0 = weight.shape(2);
let unroll_0 = Comptime::is_some(kernel_size_0_unroll);
let kernel_size_1 = Comptime::unwrap_or_else(kernel_size_1_unroll, || weight.shape(3)); let kernel_size_1 = Comptime::unwrap_or_else(kernel_size_1_unroll, || weight.shape(3));
let unroll_1 = Comptime::is_some(kernel_size_1_unroll); let unroll_1 = Comptime::is_some(kernel_size_1_unroll);
@@ -82,7 +80,7 @@ fn conv2d_kernel<F: Float>(
let index_input_1 = ic * input_stride_1; let index_input_1 = ic * input_stride_1;
let index_weight_1 = (ic - ic_start) * weight_stride_1; let index_weight_1 = (ic - ic_start) * weight_stride_1;
for kh in range(0, kernel_size_0, unroll_0) { for kh in range(0, kernel_size_0, Comptime::new(false)) {
for kw in range(0, kernel_size_1, unroll_1) { for kw in range(0, kernel_size_1, unroll_1) {
let ih = kh * args.dilation_0 + ih_base; let ih = kh * args.dilation_0 + ih_base;
let iw = kw * args.dilation_1 + iw_base; let iw = kw * args.dilation_1 + iw_base;
@@ -126,6 +124,13 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
let [batch_size, _, in_height, in_width] = input.shape.dims; let [batch_size, _, in_height, in_width] = input.shape.dims;
let [out_channels, _, kernel_0, kernel_1] = weight.shape.dims; let [out_channels, _, kernel_0, kernel_1] = weight.shape.dims;
// Limit loop unrolling factor to 8 or smaller
let kernel_1_unroll = if kernel_1 > 8 {
None
} else {
Some(kernel_1.into())
};
let out_0 = calculate_conv_output_size( let out_0 = calculate_conv_output_size(
kernel_0, kernel_0,
options.stride[0], options.stride[0],
@@ -181,8 +186,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
ScalarArg::new(options.padding[1] as u32), ScalarArg::new(options.padding[1] as u32),
ScalarArg::new(options.groups as u32), ScalarArg::new(options.groups as u32),
), ),
Some(kernel_0.into()), kernel_1_unroll,
Some(kernel_1.into()),
); );
output output