Mirror of https://github.com/tracel-ai/burn.git
Fix burn-jit conv2d excessive loop unrolling (#2263)
* Related to issue #2260
This commit is contained in:
parent 94cd8a2556
commit ccb5b2214e
|
@ -33,7 +33,6 @@ fn conv2d_kernel<F: Float>(
|
||||||
bias: &Tensor<F>,
|
bias: &Tensor<F>,
|
||||||
output: &mut Tensor<F>,
|
output: &mut Tensor<F>,
|
||||||
args: &Conv2dArgs,
|
args: &Conv2dArgs,
|
||||||
kernel_size_0_unroll: Comptime<Option<UInt>>,
|
|
||||||
kernel_size_1_unroll: Comptime<Option<UInt>>,
|
kernel_size_1_unroll: Comptime<Option<UInt>>,
|
||||||
) {
|
) {
|
||||||
if ABSOLUTE_POS >= output.len() {
|
if ABSOLUTE_POS >= output.len() {
|
||||||
|
@ -42,8 +41,7 @@ fn conv2d_kernel<F: Float>(
|
||||||
|
|
||||||
let in_channels = weight.shape(1);
|
let in_channels = weight.shape(1);
|
||||||
|
|
||||||
let kernel_size_0 = Comptime::unwrap_or_else(kernel_size_0_unroll, || weight.shape(2));
|
let kernel_size_0 = weight.shape(2);
|
||||||
let unroll_0 = Comptime::is_some(kernel_size_0_unroll);
|
|
||||||
let kernel_size_1 = Comptime::unwrap_or_else(kernel_size_1_unroll, || weight.shape(3));
|
let kernel_size_1 = Comptime::unwrap_or_else(kernel_size_1_unroll, || weight.shape(3));
|
||||||
let unroll_1 = Comptime::is_some(kernel_size_1_unroll);
|
let unroll_1 = Comptime::is_some(kernel_size_1_unroll);
|
||||||
|
|
||||||
|
@ -82,7 +80,7 @@ fn conv2d_kernel<F: Float>(
|
||||||
let index_input_1 = ic * input_stride_1;
|
let index_input_1 = ic * input_stride_1;
|
||||||
let index_weight_1 = (ic - ic_start) * weight_stride_1;
|
let index_weight_1 = (ic - ic_start) * weight_stride_1;
|
||||||
|
|
||||||
for kh in range(0, kernel_size_0, unroll_0) {
|
for kh in range(0, kernel_size_0, Comptime::new(false)) {
|
||||||
for kw in range(0, kernel_size_1, unroll_1) {
|
for kw in range(0, kernel_size_1, unroll_1) {
|
||||||
let ih = kh * args.dilation_0 + ih_base;
|
let ih = kh * args.dilation_0 + ih_base;
|
||||||
let iw = kw * args.dilation_1 + iw_base;
|
let iw = kw * args.dilation_1 + iw_base;
|
||||||
|
@ -126,6 +124,13 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
|
||||||
let [batch_size, _, in_height, in_width] = input.shape.dims;
|
let [batch_size, _, in_height, in_width] = input.shape.dims;
|
||||||
let [out_channels, _, kernel_0, kernel_1] = weight.shape.dims;
|
let [out_channels, _, kernel_0, kernel_1] = weight.shape.dims;
|
||||||
|
|
||||||
|
// Limit loop unrolling factor to 8 or smaller
|
||||||
|
let kernel_1_unroll = if kernel_1 > 8 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(kernel_1.into())
|
||||||
|
};
|
||||||
|
|
||||||
let out_0 = calculate_conv_output_size(
|
let out_0 = calculate_conv_output_size(
|
||||||
kernel_0,
|
kernel_0,
|
||||||
options.stride[0],
|
options.stride[0],
|
||||||
|
@ -181,8 +186,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
|
||||||
ScalarArg::new(options.padding[1] as u32),
|
ScalarArg::new(options.padding[1] as u32),
|
||||||
ScalarArg::new(options.groups as u32),
|
ScalarArg::new(options.groups as u32),
|
||||||
),
|
),
|
||||||
Some(kernel_0.into()),
|
kernel_1_unroll,
|
||||||
Some(kernel_1.into()),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
output
|
output
|
||||||
|
|
Loading…
Reference in New Issue