From d3fbdeaa4866a94e3c0bf12d9e31b35ae54f3a0a Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Tue, 10 Sep 2024 12:13:48 -0400 Subject: [PATCH] Fix CI (#2268) --- .github/workflows/test.yml | 2 +- crates/burn-jit/src/kernel/conv/conv2d.rs | 7 +++---- crates/burn/Cargo.toml | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b3980ef93..d667e7a8f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,7 +31,7 @@ env: # Note: It is not possible to define env vars in composite actions. # To work around this issue we use inputs and define all the env vars here. - RUST_PREVIOUS_VERSION: 1.79.0 + RUST_PREVIOUS_VERSION: 1.80.0 # Cargo CARGO_TERM_COLOR: "always" diff --git a/crates/burn-jit/src/kernel/conv/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d.rs index 93d8b4e91..c984008c1 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d.rs @@ -41,11 +41,10 @@ fn conv2d_kernel( let in_channels = weight.shape(1); - let kernel_size_0 = kernel_size_0_unroll.unwrap_or_else(|| weight.shape(2)); + let kernel_size_0 = weight.shape(2); let kernel_size_1 = kernel_size_1_unroll.unwrap_or_else(|| weight.shape(3)); let unroll_1 = kernel_size_1_unroll.is_some(); - let b = ABSOLUTE_POS / output.stride(0) % output.shape(0); let oc = ABSOLUTE_POS / output.stride(1) % output.shape(1); let oh = ABSOLUTE_POS / output.stride(2) % output.shape(2); @@ -130,7 +129,7 @@ pub(crate) fn conv2d( let kernel_1_unroll = if kernel_1 > 8 { None } else { - Some(kernel_1.into()) + Some(kernel_1 as u32) }; let out_0 = calculate_conv_output_size( @@ -188,7 +187,7 @@ pub(crate) fn conv2d( ScalarArg::new(options.padding[1] as u32), ScalarArg::new(options.groups as u32), ), - Some(kernel_1 as u32), + kernel_1_unroll, ); output diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index 230bc797c..70c030247 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -9,7 +9,7 @@ name = "burn" readme.workspace = true repository = "https://github.com/tracel-ai/burn" version.workspace = true -rust-version = "1.79" +rust-version = "1.80" [features] default = ["burn-core/default", "burn-train?/default", "std"]