nathaniel 2024-07-11 17:25:40 -04:00
parent 1f56aae659
commit 15f7745264
26 changed files with 91 additions and 93 deletions

View File

@ -454,10 +454,9 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream {
let ident = &sig.ident;
-let mut ident_expand = TokenStream::new();
-ident_expand.extend(quote::quote! {
-#ident::__expand
-});
+let ident_expand = quote::quote! {
+__expand
+};
let generics = add_runtime(add_lifetime(sig.generics.clone()));
let body = codegen.gen_launch_body();

View File

@ -113,7 +113,7 @@ fn codegen_cube(
let signature = expand_sig(
&func.sig,
&syn::Visibility::Public(Default::default()), // Always public, otherwise we can't import
// it from an outside module.
// it from an outside module.
Some(variable_tracker),
ExpandMode::FuncImpl,
);
@ -143,6 +143,8 @@ fn codegen_cube(
return Err(code);
}
let launch_doc = if launch { "and launch function " } else { "" };
let launch = if launch {
codegen_launch(&func.sig)
} else {
@ -151,12 +153,15 @@ fn codegen_cube(
let mod_name = &func.sig.ident;
let vis = &func.vis;
let doc = format!("Module containing the expand method {launch_doc}of {mod_name}.");
Ok(quote::quote! {
#[allow(dead_code)]
#[allow(clippy::too_many_arguments)]
#func
#[doc = #doc]
#vis mod #mod_name {
use super::*;

View File

@ -12,7 +12,7 @@ pub trait Clamp: CubePrimitive + Sized {
fn clamp(input: Self, min_value: Self, max_value: Self) -> Self {
unexpanded!()
}
-fn clamp_expand(
+fn __expand_clamp(
context: &mut CubeContext,
input: Self::ExpandType,
min_value: Self::ExpandType,
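Trait methods get the same treatment as free functions but with a different spelling: a trait cannot grow a module per method, so the expand companion stays on the trait and moves from a `_expand` suffix to a `__expand_` prefix (`clamp_expand` becomes `__expand_clamp` here, `execute_expand` becomes `__expand_execute` in the unary kernels further down). The sketch below uses a reduced stand-in trait, not the real `CubePrimitive`/`CubeContext` API, just to show the renamed default method.

```rust
// Reduced stand-in trait: the real methods take a &mut CubeContext and expand
// types, elided here so the example compiles on its own.
trait ClampLike: Sized {
    fn clamp(input: Self, min_value: Self, max_value: Self) -> Self;

    // Previously `fn clamp_expand(...)`; expand companions are now prefixed.
    fn __expand_clamp(input: Self, min_value: Self, max_value: Self) -> Self {
        Self::clamp(input, min_value, max_value)
    }
}

impl ClampLike for f32 {
    fn clamp(input: f32, min_value: f32, max_value: f32) -> f32 {
        input.max(min_value).min(max_value)
    }
}

fn main() {
    // Call sites move from `C::clamp_expand(...)` to `C::__expand_clamp(...)`.
    assert_eq!(<f32 as ClampLike>::__expand_clamp(5.0, 0.0, 1.0), 1.0);
}
```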

View File

@ -3,6 +3,10 @@ use crate::ir::Synchronization;
pub fn sync_units() {}
-pub fn sync_units_expand(context: &mut CubeContext) {
-context.register(Synchronization::SyncUnits)
+pub mod sync_units {
+use super::*;
+pub fn __expand(context: &mut CubeContext) {
+context.register(Synchronization::SyncUnits)
+}
+}
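The sync_units hunk above is the smallest example of the pattern applied throughout this commit: the hand-written `*_expand` free function disappears, and a module named after the function exposes `__expand` instead (for `#[cube(launch)]` kernels the same module also carries `launch`, which is why call sites below change from `kernel_cmp_launch::<…>` to `kernel_cmp::launch::<…>`). The sketch below is illustrative only, with a stand-in `CubeContext` so it compiles on its own; it is not the generated burn-cube code.

```rust
// Illustrative sketch only: stand-in type, not burn-cube's real CubeContext.
pub struct CubeContext;

// The runtime-facing function keeps its name...
pub fn sync_units() {}

// ...and a module with the same name (functions and modules occupy different
// namespaces) now carries the expand method that used to be `sync_units_expand`.
pub mod sync_units {
    use super::*;

    pub fn __expand(_context: &mut CubeContext) {
        // real code: context.register(Synchronization::SyncUnits)
    }
}

fn main() {
    let mut context = CubeContext;
    sync_units();
    // old call site: sync_units_expand(&mut context);
    sync_units::__expand(&mut context);
}
```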

View File

@ -3,23 +3,23 @@ use burn_cube::prelude::*;
use crate::kernel::{launch_unary, UnaryOp};
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
+#[derive(CubeLaunch)]
+struct Options<C: Numeric> {
+min_value: C,
+max_value: C,
+}
pub(crate) fn clamp<R: JitRuntime, E: JitElement, const D: usize>(
input: JitTensor<R, E, D>,
min_value: E,
max_value: E,
) -> JitTensor<R, E, D> {
-#[derive(CubeLaunch)]
-struct Options<C: Numeric> {
-min_value: C,
-max_value: C,
-}
struct ClampOp;
impl<C: Numeric> UnaryOp<C> for ClampOp {
type Options = Options<C>;
-fn execute_expand(
+fn __expand_execute(
context: &mut CubeContext,
input: C::ExpandType,
options: OptionsExpand<C>,
@ -29,7 +29,7 @@ pub(crate) fn clamp<R: JitRuntime, E: JitElement, const D: usize>(
C::clamp(input, options.min_value, options.max_value)
}
-execute_expand(context, input, options)
+execute::__expand(context, input, options)
}
}
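One detail worth calling out in this file: the `#[derive(CubeLaunch)]` `Options` struct moves from inside `clamp` up to module scope. That is consistent with the inner `#[cube]` helper `execute` now expanding into a module declared inside the method body: a module declared in a function body cannot see the function's local items, and its `use super::*` resolves to the surrounding file-level module. The motivation is an inference from the diff, not a statement from the commit; the snippet below demonstrates the scoping rule with plain-Rust stand-ins, not burn-cube types.

```rust
// Plain-Rust illustration of the scoping rule: the nested module stands in for
// what `#[cube]` generates for the inner `execute` helper.
pub(crate) struct Options {
    min_value: f32,
    max_value: f32,
}

fn clamp(input: f32, options: &Options) -> f32 {
    // A module declared inside a function body cannot see the function's local
    // items; `use super::*` resolves to the file-level module. That is why
    // `Options` has to live at module scope once the expand helper becomes a
    // nested module instead of a plain `execute_expand` function.
    mod execute {
        use super::*;

        pub(crate) fn __expand(input: f32, options: &Options) -> f32 {
            input.max(options.min_value).min(options.max_value)
        }
    }

    execute::__expand(input, options)
}

fn main() {
    let options = Options { min_value: 0.0, max_value: 1.0 };
    println!("{}", clamp(5.0, &options));
}
```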

View File

@ -1,4 +1,4 @@
-use super::{index_offset_with_layout, index_offset_with_layout_expand, Kernel};
+use super::{index_offset_with_layout, Kernel};
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
use burn_cube::{
calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, Runtime,
@ -146,7 +146,7 @@ pub(crate) fn launch_cmp<
let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
if same_tensor_type && lhs.can_mut_broadcast(&rhs) {
-kernel_cmp_launch::<E::Primitive, O, R>(
+kernel_cmp::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),
@ -170,7 +170,7 @@ pub(crate) fn launch_cmp<
JitTensor::new(lhs.client, lhs.handle, lhs.shape, lhs.device, lhs.strides)
} else if same_tensor_type && rhs.can_mut_broadcast(&lhs) {
-kernel_cmp_launch::<E::Primitive, O, R>(
+kernel_cmp::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),
@ -199,7 +199,7 @@ pub(crate) fn launch_cmp<
let to_contiguous_rhs = !rhs.is_contiguous();
let output = JitTensor::new_contiguous(lhs.client.clone(), lhs.device, shape_out, buffer);
-kernel_cmp_launch::<E::Primitive, O, R>(
+kernel_cmp::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),
@ -251,7 +251,7 @@ pub(crate) fn launch_scalar_cmp<
let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
if same_tensor_type && tensor.can_mut() {
-kernel_scalar_cmp_launch::<E::Primitive, O, R>(
+kernel_scalar_cmp::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),
@ -282,7 +282,7 @@ pub(crate) fn launch_scalar_cmp<
tensor.strides,
);
-kernel_scalar_cmp_launch::<E::Primitive, O, R>(
+kernel_scalar_cmp::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),

View File

@ -87,7 +87,7 @@ pub fn into_contiguous<R: JitRuntime, E: JitElement, const D: usize>(
SUBCUBE_DIM_APPROX,
);
-into_contiguous_kernel_launch::<E::Primitive, R>(
+into_contiguous_kernel::launch::<E::Primitive, R>(
client,
cube_count,
CubeDim::default(),

View File

@ -163,7 +163,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
let num_elems_output = output.shape.num_elements();
let cube_dim = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX);
-conv2d_kernel_launch::<E::FloatPrimitive, R>(
+conv2d_kernel::launch::<E::FloatPrimitive, R>(
input.client,
cube_dim,
CubeDim::default(),

View File

@ -188,7 +188,7 @@ pub(crate) fn conv3d<R: JitRuntime, E: FloatElement>(
}
};
-conv3d_kernel_launch::<E::FloatPrimitive, R>(
+conv3d_kernel::launch::<E::FloatPrimitive, R>(
input.client,
calculate_cube_count_elemwise(output.shape.num_elements(), SUBCUBE_DIM_APPROX),
CubeDim::default(),

View File

@ -116,7 +116,7 @@ pub fn matmul_simple<R: JitRuntime, E: FloatElement, const D: usize>(
false => 1,
};
-matmul_kernel_launch::<E::FloatPrimitive, R>(
+matmul_kernel::launch::<E::FloatPrimitive, R>(
lhs.client,
cube_count,
CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1),

View File

@ -2,11 +2,11 @@ use burn_cube::prelude::*;
use crate::kernel::matmul::config::CubeTiling2dConfig;
-use super::block_loop::{block_loop, block_loop_expand};
+use super::block_loop::block_loop;
#[cube(launch)]
#[allow(unused_mut)]
-fn tiling2d_cube<F: Float>(
+pub fn tiling2d_cube_kernel<F: Float>(
lhs: &Tensor<F>,
rhs: &Tensor<F>,
out: &mut Tensor<F>,

View File

@ -4,10 +4,10 @@ use crate::kernel::matmul::config::CubeTiling2dConfig;
use super::{
base::{BatchOffsets, Coordinates, Dimensions, SharedMemories},
-compute_loop::{compute_loop, compute_loop_expand},
-load_shared_memory::{load_to_shared_memories, load_to_shared_memories_expand},
+compute_loop::compute_loop,
+load_shared_memory::load_to_shared_memories,
tile::{loader::TileLoader, writer::TileWriter},
-write_output::{write_to_output, write_to_output_expand},
+write_output::write_to_output,
};
#[cube]

View File

@ -2,10 +2,7 @@ use burn_cube::prelude::*;
use crate::kernel::matmul::config::CubeTiling2dConfig;
-use super::{
-base::Coordinates,
-outer_product::{tile_outer_product, tile_outer_product_expand},
-};
+use super::{base::Coordinates, outer_product::tile_outer_product};
#[cube]
#[allow(unused_mut)]
@ -105,7 +102,7 @@ pub mod tests {
const SOME_DIM: usize = 12;
let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM);
-compute_loop_test_launch::<F32, R>(
+compute_loop_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -134,7 +131,7 @@ pub mod tests {
let config = make_config(4, 8, 4);
-compute_loop_test_launch::<F32, R>(
+compute_loop_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,

View File

@ -7,6 +7,7 @@ use crate::{
into_contiguous,
matmul::{
config::{tiling2d_cube_count, tiling2d_cube_dim, CubeTiling2dConfig},
+tiling2d_cube::base::tiling2d_cube_kernel,
Tiling2dConfig,
},
},
@ -14,8 +15,6 @@ use crate::{
FloatElement, JitRuntime,
};
-use super::base::tiling2d_cube_launch;
/// Matrix multiplication using tiling 2d algorithm
pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
lhs: JitTensor<R, E, D>,
@ -69,7 +68,7 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
let cube_dim = tiling2d_cube_dim(&config);
let cube_config = CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed);
-tiling2d_cube_launch::<E::FloatPrimitive, R>(
+tiling2d_cube_kernel::launch::<E::FloatPrimitive, R>(
client,
cube_count,
cube_dim,

View File

@ -385,7 +385,7 @@ pub mod tests {
let config = make_config(16, 16, 8);
-load_tensor_test_launch::<F32, R>(
+load_tensor_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -417,7 +417,7 @@ pub mod tests {
let config = make_config(5, 1, 1);
-load_tensor_multiple_tiles_test_launch::<F32, R>(
+load_tensor_multiple_tiles_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -451,7 +451,7 @@ pub mod tests {
let config = make_config(8, 8, 8);
-load_tensor_multiple_tiles_test_launch::<F32, R>(
+load_tensor_multiple_tiles_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -481,7 +481,7 @@ pub mod tests {
let config = make_config(8, 8, 16);
-load_tensor_multiple_tiles_test_launch::<F32, R>(
+load_tensor_multiple_tiles_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -511,7 +511,7 @@ pub mod tests {
let config = make_config(8, 16, 16);
-load_tensor_test_launch::<F32, R>(
+load_tensor_test::launch::<F32, R>(
rhs.client.clone(),
cube_count,
cube_dim,
@ -543,7 +543,7 @@ pub mod tests {
let config = make_config(8, 8, 8);
-load_tensor_multiple_tiles_test_launch::<F32, R>(
+load_tensor_multiple_tiles_test::launch::<F32, R>(
rhs.client.clone(),
cube_count,
cube_dim,
@ -573,7 +573,7 @@ pub mod tests {
let config = make_config(16, 16, 8);
-load_tensor_multiple_tiles_test_launch::<F32, R>(
+load_tensor_multiple_tiles_test::launch::<F32, R>(
rhs.client.clone(),
cube_count,
cube_dim,
@ -603,7 +603,7 @@ pub mod tests {
let config = make_config(16, 16, 8);
-load_tensor_permuted_test_launch::<F32, R>(
+load_tensor_permuted_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -636,7 +636,7 @@ pub mod tests {
let config = make_config(m, k, 8);
-load_tensor_permuted_test_launch::<F32, R>(
+load_tensor_permuted_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -667,7 +667,7 @@ pub mod tests {
let config = make_config(16, 16, 8);
-load_tensor_permuted_test_launch::<F32, R>(
+load_tensor_permuted_test::launch::<F32, R>(
rhs.client.clone(),
cube_count,
cube_dim,
@ -699,7 +699,7 @@ pub mod tests {
let config = make_config(8, k, n);
-load_tensor_permuted_test_launch::<F32, R>(
+load_tensor_permuted_test::launch::<F32, R>(
rhs.client.clone(),
cube_count,
cube_dim,

View File

@ -70,7 +70,7 @@ pub mod tests {
const SOME_DIM: usize = 12;
let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM);
-tile_outer_product_test_launch::<F32, R>(
+tile_outer_product_test::launch::<F32, R>(
client.clone(),
cube_count,
cube_dim,
@ -99,7 +99,7 @@ pub mod tests {
const SOME_DIM: usize = 12;
let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM);
-tile_outer_product_test_launch::<F32, R>(
+tile_outer_product_test::launch::<F32, R>(
client.clone(),
cube_count,
cube_dim,

View File

@ -14,10 +14,7 @@ use crate::kernel::matmul::{
},
};
-use super::base::{
-all_zeros_comptime, all_zeros_comptime_expand, all_zeros_runtime, all_zeros_runtime_expand,
-BlockLoader, BlockWriter,
-};
+use super::base::{all_zeros_comptime, all_zeros_runtime, BlockLoader, BlockWriter};
pub(crate) struct HorizontalCheckBlockIO;

View File

@ -14,7 +14,7 @@ use crate::kernel::matmul::{
},
};
-use super::base::{all_zeros_runtime, all_zeros_runtime_expand, BlockLoader, BlockWriter};
+use super::base::{all_zeros_runtime, BlockLoader, BlockWriter};
pub(crate) struct VerticalCheckBlockIO;

View File

@ -14,10 +14,7 @@ use crate::kernel::matmul::{
},
};
-use super::base::{
-all_zeros_comptime, all_zeros_comptime_expand, all_zeros_runtime, all_zeros_runtime_expand,
-BlockLoader, BlockWriter,
-};
+use super::base::{all_zeros_comptime, all_zeros_runtime, BlockLoader, BlockWriter};
pub(crate) struct WholeCheckBlockIO;

View File

@ -130,7 +130,7 @@ pub mod tests {
let config = make_config(6, 8, 8);
-write_to_output_test_launch::<F32, R>(
+write_to_output_test::launch::<F32, R>(
out.client.clone(),
cube_count,
cube_dim,
@ -156,7 +156,7 @@ pub mod tests {
let config = make_config(8, 8, 4);
-write_to_output_test_launch::<F32, R>(
+write_to_output_test::launch::<F32, R>(
out.client.clone(),
cube_count,
cube_dim,
@ -182,7 +182,7 @@ pub mod tests {
let config = make_config(8, 8, 8);
-write_to_output_test_launch::<F32, R>(
+write_to_output_test::launch::<F32, R>(
out.client.clone(),
cube_count,
cube_dim,
@ -215,7 +215,7 @@ pub mod tests {
let config = make_config(8, 8, 8);
-write_to_output_test_launch::<F32, R>(
+write_to_output_test::launch::<F32, R>(
out.client.clone(),
cube_count,
cube_dim,
@ -248,7 +248,7 @@ pub mod tests {
let config = make_config(5, 8, 1);
-write_results_to_output_out_of_bounds_test_launch::<F32, R>(
+write_results_to_output_out_of_bounds_test::launch::<F32, R>(
out.client.clone(),
cube_count,
cube_dim,

View File

@ -11,7 +11,7 @@ pub use binary::*;
pub use cast::*;
pub use contiguous::*;
pub use mask::*;
-pub use unary::*;
+pub(crate) use unary::*;
pub use burn_cube::{Kernel, SUBCUBE_DIM_APPROX};

View File

@ -4,7 +4,7 @@ use burn_cube::{
SUBCUBE_DIM_APPROX,
};
-use super::{index_offset_with_layout, index_offset_with_layout_expand, Kernel};
+use super::{index_offset_with_layout, Kernel};
pub(crate) trait UnaryOp<C: CubePrimitive>: 'static + Send + Sync {
type Options: LaunchArg;
@ -13,7 +13,7 @@ pub(crate) trait UnaryOp<C: CubePrimitive>: 'static + Send + Sync {
fn execute(_input: C, _options: &Self::Options) -> C {
unexpanded!();
}
-fn execute_expand(
+fn __expand_execute(
context: &mut CubeContext,
input: C::ExpandType,
options: <Self::Options as CubeType>::ExpandType,
@ -78,7 +78,7 @@ where
let is_contiguous = tensor.is_contiguous();
if tensor.can_mut() && is_contiguous {
-unary_kernel_launch::<E::Primitive, O, R>(
+unary_kernel::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),
@ -104,7 +104,7 @@ where
buffer,
);
-unary_kernel_launch::<E::Primitive, O, R>(
+unary_kernel::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),
@ -136,7 +136,7 @@ macro_rules! unary_op {
type Options = ();
#[allow(clippy::redundant_closure_call)]
-fn execute_expand(
+fn __expand_execute(
context: &mut CubeContext,
input: C::ExpandType,
_options: <Self::Options as CubeType>::ExpandType,
@ -152,7 +152,7 @@ macro_rules! unary_op {
type Options = C;
#[allow(clippy::redundant_closure_call)]
-fn execute_expand(
+fn __expand_execute(
context: &mut CubeContext,
input: C::ExpandType,
scalar: C::ExpandType,

View File

@ -334,7 +334,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::exp(input)
}
-execute_expand::<C>(context, input)
+execute::__expand::<C>(context, input)
})
}
@ -344,7 +344,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::log(input)
}
-execute_expand::<C>(context, tensor)
+execute::__expand::<C>(context, tensor)
})
}
@ -354,7 +354,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::log1p(input)
}
-execute_expand::<C>(context, tensor)
+execute::__expand::<C>(context, tensor)
})
}
@ -367,7 +367,7 @@ where
fn execute<C: Float>(input: C, scalar: C) -> C {
C::powf(input, scalar)
}
-execute_expand::<C>(context, tensor, scalar)
+execute::__expand::<C>(context, tensor, scalar)
})
}
@ -377,7 +377,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::sqrt(input)
}
-execute_expand::<C>(context, tensor)
+execute::__expand::<C>(context, tensor)
})
}
@ -387,7 +387,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::abs(input)
}
-execute_expand::<C>(context, tensor)
+execute::__expand::<C>(context, tensor)
})
}
@ -397,7 +397,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::cos(input)
}
-execute_expand::<C>(context, tensor)
+execute::__expand::<C>(context, tensor)
})
}
@ -407,7 +407,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::sin(input)
}
-execute_expand::<C>(context, tensor)
+execute::__expand::<C>(context, tensor)
})
}
@ -417,7 +417,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::tanh(input)
}
-execute_expand::<C>(context, tensor)
+execute::__expand::<C>(context, tensor)
})
}
@ -427,7 +427,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::erf(input)
}
-execute_expand::<C>(context, tensor)
+execute::__expand::<C>(context, tensor)
})
}
@ -463,7 +463,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::recip(input)
}
-execute_expand::<C>(context, tensor)
+execute::__expand::<C>(context, tensor)
})
}

View File

@ -299,7 +299,7 @@ where
fn execute<C: Numeric>(input: C) -> C {
C::abs(input)
}
-execute_expand::<C>(context, tensor)
+execute::__expand::<C>(context, tensor)
})
}

View File

@ -26,7 +26,7 @@ pub fn full_device<R: JitRuntime, E: JitElement, const D: usize>(
let empty = empty_device(client, device, shape);
#[cube(launch)]
-pub(crate) fn full_kernel<C: Numeric + Vectorized>(tensor: &mut Tensor<C>, value: C) {
+pub fn full_kernel<C: Numeric + Vectorized>(tensor: &mut Tensor<C>, value: C) {
if ABSOLUTE_POS >= tensor.len() {
return;
}
@ -42,7 +42,7 @@ pub fn full_device<R: JitRuntime, E: JitElement, const D: usize>(
SUBCUBE_DIM_APPROX,
);
-full_kernel_launch::<E::Primitive, R>(
+full_kernel::launch::<E::Primitive, R>(
empty.client.clone(),
cube_count,
CubeDim::default(),
@ -127,7 +127,7 @@ pub fn add_scalar<R: JitRuntime, E: JitElement, const D: usize>(
fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
lhs + rhs
}
-execute_expand::<C>(context, lhs, rhs)
+execute::__expand::<C>(context, lhs, rhs)
})
}
@ -156,7 +156,7 @@ pub fn sub_scalar<R: JitRuntime, E: JitElement, const D: usize>(
fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
lhs - rhs
}
-execute_expand::<C>(context, lhs, rhs)
+execute::__expand::<C>(context, lhs, rhs)
})
}
@ -185,7 +185,7 @@ pub fn mul_scalar<R: JitRuntime, E: JitElement, const D: usize>(
fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
lhs * rhs
}
-execute_expand::<C>(context, lhs, rhs)
+execute::__expand::<C>(context, lhs, rhs)
})
}
@ -214,7 +214,7 @@ pub fn div_scalar<R: JitRuntime, E: JitElement, const D: usize>(
fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
lhs / rhs
}
-execute_expand::<C>(context, lhs, rhs)
+execute::__expand::<C>(context, lhs, rhs)
})
}

View File

@ -146,7 +146,7 @@ where
fn execute<C: Numeric>(input: C) -> C {
input
}
-execute_expand::<C>(context, tensor)
+execute::__expand::<C>(context, tensor)
})
}