mirror of https://github.com/tracel-ai/burn.git
refactor following cube changes
parent 096ec13c48
commit 4c025aa95c

Cargo.lock

@@ -1303,7 +1303,7 @@ dependencies = [
 [[package]]
 name = "cubecl"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "cubecl-core",
  "cubecl-cuda",

@@ -1314,7 +1314,7 @@ dependencies = [
 [[package]]
 name = "cubecl-common"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "derive-new",
  "getrandom",

@@ -1328,7 +1328,7 @@ dependencies = [
 [[package]]
 name = "cubecl-core"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "bytemuck",
  "cubecl-macros",

@@ -1343,7 +1343,7 @@ dependencies = [
 [[package]]
 name = "cubecl-cuda"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "bytemuck",
  "cubecl-common",

@@ -1358,7 +1358,7 @@ dependencies = [
 [[package]]
 name = "cubecl-linalg"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "bytemuck",
  "cubecl-core",

@@ -1369,7 +1369,7 @@ dependencies = [
 [[package]]
 name = "cubecl-macros"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "derive-new",
  "proc-macro2",

@@ -1380,7 +1380,7 @@ dependencies = [
 [[package]]
 name = "cubecl-runtime"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "async-channel",
  "cubecl-common",

@@ -1399,7 +1399,7 @@ dependencies = [
 [[package]]
 name = "cubecl-wgpu"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl#2b95a9e245bf4362b497866ee24bec399d1c74fb"
+source = "git+https://github.com/tracel-ai/cubecl#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
 dependencies = [
  "async-channel",
  "bytemuck",

Cargo.toml

@@ -67,7 +67,9 @@ rstest = "0.19.0"
 rusqlite = { version = "0.31.0" }
 rust-format = { version = "0.3.4" }
 sanitize-filename = "0.5.0"
-serde_bytes = { version = "0.11.15", default-features = false, features = ["alloc"] } # alloc for no_std
+serde_bytes = { version = "0.11.15", default-features = false, features = [
+    "alloc",
+] } # alloc for no_std
 serde_rusqlite = "0.35.0"
 serial_test = "3.1.1"
 spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }

@@ -148,5 +150,5 @@ cubecl-common = { version = "0.1.1", git = "https://github.com/tracel-ai/cubecl"
 # cubecl-common = { path = "../cubecl/crates/cubecl-common" }

 [profile.dev]
 debug = 0 # Speed up compilation time and not necessary.
 opt-level = 2

@@ -30,11 +30,10 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
         outputs: &[&TensorDescription],
         stateful: bool,
     ) -> FusionKernel<R> {
-        let cube_dim_x = self.cube_dim.x;
-        let cube_dim_y = self.cube_dim.y;
-
-        assert_eq!(cube_dim_x, cube_dim_y, "The grid must be a square");
-        let cube_dim = cube_dim_x as usize;
+        assert_eq!(
+            self.cube_dim.x, self.cube_dim.y,
+            "The grid must be a square"
+        );

         let vectorize_4 = can_vectorize(handles_inputs, inputs, outputs, 4);
         let vectorize_2 = can_vectorize(handles_inputs, inputs, outputs, 2);

@@ -69,7 +68,7 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {

                 let reference_tensor = inputs[settings.mappings[0].pos_input];
                 let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
-                let cube_count = calculate_cube_count_elemwise(num_elems / factor, cube_dim);
+                let cube_count = calculate_cube_count_elemwise(num_elems / factor, self.cube_dim);
                 let output_infos =
                     inplace_output2input
                         .iter()

@@ -96,7 +95,7 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
             false => {
                 let reference_tensor = outputs[0];
                 let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
-                let cube_count = calculate_cube_count_elemwise(num_elems / factor, cube_dim);
+                let cube_count = calculate_cube_count_elemwise(num_elems / factor, self.cube_dim);
                 let output_infos = outputs.iter().enumerate().map(|(pos, tensor)| {
                     let size = calculate_num_elems_dyn_rank(&tensor.shape)
                         * self.info.outputs[pos].elem_size::<R>();

@@ -1,8 +1,6 @@
 use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
 use cubecl::linalg::tensor::index_offset_with_layout;
-use cubecl::{
-    calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, SUBCUBE_DIM_APPROX,
-};
+use cubecl::{calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor};
 use cubecl::{ir::KernelDefinition, KernelSettings};
 use std::any::TypeId;

@@ -46,10 +44,10 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
         tensor_vectorization_factor(&[4, 2], &input.shape.dims, &input.strides, rank - 1);

     let num_elems: usize = input.shape.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
     let client = input.client.clone();
     let handle = client.empty(num_elems * core::mem::size_of::<EO>());
     let output =

@@ -58,7 +56,7 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
     cast_element::launch::<EI::Primitive, EO::Primitive, R>(
         &client,
         cube_count,
-        CubeDim::default(),
+        cube_dim,
         TensorArg::vectorized(
             vectorization_factor,
             &input.handle,
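
The same migration appears in each of the kernels below: the removed SUBCUBE_DIM_APPROX constant is replaced by a CubeDim value that is created once, used to compute the cube count, and then passed to the launch. A minimal sketch of the new calling convention, mirroring the hunks above rather than compiling on its own (some_kernel, num_elems, vectorization_factor, client, and the generic parameters stand in for whatever the surrounding function provides):

// Not a standalone program; sketch of the pattern used by the refactored kernels.
let cube_dim = CubeDim::default();
let cube_count =
    calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

some_kernel::launch::<E::Primitive, R>(
    &client,
    cube_count,
    cube_dim, // the same CubeDim, instead of a second CubeDim::default() at the call site
    /* tensor arguments unchanged */
);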

@@ -1,12 +1,14 @@
-use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
-use cubecl::{
-    cpa,
-    frontend::TensorHandleRef,
-    ir::{Elem, Item, KernelDefinition, Scope, Variable, Visibility},
-    CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
-    OutputInfo,
-};
-use std::marker::PhantomData;
+use crate::{tensor::JitTensor, JitElement, JitRuntime};
+use cubecl::{calculate_cube_count_elemwise, prelude::*, CubeDim};
+
+#[cube(launch)]
+fn bool_cast_kernel<T: Numeric>(input: &Tensor<UInt>, output: &mut Tensor<T>) {
+    if input[ABSOLUTE_POS] == UInt::new(1) {
+        output[ABSOLUTE_POS] = T::from_int(1);
+    } else {
+        output[ABSOLUTE_POS] = T::from_int(0);
+    }
+}

 /// Cast a bool tensor to the given element type.
 ///

@@ -17,7 +19,6 @@ use std::marker::PhantomData;
 pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
     tensor: JitTensor<R, u32, D>,
 ) -> JitTensor<R, EO, D> {
-    let kernel = BoolCastEagerKernel::<R, EO>::new();
     let num_elems = tensor.shape.num_elements();
     let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
     let output = JitTensor::new_contiguous(

@@ -27,86 +28,16 @@ pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
         buffer,
     );

-    Execution::start(kernel, tensor.client)
-        .inputs(&[TensorHandleRef::<R>::new(
-            &tensor.handle,
-            &tensor.strides,
-            &tensor.shape.dims,
-        )])
-        .outputs(&[TensorHandleRef::new(
-            &output.handle,
-            &output.strides,
-            &output.shape.dims,
-        )])
-        .execute(CubeCountSettings::Output { pos: 0 });
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
+
+    bool_cast_kernel::launch::<EO::Primitive, R>(
+        &tensor.client,
+        cube_count,
+        cube_dim,
+        TensorArg::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
+        TensorArg::new(&output.handle, &output.strides, &output.shape.dims),
+    );

     output
 }
-
-pub(crate) struct BoolCastShader {
-    tensor: Variable,
-    output: Variable,
-}
-
-#[derive(new)]
-pub(crate) struct BoolCastEagerKernel<R: JitRuntime, EO: JitElement> {
-    _runtime: PhantomData<R>,
-    _elem_out: PhantomData<EO>,
-}
-
-impl<R: JitRuntime, EO: JitElement> Kernel for BoolCastEagerKernel<R, EO> {
-    fn define(&self) -> KernelDefinition {
-        let mut scope = Scope::root();
-        let item_input = Item::new(Elem::Bool);
-        let item_output = EO::cube_elem().into();
-
-        let tensor = Variable::GlobalInputArray {
-            id: 0,
-            item: item_input,
-        };
-        let output = Variable::GlobalOutputArray {
-            id: 0,
-            item: item_output,
-        };
-
-        BoolCastShader { tensor, output }.expand(&mut scope);
-
-        scope.write_global_custom(output);
-
-        let tensor = InputInfo::Array {
-            item: item_input,
-            visibility: Visibility::Read,
-        };
-
-        let out = OutputInfo::Array { item: item_output };
-
-        let info = KernelExpansion {
-            inputs: vec![tensor],
-            outputs: vec![out],
-            scope,
-        };
-
-        let settings = KernelSettings::default();
-        KernelIntegrator::new(info).integrate(settings)
-    }
-
-    fn id(&self) -> cubecl::KernelId {
-        cubecl::KernelId::new::<Self>()
-    }
-}
-
-impl BoolCastShader {
-    pub(crate) fn expand(self, scope: &mut Scope) {
-        let tensor = self.tensor;
-        let id = Variable::AbsolutePos;
-        let output = self.output;
-
-        let represents_true = scope.create_local(Elem::Bool);
-        cpa!(scope, represents_true = tensor[id]);
-        cpa!(scope, if(represents_true).then(|scope|{
-            cpa!(scope, output[id] = 1);
-        }).else(|scope|{
-            cpa!(scope, output[id] = 0);
-        }));
-    }
-}
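
This is the largest functional change in the commit: the hand-built BoolCastShader and BoolCastEagerKernel (a Scope plus cpa! expansion wired together through KernelIntegrator) are deleted, and the cast is expressed as a single #[cube(launch)] function launched directly from bool_cast. A hypothetical call site, assuming a JitRuntime R and a u32-backed boolean tensor named mask of rank 2; f32 is only an example of an output element type:

// Hypothetical usage of the refactored bool_cast shown above.
let floats: JitTensor<R, f32, 2> = bool_cast::<R, f32, 2>(mask);
// Each element becomes T::from_int(1) or T::from_int(0), i.e. 1.0 or 0.0 for f32.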

@@ -3,7 +3,7 @@ use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
 use burn_tensor::Shape;
 use cubecl::{
     calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
-    tensor_vectorization_factor, Runtime, SUBCUBE_DIM_APPROX,
+    tensor_vectorization_factor, Runtime,
 };

 #[cube]

@@ -139,17 +139,17 @@ pub(crate) fn launch_cmp<
     let shape_out = Shape::new(shape_out);
     let client = lhs.client.clone();
     let num_elems = shape_out.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

     let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
     if same_tensor_type && lhs.can_mut_broadcast(&rhs) {
         kernel_cmp::launch::<E::Primitive, O, R>(
             &client,
             cube_count,
-            CubeDim::default(),
+            cube_dim,
             TensorArg::vectorized(
                 vectorization_factor,
                 &lhs.handle,

@@ -244,17 +244,17 @@ pub(crate) fn launch_scalar_cmp<
         tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, D - 1);
     let client = tensor.client.clone();
     let num_elems = tensor.shape.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

     let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
     if same_tensor_type && tensor.can_mut() {
         kernel_scalar_cmp::launch::<E::Primitive, O, R>(
             &client,
             cube_count,
-            CubeDim::default(),
+            cube_dim,
             TensorArg::vectorized(
                 vectorization_factor,
                 &tensor.handle,

@@ -1,4 +1,4 @@
-use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX};
+use cubecl::{calculate_cube_count_elemwise, prelude::*};

 use burn_tensor::{
     ops::{conv::calculate_conv_output_size, ConvOptions},

@@ -161,12 +161,13 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
     };

     let num_elems_output = output.shape.num_elements();
-    let cube_dim = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX);
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems_output, cube_dim);

     conv2d_kernel::launch::<E::FloatPrimitive, R>(
         &input.client,
+        cube_count,
         cube_dim,
-        CubeDim::default(),
         TensorArg::new(&input.handle, &input.strides, &input.shape.dims),
         TensorArg::new(&weight.handle, &weight.strides, &weight.shape.dims),
         TensorArg::new(&bias.handle, &bias.strides, &bias.shape.dims),

@@ -1,4 +1,4 @@
-use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX};
+use cubecl::{calculate_cube_count_elemwise, prelude::*};

 use burn_tensor::{
     ops::{conv::calculate_conv_output_size, ConvOptions},

@@ -188,10 +188,13 @@ pub(crate) fn conv3d<R: JitRuntime, E: FloatElement>(
         }
     };

+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim);
+
     conv3d_kernel::launch::<E::FloatPrimitive, R>(
         &input.client,
-        calculate_cube_count_elemwise(output.shape.num_elements(), SUBCUBE_DIM_APPROX),
-        CubeDim::default(),
+        cube_count,
+        cube_dim,
         TensorArg::new(&input.handle, &input.strides, &input.shape.dims),
         TensorArg::new(&weight.handle, &weight.strides, &weight.shape.dims),
         TensorArg::new(&bias.handle, &bias.strides, &bias.shape.dims),

@@ -4,15 +4,15 @@ use crate::{
     tensor::JitTensor,
     JitRuntime,
 };
+use cubecl::InputInfo;
 use cubecl::{
-    calculate_cube_count_elemwise, cpa, frontend::TensorHandleRef, CubeCountSettings,
+    calculate_cube_count_elemwise, cpa, frontend::TensorHandleRef, CubeCountSettings, CubeDim,
     KernelExpansion, KernelIntegrator, KernelSettings,
 };
 use cubecl::{
     ir::{Branch, Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility},
     Execution,
 };
-use cubecl::{InputInfo, SUBCUBE_DIM_APPROX};
 use std::marker::PhantomData;

 #[derive(new)]

@@ -223,7 +223,8 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
     // Fake strides of the virtual output where the strides of dim is hardcoded to one.
     indices.strides = strides;

-    let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);

     Execution::start(kernel, indices.client)
         .inputs(&[

@@ -1,14 +1,10 @@
-use crate::{
-    element::JitElement,
-    kernel::{Kernel, SUBCUBE_DIM_APPROX},
-    tensor::JitTensor,
-    JitRuntime,
-};
+use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
 use cubecl::{
     calculate_cube_count_elemwise, cpa,
     frontend::TensorHandleRef,
     ir::{Branch, Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility},
-    CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
+    CubeCountSettings, CubeDim, Execution, InputInfo, KernelExpansion, KernelIntegrator,
+    KernelSettings,
 };
 use std::marker::PhantomData;

@@ -208,7 +204,9 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
     });

     let kernel = SelectAssignEagerKernel::<R, E>::new(dim);
-    let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
+
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);

     Execution::start(kernel, indices.client)
         .inputs(&[

@@ -1,7 +1,7 @@
 use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
 use cubecl::{
     calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
-    tensor_vectorization_factor, unexpanded, SUBCUBE_DIM_APPROX,
+    tensor_vectorization_factor, unexpanded,
 };

 use super::Kernel;

@@ -71,17 +71,17 @@ where

     let client = tensor.client.clone();
     let num_elems = tensor.shape.num_elements();
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
     let is_contiguous = tensor.is_contiguous();

     if tensor.can_mut() && is_contiguous {
         unary_kernel::launch::<E::Primitive, O, R>(
             &client,
             cube_count,
-            CubeDim::default(),
+            cube_dim,
             TensorArg::vectorized(
                 vectorization_factor,
                 &tensor.handle,

@@ -4,7 +4,7 @@ use crate::{element::JitElement, tensor::JitTensor};
 use burn_tensor::{ElementConversion, Shape};
 use cubecl::client::ComputeClient;
 use cubecl::ir::{BinaryOperator, Elem, Operator, Scope, Variable};
-use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX};
+use cubecl::{calculate_cube_count_elemwise, prelude::*};
 use cubecl::{tensor_vectorization_factor, Runtime};

 pub fn full<R: JitRuntime, E: JitElement, const D: usize>(

@@ -37,15 +37,15 @@ pub fn full_device<R: JitRuntime, E: JitElement, const D: usize>(
     let num_elems = empty.shape.num_elements();
     let vectorization_factor =
         tensor_vectorization_factor(&[4, 2], &empty.shape.dims, &empty.strides, D - 1);
-    let cube_count = calculate_cube_count_elemwise(
-        num_elems / vectorization_factor as usize,
-        SUBCUBE_DIM_APPROX,
-    );
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

     full_kernel::launch::<E::Primitive, R>(
         &empty.client,
         cube_count,
-        CubeDim::default(),
+        cube_dim,
         TensorArg::vectorized(
             vectorization_factor,
             &empty.handle,