mirror of https://github.com/tracel-ai/burn.git
Enable cuda-jit in burn-core + in text classification example (#2160)
This commit is contained in:
parent
7c17e84a0e
commit
ff8d0308fb
|
@ -1379,7 +1379,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
||||
dependencies = [
|
||||
"cubecl-core",
|
||||
"cubecl-cuda",
|
||||
|
@ -1390,7 +1390,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-common"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"getrandom",
|
||||
|
@ -1404,7 +1404,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-core"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-macros",
|
||||
|
@ -1419,7 +1419,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-cuda"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
|
@ -1434,7 +1434,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-linalg"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-core",
|
||||
|
@ -1445,7 +1445,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-macros"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"proc-macro2",
|
||||
|
@ -1456,7 +1456,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-runtime"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
||||
dependencies = [
|
||||
"async-channel",
|
||||
"cubecl-common",
|
||||
|
@ -1475,7 +1475,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-wgpu"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
||||
dependencies = [
|
||||
"async-channel",
|
||||
"bytemuck",
|
||||
|
|
|
@ -143,8 +143,8 @@ sysinfo = "0.30.13"
|
|||
systemstat = "0.2.3"
|
||||
|
||||
### For the main burn branch. ###
|
||||
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "546f588eb0f29e23b8f5dca2c2858253dd19bb5e" }
|
||||
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "546f588eb0f29e23b8f5dca2c2858253dd19bb5e" }
|
||||
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" }
|
||||
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" }
|
||||
### For local development. ###
|
||||
# cubecl = { path = "../cubecl/crates/cubecl" }
|
||||
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }
|
||||
|
|
|
@ -99,6 +99,7 @@ record-backward-compat = []
|
|||
|
||||
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
|
||||
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
|
||||
test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray.
|
||||
|
||||
[dependencies]
|
||||
|
||||
|
|
|
@ -44,7 +44,12 @@ pub mod backend;
|
|||
|
||||
extern crate alloc;
|
||||
|
||||
#[cfg(all(test, not(feature = "test-tch"), not(feature = "test-wgpu"),))]
|
||||
#[cfg(all(
|
||||
test,
|
||||
not(feature = "test-tch"),
|
||||
not(feature = "test-wgpu"),
|
||||
not(feature = "test-cuda")
|
||||
))]
|
||||
pub type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
#[cfg(all(test, feature = "test-tch"))]
|
||||
|
@ -53,6 +58,9 @@ pub type TestBackend = burn_tch::LibTorch<f32>;
|
|||
#[cfg(all(test, feature = "test-wgpu"))]
|
||||
pub type TestBackend = burn_wgpu::Wgpu;
|
||||
|
||||
#[cfg(all(test, feature = "test-cuda"))]
|
||||
pub type TestBackend = burn_cuda::Cuda;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg(test)]
|
||||
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
|
||||
|
|
|
@ -70,8 +70,10 @@ pub enum OutputRuntimeInfo {
|
|||
impl<R: JitRuntime> ExecutableKernel<R> {
|
||||
/// Execute the kernel.
|
||||
pub fn execute(self) {
|
||||
self.client
|
||||
.execute(self.kernel, self.cube_count, self.bindings)
|
||||
unsafe {
|
||||
self.client
|
||||
.execute_unchecked(self.kernel, self.cube_count, self.bindings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ cuda-jit = ["burn/cuda-jit"]
|
|||
|
||||
[dependencies]
|
||||
# Burn
|
||||
burn = {path = "../../crates/burn", features=["train", "ndarray", "std", "tui", "metrics", "autotune", "fusion", "default"], default-features = false}
|
||||
burn = {path = "../../crates/burn", features=["train", "ndarray", "std", "metrics", "autotune", "fusion", "default"], default-features = false}
|
||||
|
||||
# Tokenizer
|
||||
tokenizers = { version = "0.19.1", default-features = false, features = [
|
||||
|
|
|
@ -29,7 +29,7 @@ use std::sync::Arc;
|
|||
pub struct ExperimentConfig {
|
||||
pub transformer: TransformerEncoderConfig,
|
||||
pub optimizer: AdamConfig,
|
||||
#[config(default = 512)]
|
||||
#[config(default = 256)]
|
||||
pub max_seq_length: usize,
|
||||
#[config(default = 32)]
|
||||
pub batch_size: usize,
|
||||
|
|
Loading…
Reference in New Issue