mirror of https://github.com/tracel-ai/burn.git
Fix burn-jit conv2d excessive loop unrolling (#2263)
* Related to issue #2260
This commit is contained in:
parent
94cd8a2556
commit
ccb5b2214e
|
@ -33,7 +33,6 @@ fn conv2d_kernel<F: Float>(
|
|||
bias: &Tensor<F>,
|
||||
output: &mut Tensor<F>,
|
||||
args: &Conv2dArgs,
|
||||
kernel_size_0_unroll: Comptime<Option<UInt>>,
|
||||
kernel_size_1_unroll: Comptime<Option<UInt>>,
|
||||
) {
|
||||
if ABSOLUTE_POS >= output.len() {
|
||||
|
@ -42,8 +41,7 @@ fn conv2d_kernel<F: Float>(
|
|||
|
||||
let in_channels = weight.shape(1);
|
||||
|
||||
let kernel_size_0 = Comptime::unwrap_or_else(kernel_size_0_unroll, || weight.shape(2));
|
||||
let unroll_0 = Comptime::is_some(kernel_size_0_unroll);
|
||||
let kernel_size_0 = weight.shape(2);
|
||||
let kernel_size_1 = Comptime::unwrap_or_else(kernel_size_1_unroll, || weight.shape(3));
|
||||
let unroll_1 = Comptime::is_some(kernel_size_1_unroll);
|
||||
|
||||
|
@ -82,7 +80,7 @@ fn conv2d_kernel<F: Float>(
|
|||
let index_input_1 = ic * input_stride_1;
|
||||
let index_weight_1 = (ic - ic_start) * weight_stride_1;
|
||||
|
||||
for kh in range(0, kernel_size_0, unroll_0) {
|
||||
for kh in range(0, kernel_size_0, Comptime::new(false)) {
|
||||
for kw in range(0, kernel_size_1, unroll_1) {
|
||||
let ih = kh * args.dilation_0 + ih_base;
|
||||
let iw = kw * args.dilation_1 + iw_base;
|
||||
|
@ -126,6 +124,13 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
|
|||
let [batch_size, _, in_height, in_width] = input.shape.dims;
|
||||
let [out_channels, _, kernel_0, kernel_1] = weight.shape.dims;
|
||||
|
||||
// Only unroll the kernel_1 loop when kernel_1 is at most 8; larger kernels pass None (no unrolling)
|
||||
let kernel_1_unroll = if kernel_1 > 8 {
|
||||
None
|
||||
} else {
|
||||
Some(kernel_1.into())
|
||||
};
|
||||
|
||||
let out_0 = calculate_conv_output_size(
|
||||
kernel_0,
|
||||
options.stride[0],
|
||||
|
@ -181,8 +186,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
|
|||
ScalarArg::new(options.padding[1] as u32),
|
||||
ScalarArg::new(options.groups as u32),
|
||||
),
|
||||
Some(kernel_0.into()),
|
||||
Some(kernel_1.into()),
|
||||
kernel_1_unroll,
|
||||
);
|
||||
|
||||
output
|
||||
|
|
Loading…
Reference in New Issue