Add burn-hip (#2399)

This commit is contained in:
Sylvain Benner 2024-10-24 16:38:56 -04:00 committed by GitHub
parent 09c65ca20f
commit 6678f79e58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 163 additions and 20 deletions

View File

@ -1,2 +1,3 @@
[alias]
bb = "run --release --bin burnbench --"
xtask = "run --target-dir target/xtask --color always --package xtask --bin xtask --"

View File

@ -126,6 +126,19 @@ jobs:
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-hip:
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-tensor
- publish-burn-autodiff
- publish-burn-ndarray
- publish-burn-common
- publish-burn-jit
with:
crate: burn-hip
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-candle:
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:

38
Cargo.lock generated
View File

@ -628,6 +628,20 @@ dependencies = [
"spin",
]
[[package]]
name = "burn-hip"
version = "0.15.0"
dependencies = [
"burn-fusion",
"burn-jit",
"burn-tensor",
"bytemuck",
"cubecl",
"derive-new",
"half",
"log",
]
[[package]]
name = "burn-import"
version = "0.15.0"
@ -1434,7 +1448,7 @@ dependencies = [
[[package]]
name = "cubecl"
version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
dependencies = [
"cubecl-core",
"cubecl-cuda",
@ -1447,7 +1461,7 @@ dependencies = [
[[package]]
name = "cubecl-common"
version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
dependencies = [
"derive-new",
"embassy-futures",
@ -1463,7 +1477,7 @@ dependencies = [
[[package]]
name = "cubecl-core"
version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
dependencies = [
"bytemuck",
"cubecl-common",
@ -1480,7 +1494,7 @@ dependencies = [
[[package]]
name = "cubecl-cpp"
version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
dependencies = [
"bytemuck",
"cubecl-common",
@ -1494,7 +1508,7 @@ dependencies = [
[[package]]
name = "cubecl-cuda"
version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
dependencies = [
"bytemuck",
"cubecl-common",
@ -1510,7 +1524,7 @@ dependencies = [
[[package]]
name = "cubecl-hip"
version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
dependencies = [
"bytemuck",
"cubecl-common",
@ -1535,7 +1549,7 @@ dependencies = [
[[package]]
name = "cubecl-linalg"
version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
dependencies = [
"bytemuck",
"cubecl-core",
@ -1546,7 +1560,7 @@ dependencies = [
[[package]]
name = "cubecl-macros"
version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
dependencies = [
"cubecl-common",
"darling",
@ -1561,7 +1575,7 @@ dependencies = [
[[package]]
name = "cubecl-opt"
version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
dependencies = [
"cubecl-common",
"cubecl-core",
@ -1576,7 +1590,7 @@ dependencies = [
[[package]]
name = "cubecl-runtime"
version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
dependencies = [
"async-channel",
"cfg_aliases 0.2.1",
@ -1596,7 +1610,7 @@ dependencies = [
[[package]]
name = "cubecl-spirv"
version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
dependencies = [
"cubecl-common",
"cubecl-core",
@ -1609,7 +1623,7 @@ dependencies = [
[[package]]
name = "cubecl-wgpu"
version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
source = "git+https://github.com/tracel-ai/cubecl?rev=fb2a5c87802c9d01801f16d93cada5fe924989f3#fb2a5c87802c9d01801f16d93cada5fe924989f3"
dependencies = [
"ash",
"async-channel",

View File

@ -152,8 +152,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e3fdc96ec2d68dcdde8135bd011907b4e662388c" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e3fdc96ec2d68dcdde8135bd011907b4e662388c" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "fb2a5c87802c9d01801f16d93cada5fe924989f3" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "fb2a5c87802c9d01801f16d93cada5fe924989f3" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }

View File

@ -17,6 +17,7 @@ candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
cuda-jit = ["burn/cuda-jit"]
cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
hip-jit = ["burn/hip-jit"]
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]

View File

@ -6,7 +6,7 @@ to complex models.
## burnbench CLI
This crate comes with a CLI binary called `burnbench` which can be executed via
`cargo run --release --bin burnbench`.
`cargo run --release --bin burnbench`.
Note that you need to run the `release` target of `burnbench` otherwise you won't
be able to share your benchmark results.
@ -15,6 +15,9 @@ The end of options argument `--` is used to pass arguments to the `burnbench`
application. For instance `cargo run --bin burnbench -- list` passes the `list`
argument to `burnbench` effectively calling `burnbench list`.
There is also a cargo alias `cargo bb` which simplifies the command line.
The example command above then becomes: `cargo bb list`.
### Commands
#### List benches and backends

View File

@ -86,6 +86,8 @@ enum BackendValues {
CudaJit,
#[strum(to_string = "cuda-jit-fusion")]
CudaJitFusion,
#[strum(to_string = "hip-jit")]
HipJit,
}
#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]

View File

@ -65,6 +65,8 @@ macro_rules! bench_on_backend {
let feature_name = "cuda-jit";
#[cfg(feature = "cuda-jit-fusion")]
let feature_name = "cuda-jit-fusion";
#[cfg(feature = "hip-jit")]
let feature_name = "hip-jit";
#[cfg(any(feature = "wgpu"))]
{
@ -146,6 +148,13 @@ macro_rules! bench_on_backend {
bench::<Cuda<half::f16>>(&CudaDevice::default(), feature_name, url, token);
}
#[cfg(feature = "hip-jit")]
{
use burn::backend::hip_jit::{Hip, HipDevice};
bench::<Hip<half::f16>>(&HipDevice::default(), feature_name, url, token);
}
};
}

View File

@ -23,6 +23,7 @@ default = [
"burn-tensor/default",
"burn-wgpu?/default",
"burn-cuda?/default",
"burn-hip?/default",
"burn-autodiff?/default",
]
doc = [
@ -35,6 +36,7 @@ doc = [
"tch",
"wgpu",
"cuda-jit",
"hip-jit",
"vision",
"autodiff",
# Doc features
@ -46,6 +48,7 @@ doc = [
"burn-tensor/doc",
"burn-wgpu/doc",
"burn-cuda/doc",
"burn-hip/doc",
]
network = ["burn-common/network"]
sqlite = ["burn-dataset?/sqlite"]
@ -59,6 +62,7 @@ std = [
"burn-tensor/std",
"burn-wgpu?/std",
"burn-cuda?/std",
"burn-hip?/std",
"flate2",
"half/std",
"log",
@ -86,6 +90,7 @@ template = ["burn-wgpu?/template"]
candle = ["burn-candle"]
candle-cuda = ["candle", "burn-candle/cuda"]
cuda-jit = ["burn-cuda"]
hip-jit = ["burn-hip"]
ndarray = ["burn-ndarray"]
tch = ["burn-tch"]
wgpu = ["burn-wgpu"]
@ -101,6 +106,7 @@ experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
record-backward-compat = []
test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray.
test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray.
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
test-wgpu-spirv = [
@ -121,6 +127,7 @@ burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features =
burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", optional = true }
burn-candle = { path = "../burn-candle", version = "0.15.0", optional = true }
burn-cuda = { path = "../burn-cuda", version = "0.15.0", optional = true, default-features = false }
burn-hip = { path = "../burn-hip", version = "0.15.0", optional = true, default-features = false }
burn-ndarray = { path = "../burn-ndarray", version = "0.15.0", optional = true, default-features = false }
burn-tch = { path = "../burn-tch", version = "0.15.0", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "0.15.0", optional = true, default-features = false }

View File

@ -28,6 +28,9 @@ pub use burn_candle as candle;
#[cfg(feature = "candle")]
pub use burn_candle::Candle;
#[cfg(feature = "hip-jit")]
pub use burn_hip as hip_jit;
#[cfg(feature = "tch")]
pub use burn_tch as libtorch;

View File

@ -0,0 +1,40 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "ROCm HIP backend for the Burn framework"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "gpu", "rocm", "hip"]
license.workspace = true
name = "burn-hip"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-hip"
documentation = "https://docs.rs/burn-hip"
version.workspace = true
[features]
default = ["fusion", "burn-jit/default", "cubecl/default"]
fusion = ["burn-fusion", "burn-jit/fusion"]
autotune = ["burn-jit/autotune"]
doc = ["burn-jit/doc"]
std = ["burn-jit/std", "cubecl/std"]
[dependencies]
cubecl = { workspace = true, features = ["hip"] }
burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = ["cubecl-hip"] }
burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
half = { workspace = true }
bytemuck = { workspace = true }
log = { workspace = true }
derive-new = { workspace = true }
[dev-dependencies]
burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false, features = [
"export_tests",
] }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

View File

@ -0,0 +1,7 @@
# Burn-hip
Backend using ROCm HIP runtime.
To execute the tests for this backend set an environment variable called `ROCM_PATH` or `CUBECL_ROCM_PATH` to the installation path of ROCm. It is often `/opt/rocm`.
For now this backend requires the version `6.2.2` of ROCm or a compatible version.

View File

@ -0,0 +1,22 @@
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
extern crate alloc;
use burn_jit::JitBackend;
pub use cubecl::hip::HipDevice;
use cubecl::hip::HipRuntime;
#[cfg(not(feature = "fusion"))]
pub type Hip<F = f32, I = i32> = JitBackend<HipRuntime, F, I>;
#[cfg(feature = "fusion")]
pub type Hip<F = f32, I = i32> = burn_fusion::Fusion<JitBackend<HipRuntime, F, I>>;
#[cfg(test)]
mod tests {
use burn_jit::JitBackend;
pub type TestRuntime = cubecl::hip::HipRuntime;
burn_jit::testgen_all!();
}

View File

@ -28,6 +28,7 @@ repr = []
cubecl = ["dep:cubecl"]
cubecl-wgpu = ["cubecl", "cubecl/wgpu"]
cubecl-cuda = ["cubecl", "cubecl/cuda"]
cubecl-hip = ["cubecl", "cubecl/hip"]
[dependencies]
burn-common = { path = "../burn-common", version = "0.15.0", default-features = false }

View File

@ -89,3 +89,15 @@ mod cube_cuda {
}
}
}
#[cfg(feature = "cubecl-hip")]
mod cube_hip {
use crate::backend::{DeviceId, DeviceOps};
use cubecl::hip::HipDevice;
impl DeviceOps for HipDevice {
fn id(&self) -> DeviceId {
DeviceId::new(0, self.index as u32)
}
}
}

View File

@ -50,6 +50,7 @@ template = ["burn-core/template"]
candle = ["burn-core/candle"]
cuda-jit = ["burn-core/cuda-jit"]
hip-jit = ["burn-core/hip-jit"]
ndarray = ["burn-core/ndarray"]
tch = ["burn-core/tch"]
wgpu = ["burn-core/wgpu"]

View File

@ -50,8 +50,11 @@ pub(crate) fn handle_command(
ExecutionEnvironment::Std => {
if args.ci {
// Exclude crates that are not supported on CI
args.exclude
.extend(vec!["burn-cuda".to_string(), "burn-tch".to_string()]);
args.exclude.extend(vec![
"burn-cuda".to_string(),
"burn-hip".to_string(),
"burn-tch".to_string(),
]);
if std::env::var("DISABLE_WGPU").is_ok() {
args.exclude.extend(vec!["burn-wgpu".to_string()]);
};

View File

@ -2,7 +2,8 @@ use tracel_xtask::prelude::*;
pub(crate) fn handle_command(mut args: DocCmdArgs) -> anyhow::Result<()> {
if args.get_command() == DocSubCommand::Build {
args.exclude.push("burn-cuda".to_string());
args.exclude
.extend(vec!["burn-cuda".to_string(), "burn-hip".to_string()]);
}
// Execute documentation command on workspace

View File

@ -34,8 +34,11 @@ pub(crate) fn handle_command(
ExecutionEnvironment::Std => {
if args.ci {
// Exclude crates that are not supported on CI
args.exclude
.extend(vec!["burn-cuda".to_string(), "burn-tch".to_string()]);
args.exclude.extend(vec![
"burn-cuda".to_string(),
"burn-hip".to_string(),
"burn-tch".to_string(),
]);
}
if std::env::var("DISABLE_WGPU").is_ok() {
args.exclude.extend(vec!["burn-wgpu".to_string()]);