This commit is contained in:
Nathaniel Simard 2024-09-10 12:13:48 -04:00 committed by GitHub
parent 17050db57e
commit d3fbdeaa48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 5 additions and 6 deletions

View File

@ -31,7 +31,7 @@ env:
# Note: It is not possible to define env vars in composite actions. # Note: It is not possible to define env vars in composite actions.
# To work around this issue we use inputs and define all the env vars here. # To work around this issue we use inputs and define all the env vars here.
RUST_PREVIOUS_VERSION: 1.79.0 RUST_PREVIOUS_VERSION: 1.80.0
# Cargo # Cargo
CARGO_TERM_COLOR: "always" CARGO_TERM_COLOR: "always"

View File

@ -41,11 +41,10 @@ fn conv2d_kernel<F: Float>(
let in_channels = weight.shape(1); let in_channels = weight.shape(1);
let kernel_size_0 = kernel_size_0_unroll.unwrap_or_else(|| weight.shape(2)); let kernel_size_0 = weight.shape(2);
let kernel_size_1 = kernel_size_1_unroll.unwrap_or_else(|| weight.shape(3)); let kernel_size_1 = kernel_size_1_unroll.unwrap_or_else(|| weight.shape(3));
let unroll_1 = kernel_size_1_unroll.is_some(); let unroll_1 = kernel_size_1_unroll.is_some();
let b = ABSOLUTE_POS / output.stride(0) % output.shape(0); let b = ABSOLUTE_POS / output.stride(0) % output.shape(0);
let oc = ABSOLUTE_POS / output.stride(1) % output.shape(1); let oc = ABSOLUTE_POS / output.stride(1) % output.shape(1);
let oh = ABSOLUTE_POS / output.stride(2) % output.shape(2); let oh = ABSOLUTE_POS / output.stride(2) % output.shape(2);
@ -130,7 +129,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
let kernel_1_unroll = if kernel_1 > 8 { let kernel_1_unroll = if kernel_1 > 8 {
None None
} else { } else {
Some(kernel_1.into()) Some(kernel_1 as u32)
}; };
let out_0 = calculate_conv_output_size( let out_0 = calculate_conv_output_size(
@ -188,7 +187,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
ScalarArg::new(options.padding[1] as u32), ScalarArg::new(options.padding[1] as u32),
ScalarArg::new(options.groups as u32), ScalarArg::new(options.groups as u32),
), ),
Some(kernel_1 as u32), kernel_1_unroll,
); );
output output

View File

@ -9,7 +9,7 @@ name = "burn"
readme.workspace = true readme.workspace = true
repository = "https://github.com/tracel-ai/burn" repository = "https://github.com/tracel-ai/burn"
version.workspace = true version.workspace = true
rust-version = "1.79" rust-version = "1.80"
[features] [features]
default = ["burn-core/default", "burn-train?/default", "std"] default = ["burn-core/default", "burn-train?/default", "std"]