mirror of https://github.com/tracel-ai/burn.git
Add burn-hip (#2399)
This commit is contained in:
parent
09c65ca20f
commit
6678f79e58
|
@ -1,2 +1,3 @@
|
|||
[alias]
|
||||
bb = "run --release --bin burnbench --"
|
||||
xtask = "run --target-dir target/xtask --color always --package xtask --bin xtask --"
|
|
@ -126,6 +126,19 @@ jobs:
|
|||
secrets:
|
||||
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
|
||||
|
||||
publish-burn-hip:
|
||||
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
|
||||
needs:
|
||||
- publish-burn-tensor
|
||||
- publish-burn-autodiff
|
||||
- publish-burn-ndarray
|
||||
- publish-burn-common
|
||||
- publish-burn-jit
|
||||
with:
|
||||
crate: burn-hip
|
||||
secrets:
|
||||
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
|
||||
|
||||
publish-burn-candle:
|
||||
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
|
||||
needs:
|
||||
|
|
|
@ -628,6 +628,20 @@ dependencies = [
|
|||
"spin",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-hip"
|
||||
version = "0.15.0"
|
||||
dependencies = [
|
||||
"burn-fusion",
|
||||
"burn-jit",
|
||||
"burn-tensor",
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
"derive-new",
|
||||
"half",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-import"
|
||||
version = "0.15.0"
|
||||
|
@ -1434,7 +1448,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
|
||||
dependencies = [
|
||||
"cubecl-core",
|
||||
"cubecl-cuda",
|
||||
|
@ -1447,7 +1461,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-common"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"embassy-futures",
|
||||
|
@ -1463,7 +1477,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-core"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
|
@ -1480,7 +1494,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-cpp"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
|
@ -1494,7 +1508,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-cuda"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
|
@ -1510,7 +1524,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-hip"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
|
@ -1535,7 +1549,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-linalg"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-core",
|
||||
|
@ -1546,7 +1560,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-macros"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
|
||||
dependencies = [
|
||||
"cubecl-common",
|
||||
"darling",
|
||||
|
@ -1561,7 +1575,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-opt"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
|
||||
dependencies = [
|
||||
"cubecl-common",
|
||||
"cubecl-core",
|
||||
|
@ -1576,7 +1590,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-runtime"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
|
||||
dependencies = [
|
||||
"async-channel",
|
||||
"cfg_aliases 0.2.1",
|
||||
|
@ -1596,7 +1610,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-spirv"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
|
||||
dependencies = [
|
||||
"cubecl-common",
|
||||
"cubecl-core",
|
||||
|
@ -1609,7 +1623,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "cubecl-wgpu"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
|
||||
dependencies = [
|
||||
"ash",
|
||||
"async-channel",
|
||||
|
|
|
@ -152,8 +152,8 @@ ahash = { version = "0.8.11", default-features = false }
|
|||
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
|
||||
|
||||
### For the main burn branch. ###
|
||||
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e3fdc96ec2d68dcdde8135bd011907b4e662388c" }
|
||||
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e3fdc96ec2d68dcdde8135bd011907b4e662388c" }
|
||||
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "fb2a5c87802c9d01801f16d93cada5fe924989f3" }
|
||||
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "fb2a5c87802c9d01801f16d93cada5fe924989f3" }
|
||||
### For local development. ###
|
||||
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
|
||||
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
|
||||
|
|
|
@ -17,6 +17,7 @@ candle-cuda = ["burn/candle-cuda"]
|
|||
candle-metal = ["burn/candle", "burn/metal"]
|
||||
cuda-jit = ["burn/cuda-jit"]
|
||||
cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
|
||||
hip-jit = ["burn/hip-jit"]
|
||||
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
|
||||
ndarray = ["burn/ndarray"]
|
||||
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
|
||||
|
|
|
@ -15,6 +15,9 @@ The end of options argument `--` is used to pass arguments to the `burnbench`
|
|||
application. For instance `cargo run --bin burnbench -- list` passes the `list`
|
||||
argument to `burnbench` effectively calling `burnbench list`.
|
||||
|
||||
There is also a cargo alias `cargo bb` which simplifies the command line.
|
||||
The example command above then becomes: `cargo bb list`.
|
||||
|
||||
### Commands
|
||||
|
||||
#### List benches and backends
|
||||
|
|
|
@ -86,6 +86,8 @@ enum BackendValues {
|
|||
CudaJit,
|
||||
#[strum(to_string = "cuda-jit-fusion")]
|
||||
CudaJitFusion,
|
||||
#[strum(to_string = "hip-jit")]
|
||||
HipJit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
|
||||
|
|
|
@ -65,6 +65,8 @@ macro_rules! bench_on_backend {
|
|||
let feature_name = "cuda-jit";
|
||||
#[cfg(feature = "cuda-jit-fusion")]
|
||||
let feature_name = "cuda-jit-fusion";
|
||||
#[cfg(feature = "hip-jit")]
|
||||
let feature_name = "hip-jit";
|
||||
|
||||
#[cfg(any(feature = "wgpu"))]
|
||||
{
|
||||
|
@ -146,6 +148,13 @@ macro_rules! bench_on_backend {
|
|||
|
||||
bench::<Cuda<half::f16>>(&CudaDevice::default(), feature_name, url, token);
|
||||
}
|
||||
|
||||
#[cfg(feature = "hip-jit")]
|
||||
{
|
||||
use burn::backend::hip_jit::{Hip, HipDevice};
|
||||
|
||||
bench::<Hip<half::f16>>(&HipDevice::default(), feature_name, url, token);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ default = [
|
|||
"burn-tensor/default",
|
||||
"burn-wgpu?/default",
|
||||
"burn-cuda?/default",
|
||||
"burn-hip?/default",
|
||||
"burn-autodiff?/default",
|
||||
]
|
||||
doc = [
|
||||
|
@ -35,6 +36,7 @@ doc = [
|
|||
"tch",
|
||||
"wgpu",
|
||||
"cuda-jit",
|
||||
"hip-jit",
|
||||
"vision",
|
||||
"autodiff",
|
||||
# Doc features
|
||||
|
@ -46,6 +48,7 @@ doc = [
|
|||
"burn-tensor/doc",
|
||||
"burn-wgpu/doc",
|
||||
"burn-cuda/doc",
|
||||
"burn-hip/doc",
|
||||
]
|
||||
network = ["burn-common/network"]
|
||||
sqlite = ["burn-dataset?/sqlite"]
|
||||
|
@ -59,6 +62,7 @@ std = [
|
|||
"burn-tensor/std",
|
||||
"burn-wgpu?/std",
|
||||
"burn-cuda?/std",
|
||||
"burn-hip?/std",
|
||||
"flate2",
|
||||
"half/std",
|
||||
"log",
|
||||
|
@ -86,6 +90,7 @@ template = ["burn-wgpu?/template"]
|
|||
candle = ["burn-candle"]
|
||||
candle-cuda = ["candle", "burn-candle/cuda"]
|
||||
cuda-jit = ["burn-cuda"]
|
||||
hip-jit = ["burn-hip"]
|
||||
ndarray = ["burn-ndarray"]
|
||||
tch = ["burn-tch"]
|
||||
wgpu = ["burn-wgpu"]
|
||||
|
@ -101,6 +106,7 @@ experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
|
|||
record-backward-compat = []
|
||||
|
||||
test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray.
|
||||
test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray.
|
||||
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
|
||||
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
|
||||
test-wgpu-spirv = [
|
||||
|
@ -121,6 +127,7 @@ burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features =
|
|||
burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", optional = true }
|
||||
burn-candle = { path = "../burn-candle", version = "0.15.0", optional = true }
|
||||
burn-cuda = { path = "../burn-cuda", version = "0.15.0", optional = true, default-features = false }
|
||||
burn-hip = { path = "../burn-hip", version = "0.15.0", optional = true, default-features = false }
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "0.15.0", optional = true, default-features = false }
|
||||
burn-tch = { path = "../burn-tch", version = "0.15.0", optional = true }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "0.15.0", optional = true, default-features = false }
|
||||
|
|
|
@ -28,6 +28,9 @@ pub use burn_candle as candle;
|
|||
#[cfg(feature = "candle")]
|
||||
pub use burn_candle::Candle;
|
||||
|
||||
#[cfg(feature = "hip-jit")]
|
||||
pub use burn_hip as hip_jit;
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
pub use burn_tch as libtorch;
|
||||
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "ROCm HIP backend for the Burn framework"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "gpu", "rocm", "hip"]
|
||||
license.workspace = true
|
||||
name = "burn-hip"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-hip"
|
||||
documentation = "https://docs.rs/burn-hip"
|
||||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["fusion", "burn-jit/default", "cubecl/default"]
|
||||
fusion = ["burn-fusion", "burn-jit/fusion"]
|
||||
autotune = ["burn-jit/autotune"]
|
||||
doc = ["burn-jit/doc"]
|
||||
std = ["burn-jit/std", "cubecl/std"]
|
||||
|
||||
[dependencies]
|
||||
cubecl = { workspace = true, features = ["hip"] }
|
||||
burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = ["cubecl-hip"] }
|
||||
burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
|
||||
|
||||
half = { workspace = true }
|
||||
bytemuck = { workspace = true }
|
||||
|
||||
log = { workspace = true }
|
||||
derive-new = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false, features = [
|
||||
"export_tests",
|
||||
] }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
|
@ -0,0 +1,7 @@
|
|||
# Burn-hip
|
||||
|
||||
Backend using ROCm HIP runtime.
|
||||
|
||||
To execute the tests for this backend set an environment variable called `ROCM_PATH` or `CUBECL_ROCM_PATH` to the installation path of ROCm. It is often `/opt/rocm`.
|
||||
|
||||
For now this backend requires the version `6.2.2` of ROCm or a compatible version.
|
|
@ -0,0 +1,22 @@
|
|||
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
use burn_jit::JitBackend;
|
||||
pub use cubecl::hip::HipDevice;
|
||||
use cubecl::hip::HipRuntime;
|
||||
|
||||
#[cfg(not(feature = "fusion"))]
|
||||
pub type Hip<F = f32, I = i32> = JitBackend<HipRuntime, F, I>;
|
||||
|
||||
#[cfg(feature = "fusion")]
|
||||
pub type Hip<F = f32, I = i32> = burn_fusion::Fusion<JitBackend<HipRuntime, F, I>>;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use burn_jit::JitBackend;
|
||||
|
||||
pub type TestRuntime = cubecl::hip::HipRuntime;
|
||||
|
||||
burn_jit::testgen_all!();
|
||||
}
|
|
@ -28,6 +28,7 @@ repr = []
|
|||
cubecl = ["dep:cubecl"]
|
||||
cubecl-wgpu = ["cubecl", "cubecl/wgpu"]
|
||||
cubecl-cuda = ["cubecl", "cubecl/cuda"]
|
||||
cubecl-hip = ["cubecl", "cubecl/hip"]
|
||||
|
||||
[dependencies]
|
||||
burn-common = { path = "../burn-common", version = "0.15.0", default-features = false }
|
||||
|
|
|
@ -89,3 +89,15 @@ mod cube_cuda {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cubecl-hip")]
|
||||
mod cube_hip {
|
||||
use crate::backend::{DeviceId, DeviceOps};
|
||||
use cubecl::hip::HipDevice;
|
||||
|
||||
impl DeviceOps for HipDevice {
|
||||
fn id(&self) -> DeviceId {
|
||||
DeviceId::new(0, self.index as u32)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,6 +50,7 @@ template = ["burn-core/template"]
|
|||
|
||||
candle = ["burn-core/candle"]
|
||||
cuda-jit = ["burn-core/cuda-jit"]
|
||||
hip-jit = ["burn-core/hip-jit"]
|
||||
ndarray = ["burn-core/ndarray"]
|
||||
tch = ["burn-core/tch"]
|
||||
wgpu = ["burn-core/wgpu"]
|
||||
|
|
|
@ -50,8 +50,11 @@ pub(crate) fn handle_command(
|
|||
ExecutionEnvironment::Std => {
|
||||
if args.ci {
|
||||
// Exclude crates that are not supported on CI
|
||||
args.exclude
|
||||
.extend(vec!["burn-cuda".to_string(), "burn-tch".to_string()]);
|
||||
args.exclude.extend(vec![
|
||||
"burn-cuda".to_string(),
|
||||
"burn-hip".to_string(),
|
||||
"burn-tch".to_string(),
|
||||
]);
|
||||
if std::env::var("DISABLE_WGPU").is_ok() {
|
||||
args.exclude.extend(vec!["burn-wgpu".to_string()]);
|
||||
};
|
||||
|
|
|
@ -2,7 +2,8 @@ use tracel_xtask::prelude::*;
|
|||
|
||||
pub(crate) fn handle_command(mut args: DocCmdArgs) -> anyhow::Result<()> {
|
||||
if args.get_command() == DocSubCommand::Build {
|
||||
args.exclude.push("burn-cuda".to_string());
|
||||
args.exclude
|
||||
.extend(vec!["burn-cuda".to_string(), "burn-hip".to_string()]);
|
||||
}
|
||||
|
||||
// Execute documentation command on workspace
|
||||
|
|
|
@ -34,8 +34,11 @@ pub(crate) fn handle_command(
|
|||
ExecutionEnvironment::Std => {
|
||||
if args.ci {
|
||||
// Exclude crates that are not supported on CI
|
||||
args.exclude
|
||||
.extend(vec!["burn-cuda".to_string(), "burn-tch".to_string()]);
|
||||
args.exclude.extend(vec![
|
||||
"burn-cuda".to_string(),
|
||||
"burn-hip".to_string(),
|
||||
"burn-tch".to_string(),
|
||||
]);
|
||||
}
|
||||
if std::env::var("DISABLE_WGPU").is_ok() {
|
||||
args.exclude.extend(vec!["burn-wgpu".to_string()]);
|
||||
|
|
Loading…
Reference in New Issue