refactor following cube changes

commit 4c025aa95c
parent 096ec13c48
Author: louisfd
Date:   2024-07-29 11:05:18 -04:00

12 changed files with 88 additions and 155 deletions
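The CubeCL revision this commit tracks drops the `SUBCUBE_DIM_APPROX` constant: `calculate_cube_count_elemwise` now takes an explicit `CubeDim` rather than an approximate scalar size, so every launch site below binds one `cube_dim` and reuses it for both the cube-count computation and the kernel launch. A minimal sketch of the recurring migration pattern, assuming the argument shape suggested by the hunks below (`num_elems: usize, cube_dim: CubeDim`); `some_kernel` is a hypothetical placeholder:

    // Before: the second argument was the scalar SUBCUBE_DIM_APPROX constant.
    // let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);

    // After: one explicit CubeDim, shared by the count and the launch.
    let cube_dim = CubeDim::default();
    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
    // some_kernel::launch::<E, R>(&client, cube_count, cube_dim, /* tensor args */);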

Cargo.lock (generated, 16 lines changed)

@@ -1303,7 +1303,7 @@ dependencies = [
 [[package]]
 name = "cubecl"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "cubecl-core",
  "cubecl-cuda",
@@ -1314,7 +1314,7 @@ dependencies = [
 [[package]]
 name = "cubecl-common"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "derive-new",
  "getrandom",
@@ -1328,7 +1328,7 @@ dependencies = [
 [[package]]
 name = "cubecl-core"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "bytemuck",
  "cubecl-macros",
@@ -1343,7 +1343,7 @@ dependencies = [
 [[package]]
 name = "cubecl-cuda"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "bytemuck",
  "cubecl-common",
@@ -1358,7 +1358,7 @@ dependencies = [
 [[package]]
 name = "cubecl-linalg"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "bytemuck",
  "cubecl-core",
@@ -1369,7 +1369,7 @@ dependencies = [
 [[package]]
 name = "cubecl-macros"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "derive-new",
  "proc-macro2",
@@ -1380,7 +1380,7 @@ dependencies = [
 [[package]]
 name = "cubecl-runtime"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "async-channel",
  "cubecl-common",
@@ -1399,7 +1399,7 @@ dependencies = [
 [[package]]
 name = "cubecl-wgpu"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "async-channel",
  "bytemuck",

Cargo.toml

@@ -67,7 +67,9 @@ rstest = "0.19.0"
 rusqlite = { version = "0.31.0" }
 rust-format = { version = "0.3.4" }
 sanitize-filename = "0.5.0"
-serde_bytes = { version = "0.11.15", default-features = false, features = ["alloc"] } # alloc for no_std
+serde_bytes = { version = "0.11.15", default-features = false, features = [
+    "alloc",
+] } # alloc for no_std
 serde_rusqlite = "0.35.0"
 serial_test = "3.1.1"
 spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
@@ -148,5 +150,5 @@ cubecl-common = { version = "0.1.1", git = "https://github.com/tracel-ai/cubecl"
 # cubecl-common = { path = "../cubecl/crates/cubecl-common" }

 [profile.dev]
 debug = 0 # Speed up compilation time and not necessary.
 opt-level = 2

(file: element-wise fusion kernel factory)

@@ -30,11 +30,10 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
         outputs: &[&TensorDescription],
         stateful: bool,
     ) -> FusionKernel<R> {
-        let cube_dim_x = self.cube_dim.x;
-        let cube_dim_y = self.cube_dim.y;
-
-        assert_eq!(cube_dim_x, cube_dim_y, "The grid must be a square");
-        let cube_dim = cube_dim_x as usize;
+        assert_eq!(
+            self.cube_dim.x, self.cube_dim.y,
+            "The grid must be a square"
+        );

         let vectorize_4 = can_vectorize(handles_inputs, inputs, outputs, 4);
         let vectorize_2 = can_vectorize(handles_inputs, inputs, outputs, 2);
@@ -69,7 +68,7 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
                 let reference_tensor = inputs[settings.mappings[0].pos_input];
                 let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
-                let cube_count = calculate_cube_count_elemwise(num_elems / factor, cube_dim);
+                let cube_count = calculate_cube_count_elemwise(num_elems / factor, self.cube_dim);
                 let output_infos =
                     inplace_output2input
                         .iter()
@@ -96,7 +95,7 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
             false => {
                 let reference_tensor = outputs[0];
                 let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
-                let cube_count = calculate_cube_count_elemwise(num_elems / factor, cube_dim);
+                let cube_count = calculate_cube_count_elemwise(num_elems / factor, self.cube_dim);
                 let output_infos = outputs.iter().enumerate().map(|(pos, tensor)| {
                     let size = calculate_num_elems_dyn_rank(&tensor.shape)
                         * self.info.outputs[pos].elem_size::<R>();
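Previously the factory collapsed the `CubeDim` into a scalar (`cube_dim_x as usize`) because the old `calculate_cube_count_elemwise` wanted a plain size; the new API consumes the `CubeDim` itself, so the temporaries disappear and only the square-grid assertion remains. A condensed contrast, restated from the hunks above:

    // Old: a scalar side length was extracted from the CubeDim first.
    // let cube_dim = self.cube_dim.x as usize;
    // let cube_count = calculate_cube_count_elemwise(num_elems / factor, cube_dim);

    // New: the CubeDim struct travels to the helper unchanged; the division
    // accounts for vectorization (each unit handles `factor` elements).
    let cube_count = calculate_cube_count_elemwise(num_elems / factor, self.cube_dim);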

(file: cast kernel)

@@ -1,8 +1,6 @@
 use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
 use cubecl::linalg::tensor::index_offset_with_layout;
-use cubecl::{
-    calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, SUBCUBE_DIM_APPROX,
-};
+use cubecl::{calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor};
 use cubecl::{ir::KernelDefinition, KernelSettings};

 use std::any::TypeId;
@@ -46,10 +44,10 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
         tensor_vectorization_factor(&[4, 2], &input.shape.dims, &input.strides, rank - 1);
     let num_elems: usize = input.shape.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

     let client = input.client.clone();
     let handle = client.empty(num_elems * core::mem::size_of::<EO>());
     let output =
@@ -58,7 +56,7 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
     cast_element::launch::<EI::Primitive, EO::Primitive, R>(
         &client,
         cube_count,
-        CubeDim::default(),
+        cube_dim,
         TensorArg::vectorized(
             vectorization_factor,
             &input.handle,

(file: bool cast kernel)

@@ -1,12 +1,14 @@
-use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
-use cubecl::{
-    cpa,
-    frontend::TensorHandleRef,
-    ir::{Elem, Item, KernelDefinition, Scope, Variable, Visibility},
-    CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
-    OutputInfo,
-};
-use std::marker::PhantomData;
+use crate::{tensor::JitTensor, JitElement, JitRuntime};
+use cubecl::{calculate_cube_count_elemwise, prelude::*, CubeDim};
+
+#[cube(launch)]
+fn bool_cast_kernel<T: Numeric>(input: &Tensor<UInt>, output: &mut Tensor<T>) {
+    if input[ABSOLUTE_POS] == UInt::new(1) {
+        output[ABSOLUTE_POS] = T::from_int(1);
+    } else {
+        output[ABSOLUTE_POS] = T::from_int(0);
+    }
+}

 /// Cast a bool tensor to the given element type.
 ///
@@ -17,7 +19,6 @@ use std::marker::PhantomData;
 pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
     tensor: JitTensor<R, u32, D>,
 ) -> JitTensor<R, EO, D> {
-    let kernel = BoolCastEagerKernel::<R, EO>::new();
     let num_elems = tensor.shape.num_elements();
     let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
     let output = JitTensor::new_contiguous(
@@ -27,86 +28,16 @@ pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
         buffer,
     );

-    Execution::start(kernel, tensor.client)
-        .inputs(&[TensorHandleRef::<R>::new(
-            &tensor.handle,
-            &tensor.strides,
-            &tensor.shape.dims,
-        )])
-        .outputs(&[TensorHandleRef::new(
-            &output.handle,
-            &output.strides,
-            &output.shape.dims,
-        )])
-        .execute(CubeCountSettings::Output { pos: 0 });
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
+
+    bool_cast_kernel::launch::<EO::Primitive, R>(
+        &tensor.client,
+        cube_count,
+        cube_dim,
+        TensorArg::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
+        TensorArg::new(&output.handle, &output.strides, &output.shape.dims),
+    );

     output
 }
-
-pub(crate) struct BoolCastShader {
-    tensor: Variable,
-    output: Variable,
-}
-
-#[derive(new)]
-pub(crate) struct BoolCastEagerKernel<R: JitRuntime, EO: JitElement> {
-    _runtime: PhantomData<R>,
-    _elem_out: PhantomData<EO>,
-}
-
-impl<R: JitRuntime, EO: JitElement> Kernel for BoolCastEagerKernel<R, EO> {
-    fn define(&self) -> KernelDefinition {
-        let mut scope = Scope::root();
-        let item_input = Item::new(Elem::Bool);
-        let item_output = EO::cube_elem().into();
-
-        let tensor = Variable::GlobalInputArray {
-            id: 0,
-            item: item_input,
-        };
-        let output = Variable::GlobalOutputArray {
-            id: 0,
-            item: item_output,
-        };
-
-        BoolCastShader { tensor, output }.expand(&mut scope);
-
-        scope.write_global_custom(output);
-
-        let tensor = InputInfo::Array {
-            item: item_input,
-            visibility: Visibility::Read,
-        };
-
-        let out = OutputInfo::Array { item: item_output };
-
-        let info = KernelExpansion {
-            inputs: vec![tensor],
-            outputs: vec![out],
-            scope,
-        };
-
-        let settings = KernelSettings::default();
-        KernelIntegrator::new(info).integrate(settings)
-    }
-
-    fn id(&self) -> cubecl::KernelId {
-        cubecl::KernelId::new::<Self>()
-    }
-}
-
-impl BoolCastShader {
-    pub(crate) fn expand(self, scope: &mut Scope) {
-        let tensor = self.tensor;
-        let id = Variable::AbsolutePos;
-        let output = self.output;
-
-        let represents_true = scope.create_local(Elem::Bool);
-        cpa!(scope, represents_true = tensor[id]);
-        cpa!(scope, if(represents_true).then(|scope|{
-            cpa!(scope, output[id] = 1);
-        }).else(|scope|{
-            cpa!(scope, output[id] = 0);
-        }));
-    }
-}
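The rewrite replaces some eighty lines of hand-assembled IR (`Scope`, `Variable`, `cpa!`, plus the `Kernel` and `Execution` plumbing) with a single `#[cube(launch)]` function; the attribute macro expands the annotated Rust into the kernel IR and generates the `bool_cast_kernel::launch` entry point used above. Bool tensors are stored as `u32` on this backend (hence the `JitTensor<R, u32, D>` argument), which is why the kernel reads a `Tensor<UInt>` and compares against `UInt::new(1)`. The new flow, condensed from the hunks above:

    let num_elems = tensor.shape.num_elements();
    let cube_dim = CubeDim::default();
    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);

    // `launch` is generated by #[cube(launch)]; each invocation writes one
    // element, indexed by the ABSOLUTE_POS builtin.
    bool_cast_kernel::launch::<EO::Primitive, R>(
        &tensor.client,
        cube_count,
        cube_dim,
        TensorArg::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
        TensorArg::new(&output.handle, &output.strides, &output.shape.dims),
    );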

(file: comparison kernels)

@@ -3,7 +3,7 @@ use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
 use burn_tensor::Shape;
 use cubecl::{
     calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
-    tensor_vectorization_factor, Runtime, SUBCUBE_DIM_APPROX,
+    tensor_vectorization_factor, Runtime,
 };

 #[cube]
@@ -139,17 +139,17 @@ pub(crate) fn launch_cmp<
     let shape_out = Shape::new(shape_out);
     let client = lhs.client.clone();
     let num_elems = shape_out.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

     let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
     if same_tensor_type && lhs.can_mut_broadcast(&rhs) {
         kernel_cmp::launch::<E::Primitive, O, R>(
             &client,
             cube_count,
-            CubeDim::default(),
+            cube_dim,
             TensorArg::vectorized(
                 vectorization_factor,
                 &lhs.handle,
@@ -244,17 +244,17 @@ pub(crate) fn launch_scalar_cmp<
         tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, D - 1);
     let client = tensor.client.clone();
     let num_elems = tensor.shape.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

     let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
     if same_tensor_type && tensor.can_mut() {
         kernel_scalar_cmp::launch::<E::Primitive, O, R>(
             &client,
             cube_count,
-            CubeDim::default(),
+            cube_dim,
             TensorArg::vectorized(
                 vectorization_factor,
                 &tensor.handle,

(file: conv2d kernel)

@@ -1,4 +1,4 @@
-use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX};
+use cubecl::{calculate_cube_count_elemwise, prelude::*};

 use burn_tensor::{
     ops::{conv::calculate_conv_output_size, ConvOptions},
@@ -161,12 +161,13 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
     };

     let num_elems_output = output.shape.num_elements();
-    let cube_dim = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX);
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems_output, cube_dim);

     conv2d_kernel::launch::<E::FloatPrimitive, R>(
         &input.client,
+        cube_count,
         cube_dim,
-        CubeDim::default(),
         TensorArg::new(&input.handle, &input.strides, &input.shape.dims),
         TensorArg::new(&weight.handle, &weight.strides, &weight.shape.dims),
         TensorArg::new(&bias.handle, &bias.strides, &bias.shape.dims),

(file: conv3d kernel)

@@ -1,4 +1,4 @@
-use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX};
+use cubecl::{calculate_cube_count_elemwise, prelude::*};

 use burn_tensor::{
     ops::{conv::calculate_conv_output_size, ConvOptions},
@@ -188,10 +188,13 @@ pub(crate) fn conv3d<R: JitRuntime, E: FloatElement>(
         }
     };

+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim);
+
     conv3d_kernel::launch::<E::FloatPrimitive, R>(
         &input.client,
-        calculate_cube_count_elemwise(output.shape.num_elements(), SUBCUBE_DIM_APPROX),
-        CubeDim::default(),
+        cube_count,
+        cube_dim,
         TensorArg::new(&input.handle, &input.strides, &input.shape.dims),
         TensorArg::new(&weight.handle, &weight.strides, &weight.shape.dims),
         TensorArg::new(&bias.handle, &bias.strides, &bias.shape.dims),

(file: scatter kernel)

@@ -4,15 +4,15 @@ use crate::{
     tensor::JitTensor,
     JitRuntime,
 };
+use cubecl::InputInfo;
 use cubecl::{
-    calculate_cube_count_elemwise, cpa, frontend::TensorHandleRef, CubeCountSettings,
+    calculate_cube_count_elemwise, cpa, frontend::TensorHandleRef, CubeCountSettings, CubeDim,
     KernelExpansion, KernelIntegrator, KernelSettings,
 };
 use cubecl::{
     ir::{Branch, Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility},
     Execution,
 };
-use cubecl::{InputInfo, SUBCUBE_DIM_APPROX};
 use std::marker::PhantomData;

 #[derive(new)]
@@ -223,7 +223,8 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
     // Fake strides of the virtual output where the strides of dim is hardcoded to one.
     indices.strides = strides;

-    let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);

     Execution::start(kernel, indices.client)
         .inputs(&[

(file: select_assign kernel)

@@ -1,14 +1,10 @@
-use crate::{
-    element::JitElement,
-    kernel::{Kernel, SUBCUBE_DIM_APPROX},
-    tensor::JitTensor,
-    JitRuntime,
-};
+use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
 use cubecl::{
     calculate_cube_count_elemwise, cpa,
     frontend::TensorHandleRef,
     ir::{Branch, Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility},
-    CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
+    CubeCountSettings, CubeDim, Execution, InputInfo, KernelExpansion, KernelIntegrator,
+    KernelSettings,
 };
 use std::marker::PhantomData;
@@ -208,7 +204,9 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
     });

     let kernel = SelectAssignEagerKernel::<R, E>::new(dim);
-    let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
+
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);

     Execution::start(kernel, indices.client)
         .inputs(&[

(file: unary kernel)

@@ -1,7 +1,7 @@
 use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
 use cubecl::{
     calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
-    tensor_vectorization_factor, unexpanded, SUBCUBE_DIM_APPROX,
+    tensor_vectorization_factor, unexpanded,
 };

 use super::Kernel;
@@ -71,17 +71,17 @@ where
     let client = tensor.client.clone();
     let num_elems = tensor.shape.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
     let is_contiguous = tensor.is_contiguous();

     if tensor.can_mut() && is_contiguous {
         unary_kernel::launch::<E::Primitive, O, R>(
             &client,
             cube_count,
-            CubeDim::default(),
+            cube_dim,
             TensorArg::vectorized(
                 vectorization_factor,
                 &tensor.handle,

(file: full kernel)

@@ -4,7 +4,7 @@ use crate::{element::JitElement, tensor::JitTensor};
 use burn_tensor::{ElementConversion, Shape};
 use cubecl::client::ComputeClient;
 use cubecl::ir::{BinaryOperator, Elem, Operator, Scope, Variable};
-use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX};
+use cubecl::{calculate_cube_count_elemwise, prelude::*};
 use cubecl::{tensor_vectorization_factor, Runtime};

 pub fn full<R: JitRuntime, E: JitElement, const D: usize>(
@@ -37,15 +37,15 @@ pub fn full_device<R: JitRuntime, E: JitElement, const D: usize>(
     let num_elems = empty.shape.num_elements();
     let vectorization_factor =
         tensor_vectorization_factor(&[4, 2], &empty.shape.dims, &empty.strides, D - 1);
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

     full_kernel::launch::<E::Primitive, R>(
         &empty.client,
         cube_count,
-        CubeDim::default(),
+        cube_dim,
         TensorArg::vectorized(
             vectorization_factor,
             &empty.handle,