Enable cuda-jit in burn-core + in text classification example (#2160)

This commit is contained in:
Nathaniel Simard 2024-08-12 18:22:27 -04:00 committed by GitHub
parent 7c17e84a0e
commit ff8d0308fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 26 additions and 15 deletions

16
Cargo.lock generated
View File

@ -1379,7 +1379,7 @@ dependencies = [
[[package]]
name = "cubecl"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
dependencies = [
"cubecl-core",
"cubecl-cuda",
@ -1390,7 +1390,7 @@ dependencies = [
[[package]]
name = "cubecl-common"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
dependencies = [
"derive-new",
"getrandom",
@ -1404,7 +1404,7 @@ dependencies = [
[[package]]
name = "cubecl-core"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
dependencies = [
"bytemuck",
"cubecl-macros",
@ -1419,7 +1419,7 @@ dependencies = [
[[package]]
name = "cubecl-cuda"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
dependencies = [
"bytemuck",
"cubecl-common",
@ -1434,7 +1434,7 @@ dependencies = [
[[package]]
name = "cubecl-linalg"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
dependencies = [
"bytemuck",
"cubecl-core",
@ -1445,7 +1445,7 @@ dependencies = [
[[package]]
name = "cubecl-macros"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
dependencies = [
"derive-new",
"proc-macro2",
@ -1456,7 +1456,7 @@ dependencies = [
[[package]]
name = "cubecl-runtime"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
dependencies = [
"async-channel",
"cubecl-common",
@ -1475,7 +1475,7 @@ dependencies = [
[[package]]
name = "cubecl-wgpu"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
dependencies = [
"async-channel",
"bytemuck",

View File

@ -143,8 +143,8 @@ sysinfo = "0.30.13"
systemstat = "0.2.3"
### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "546f588eb0f29e23b8f5dca2c2858253dd19bb5e" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "546f588eb0f29e23b8f5dca2c2858253dd19bb5e" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl" }
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }

View File

@ -99,6 +99,7 @@ record-backward-compat = []
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray.
[dependencies]

View File

@ -44,7 +44,12 @@ pub mod backend;
extern crate alloc;
#[cfg(all(test, not(feature = "test-tch"), not(feature = "test-wgpu"),))]
#[cfg(all(
test,
not(feature = "test-tch"),
not(feature = "test-wgpu"),
not(feature = "test-cuda")
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;
#[cfg(all(test, feature = "test-tch"))]
@ -53,6 +58,9 @@ pub type TestBackend = burn_tch::LibTorch<f32>;
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;
#[cfg(feature = "std")]
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;

View File

@ -70,8 +70,10 @@ pub enum OutputRuntimeInfo {
impl<R: JitRuntime> ExecutableKernel<R> {
/// Execute the kernel.
pub fn execute(self) {
self.client
.execute(self.kernel, self.cube_count, self.bindings)
unsafe {
self.client
.execute_unchecked(self.kernel, self.cube_count, self.bindings)
}
}
}

View File

@ -20,7 +20,7 @@ cuda-jit = ["burn/cuda-jit"]
[dependencies]
# Burn
burn = {path = "../../crates/burn", features=["train", "ndarray", "std", "tui", "metrics", "autotune", "fusion", "default"], default-features = false}
burn = {path = "../../crates/burn", features=["train", "ndarray", "std", "metrics", "autotune", "fusion", "default"], default-features = false}
# Tokenizer
tokenizers = { version = "0.19.1", default-features = false, features = [

View File

@ -29,7 +29,7 @@ use std::sync::Arc;
pub struct ExperimentConfig {
pub transformer: TransformerEncoderConfig,
pub optimizer: AdamConfig,
#[config(default = 512)]
#[config(default = 256)]
pub max_seq_length: usize,
#[config(default = 32)]
pub batch_size: usize,