Bump cube version (#2447)

This commit is contained in:
Arthur Brussee 2024-10-31 22:15:53 +00:00 committed by GitHub
parent 8e466d7ce1
commit 9eb25780ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 16 deletions

25
Cargo.lock generated
View File

@ -1460,7 +1460,7 @@ dependencies = [
[[package]]
name = "cubecl"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
dependencies = [
"cubecl-core",
"cubecl-cuda",
@ -1491,7 +1491,7 @@ dependencies = [
[[package]]
name = "cubecl-common"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
dependencies = [
"derive-new",
"embassy-futures",
@ -1508,7 +1508,7 @@ dependencies = [
[[package]]
name = "cubecl-core"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
dependencies = [
"bytemuck",
"cubecl-common 0.4.0",
@ -1525,7 +1525,7 @@ dependencies = [
[[package]]
name = "cubecl-cpp"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
dependencies = [
"bytemuck",
"cubecl-common 0.4.0",
@ -1539,7 +1539,7 @@ dependencies = [
[[package]]
name = "cubecl-cuda"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
dependencies = [
"bytemuck",
"cubecl-common 0.4.0",
@ -1555,7 +1555,7 @@ dependencies = [
[[package]]
name = "cubecl-hip"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
dependencies = [
"bytemuck",
"cubecl-common 0.4.0",
@ -1580,7 +1580,7 @@ dependencies = [
[[package]]
name = "cubecl-linalg"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
dependencies = [
"bytemuck",
"cubecl-core",
@ -1591,7 +1591,7 @@ dependencies = [
[[package]]
name = "cubecl-macros"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
dependencies = [
"cubecl-common 0.4.0",
"darling",
@ -1606,7 +1606,7 @@ dependencies = [
[[package]]
name = "cubecl-opt"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
dependencies = [
"cubecl-common 0.4.0",
"cubecl-core",
@ -1643,7 +1643,7 @@ dependencies = [
[[package]]
name = "cubecl-runtime"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
dependencies = [
"async-channel",
"async-lock",
@ -1664,7 +1664,7 @@ dependencies = [
[[package]]
name = "cubecl-spirv"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
dependencies = [
"cubecl-common 0.4.0",
"cubecl-core",
@ -1677,11 +1677,12 @@ dependencies = [
[[package]]
name = "cubecl-wgpu"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
source = "git+https://github.com/tracel-ai/cubecl?rev=0dff475fec254e884f6b82e305e7a52adebf1dd7#0dff475fec254e884f6b82e305e7a52adebf1dd7"
dependencies = [
"ash",
"async-channel",
"bytemuck",
"cfg-if",
"cfg_aliases 0.2.1",
"cubecl-common 0.4.0",
"cubecl-core",

View File

@ -152,8 +152,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99404b1e29946832a42b72a5c26d4cf42c67692e" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99404b1e29946832a42b72a5c26d4cf42c67692e" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0dff475fec254e884f6b82e305e7a52adebf1dd7" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0dff475fec254e884f6b82e305e7a52adebf1dd7" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }

View File

@ -9,7 +9,7 @@ use core::convert::Into;
use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet::Model as SqueezenetModel};
use burn::{
backend::{wgpu::init_device, NdArray},
backend::{wgpu::init_setup_async, NdArray},
prelude::*,
tensor::activation::softmax,
};
@ -110,7 +110,7 @@ impl ImageClassifier {
log::info!("Loading the model to the Wgpu backend");
let start = Instant::now();
let device = WgpuDevice::default();
init_device::<AutoGraphicsApi>(&device, Default::default()).await;
init_setup_async::<AutoGraphicsApi>(&device, Default::default()).await;
self.model = ModelType::WithWgpuBackend(Model::new(&device));
let duration = start.elapsed();
log::debug!("Model is loaded to the Wgpu backend in {:?}", duration);