mirror of https://github.com/tracel-ai/burn.git
Fix CI (#2268)
This commit is contained in:
parent
17050db57e
commit
d3fbdeaa48
|
@ -31,7 +31,7 @@ env:
|
||||||
# Note: It is not possible to define env vars in composite actions.
|
# Note: It is not possible to define env vars in composite actions.
|
||||||
# To work around this issue we use inputs and define all the env vars here.
|
# To work around this issue we use inputs and define all the env vars here.
|
||||||
|
|
||||||
RUST_PREVIOUS_VERSION: 1.79.0
|
RUST_PREVIOUS_VERSION: 1.80.0
|
||||||
|
|
||||||
# Cargo
|
# Cargo
|
||||||
CARGO_TERM_COLOR: "always"
|
CARGO_TERM_COLOR: "always"
|
||||||
|
|
|
@ -41,11 +41,10 @@ fn conv2d_kernel<F: Float>(
|
||||||
|
|
||||||
let in_channels = weight.shape(1);
|
let in_channels = weight.shape(1);
|
||||||
|
|
||||||
let kernel_size_0 = kernel_size_0_unroll.unwrap_or_else(|| weight.shape(2));
|
let kernel_size_0 = weight.shape(2);
|
||||||
let kernel_size_1 = kernel_size_1_unroll.unwrap_or_else(|| weight.shape(3));
|
let kernel_size_1 = kernel_size_1_unroll.unwrap_or_else(|| weight.shape(3));
|
||||||
let unroll_1 = kernel_size_1_unroll.is_some();
|
let unroll_1 = kernel_size_1_unroll.is_some();
|
||||||
|
|
||||||
|
|
||||||
let b = ABSOLUTE_POS / output.stride(0) % output.shape(0);
|
let b = ABSOLUTE_POS / output.stride(0) % output.shape(0);
|
||||||
let oc = ABSOLUTE_POS / output.stride(1) % output.shape(1);
|
let oc = ABSOLUTE_POS / output.stride(1) % output.shape(1);
|
||||||
let oh = ABSOLUTE_POS / output.stride(2) % output.shape(2);
|
let oh = ABSOLUTE_POS / output.stride(2) % output.shape(2);
|
||||||
|
@ -130,7 +129,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
|
||||||
let kernel_1_unroll = if kernel_1 > 8 {
|
let kernel_1_unroll = if kernel_1 > 8 {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(kernel_1.into())
|
Some(kernel_1 as u32)
|
||||||
};
|
};
|
||||||
|
|
||||||
let out_0 = calculate_conv_output_size(
|
let out_0 = calculate_conv_output_size(
|
||||||
|
@ -188,7 +187,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
|
||||||
ScalarArg::new(options.padding[1] as u32),
|
ScalarArg::new(options.padding[1] as u32),
|
||||||
ScalarArg::new(options.groups as u32),
|
ScalarArg::new(options.groups as u32),
|
||||||
),
|
),
|
||||||
Some(kernel_1 as u32),
|
kernel_1_unroll,
|
||||||
);
|
);
|
||||||
|
|
||||||
output
|
output
|
||||||
|
|
|
@ -9,7 +9,7 @@ name = "burn"
|
||||||
readme.workspace = true
|
readme.workspace = true
|
||||||
repository = "https://github.com/tracel-ai/burn"
|
repository = "https://github.com/tracel-ai/burn"
|
||||||
version.workspace = true
|
version.workspace = true
|
||||||
rust-version = "1.79"
|
rust-version = "1.80"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["burn-core/default", "burn-train?/default", "std"]
|
default = ["burn-core/default", "burn-train?/default", "std"]
|
||||||
|
|
Loading…
Reference in New Issue