mirror of https://github.com/tracel-ai/burn.git
Update cubecl (#2376)
parent 3d77efc305
commit 353120ea5f
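In short: this commit bumps the pinned cubecl revision (fb47090f… → ed136e23…) and follows cubecl's new API, in which syncing a device is an async operation. Backend::sync loses its SyncType parameter and now always waits for completion; burn-jit blocks on the async client.sync() future via futures-lite, which replaces pollster in the workspace dependencies; and the backend-comparison crate gains a cuda-jit-fusion feature.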
File diff suppressed because it is too large
Cargo.toml (10 changes)

@@ -94,7 +94,7 @@ zip = "2.2.0"
 
 # Async handling
 async-channel = "2.3"
-pollster = "0.3"
+futures-lite = { version = "2.3.0", default-features = false }
 
 # Terminal UI
 crossterm = "0.27.0"
@@ -152,11 +152,11 @@ tch = "0.15.0"
 portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
 
 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "fb47090f7a44952ae3e3b2b72f8c5a88d8af56fd" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "fb47090f7a44952ae3e3b2b72f8c5a88d8af56fd" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ed136e2385b17e36680589f8a6245926f430f59f" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ed136e2385b17e36680589f8a6245926f430f59f" }
### For local development. ###
-# cubecl = { path = "../cubecl/crates/cubecl" }
-# cubecl-common = { path = "../cubecl/crates/cubecl-common" }
+# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
+# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
 ### For the release. ###
 # cubecl = { version="0.2.0", default-features = false }
 # cubecl-common = { version="0.2.0", default-features = false }
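Note that cubecl and cubecl-common always move in lockstep to the same revision, and the commented-out local-development entries gain default-features = false so they stay interchangeable with the git entries above them.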

@@ -16,6 +16,7 @@ candle-cpu = ["burn/candle"]
 candle-cuda = ["burn/candle-cuda"]
 candle-metal = ["burn/candle", "burn/metal"]
 cuda-jit = ["burn/cuda-jit"]
+cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
 default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
 ndarray = ["burn/ndarray"]
 ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]

@@ -7,10 +7,7 @@ use burn::{
         Distribution, Tensor,
     },
 };
-use burn_common::{
-    benchmark::{run_benchmark, Benchmark},
-    sync_type::SyncType,
-};
+use burn_common::benchmark::{run_benchmark, Benchmark};
 
 pub struct AutodiffOverheadBenchmark<B: AutodiffBackend> {
     config: nn::LstmConfig,
@@ -50,7 +47,7 @@ impl<B: AutodiffBackend> Benchmark for AutodiffOverheadBenchmark<B> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait)
+        B::sync(&self.device)
     }
 }
 
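Every benchmark file below repeats the same two-part edit seen here: the grouped burn_common import shrinks to the benchmark module alone, and Benchmark::sync calls the new single-argument B::sync. A minimal self-contained sketch of the resulting shape (SyncHook and Bench are stand-ins for the real burn_common::benchmark::Benchmark machinery, which has more required methods):

    use burn::tensor::backend::Backend;

    // Stand-in for the sync hook on burn_common::benchmark::Benchmark.
    trait SyncHook {
        fn sync(&self);
    }

    struct Bench<B: Backend> {
        device: B::Device,
    }

    impl<B: Backend> SyncHook for Bench<B> {
        fn sync(&self) {
            // After this change there is no SyncType to choose:
            // syncing always waits for queued work to finish.
            B::sync(&self.device)
        }
    }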

@@ -1,9 +1,6 @@
 use backend_comparison::persistence::save;
 use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
-use burn_common::{
-    benchmark::{run_benchmark, Benchmark},
-    sync_type::SyncType,
-};
+use burn_common::benchmark::{run_benchmark, Benchmark};
 
 pub struct BinaryBenchmark<B: Backend, const D: usize> {
     shape: Shape,
@@ -33,7 +30,7 @@ impl<B: Backend, const D: usize> Benchmark for BinaryBenchmark<B, D> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait);
+        B::sync(&self.device);
     }
 }
 

@@ -2,10 +2,7 @@ use backend_comparison::persistence::save;
 use burn::tensor::{
     backend::Backend, module::conv2d, ops::ConvOptions, Distribution, Shape, Tensor,
 };
-use burn_common::{
-    benchmark::{run_benchmark, Benchmark},
-    sync_type::SyncType,
-};
+use burn_common::benchmark::{run_benchmark, Benchmark};
 
 pub struct Conv2dBenchmark<B: Backend> {
     input_shape: Shape,
@@ -51,7 +48,7 @@ impl<B: Backend> Benchmark for Conv2dBenchmark<B> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait)
+        B::sync(&self.device)
     }
 }
 

@@ -2,10 +2,7 @@ use backend_comparison::persistence::save;
 use burn::tensor::{
     backend::Backend, module::conv3d, ops::ConvOptions, Distribution, Shape, Tensor,
 };
-use burn_common::{
-    benchmark::{run_benchmark, Benchmark},
-    sync_type::SyncType,
-};
+use burn_common::benchmark::{run_benchmark, Benchmark};
 
 pub struct Conv3dBenchmark<B: Backend> {
     input_shape: Shape,
@@ -51,7 +48,7 @@ impl<B: Backend> Benchmark for Conv3dBenchmark<B> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait)
+        B::sync(&self.device)
     }
 }
 

@@ -3,10 +3,7 @@ use burn::tensor::{
     backend::Backend, module::conv_transpose2d, ops::ConvTransposeOptions, Distribution, Shape,
     Tensor,
 };
-use burn_common::{
-    benchmark::{run_benchmark, Benchmark},
-    sync_type::SyncType,
-};
+use burn_common::benchmark::{run_benchmark, Benchmark};
 
 pub struct ConvTranspose2dBenchmark<B: Backend> {
     input_shape: Shape,
@@ -52,7 +49,7 @@ impl<B: Backend> Benchmark for ConvTranspose2dBenchmark<B> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait)
+        B::sync(&self.device)
     }
 }
 

@@ -3,10 +3,7 @@ use burn::tensor::{
     backend::Backend, module::conv_transpose3d, ops::ConvTransposeOptions, Distribution, Shape,
     Tensor,
 };
-use burn_common::{
-    benchmark::{run_benchmark, Benchmark},
-    sync_type::SyncType,
-};
+use burn_common::benchmark::{run_benchmark, Benchmark};
 
 pub struct ConvTranspose3dBenchmark<B: Backend> {
     input_shape: Shape,
@@ -52,7 +49,7 @@ impl<B: Backend> Benchmark for ConvTranspose3dBenchmark<B> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait)
+        B::sync(&self.device)
     }
 }
 

@@ -2,7 +2,6 @@ use backend_comparison::persistence::save;
 use burn::backend::Autodiff;
 use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
 use burn_common::benchmark::{run_benchmark, Benchmark};
-use burn_common::sync_type::SyncType;
 use core::f64::consts::SQRT_2;
 use derive_new::new;
 
@@ -69,7 +68,7 @@ impl<B: Backend, const D: usize> Benchmark for CustomGeluBenchmark<B, D> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait)
+        B::sync(&self.device)
     }
 
     fn num_samples(&self) -> usize {

@@ -1,9 +1,6 @@
 use backend_comparison::persistence::save;
 use burn::tensor::{backend::Backend, Distribution, Shape, Tensor, TensorData};
-use burn_common::{
-    benchmark::{run_benchmark, Benchmark},
-    sync_type::SyncType,
-};
+use burn_common::benchmark::{run_benchmark, Benchmark};
 use derive_new::new;
 
 #[derive(new)]
@@ -32,7 +29,7 @@ impl<B: Backend, const D: usize> Benchmark for ToDataBenchmark<B, D> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait)
+        B::sync(&self.device)
     }
 }
@@ -69,7 +66,7 @@ impl<B: Backend, const D: usize> Benchmark for FromDataBenchmark<B, D> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait)
+        B::sync(&self.device)
     }
 }
 

@@ -3,7 +3,6 @@ use burn::tensor::backend::Backend;
 use burn::tensor::Device;
 use burn::{config::Config, module::Module, nn};
 use burn_common::benchmark::{run_benchmark, Benchmark};
-use burn_common::sync_type::SyncType;
 use derive_new::new;
 
 #[derive(Module, Debug)]
@@ -94,7 +93,7 @@ impl<B: Backend> Benchmark for LoadRecordBenchmark<B> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait)
+        B::sync(&self.device)
     }
 }
 

@@ -1,9 +1,6 @@
 use backend_comparison::persistence::save;
 use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
-use burn_common::{
-    benchmark::{run_benchmark, Benchmark},
-    sync_type::SyncType,
-};
+use burn_common::benchmark::{run_benchmark, Benchmark};
 use derive_new::new;
 
 #[derive(new)]
@@ -40,7 +37,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait)
+        B::sync(&self.device)
     }
 }
 

@@ -1,9 +1,6 @@
 use backend_comparison::persistence::save;
 use burn::tensor::{backend::Backend, module::max_pool2d, Distribution, Shape, Tensor};
-use burn_common::{
-    benchmark::{run_benchmark, Benchmark},
-    sync_type::SyncType,
-};
+use burn_common::benchmark::{run_benchmark, Benchmark};
 
 pub struct MaxPool2dBenchmark<B: Backend> {
     shape: Shape,
@@ -40,7 +37,7 @@ impl<B: Backend> Benchmark for MaxPool2dBenchmark<B> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait)
+        B::sync(&self.device)
     }
 }
 

@@ -1,7 +1,6 @@
 use backend_comparison::persistence::save;
 use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
 use burn_common::benchmark::{run_benchmark, Benchmark};
-use cubecl::client::SyncType;
 
 // Files retrieved during build to avoid reimplementing ResNet for benchmarks
 mod block {
@@ -42,7 +41,7 @@ impl<B: Backend> Benchmark for ResNetBenchmark<B> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait)
+        B::sync(&self.device)
     }
 }
 

@@ -1,9 +1,6 @@
 use backend_comparison::persistence::save;
 use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
-use burn_common::{
-    benchmark::{run_benchmark, Benchmark},
-    sync_type::SyncType,
-};
+use burn_common::benchmark::{run_benchmark, Benchmark};
 use derive_new::new;
 
 #[derive(new)]
@@ -33,7 +30,7 @@ impl<B: Backend, const D: usize> Benchmark for UnaryBenchmark<B, D> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device, SyncType::Wait)
+        B::sync(&self.device)
     }
 }
 

@@ -80,6 +80,8 @@ enum BackendValues {
     WgpuFusion,
     #[strum(to_string = "cuda-jit")]
    CudaJit,
+    #[strum(to_string = "cuda-jit-fusion")]
+    CudaJitFusion,
 }
 
 #[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]

@@ -59,6 +59,8 @@ macro_rules! bench_on_backend {
     let feature_name = "wgpu-fusion";
     #[cfg(feature = "cuda-jit")]
     let feature_name = "cuda-jit";
+    #[cfg(feature = "cuda-jit-fusion")]
+    let feature_name = "cuda-jit-fusion";
 
     #[cfg(feature = "wgpu")]
     {
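The macro resolves feature_name by compile-time shadowing: each cfg-gated let is compiled in only when its feature is enabled and rebinds the name. A standalone sketch of the pattern (the demo feature name is hypothetical):

    fn main() {
        // Base binding so the name exists even with no feature enabled.
        let feature_name = "none";
        // Compiled in only under the feature; shadows the binding above.
        #[cfg(feature = "demo")]
        let feature_name = "demo";
        println!("benchmarking with feature: {feature_name}");
    }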

@@ -5,7 +5,6 @@ use crate::{
     tensor::AutodiffTensor,
     AutodiffBridge,
 };
-use burn_common::sync_type::SyncType;
 use burn_tensor::{
     backend::{AutodiffBackend, Backend},
     ops::{BoolTensor, IntTensor, QuantizedTensor},
@@ -50,8 +49,8 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
         B::seed(seed)
     }
 
-    fn sync(device: &B::Device, sync_type: SyncType) {
-        B::sync(device, sync_type)
+    fn sync(device: &B::Device) {
+        B::sync(device)
     }
 }
 
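Autodiff is a decorator backend: it wraps an inner backend B and, having no device queue of its own, forwards sync (like seed above) directly to B. Fusion, further down, follows the same delegation pattern but flushes its own operation stream first.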

@@ -1,7 +1,7 @@
 use std::marker::PhantomData;
 
 use burn_tensor::{
-    backend::{Backend, DeviceId, DeviceOps, SyncType},
+    backend::{Backend, DeviceId, DeviceOps},
     quantization::{QTensorPrimitive, QuantizationStrategy},
     Device,
 };
@@ -187,25 +187,20 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
         panic!("Manual seed not supported by Candle. ")
     }
 
-    fn sync(device: &Device<Self>, sync_type: SyncType) {
-        match sync_type {
-            SyncType::Wait => {
-                let device: candle_core::Device = (device.clone()).into();
-
-                match device {
-                    candle_core::Device::Cpu => (),
-                    candle_core::Device::Cuda(device) => {
-                        #[cfg(feature = "cuda")]
-                        device.synchronize().unwrap();
-                    }
-                    candle_core::Device::Metal(device) => {
-                        // For some reason, device.wait_until_completed() does not seem to work,
-                        // and neither does writing and reading a value with into_data
-                        panic!("Device synchronization unavailable with Metal device on Candle backend")
-                    }
-                }
-            }
-            SyncType::Flush => (), // Nothhing to flush.
-        };
-    }
+    fn sync(device: &Device<Self>) {
+        let device: candle_core::Device = (device.clone()).into();
+
+        match device {
+            candle_core::Device::Cpu => (),
+            candle_core::Device::Cuda(device) => {
+                #[cfg(feature = "cuda")]
+                device.synchronize().unwrap();
+            }
+            candle_core::Device::Metal(device) => {
+                // For some reason, device.wait_until_completed() does not seem to work,
+                // and neither does writing and reading a value with into_data
+                panic!("Device synchronization unavailable with Metal device on Candle backend")
+            }
+        }
+    }
 }
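The restructuring here is mechanical: the body of the former SyncType::Wait arm becomes the whole function, one indentation level up, while the dropped SyncType::Flush arm was already a no-op, so no behavior is lost on the Candle backend.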

@@ -3,7 +3,7 @@ use crate::{
     QFusionTensor,
 };
 use burn_tensor::{
-    backend::{Backend, DeviceOps, SyncType},
+    backend::{Backend, DeviceOps},
     ops::FloatTensor,
     repr::{OperationDescription, ReprBackend},
     Device,
@@ -50,10 +50,10 @@ impl<B: FusionBackend> Backend for Fusion<B> {
         B::seed(seed);
     }
 
-    fn sync(device: &Self::Device, sync_type: SyncType) {
+    fn sync(device: &Self::Device) {
         let client = CLIENTS.client::<B::FusionRuntime>(&device.clone());
         client.drain();
-        B::sync(device, sync_type);
+        B::sync(device);
     }
 
     fn ad_enabled() -> bool {
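Ordering matters in Fusion::sync: client.drain() first flushes the fusion stream's queued operations to the inner backend, and only then does B::sync(device) wait, so the wait covers everything that was pending.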

@@ -43,6 +43,9 @@ num-traits = { workspace = true }
 rand = { workspace = true }
 spin = { workspace = true }
 
+# Async
+futures-lite = { workspace = true, features = ["std"] }
+
 # Template
 serde = { workspace = true }
 text_placeholder = { workspace = true, features = ["struct_context"] }

@@ -2,7 +2,7 @@ use crate::{
     tensor::{JitTensor, QJitTensor},
     FloatElement, IntElement, JitRuntime, PrecisionBridge,
 };
-use burn_tensor::backend::{Backend, DeviceOps, SyncType};
+use burn_tensor::backend::{Backend, DeviceOps};
 use cubecl::server::ComputeServer;
 use rand::{rngs::StdRng, SeedableRng};
 use std::{marker::PhantomData, sync::Mutex};
@@ -51,13 +51,9 @@
         false
     }
 
-    fn sync(device: &Self::Device, sync_type: SyncType) {
-        let sync = match sync_type {
-            SyncType::Flush => cubecl::client::SyncType::Flush,
-            SyncType::Wait => cubecl::client::SyncType::Wait,
-        };
+    fn sync(device: &Self::Device) {
         let client = R::client(device);
-        client.sync(sync);
+        futures_lite::future::block_on(client.sync());
     }
 }
 
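This hunk is the heart of the migration: in the updated cubecl, the compute client's sync() returns a future rather than taking a cubecl::client::SyncType, so burn-jit resolves it on the spot with futures-lite (added to this crate's dependencies in the previous hunk). A minimal sketch of that blocking pattern, with device_work standing in for client.sync():

    // Stand-in for the now-async sync() of a cubecl compute client.
    async fn device_work() {
        // In the real backend, this future resolves once queued kernels finish.
    }

    fn sync_blocking() {
        // Drive the future to completion on the current thread, no runtime needed.
        futures_lite::future::block_on(device_work());
    }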

@@ -2,7 +2,7 @@ use crate::{PrecisionBridge, QuantElement, TchQTensor};
 
 use super::element::TchElement;
 use super::TchTensor;
-use burn_tensor::backend::{Backend, DeviceId, DeviceOps, SyncType};
+use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
 use burn_tensor::ops::IntTensorOps;
 use burn_tensor::{Int, Tensor};
 
@@ -118,20 +118,19 @@ impl<E: TchElement, Q: QuantElement> Backend for LibTorch<E, Q> {
         "tch".to_string()
     }
 
-    fn sync(device: &Self::Device, sync_type: SyncType) {
-        if sync_type == SyncType::Wait {
-            match device {
-                LibTorchDevice::Cpu => (),
-                LibTorchDevice::Cuda(index) => {
-                    tch::Cuda::synchronize(*index as i64);
-                }
-                _ => {
-                    // When there is no explicit way to synchronize, we write and read one value to sync
-                    Tensor::<Self, 1, Int>::from_primitive(
-                        <Self as IntTensorOps<Self>>::int_zeros([1].into(), device),
-                    )
-                    .into_data();
-                }
-            }
-        }
-    }
+    fn sync(device: &Self::Device) {
+        match device {
+            LibTorchDevice::Cpu => (),
+            LibTorchDevice::Cuda(index) => {
+                tch::Cuda::synchronize(*index as i64);
+            }
+            _ => {
+                // When there is no explicit way to synchronize, we write and read one value to sync
+                Tensor::<Self, 1, Int>::from_primitive(<Self as IntTensorOps<Self>>::int_zeros(
+                    [1].into(),
+                    device,
+                ))
+                .into_data();
+            }
+        }
+    }
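The LibTorch fallback arm keeps its existing trick for devices without an explicit synchronize call: materialize a one-element tensor and read it back with into_data(), which cannot complete until all queued work has. Only two things change: the SyncType::Wait guard disappears (waiting is now unconditional) and the call is reflowed by the formatter.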

@@ -1,5 +1,4 @@
 use alloc::string::String;
-pub use burn_common::sync_type::SyncType;
 
 use crate::tensor::Element;
 use crate::{ops::*, quantization::QTensorPrimitive};
@@ -103,7 +102,7 @@ pub trait Backend:
     fn seed(seed: u64);
 
     /// Sync the backend, ensure that all computation are finished.
-    fn sync(_device: &Self::Device, _sync_type: SyncType) {}
+    fn sync(_device: &Self::Device) {}
 }
 
 /// Trait that allows a backend to support autodiff.
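With the default no-op intact, backends that have nothing to wait on simply inherit it. Reduced to the sync method only, the simplified trait reads as follows (a sketch; the real trait has many more items and supertraits):

    pub trait Backend: Sized {
        type Device;

        /// Sync the backend, ensuring all queued computation is finished.
        fn sync(_device: &Self::Device) {}
    }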

@@ -16,11 +16,7 @@ fn main() {
     );
 
     text_generation::training::train::<Backend, DbPediaDataset>(
-        if cfg!(target_os = "macos") {
-            burn::tensor::Device::<Backend>::Mps
-        } else {
-            burn::tensor::Device::<Backend>::Cuda(0)
-        },
+        Default::default(),
         DbPediaDataset::train(),
         DbPediaDataset::test(),
         config,