diff --git a/crates/burn-jit/src/kernel/conv/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d.rs index 60b107ad0..b5fced121 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d.rs @@ -33,7 +33,6 @@ fn conv2d_kernel( bias: &Tensor, output: &mut Tensor, args: &Conv2dArgs, - kernel_size_0_unroll: Comptime>, kernel_size_1_unroll: Comptime>, ) { if ABSOLUTE_POS >= output.len() { @@ -42,8 +41,7 @@ fn conv2d_kernel( let in_channels = weight.shape(1); - let kernel_size_0 = Comptime::unwrap_or_else(kernel_size_0_unroll, || weight.shape(2)); - let unroll_0 = Comptime::is_some(kernel_size_0_unroll); + let kernel_size_0 = weight.shape(2); let kernel_size_1 = Comptime::unwrap_or_else(kernel_size_1_unroll, || weight.shape(3)); let unroll_1 = Comptime::is_some(kernel_size_1_unroll); @@ -82,7 +80,7 @@ fn conv2d_kernel( let index_input_1 = ic * input_stride_1; let index_weight_1 = (ic - ic_start) * weight_stride_1; - for kh in range(0, kernel_size_0, unroll_0) { + for kh in range(0, kernel_size_0, Comptime::new(false)) { for kw in range(0, kernel_size_1, unroll_1) { let ih = kh * args.dilation_0 + ih_base; let iw = kw * args.dilation_1 + iw_base; @@ -126,6 +124,13 @@ pub(crate) fn conv2d( let [batch_size, _, in_height, in_width] = input.shape.dims; let [out_channels, _, kernel_0, kernel_1] = weight.shape.dims; + // Limit loop unrolling factor to 8 or smaller + let kernel_1_unroll = if kernel_1 > 8 { + None + } else { + Some(kernel_1.into()) + }; + let out_0 = calculate_conv_output_size( kernel_0, options.stride[0], @@ -181,8 +186,7 @@ pub(crate) fn conv2d( ScalarArg::new(options.padding[1] as u32), ScalarArg::new(options.groups as u32), ), - Some(kernel_0.into()), - Some(kernel_1.into()), + kernel_1_unroll, ); output