mirror of https://github.com/tracel-ai/burn.git

Cleanup

parent 1f56aae659
commit 15f7745264
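This commit renames the code that the #[cube] and #[cube(launch)] macros generate: instead of emitting sibling free functions such as foo_expand and foo_launch next to a kernel fn foo, the macro now emits a module named after the kernel that contains __expand (and launch for launchable kernels), so call sites become foo::__expand(...) and foo::launch(...). Below is a minimal hand-written sketch of the new layout, compilable on its own; the scale kernel and its bodies are illustrative stand-ins, not code from the repository:

    // A function and a module may share a name in Rust (they live in
    // different namespaces), which is what the generated code relies on.
    pub fn scale(x: f32, factor: f32) -> f32 {
        x * factor
    }

    /// Module containing the expand method and launch function of scale.
    pub mod scale {
        // The real generated functions take a CubeContext and expand
        // types; plain f32 keeps this sketch self-contained.
        pub fn __expand(x: f32, factor: f32) -> f32 {
            super::scale(x, factor)
        }

        pub fn launch(x: f32, factor: f32) -> f32 {
            __expand(x, factor)
        }
    }

Old call sites like scale_launch::<F, R>(...) therefore become scale::launch::<F, R>(...) throughout the diff below.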
@@ -454,10 +454,9 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream {
     let ident = &sig.ident;
 
-    let mut ident_expand = TokenStream::new();
-    ident_expand.extend(quote::quote! {
-        #ident::__expand
-    });
+    let ident_expand = quote::quote! {
+        __expand
+    };
 
     let generics = add_runtime(add_lifetime(sig.generics.clone()));
     let body = codegen.gen_launch_body();

@@ -113,7 +113,7 @@ fn codegen_cube(
     let signature = expand_sig(
         &func.sig,
         &syn::Visibility::Public(Default::default()), // Always public, otherwise we can't import
-                                                      // it from an outside module.
+        // it from an outside module.
         Some(variable_tracker),
         ExpandMode::FuncImpl,
     );

@@ -143,6 +143,8 @@ fn codegen_cube(
         return Err(code);
     }
 
+    let launch_doc = if launch { "and launch function " } else { "" };
+
     let launch = if launch {
         codegen_launch(&func.sig)
     } else {

@@ -151,12 +153,15 @@ fn codegen_cube(
     let mod_name = &func.sig.ident;
     let vis = &func.vis;
+    let doc = format!("Module containing the expand method {launch_doc}of {mod_name}.");
 
     Ok(quote::quote! {
         #[allow(dead_code)]
+        #[allow(clippy::too_many_arguments)]
         #func
 
+        #[doc = #doc]
+        #vis mod #mod_name {
+            use super::*;
@@ -12,7 +12,7 @@ pub trait Clamp: CubePrimitive + Sized {
     fn clamp(input: Self, min_value: Self, max_value: Self) -> Self {
         unexpanded!()
     }
-    fn clamp_expand(
+    fn __expand_clamp(
         context: &mut CubeContext,
         input: Self::ExpandType,
         min_value: Self::ExpandType,
@@ -3,6 +3,10 @@ use crate::ir::Synchronization;
 
 pub fn sync_units() {}
 
-pub fn sync_units_expand(context: &mut CubeContext) {
-    context.register(Synchronization::SyncUnits)
-}
+pub mod sync_units {
+    use super::*;
+
+    pub fn __expand(context: &mut CubeContext) {
+        context.register(Synchronization::SyncUnits)
+    }
+}
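The sync_units hunk above is the smallest instance of the new convention: the user-facing no-op function keeps its name, while the expand stage moves from the free function sync_units_expand into a same-named module. Here is a runnable model of the before/after call sites; the Context struct is only a stand-in for burn_cube's CubeContext:

    pub struct Context {
        pub ops: Vec<&'static str>,
    }

    pub fn sync_units() {}

    pub mod sync_units {
        use super::*;

        pub fn __expand(context: &mut Context) {
            // Mirrors context.register(Synchronization::SyncUnits).
            context.ops.push("SyncUnits");
        }
    }

    fn main() {
        let mut ctx = Context { ops: Vec::new() };
        // Before: sync_units_expand(&mut ctx); after:
        sync_units::__expand(&mut ctx);
        assert_eq!(ctx.ops, ["SyncUnits"]);
    }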
@@ -3,23 +3,23 @@ use burn_cube::prelude::*;
 use crate::kernel::{launch_unary, UnaryOp};
 use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
 
-#[derive(CubeLaunch)]
-struct Options<C: Numeric> {
-    min_value: C,
-    max_value: C,
-}
-
 pub(crate) fn clamp<R: JitRuntime, E: JitElement, const D: usize>(
     input: JitTensor<R, E, D>,
     min_value: E,
     max_value: E,
 ) -> JitTensor<R, E, D> {
+    #[derive(CubeLaunch)]
+    struct Options<C: Numeric> {
+        min_value: C,
+        max_value: C,
+    }
+
     struct ClampOp;
 
     impl<C: Numeric> UnaryOp<C> for ClampOp {
         type Options = Options<C>;
 
-        fn execute_expand(
+        fn __expand_execute(
             context: &mut CubeContext,
             input: C::ExpandType,
             options: OptionsExpand<C>,

@@ -29,7 +29,7 @@ pub(crate) fn clamp<R: JitRuntime, E: JitElement, const D: usize>(
                 C::clamp(input, options.min_value, options.max_value)
             }
 
-            execute_expand(context, input, options)
+            execute::__expand(context, input, options)
         }
     }
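The clamp hunks also move the #[derive(CubeLaunch)] options struct into the function body, keeping it out of the module namespace, and the expand-side method receives the derive's companion type OptionsExpand<C>. A rough model of that pairing follows; the to_expand helper and the plain-value fields are illustrative only (the real generated type wraps each field in the element's ExpandType):

    struct Options<C> {
        min_value: C,
        max_value: C,
    }

    // Companion type consumed by __expand_execute in the hunk above.
    struct OptionsExpand<C> {
        min_value: C,
        max_value: C,
    }

    impl<C: Copy> Options<C> {
        // Illustrative only: shows the correspondence between the two
        // structs, not an actual API of the CubeLaunch derive.
        fn to_expand(&self) -> OptionsExpand<C> {
            OptionsExpand {
                min_value: self.min_value,
                max_value: self.max_value,
            }
        }
    }

    fn main() {
        let opts = Options { min_value: 0.0f32, max_value: 1.0 };
        assert_eq!(opts.to_expand().max_value, 1.0);
    }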
@@ -1,4 +1,4 @@
-use super::{index_offset_with_layout, index_offset_with_layout_expand, Kernel};
+use super::{index_offset_with_layout, Kernel};
 use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
 use burn_cube::{
     calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, Runtime,

@@ -146,7 +146,7 @@ pub(crate) fn launch_cmp<
 
     let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
     if same_tensor_type && lhs.can_mut_broadcast(&rhs) {
-        kernel_cmp_launch::<E::Primitive, O, R>(
+        kernel_cmp::launch::<E::Primitive, O, R>(
             client,
             cube_count,
             CubeDim::default(),

@@ -170,7 +170,7 @@ pub(crate) fn launch_cmp<
 
         JitTensor::new(lhs.client, lhs.handle, lhs.shape, lhs.device, lhs.strides)
     } else if same_tensor_type && rhs.can_mut_broadcast(&lhs) {
-        kernel_cmp_launch::<E::Primitive, O, R>(
+        kernel_cmp::launch::<E::Primitive, O, R>(
             client,
             cube_count,
             CubeDim::default(),

@@ -199,7 +199,7 @@ pub(crate) fn launch_cmp<
     let to_contiguous_rhs = !rhs.is_contiguous();
     let output = JitTensor::new_contiguous(lhs.client.clone(), lhs.device, shape_out, buffer);
 
-    kernel_cmp_launch::<E::Primitive, O, R>(
+    kernel_cmp::launch::<E::Primitive, O, R>(
         client,
         cube_count,
         CubeDim::default(),

@@ -251,7 +251,7 @@ pub(crate) fn launch_scalar_cmp<
 
     let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
     if same_tensor_type && tensor.can_mut() {
-        kernel_scalar_cmp_launch::<E::Primitive, O, R>(
+        kernel_scalar_cmp::launch::<E::Primitive, O, R>(
             client,
             cube_count,
             CubeDim::default(),

@@ -282,7 +282,7 @@ pub(crate) fn launch_scalar_cmp<
         tensor.strides,
     );
 
-    kernel_scalar_cmp_launch::<E::Primitive, O, R>(
+    kernel_scalar_cmp::launch::<E::Primitive, O, R>(
         client,
         cube_count,
         CubeDim::default(),
@@ -87,7 +87,7 @@ pub fn into_contiguous<R: JitRuntime, E: JitElement, const D: usize>(
         SUBCUBE_DIM_APPROX,
     );
 
-    into_contiguous_kernel_launch::<E::Primitive, R>(
+    into_contiguous_kernel::launch::<E::Primitive, R>(
         client,
         cube_count,
         CubeDim::default(),
@@ -163,7 +163,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
     let num_elems_output = output.shape.num_elements();
     let cube_dim = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX);
 
-    conv2d_kernel_launch::<E::FloatPrimitive, R>(
+    conv2d_kernel::launch::<E::FloatPrimitive, R>(
         input.client,
         cube_dim,
         CubeDim::default(),

@@ -188,7 +188,7 @@ pub(crate) fn conv3d<R: JitRuntime, E: FloatElement>(
         }
     };
 
-    conv3d_kernel_launch::<E::FloatPrimitive, R>(
+    conv3d_kernel::launch::<E::FloatPrimitive, R>(
         input.client,
         calculate_cube_count_elemwise(output.shape.num_elements(), SUBCUBE_DIM_APPROX),
         CubeDim::default(),
@@ -116,7 +116,7 @@ pub fn matmul_simple<R: JitRuntime, E: FloatElement, const D: usize>(
         false => 1,
     };
 
-    matmul_kernel_launch::<E::FloatPrimitive, R>(
+    matmul_kernel::launch::<E::FloatPrimitive, R>(
         lhs.client,
         cube_count,
         CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1),
@@ -2,11 +2,11 @@ use burn_cube::prelude::*;
 
 use crate::kernel::matmul::config::CubeTiling2dConfig;
 
-use super::block_loop::{block_loop, block_loop_expand};
+use super::block_loop::block_loop;
 
 #[cube(launch)]
 #[allow(unused_mut)]
-fn tiling2d_cube<F: Float>(
+pub fn tiling2d_cube_kernel<F: Float>(
     lhs: &Tensor<F>,
     rhs: &Tensor<F>,
     out: &mut Tensor<F>,

@@ -4,10 +4,10 @@ use crate::kernel::matmul::config::CubeTiling2dConfig;
 
 use super::{
     base::{BatchOffsets, Coordinates, Dimensions, SharedMemories},
-    compute_loop::{compute_loop, compute_loop_expand},
-    load_shared_memory::{load_to_shared_memories, load_to_shared_memories_expand},
+    compute_loop::compute_loop,
+    load_shared_memory::load_to_shared_memories,
     tile::{loader::TileLoader, writer::TileWriter},
-    write_output::{write_to_output, write_to_output_expand},
+    write_output::write_to_output,
 };
 
 #[cube]
@@ -2,10 +2,7 @@ use burn_cube::prelude::*;
 
 use crate::kernel::matmul::config::CubeTiling2dConfig;
 
-use super::{
-    base::Coordinates,
-    outer_product::{tile_outer_product, tile_outer_product_expand},
-};
+use super::{base::Coordinates, outer_product::tile_outer_product};
 
 #[cube]
 #[allow(unused_mut)]

@@ -105,7 +102,7 @@ pub mod tests {
         const SOME_DIM: usize = 12;
         let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM);
 
-        compute_loop_test_launch::<F32, R>(
+        compute_loop_test::launch::<F32, R>(
            lhs.client.clone(),
            cube_count,
            cube_dim,

@@ -134,7 +131,7 @@ pub mod tests {
 
         let config = make_config(4, 8, 4);
 
-        compute_loop_test_launch::<F32, R>(
+        compute_loop_test::launch::<F32, R>(
            lhs.client.clone(),
            cube_count,
            cube_dim,
@@ -7,6 +7,7 @@ use crate::{
         into_contiguous,
         matmul::{
             config::{tiling2d_cube_count, tiling2d_cube_dim, CubeTiling2dConfig},
+            tiling2d_cube::base::tiling2d_cube_kernel,
             Tiling2dConfig,
         },
     },

@@ -14,8 +15,6 @@ use crate::{
     FloatElement, JitRuntime,
 };
 
-use super::base::tiling2d_cube_launch;
-
 /// Matrix multiplication using tiling 2d algorithm
 pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
     lhs: JitTensor<R, E, D>,

@@ -69,7 +68,7 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
     let cube_dim = tiling2d_cube_dim(&config);
     let cube_config = CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed);
 
-    tiling2d_cube_launch::<E::FloatPrimitive, R>(
+    tiling2d_cube_kernel::launch::<E::FloatPrimitive, R>(
         client,
         cube_count,
         cube_dim,
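Renaming tiling2d_cube to pub fn tiling2d_cube_kernel lets callers import the kernel (and its generated module) through the crate's module tree and invoke tiling2d_cube_kernel::launch directly, replacing the removed use super::base::tiling2d_cube_launch import. A compilable sketch of that import-and-call shape; the base module here abbreviates the real kernel::matmul::tiling2d_cube path:

    mod base {
        // Stand-in for the #[cube(launch)] kernel...
        pub fn tiling2d_cube_kernel() {}

        // ...and for its generated module.
        pub mod tiling2d_cube_kernel {
            pub fn launch() {
                println!("launching tiling2d kernel");
            }
        }
    }

    // Before: use base::tiling2d_cube_launch; (a free function).
    // After: one import brings in both the fn and the module, since
    // they share a name across namespaces.
    use base::tiling2d_cube_kernel;

    fn main() {
        tiling2d_cube_kernel::launch();
    }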
@@ -385,7 +385,7 @@ pub mod tests {
 
         let config = make_config(16, 16, 8);
 
-        load_tensor_test_launch::<F32, R>(
+        load_tensor_test::launch::<F32, R>(
            lhs.client.clone(),
            cube_count,
            cube_dim,

@@ -417,7 +417,7 @@ pub mod tests {
 
         let config = make_config(5, 1, 1);
 
-        load_tensor_multiple_tiles_test_launch::<F32, R>(
+        load_tensor_multiple_tiles_test::launch::<F32, R>(
            lhs.client.clone(),
            cube_count,
            cube_dim,

@@ -451,7 +451,7 @@ pub mod tests {
 
         let config = make_config(8, 8, 8);
 
-        load_tensor_multiple_tiles_test_launch::<F32, R>(
+        load_tensor_multiple_tiles_test::launch::<F32, R>(
            lhs.client.clone(),
            cube_count,
            cube_dim,

@@ -481,7 +481,7 @@ pub mod tests {
 
         let config = make_config(8, 8, 16);
 
-        load_tensor_multiple_tiles_test_launch::<F32, R>(
+        load_tensor_multiple_tiles_test::launch::<F32, R>(
            lhs.client.clone(),
            cube_count,
            cube_dim,

@@ -511,7 +511,7 @@ pub mod tests {
 
         let config = make_config(8, 16, 16);
 
-        load_tensor_test_launch::<F32, R>(
+        load_tensor_test::launch::<F32, R>(
            rhs.client.clone(),
            cube_count,
            cube_dim,

@@ -543,7 +543,7 @@ pub mod tests {
 
         let config = make_config(8, 8, 8);
 
-        load_tensor_multiple_tiles_test_launch::<F32, R>(
+        load_tensor_multiple_tiles_test::launch::<F32, R>(
            rhs.client.clone(),
            cube_count,
            cube_dim,

@@ -573,7 +573,7 @@ pub mod tests {
 
         let config = make_config(16, 16, 8);
 
-        load_tensor_multiple_tiles_test_launch::<F32, R>(
+        load_tensor_multiple_tiles_test::launch::<F32, R>(
            rhs.client.clone(),
            cube_count,
            cube_dim,

@@ -603,7 +603,7 @@ pub mod tests {
 
         let config = make_config(16, 16, 8);
 
-        load_tensor_permuted_test_launch::<F32, R>(
+        load_tensor_permuted_test::launch::<F32, R>(
            lhs.client.clone(),
            cube_count,
            cube_dim,

@@ -636,7 +636,7 @@ pub mod tests {
 
         let config = make_config(m, k, 8);
 
-        load_tensor_permuted_test_launch::<F32, R>(
+        load_tensor_permuted_test::launch::<F32, R>(
            lhs.client.clone(),
            cube_count,
            cube_dim,

@@ -667,7 +667,7 @@ pub mod tests {
 
         let config = make_config(16, 16, 8);
 
-        load_tensor_permuted_test_launch::<F32, R>(
+        load_tensor_permuted_test::launch::<F32, R>(
            rhs.client.clone(),
            cube_count,
            cube_dim,

@@ -699,7 +699,7 @@ pub mod tests {
 
         let config = make_config(8, k, n);
 
-        load_tensor_permuted_test_launch::<F32, R>(
+        load_tensor_permuted_test::launch::<F32, R>(
            rhs.client.clone(),
            cube_count,
            cube_dim,
@@ -70,7 +70,7 @@ pub mod tests {
         const SOME_DIM: usize = 12;
         let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM);
 
-        tile_outer_product_test_launch::<F32, R>(
+        tile_outer_product_test::launch::<F32, R>(
            client.clone(),
            cube_count,
            cube_dim,

@@ -99,7 +99,7 @@ pub mod tests {
         const SOME_DIM: usize = 12;
         let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM);
 
-        tile_outer_product_test_launch::<F32, R>(
+        tile_outer_product_test::launch::<F32, R>(
            client.clone(),
            cube_count,
            cube_dim,
@@ -14,10 +14,7 @@ use crate::kernel::matmul::{
     },
 };
 
-use super::base::{
-    all_zeros_comptime, all_zeros_comptime_expand, all_zeros_runtime, all_zeros_runtime_expand,
-    BlockLoader, BlockWriter,
-};
+use super::base::{all_zeros_comptime, all_zeros_runtime, BlockLoader, BlockWriter};
 
 pub(crate) struct HorizontalCheckBlockIO;

@@ -14,7 +14,7 @@ use crate::kernel::matmul::{
     },
 };
 
-use super::base::{all_zeros_runtime, all_zeros_runtime_expand, BlockLoader, BlockWriter};
+use super::base::{all_zeros_runtime, BlockLoader, BlockWriter};
 
 pub(crate) struct VerticalCheckBlockIO;

@@ -14,10 +14,7 @@ use crate::kernel::matmul::{
     },
 };
 
-use super::base::{
-    all_zeros_comptime, all_zeros_comptime_expand, all_zeros_runtime, all_zeros_runtime_expand,
-    BlockLoader, BlockWriter,
-};
+use super::base::{all_zeros_comptime, all_zeros_runtime, BlockLoader, BlockWriter};
 
 pub(crate) struct WholeCheckBlockIO;
@@ -130,7 +130,7 @@ pub mod tests {
 
         let config = make_config(6, 8, 8);
 
-        write_to_output_test_launch::<F32, R>(
+        write_to_output_test::launch::<F32, R>(
            out.client.clone(),
            cube_count,
            cube_dim,

@@ -156,7 +156,7 @@ pub mod tests {
 
         let config = make_config(8, 8, 4);
 
-        write_to_output_test_launch::<F32, R>(
+        write_to_output_test::launch::<F32, R>(
            out.client.clone(),
            cube_count,
            cube_dim,

@@ -182,7 +182,7 @@ pub mod tests {
 
         let config = make_config(8, 8, 8);
 
-        write_to_output_test_launch::<F32, R>(
+        write_to_output_test::launch::<F32, R>(
            out.client.clone(),
            cube_count,
            cube_dim,

@@ -215,7 +215,7 @@ pub mod tests {
 
         let config = make_config(8, 8, 8);
 
-        write_to_output_test_launch::<F32, R>(
+        write_to_output_test::launch::<F32, R>(
            out.client.clone(),
            cube_count,
            cube_dim,

@@ -248,7 +248,7 @@ pub mod tests {
 
         let config = make_config(5, 8, 1);
 
-        write_results_to_output_out_of_bounds_test_launch::<F32, R>(
+        write_results_to_output_out_of_bounds_test::launch::<F32, R>(
            out.client.clone(),
            cube_count,
            cube_dim,
@@ -11,7 +11,7 @@ pub use binary::*;
 pub use cast::*;
 pub use contiguous::*;
 pub use mask::*;
-pub use unary::*;
+pub(crate) use unary::*;
 
 pub use burn_cube::{Kernel, SUBCUBE_DIM_APPROX};
@@ -4,7 +4,7 @@ use burn_cube::{
     SUBCUBE_DIM_APPROX,
 };
 
-use super::{index_offset_with_layout, index_offset_with_layout_expand, Kernel};
+use super::{index_offset_with_layout, Kernel};
 
 pub(crate) trait UnaryOp<C: CubePrimitive>: 'static + Send + Sync {
     type Options: LaunchArg;

@@ -13,7 +13,7 @@ pub(crate) trait UnaryOp<C: CubePrimitive>: 'static + Send + Sync {
     fn execute(_input: C, _options: &Self::Options) -> C {
         unexpanded!();
     }
-    fn execute_expand(
+    fn __expand_execute(
         context: &mut CubeContext,
         input: C::ExpandType,
         options: <Self::Options as CubeType>::ExpandType,

@@ -78,7 +78,7 @@ where
     let is_contiguous = tensor.is_contiguous();
 
     if tensor.can_mut() && is_contiguous {
-        unary_kernel_launch::<E::Primitive, O, R>(
+        unary_kernel::launch::<E::Primitive, O, R>(
            client,
            cube_count,
            CubeDim::default(),

@@ -104,7 +104,7 @@ where
         buffer,
     );
 
-    unary_kernel_launch::<E::Primitive, O, R>(
+    unary_kernel::launch::<E::Primitive, O, R>(
        client,
        cube_count,
        CubeDim::default(),

@@ -136,7 +136,7 @@ macro_rules! unary_op {
            type Options = ();
 
            #[allow(clippy::redundant_closure_call)]
-           fn execute_expand(
+           fn __expand_execute(
                context: &mut CubeContext,
                input: C::ExpandType,
                _options: <Self::Options as CubeType>::ExpandType,

@@ -152,7 +152,7 @@ macro_rules! unary_op {
            type Options = C;
 
            #[allow(clippy::redundant_closure_call)]
-           fn execute_expand(
+           fn __expand_execute(
                context: &mut CubeContext,
                input: C::ExpandType,
                scalar: C::ExpandType,
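Implementors of UnaryOp (including those produced by the unary_op! macro) follow the same rename from execute_expand to __expand_execute. A self-contained model of the trait after this commit, with CubeContext and the expand types replaced by plain f32 values so the sketch compiles on its own:

    trait UnaryOp {
        type Options;

        // Comptime-level placeholder; the real method calls unexpanded!().
        fn execute(input: f32, options: &Self::Options) -> f32;

        // Renamed from execute_expand in this commit.
        fn __expand_execute(input: f32, options: &Self::Options) -> f32 {
            Self::execute(input, options)
        }
    }

    struct ClampOp;

    impl UnaryOp for ClampOp {
        type Options = (f32, f32); // (min_value, max_value)

        fn execute(input: f32, options: &Self::Options) -> f32 {
            input.clamp(options.0, options.1)
        }
    }

    fn main() {
        assert_eq!(ClampOp::__expand_execute(2.5, &(0.0, 1.0)), 1.0);
    }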
@@ -334,7 +334,7 @@ where
             fn execute<C: Float>(input: C) -> C {
                 C::exp(input)
             }
-            execute_expand::<C>(context, input)
+            execute::__expand::<C>(context, input)
         })
     }
 

@@ -344,7 +344,7 @@ where
             fn execute<C: Float>(input: C) -> C {
                 C::log(input)
             }
-            execute_expand::<C>(context, tensor)
+            execute::__expand::<C>(context, tensor)
         })
     }
 

@@ -354,7 +354,7 @@ where
             fn execute<C: Float>(input: C) -> C {
                 C::log1p(input)
             }
-            execute_expand::<C>(context, tensor)
+            execute::__expand::<C>(context, tensor)
         })
     }
 

@@ -367,7 +367,7 @@ where
             fn execute<C: Float>(input: C, scalar: C) -> C {
                 C::powf(input, scalar)
             }
-            execute_expand::<C>(context, tensor, scalar)
+            execute::__expand::<C>(context, tensor, scalar)
         })
     }
 

@@ -377,7 +377,7 @@ where
             fn execute<C: Float>(input: C) -> C {
                 C::sqrt(input)
             }
-            execute_expand::<C>(context, tensor)
+            execute::__expand::<C>(context, tensor)
         })
     }
 

@@ -387,7 +387,7 @@ where
             fn execute<C: Float>(input: C) -> C {
                 C::abs(input)
             }
-            execute_expand::<C>(context, tensor)
+            execute::__expand::<C>(context, tensor)
         })
     }
 

@@ -397,7 +397,7 @@ where
             fn execute<C: Float>(input: C) -> C {
                 C::cos(input)
             }
-            execute_expand::<C>(context, tensor)
+            execute::__expand::<C>(context, tensor)
         })
     }
 

@@ -407,7 +407,7 @@ where
             fn execute<C: Float>(input: C) -> C {
                 C::sin(input)
             }
-            execute_expand::<C>(context, tensor)
+            execute::__expand::<C>(context, tensor)
         })
     }
 

@@ -417,7 +417,7 @@ where
             fn execute<C: Float>(input: C) -> C {
                 C::tanh(input)
             }
-            execute_expand::<C>(context, tensor)
+            execute::__expand::<C>(context, tensor)
         })
     }
 

@@ -427,7 +427,7 @@ where
             fn execute<C: Float>(input: C) -> C {
                 C::erf(input)
             }
-            execute_expand::<C>(context, tensor)
+            execute::__expand::<C>(context, tensor)
         })
     }
 

@@ -463,7 +463,7 @@ where
             fn execute<C: Float>(input: C) -> C {
                 C::recip(input)
             }
-            execute_expand::<C>(context, tensor)
+            execute::__expand::<C>(context, tensor)
         })
     }
 
@@ -299,7 +299,7 @@ where
             fn execute<C: Numeric>(input: C) -> C {
                 C::abs(input)
             }
-            execute_expand::<C>(context, tensor)
+            execute::__expand::<C>(context, tensor)
         })
     }
 
@@ -26,7 +26,7 @@ pub fn full_device<R: JitRuntime, E: JitElement, const D: usize>(
     let empty = empty_device(client, device, shape);
 
     #[cube(launch)]
-    pub(crate) fn full_kernel<C: Numeric + Vectorized>(tensor: &mut Tensor<C>, value: C) {
+    pub fn full_kernel<C: Numeric + Vectorized>(tensor: &mut Tensor<C>, value: C) {
         if ABSOLUTE_POS >= tensor.len() {
             return;
         }

@@ -42,7 +42,7 @@ pub fn full_device<R: JitRuntime, E: JitElement, const D: usize>(
         SUBCUBE_DIM_APPROX,
     );
 
-    full_kernel_launch::<E::Primitive, R>(
+    full_kernel::launch::<E::Primitive, R>(
         empty.client.clone(),
         cube_count,
         CubeDim::default(),
@@ -127,7 +127,7 @@ pub fn add_scalar<R: JitRuntime, E: JitElement, const D: usize>(
             fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
                 lhs + rhs
             }
-            execute_expand::<C>(context, lhs, rhs)
+            execute::__expand::<C>(context, lhs, rhs)
         })
     }
 

@@ -156,7 +156,7 @@ pub fn sub_scalar<R: JitRuntime, E: JitElement, const D: usize>(
             fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
                 lhs - rhs
             }
-            execute_expand::<C>(context, lhs, rhs)
+            execute::__expand::<C>(context, lhs, rhs)
         })
     }
 

@@ -185,7 +185,7 @@ pub fn mul_scalar<R: JitRuntime, E: JitElement, const D: usize>(
             fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
                 lhs * rhs
             }
-            execute_expand::<C>(context, lhs, rhs)
+            execute::__expand::<C>(context, lhs, rhs)
         })
     }
 

@@ -214,7 +214,7 @@ pub fn div_scalar<R: JitRuntime, E: JitElement, const D: usize>(
             fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
                 lhs / rhs
             }
-            execute_expand::<C>(context, lhs, rhs)
+            execute::__expand::<C>(context, lhs, rhs)
         })
     }
 
@@ -146,7 +146,7 @@ where
             fn execute<C: Numeric>(input: C) -> C {
                 input
             }
-            execute_expand::<C>(context, tensor)
+            execute::__expand::<C>(context, tensor)
         })
     }