Update cubecl (#2376)

Authored by Nathaniel Simard on 2024-10-16 15:56:12 -04:00, committed by GitHub
parent 3d77efc305
commit 353120ea5f
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
26 changed files with 338 additions and 456 deletions

Cargo.lock (generated, 603 changed lines): file diff suppressed because it is too large.


@ -94,7 +94,7 @@ zip = "2.2.0"
# Async handling
async-channel = "2.3"
pollster = "0.3"
futures-lite = { version = "2.3.0", default-features = false }
# Terminal UI
crossterm = "0.27.0"
@ -152,11 +152,11 @@ tch = "0.15.0"
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "fb47090f7a44952ae3e3b2b72f8c5a88d8af56fd" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "fb47090f7a44952ae3e3b2b72f8c5a88d8af56fd" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ed136e2385b17e36680589f8a6245926f430f59f" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ed136e2385b17e36680589f8a6245926f430f59f" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl" }
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
### For the release. ###
# cubecl = { version="0.2.0", default-features = false }
# cubecl-common = { version="0.2.0", default-features = false }


@ -16,6 +16,7 @@ candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
cuda-jit = ["burn/cuda-jit"]
cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]


@ -7,10 +7,7 @@ use burn::{
Distribution, Tensor,
},
};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
pub struct AutodiffOverheadBenchmark<B: AutodiffBackend> {
config: nn::LstmConfig,
@ -50,7 +47,7 @@ impl<B: AutodiffBackend> Benchmark for AutodiffOverheadBenchmark<B> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}


@ -1,9 +1,6 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
pub struct BinaryBenchmark<B: Backend, const D: usize> {
shape: Shape,
@ -33,7 +30,7 @@ impl<B: Backend, const D: usize> Benchmark for BinaryBenchmark<B, D> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait);
B::sync(&self.device);
}
}


@ -2,10 +2,7 @@ use backend_comparison::persistence::save;
use burn::tensor::{
backend::Backend, module::conv2d, ops::ConvOptions, Distribution, Shape, Tensor,
};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
pub struct Conv2dBenchmark<B: Backend> {
input_shape: Shape,
@ -51,7 +48,7 @@ impl<B: Backend> Benchmark for Conv2dBenchmark<B> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}


@ -2,10 +2,7 @@ use backend_comparison::persistence::save;
use burn::tensor::{
backend::Backend, module::conv3d, ops::ConvOptions, Distribution, Shape, Tensor,
};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
pub struct Conv3dBenchmark<B: Backend> {
input_shape: Shape,
@ -51,7 +48,7 @@ impl<B: Backend> Benchmark for Conv3dBenchmark<B> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}


@ -3,10 +3,7 @@ use burn::tensor::{
backend::Backend, module::conv_transpose2d, ops::ConvTransposeOptions, Distribution, Shape,
Tensor,
};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
pub struct ConvTranspose2dBenchmark<B: Backend> {
input_shape: Shape,
@ -52,7 +49,7 @@ impl<B: Backend> Benchmark for ConvTranspose2dBenchmark<B> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}


@ -3,10 +3,7 @@ use burn::tensor::{
backend::Backend, module::conv_transpose3d, ops::ConvTransposeOptions, Distribution, Shape,
Tensor,
};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
pub struct ConvTranspose3dBenchmark<B: Backend> {
input_shape: Shape,
@ -52,7 +49,7 @@ impl<B: Backend> Benchmark for ConvTranspose3dBenchmark<B> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}


@ -2,7 +2,6 @@ use backend_comparison::persistence::save;
use burn::backend::Autodiff;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::sync_type::SyncType;
use core::f64::consts::SQRT_2;
use derive_new::new;
@ -69,7 +68,7 @@ impl<B: Backend, const D: usize> Benchmark for CustomGeluBenchmark<B, D> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
fn num_samples(&self) -> usize {


@ -1,9 +1,6 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor, TensorData};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
use derive_new::new;
#[derive(new)]
@ -32,7 +29,7 @@ impl<B: Backend, const D: usize> Benchmark for ToDataBenchmark<B, D> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}
@ -69,7 +66,7 @@ impl<B: Backend, const D: usize> Benchmark for FromDataBenchmark<B, D> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}


@ -3,7 +3,6 @@ use burn::tensor::backend::Backend;
use burn::tensor::Device;
use burn::{config::Config, module::Module, nn};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::sync_type::SyncType;
use derive_new::new;
#[derive(Module, Debug)]
@ -94,7 +93,7 @@ impl<B: Backend> Benchmark for LoadRecordBenchmark<B> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}


@ -1,9 +1,6 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
use derive_new::new;
#[derive(new)]
@ -40,7 +37,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}


@ -1,9 +1,6 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, module::max_pool2d, Distribution, Shape, Tensor};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
pub struct MaxPool2dBenchmark<B: Backend> {
shape: Shape,
@ -40,7 +37,7 @@ impl<B: Backend> Benchmark for MaxPool2dBenchmark<B> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}


@ -1,7 +1,6 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use cubecl::client::SyncType;
// Files retrieved during build to avoid reimplementing ResNet for benchmarks
mod block {
@ -42,7 +41,7 @@ impl<B: Backend> Benchmark for ResNetBenchmark<B> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}


@ -1,9 +1,6 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
use derive_new::new;
#[derive(new)]
@ -33,7 +30,7 @@ impl<B: Backend, const D: usize> Benchmark for UnaryBenchmark<B, D> {
}
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
B::sync(&self.device)
}
}


@ -80,6 +80,8 @@ enum BackendValues {
WgpuFusion,
#[strum(to_string = "cuda-jit")]
CudaJit,
#[strum(to_string = "cuda-jit-fusion")]
CudaJitFusion,
}
#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]


@ -59,6 +59,8 @@ macro_rules! bench_on_backend {
let feature_name = "wgpu-fusion";
#[cfg(feature = "cuda-jit")]
let feature_name = "cuda-jit";
#[cfg(feature = "cuda-jit-fusion")]
let feature_name = "cuda-jit-fusion";
#[cfg(feature = "wgpu")]
{


@ -5,7 +5,6 @@ use crate::{
tensor::AutodiffTensor,
AutodiffBridge,
};
use burn_common::sync_type::SyncType;
use burn_tensor::{
backend::{AutodiffBackend, Backend},
ops::{BoolTensor, IntTensor, QuantizedTensor},
@ -50,8 +49,8 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
B::seed(seed)
}
fn sync(device: &B::Device, sync_type: SyncType) {
B::sync(device, sync_type)
fn sync(device: &B::Device) {
B::sync(device)
}
}


@ -1,7 +1,7 @@
use std::marker::PhantomData;
use burn_tensor::{
backend::{Backend, DeviceId, DeviceOps, SyncType},
backend::{Backend, DeviceId, DeviceOps},
quantization::{QTensorPrimitive, QuantizationStrategy},
Device,
};
@ -187,25 +187,20 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
panic!("Manual seed not supported by Candle. ")
}
fn sync(device: &Device<Self>, sync_type: SyncType) {
match sync_type {
SyncType::Wait => {
let device: candle_core::Device = (device.clone()).into();
fn sync(device: &Device<Self>) {
let device: candle_core::Device = (device.clone()).into();
match device {
candle_core::Device::Cpu => (),
candle_core::Device::Cuda(device) => {
#[cfg(feature = "cuda")]
device.synchronize().unwrap();
}
candle_core::Device::Metal(device) => {
// For some reason, device.wait_until_completed() does not seem to work,
// and neither does writing and reading a value with into_data
panic!("Device synchronization unavailable with Metal device on Candle backend")
}
}
match device {
candle_core::Device::Cpu => (),
candle_core::Device::Cuda(device) => {
#[cfg(feature = "cuda")]
device.synchronize().unwrap();
}
SyncType::Flush => (), // Nothhing to flush.
};
candle_core::Device::Metal(device) => {
// For some reason, device.wait_until_completed() does not seem to work,
// and neither does writing and reading a value with into_data
panic!("Device synchronization unavailable with Metal device on Candle backend")
}
}
}
}


@ -3,7 +3,7 @@ use crate::{
QFusionTensor,
};
use burn_tensor::{
backend::{Backend, DeviceOps, SyncType},
backend::{Backend, DeviceOps},
ops::FloatTensor,
repr::{OperationDescription, ReprBackend},
Device,
@ -50,10 +50,10 @@ impl<B: FusionBackend> Backend for Fusion<B> {
B::seed(seed);
}
fn sync(device: &Self::Device, sync_type: SyncType) {
fn sync(device: &Self::Device) {
let client = CLIENTS.client::<B::FusionRuntime>(&device.clone());
client.drain();
B::sync(device, sync_type);
B::sync(device);
}
fn ad_enabled() -> bool {


@ -43,6 +43,9 @@ num-traits = { workspace = true }
rand = { workspace = true }
spin = { workspace = true }
# Async
futures-lite = { workspace = true, features = ["std"] }
# Template
serde = { workspace = true }
text_placeholder = { workspace = true, features = ["struct_context"] }


@ -2,7 +2,7 @@ use crate::{
tensor::{JitTensor, QJitTensor},
FloatElement, IntElement, JitRuntime, PrecisionBridge,
};
use burn_tensor::backend::{Backend, DeviceOps, SyncType};
use burn_tensor::backend::{Backend, DeviceOps};
use cubecl::server::ComputeServer;
use rand::{rngs::StdRng, SeedableRng};
use std::{marker::PhantomData, sync::Mutex};
@ -51,13 +51,9 @@ where
false
}
fn sync(device: &Self::Device, sync_type: SyncType) {
let sync = match sync_type {
SyncType::Flush => cubecl::client::SyncType::Flush,
SyncType::Wait => cubecl::client::SyncType::Wait,
};
fn sync(device: &Self::Device) {
let client = R::client(device);
client.sync(sync);
futures_lite::future::block_on(client.sync());
}
}
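Note: the hunk above replaces the SyncType mapping with futures_lite::future::block_on(client.sync()), which implies the cubecl compute client's sync() now returns a future. A minimal sketch, assuming the JitRuntime and client API used in this file (the helper name sync_device_async is illustrative, not part of this change): callers already running in an async context could await the future directly instead of blocking.

// Illustrative sketch only: await device synchronization instead of blocking on it.
async fn sync_device_async<R: JitRuntime>(device: &R::Device) {
    let client = R::client(device);
    // In the updated cubecl API, `sync()` returns a future (blocked on above via futures_lite).
    client.sync().await;
}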


@ -2,7 +2,7 @@ use crate::{PrecisionBridge, QuantElement, TchQTensor};
use super::element::TchElement;
use super::TchTensor;
use burn_tensor::backend::{Backend, DeviceId, DeviceOps, SyncType};
use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
use burn_tensor::ops::IntTensorOps;
use burn_tensor::{Int, Tensor};
@ -118,20 +118,19 @@ impl<E: TchElement, Q: QuantElement> Backend for LibTorch<E, Q> {
"tch".to_string()
}
fn sync(device: &Self::Device, sync_type: SyncType) {
if sync_type == SyncType::Wait {
match device {
LibTorchDevice::Cpu => (),
LibTorchDevice::Cuda(index) => {
tch::Cuda::synchronize(*index as i64);
}
_ => {
// When there is no explicit way to synchronize, we write and read one value to sync
Tensor::<Self, 1, Int>::from_primitive(
<Self as IntTensorOps<Self>>::int_zeros([1].into(), device),
)
.into_data();
}
fn sync(device: &Self::Device) {
match device {
LibTorchDevice::Cpu => (),
LibTorchDevice::Cuda(index) => {
tch::Cuda::synchronize(*index as i64);
}
_ => {
// When there is no explicit way to synchronize, we write and read one value to sync
Tensor::<Self, 1, Int>::from_primitive(<Self as IntTensorOps<Self>>::int_zeros(
[1].into(),
device,
))
.into_data();
}
}
}


@ -1,5 +1,4 @@
use alloc::string::String;
pub use burn_common::sync_type::SyncType;
use crate::tensor::Element;
use crate::{ops::*, quantization::QTensorPrimitive};
@ -103,7 +102,7 @@ pub trait Backend:
fn seed(seed: u64);
/// Sync the backend, ensure that all computation are finished.
fn sync(_device: &Self::Device, _sync_type: SyncType) {}
fn sync(_device: &Self::Device) {}
}
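With the new trait signature, backends are synchronized with a single call and no SyncType argument; all the benchmark changes above follow this pattern. A minimal usage sketch, assuming a generic Backend B; the timing helper time_matmul is illustrative and not part of this change.

// Illustrative only: wait for all queued work to finish before reading the timer.
fn time_matmul<B: Backend>(device: &B::Device) {
    use burn_tensor::{Distribution, Tensor};
    let x = Tensor::<B, 2>::random([512, 512], Distribution::Default, device);
    let start = std::time::Instant::now();
    let _y = x.clone().matmul(x);
    B::sync(device); // replaces the old B::sync(device, SyncType::Wait)
    println!("matmul took {:?}", start.elapsed());
}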
/// Trait that allows a backend to support autodiff.


@ -16,11 +16,7 @@ fn main() {
);
text_generation::training::train::<Backend, DbPediaDataset>(
if cfg!(target_os = "macos") {
burn::tensor::Device::<Backend>::Mps
} else {
burn::tensor::Device::<Backend>::Cuda(0)
},
Default::default(),
DbPediaDataset::train(),
DbPediaDataset::test(),
config,