refactor following cube changes

louisfd 2024-07-29 11:05:18 -04:00
parent 096ec13c48
commit 4c025aa95c
12 changed files with 88 additions and 155 deletions

Cargo.lock (generated): 16 lines changed

@@ -1303,7 +1303,7 @@ dependencies = [
 [[package]]
 name = "cubecl"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "cubecl-core",
  "cubecl-cuda",
@@ -1314,7 +1314,7 @@ dependencies = [
 [[package]]
 name = "cubecl-common"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "derive-new",
  "getrandom",
@@ -1328,7 +1328,7 @@ dependencies = [
 [[package]]
 name = "cubecl-core"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "bytemuck",
  "cubecl-macros",
@@ -1343,7 +1343,7 @@ dependencies = [
 [[package]]
 name = "cubecl-cuda"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "bytemuck",
  "cubecl-common",
@@ -1358,7 +1358,7 @@ dependencies = [
 [[package]]
 name = "cubecl-linalg"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "bytemuck",
  "cubecl-core",
@@ -1369,7 +1369,7 @@ dependencies = [
 [[package]]
 name = "cubecl-macros"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "derive-new",
  "proc-macro2",
@@ -1380,7 +1380,7 @@ dependencies = [
 [[package]]
 name = "cubecl-runtime"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "async-channel",
  "cubecl-common",
@@ -1399,7 +1399,7 @@ dependencies = [
 [[package]]
 name = "cubecl-wgpu"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "async-channel",
  "bytemuck",

Cargo.toml

@@ -67,7 +67,9 @@ rstest = "0.19.0"
 rusqlite = { version = "0.31.0" }
 rust-format = { version = "0.3.4" }
 sanitize-filename = "0.5.0"
-serde_bytes = { version = "0.11.15", default-features = false, features = ["alloc"] } # alloc for no_std
+serde_bytes = { version = "0.11.15", default-features = false, features = [
+    "alloc",
+] } # alloc for no_std
 serde_rusqlite = "0.35.0"
 serial_test = "3.1.1"
 spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
@@ -148,5 +150,5 @@ cubecl-common = { version = "0.1.1", git = "https://github.com/tracel-ai/cubecl"
 # cubecl-common = { path = "../cubecl/crates/cubecl-common" }

 [profile.dev]
-debug = 0 # Speed up compilation time and not necessary.
+debug = 0 # Speed up compilation time and not necessary.
 opt-level = 2

@@ -30,11 +30,10 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
         outputs: &[&TensorDescription],
         stateful: bool,
     ) -> FusionKernel<R> {
-        let cube_dim_x = self.cube_dim.x;
-        let cube_dim_y = self.cube_dim.y;
-        assert_eq!(cube_dim_x, cube_dim_y, "The grid must be a square");
-        let cube_dim = cube_dim_x as usize;
+        assert_eq!(
+            self.cube_dim.x, self.cube_dim.y,
+            "The grid must be a square"
+        );

         let vectorize_4 = can_vectorize(handles_inputs, inputs, outputs, 4);
         let vectorize_2 = can_vectorize(handles_inputs, inputs, outputs, 2);
@@ -69,7 +68,7 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
                 let reference_tensor = inputs[settings.mappings[0].pos_input];
                 let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
-                let cube_count = calculate_cube_count_elemwise(num_elems / factor, cube_dim);
+                let cube_count = calculate_cube_count_elemwise(num_elems / factor, self.cube_dim);
                 let output_infos =
                     inplace_output2input
                         .iter()
@@ -96,7 +95,7 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
             false => {
                 let reference_tensor = outputs[0];
                 let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
-                let cube_count = calculate_cube_count_elemwise(num_elems / factor, cube_dim);
+                let cube_count = calculate_cube_count_elemwise(num_elems / factor, self.cube_dim);
                 let output_infos = outputs.iter().enumerate().map(|(pos, tensor)| {
                     let size = calculate_num_elems_dyn_rank(&tensor.shape)
                         * self.info.outputs[pos].elem_size::<R>();

@@ -1,8 +1,6 @@
 use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
 use cubecl::linalg::tensor::index_offset_with_layout;
-use cubecl::{
-    calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, SUBCUBE_DIM_APPROX,
-};
+use cubecl::{calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor};
 use cubecl::{ir::KernelDefinition, KernelSettings};
 use std::any::TypeId;
@@ -46,10 +44,10 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
         tensor_vectorization_factor(&[4, 2], &input.shape.dims, &input.strides, rank - 1);
     let num_elems: usize = input.shape.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
     let client = input.client.clone();
     let handle = client.empty(num_elems * core::mem::size_of::<EO>());
     let output =
@@ -58,7 +56,7 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
     cast_element::launch::<EI::Primitive, EO::Primitive, R>(
         &client,
         cube_count,
-        CubeDim::default(),
+        cube_dim,
         TensorArg::vectorized(
             vectorization_factor,
             &input.handle,
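The cast hunks show the pattern this commit repeats across burn-jit: the removed SUBCUBE_DIM_APPROX constant gives way to one explicitly constructed CubeDim that feeds both the cube-count calculation and the launch call, so the two can no longer drift apart. A minimal self-contained sketch of that arithmetic, assuming an illustrative CubeDim stand-in with a 16x16x1 default rather than the real cubecl types:

// Sketch of the new dispatch-shape pattern (illustrative types, not cubecl's).
#[derive(Clone, Copy)]
struct CubeDim {
    x: u32,
    y: u32,
    z: u32,
}

impl Default for CubeDim {
    // Assumed default working-group shape; cubecl's actual default may differ.
    fn default() -> Self {
        Self { x: 16, y: 16, z: 1 }
    }
}

// Ceil-divide the element count by the volume of one cube, which is what the
// calculate_cube_count_elemwise calls above are used for.
fn cube_count_elemwise(num_elems: usize, cube_dim: CubeDim) -> usize {
    let per_cube = (cube_dim.x * cube_dim.y * cube_dim.z) as usize;
    num_elems.div_ceil(per_cube)
}

fn main() {
    let cube_dim = CubeDim::default();
    // The same cube_dim value sizes the grid and shapes the launch.
    let cube_count = cube_count_elemwise(1_000_000, cube_dim);
    assert_eq!(cube_count, 3907); // ceil(1_000_000 / 256)
}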

@@ -1,12 +1,14 @@
-use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
-use cubecl::{
-    cpa,
-    frontend::TensorHandleRef,
-    ir::{Elem, Item, KernelDefinition, Scope, Variable, Visibility},
-    CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
-    OutputInfo,
-};
-use std::marker::PhantomData;
+use crate::{tensor::JitTensor, JitElement, JitRuntime};
+use cubecl::{calculate_cube_count_elemwise, prelude::*, CubeDim};
+
+#[cube(launch)]
+fn bool_cast_kernel<T: Numeric>(input: &Tensor<UInt>, output: &mut Tensor<T>) {
+    if input[ABSOLUTE_POS] == UInt::new(1) {
+        output[ABSOLUTE_POS] = T::from_int(1);
+    } else {
+        output[ABSOLUTE_POS] = T::from_int(0);
+    }
+}

 /// Cast a bool tensor to the given element type.
 ///
@@ -17,7 +19,6 @@ use std::marker::PhantomData;
 pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
     tensor: JitTensor<R, u32, D>,
 ) -> JitTensor<R, EO, D> {
-    let kernel = BoolCastEagerKernel::<R, EO>::new();
     let num_elems = tensor.shape.num_elements();
     let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
     let output = JitTensor::new_contiguous(
@@ -27,86 +28,16 @@ pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
         buffer,
     );
-    Execution::start(kernel, tensor.client)
-        .inputs(&[TensorHandleRef::<R>::new(
-            &tensor.handle,
-            &tensor.strides,
-            &tensor.shape.dims,
-        )])
-        .outputs(&[TensorHandleRef::new(
-            &output.handle,
-            &output.strides,
-            &output.shape.dims,
-        )])
-        .execute(CubeCountSettings::Output { pos: 0 });
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
+
+    bool_cast_kernel::launch::<EO::Primitive, R>(
+        &tensor.client,
+        cube_count,
+        cube_dim,
+        TensorArg::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
+        TensorArg::new(&output.handle, &output.strides, &output.shape.dims),
+    );

     output
 }
-
-pub(crate) struct BoolCastShader {
-    tensor: Variable,
-    output: Variable,
-}
-
-#[derive(new)]
-pub(crate) struct BoolCastEagerKernel<R: JitRuntime, EO: JitElement> {
-    _runtime: PhantomData<R>,
-    _elem_out: PhantomData<EO>,
-}
-
-impl<R: JitRuntime, EO: JitElement> Kernel for BoolCastEagerKernel<R, EO> {
-    fn define(&self) -> KernelDefinition {
-        let mut scope = Scope::root();
-        let item_input = Item::new(Elem::Bool);
-        let item_output = EO::cube_elem().into();
-
-        let tensor = Variable::GlobalInputArray {
-            id: 0,
-            item: item_input,
-        };
-        let output = Variable::GlobalOutputArray {
-            id: 0,
-            item: item_output,
-        };
-
-        BoolCastShader { tensor, output }.expand(&mut scope);
-        scope.write_global_custom(output);
-
-        let tensor = InputInfo::Array {
-            item: item_input,
-            visibility: Visibility::Read,
-        };
-        let out = OutputInfo::Array { item: item_output };
-
-        let info = KernelExpansion {
-            inputs: vec![tensor],
-            outputs: vec![out],
-            scope,
-        };
-
-        let settings = KernelSettings::default();
-        KernelIntegrator::new(info).integrate(settings)
-    }
-
-    fn id(&self) -> cubecl::KernelId {
-        cubecl::KernelId::new::<Self>()
-    }
-}
-
-impl BoolCastShader {
-    pub(crate) fn expand(self, scope: &mut Scope) {
-        let tensor = self.tensor;
-        let id = Variable::AbsolutePos;
-        let output = self.output;
-
-        let represents_true = scope.create_local(Elem::Bool);
-        cpa!(scope, represents_true = tensor[id]);
-        cpa!(scope, if(represents_true).then(|scope|{
-            cpa!(scope, output[id] = 1);
-        }).else(|scope|{
-            cpa!(scope, output[id] = 0);
-        }));
-    }
-}
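This file is the deepest cut of the refactor: roughly eighty lines of hand-assembled IR (a Kernel impl wiring Scope, Variable, and cpa! statements together) collapse into the short #[cube(launch)] function above, whose generated launch wrapper is what bool_cast now calls. Semantically the kernel maps each u32 flag to 1 or 0 in the output element type; a plain-Rust model of that per-element logic (a CPU-side illustration only, not the kernel itself, with f32 standing in for the generic output type):

// CPU model of bool_cast_kernel: each GPU thread at ABSOLUTE_POS reads one
// u32 flag and writes T::from_int(1) or T::from_int(0); here T is f32.
fn bool_cast_model(input: &[u32]) -> Vec<f32> {
    input
        .iter()
        .map(|&flag| if flag == 1 { 1.0 } else { 0.0 })
        .collect()
}

fn main() {
    // Note that any value other than exactly 1 maps to 0, matching the
    // kernel's `== UInt::new(1)` comparison.
    assert_eq!(bool_cast_model(&[1, 0, 2, 1]), vec![1.0, 0.0, 0.0, 1.0]);
}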

@@ -3,7 +3,7 @@ use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
 use burn_tensor::Shape;
 use cubecl::{
     calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
-    tensor_vectorization_factor, Runtime, SUBCUBE_DIM_APPROX,
+    tensor_vectorization_factor, Runtime,
 };

 #[cube]
@@ -139,17 +139,17 @@ pub(crate) fn launch_cmp<
     let shape_out = Shape::new(shape_out);
     let client = lhs.client.clone();
     let num_elems = shape_out.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

     let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
     if same_tensor_type && lhs.can_mut_broadcast(&rhs) {
         kernel_cmp::launch::<E::Primitive, O, R>(
             &client,
             cube_count,
-            CubeDim::default(),
+            cube_dim,
             TensorArg::vectorized(
                 vectorization_factor,
                 &lhs.handle,
@@ -244,17 +244,17 @@ pub(crate) fn launch_scalar_cmp<
         tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, D - 1);
     let client = tensor.client.clone();
     let num_elems = tensor.shape.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

     let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
     if same_tensor_type && tensor.can_mut() {
         kernel_scalar_cmp::launch::<E::Primitive, O, R>(
             &client,
             cube_count,
-            CubeDim::default(),
+            cube_dim,
             TensorArg::vectorized(
                 vectorization_factor,
                 &tensor.handle,
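Both comparison launchers divide num_elems by the vectorization factor before sizing the grid: with a 4-wide or 2-wide vectorized TensorArg, each work item consumes several contiguous elements, so dispatching one unit per element would overshoot. A small sketch of that accounting, treating the per-cube volume as a plain illustrative number:

// With vectorization, one work item handles `factor` contiguous elements,
// so the grid is sized over num_elems / factor work items.
fn vectorized_cube_count(num_elems: usize, factor: usize, per_cube: usize) -> usize {
    // The vectorized path only applies when factor divides num_elems evenly.
    let work_items = num_elems / factor;
    work_items.div_ceil(per_cube)
}

fn main() {
    // 4-wide vectorization dispatches a quarter of the work items.
    assert_eq!(vectorized_cube_count(4096, 4, 256), 4);
    assert_eq!(vectorized_cube_count(4096, 1, 256), 16);
}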

@@ -1,4 +1,4 @@
-use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX};
+use cubecl::{calculate_cube_count_elemwise, prelude::*};

 use burn_tensor::{
     ops::{conv::calculate_conv_output_size, ConvOptions},
@@ -161,12 +161,13 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
     };

     let num_elems_output = output.shape.num_elements();
-    let cube_dim = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX);
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems_output, cube_dim);

     conv2d_kernel::launch::<E::FloatPrimitive, R>(
         &input.client,
+        cube_count,
         cube_dim,
-        CubeDim::default(),
         TensorArg::new(&input.handle, &input.strides, &input.shape.dims),
         TensorArg::new(&weight.handle, &weight.strides, &weight.shape.dims),
         TensorArg::new(&bias.handle, &bias.strides, &bias.shape.dims),
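Besides dropping SUBCUBE_DIM_APPROX, this hunk fixes a misleading binding: the old code stored the computed cube count in a variable named cube_dim and passed it alongside CubeDim::default(). A toy illustration of how distinct wrapper types would make such a swap unrepresentable (hypothetical newtypes, not cubecl's API):

// Hypothetical newtypes: with distinct wrapper types, passing the dispatch
// count where the block shape belongs is a compile error instead of a bug.
struct CubeCount(u32);
struct BlockShape(u32, u32, u32);

fn launch(count: CubeCount, shape: BlockShape) {
    println!(
        "dispatching {} cubes of shape {}x{}x{}",
        count.0, shape.0, shape.1, shape.2
    );
}

fn main() {
    let shape = BlockShape(16, 16, 1);
    let count = CubeCount(1_000_000u32.div_ceil(16 * 16));
    launch(count, shape); // launch(shape, count) would not type-check
}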

@@ -1,4 +1,4 @@
-use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX};
+use cubecl::{calculate_cube_count_elemwise, prelude::*};

 use burn_tensor::{
     ops::{conv::calculate_conv_output_size, ConvOptions},
@@ -188,10 +188,13 @@ pub(crate) fn conv3d<R: JitRuntime, E: FloatElement>(
         }
     };

+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim);
+
     conv3d_kernel::launch::<E::FloatPrimitive, R>(
         &input.client,
-        calculate_cube_count_elemwise(output.shape.num_elements(), SUBCUBE_DIM_APPROX),
-        CubeDim::default(),
+        cube_count,
+        cube_dim,
         TensorArg::new(&input.handle, &input.strides, &input.shape.dims),
         TensorArg::new(&weight.handle, &weight.strides, &weight.shape.dims),
         TensorArg::new(&bias.handle, &bias.strides, &bias.shape.dims),

@@ -4,15 +4,15 @@ use crate::{
     tensor::JitTensor,
     JitRuntime,
 };
+use cubecl::InputInfo;
 use cubecl::{
-    calculate_cube_count_elemwise, cpa, frontend::TensorHandleRef, CubeCountSettings,
+    calculate_cube_count_elemwise, cpa, frontend::TensorHandleRef, CubeCountSettings, CubeDim,
     KernelExpansion, KernelIntegrator, KernelSettings,
 };
 use cubecl::{
     ir::{Branch, Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility},
     Execution,
 };
-use cubecl::{InputInfo, SUBCUBE_DIM_APPROX};
 use std::marker::PhantomData;
@@ -223,7 +223,8 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
     // Fake strides of the virtual output where the strides of dim is hardcoded to one.
     indices.strides = strides;

-    let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);

     Execution::start(kernel, indices.client)
         .inputs(&[

@@ -1,14 +1,10 @@
-use crate::{
-    element::JitElement,
-    kernel::{Kernel, SUBCUBE_DIM_APPROX},
-    tensor::JitTensor,
-    JitRuntime,
-};
+use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
 use cubecl::{
     calculate_cube_count_elemwise, cpa,
     frontend::TensorHandleRef,
     ir::{Branch, Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility},
-    CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
+    CubeCountSettings, CubeDim, Execution, InputInfo, KernelExpansion, KernelIntegrator,
+    KernelSettings,
 };
 use std::marker::PhantomData;
@@ -208,7 +204,9 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
     });

     let kernel = SelectAssignEagerKernel::<R, E>::new(dim);
-    let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
+
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);

     Execution::start(kernel, indices.client)
         .inputs(&[
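scatter and select_assign still go through the legacy Execution/cpa path, so only their sizing changes: the cube count is now derived from the same CubeDim value that describes the launch. The invariant this pairing protects is coverage, checked in miniature below (plain Rust with illustrative volumes, not the cubecl API):

// The dispatched grid must cover every element: count * volume >= num_elems,
// where `volume` is the cube volume actually used at launch.
fn covers_all(num_elems: usize, cube_count: usize, volume: usize) -> bool {
    cube_count * volume >= num_elems
}

fn main() {
    let num_elems = 1_000_000;
    let launch_volume = 16 * 16; // assumed default cube volume
    let count = num_elems.div_ceil(launch_volume);
    assert!(covers_all(num_elems, count, launch_volume));

    // A count computed against a stale, larger volume (say 32 * 32) leaves
    // elements uncovered at the real launch volume:
    let stale_count = num_elems.div_ceil(32 * 32);
    assert!(!covers_all(num_elems, stale_count, launch_volume));
}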

@@ -1,7 +1,7 @@
 use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
 use cubecl::{
     calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
-    tensor_vectorization_factor, unexpanded, SUBCUBE_DIM_APPROX,
+    tensor_vectorization_factor, unexpanded,
 };

 use super::Kernel;
@@ -71,17 +71,17 @@ where
     let client = tensor.client.clone();
     let num_elems = tensor.shape.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

     let is_contiguous = tensor.is_contiguous();

     if tensor.can_mut() && is_contiguous {
         unary_kernel::launch::<E::Primitive, O, R>(
             &client,
             cube_count,
-            CubeDim::default(),
+            cube_dim,
             TensorArg::vectorized(
                 vectorization_factor,
                 &tensor.handle,

@@ -4,7 +4,7 @@ use crate::{element::JitElement, tensor::JitTensor};
 use burn_tensor::{ElementConversion, Shape};
 use cubecl::client::ComputeClient;
 use cubecl::ir::{BinaryOperator, Elem, Operator, Scope, Variable};
-use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX};
+use cubecl::{calculate_cube_count_elemwise, prelude::*};
 use cubecl::{tensor_vectorization_factor, Runtime};

 pub fn full<R: JitRuntime, E: JitElement, const D: usize>(
@@ -37,15 +37,15 @@ pub fn full_device<R: JitRuntime, E: JitElement, const D: usize>(
     let num_elems = empty.shape.num_elements();
     let vectorization_factor =
         tensor_vectorization_factor(&[4, 2], &empty.shape.dims, &empty.strides, D - 1);
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

     full_kernel::launch::<E::Primitive, R>(
         &empty.client,
         cube_count,
-        CubeDim::default(),
+        cube_dim,
         TensorArg::vectorized(
             vectorization_factor,
             &empty.handle,