RmsNorm kernel for metal. (#1895)
* RmsNorm kernel for metal.
* Wrapper for the metal kernel.
* Get the ops to actually work.
* Fix, get the tests to pass.
parent 74b7f59261
commit 0fddec762e
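The patch adds a `call_rms_norm` launcher to the Metal kernel crate, an `RMSNORM` shader macro, and a `metal_fwd` implementation for the `RmsNorm` custom op. For context, a minimal usage sketch of how the new path gets exercised end to end, assuming the public `candle_nn::ops::rms_norm` entry point and a Metal-capable machine (shapes and eps are illustrative):

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    // Selecting a Metal device routes rms_norm through the new kernel.
    let device = Device::new_metal(0)?;
    let x = Tensor::randn(0f32, 1f32, (4, 1024), &device)?;
    let alpha = Tensor::ones(1024, DType::F32, &device)?;
    let y = candle_nn::ops::rms_norm(&x, &alpha, 1e-5)?;
    println!("{:?}", y.shape());
    Ok(())
}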
@@ -750,6 +750,64 @@ pub fn call_last_softmax(
     Ok(())
 }
 
+#[allow(clippy::too_many_arguments)]
+pub fn call_rms_norm(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: &'static str,
+    length: usize,
+    elements_to_sum: usize,
+    eps: f32,
+    input: &Buffer,
+    input_offset: usize,
+    alpha: &Buffer,
+    alpha_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(
+        encoder,
+        (
+            length,
+            elements_to_sum,
+            (input, input_offset),
+            output,
+            (alpha, alpha_offset),
+            eps
+        )
+    );
+
+    let out_length = length / elements_to_sum;
+
+    let thread_group_count = MTLSize {
+        width: out_length as u64,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(
+        pipeline.max_total_threads_per_threadgroup(),
+        elements_to_sum as u64,
+    )
+    .next_power_of_two();
+
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
 #[allow(clippy::too_many_arguments)]
 pub fn call_affine(
     device: &Device,
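A note on the dispatch geometry above: one threadgroup is launched per output row (length / elements_to_sum), and the threadgroup width is the next power of two of the row length, capped by the pipeline's thread limit, so the tree reduction in the shader always runs on a power-of-two block. A small worked example of the same arithmetic; the 1024-thread limit is an assumption, where the real code queries pipeline.max_total_threads_per_threadgroup():

fn geometry(length: usize, elements_to_sum: usize) -> (u64, u64) {
    let max_threads: u64 = 1024; // assumed device limit for this sketch
    let thread_groups = (length / elements_to_sum) as u64; // one threadgroup per row
    let width = std::cmp::min(max_threads, elements_to_sum as u64).next_power_of_two();
    (thread_groups, width)
}

fn main() {
    // An (8, 300) tensor: 8 rows of 300 elements each.
    // 300 rounds up to 512 threads; the kernel's bounds checks idle the excess.
    assert_eq!(geometry(8 * 300, 300), (8, 512));
}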
@@ -260,6 +260,59 @@ kernel void NAME(
     } \
 } \
 
+#define RMSNORM(NAME, T) \
+kernel void NAME( \
+    constant size_t &src_numel, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device T *dst, \
+    device const T *alpha, \
+    constant float &eps, \
+    \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    threadgroup float shared_memory[THREADGROUP_SIZE]; \
+    shared_memory[tid] = 0; \
+    size_t start_idx = dst_id * el_to_sum_per_block; \
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
+    size_t idx = start_idx + tid; \
+    \
+    \
+    float tmp = 0; \
+    while (idx < stop_idx) { \
+        tmp = tmp + float(src[idx]) * float(src[idx]); \
+        idx += block_dim; \
+    } \
+    shared_memory[tid] = tmp; \
+    \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+    \
+    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+        if (tid < s) { \
+            shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; \
+        } \
+        threadgroup_barrier(mem_flags::mem_threadgroup); \
+    } \
+    \
+    /* wait for shared_memory[0] to be filled */ \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+    \
+    float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps); \
+    float inv_norm = 1.0f / norm; \
+    idx = start_idx + tid; \
+    while (idx < stop_idx) { \
+        float val = float(src[idx]) * inv_norm; \
+        if (alpha != nullptr) { \
+            val *= float(alpha[idx - start_idx]); \
+        } \
+        dst[idx] = T(val); \
+        idx += block_dim; \
+    } \
+} \
+
 REDUCE(x + y, fast_sum_f32_strided, float, 0)
 REDUCE(x + y, fast_sum_u32_strided, uint, 0)
 REDUCE(x + y, fast_sum_f16_strided, half, 0)
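In math terms, for each row of n = el_to_sum_per_block elements the macro computes

y_i = \frac{\alpha_i \, x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \varepsilon}}

The first while loop accumulates per-thread partial sums of squares in threadgroup memory, the for loop folds them with a tree reduction that assumes block_dim is a power of two (which the Rust wrapper guarantees via next_power_of_two), and the second while loop writes out the scaled values, skipping the alpha multiply when no scale buffer is bound.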
@@ -286,6 +339,8 @@ ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
 
 SOFTMAX(softmax_f32, float)
 SOFTMAX(softmax_f16, half)
+RMSNORM(rmsnorm_f32, float)
+RMSNORM(rmsnorm_f16, half)
 
 #if __METAL_VERSION__ >= 220
 REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
@@ -303,4 +358,5 @@ REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
 ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
 ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
 SOFTMAX(softmax_bf16, bfloat)
+RMSNORM(rmsnorm_bf16, bfloat)
 #endif
@@ -236,7 +236,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
             layout.start_offset() * storage.dtype().size_in_bytes(),
             &output,
         )
-        .unwrap();
+        .map_err(candle::Error::wrap)?;
         let newstorage =
             candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
         Ok((newstorage, layout.shape().clone()))
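The softmax hunk above swaps a hard panic (.unwrap()) for a propagated candle::Error, matching the error handling used in the new metal_fwd below. A minimal sketch of the pattern, with kernel_call as a hypothetical stand-in for a fallible Metal launch:

fn kernel_call() -> Result<(), std::io::Error> {
    // hypothetical stand-in for a kernel launch that can fail
    Err(std::io::Error::new(std::io::ErrorKind::Other, "pipeline creation failed"))
}

fn launch() -> candle::Result<()> {
    // surface the failure as a candle::Error instead of panicking
    kernel_call().map_err(candle::Error::wrap)?;
    Ok(())
}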
@@ -383,6 +383,51 @@ impl candle::CustomOp2 for RmsNorm {
         };
         Ok((dst, l1.shape().clone()))
     }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        s1: &candle::MetalStorage,
+        l1: &Layout,
+        s2: &candle::MetalStorage,
+        l2: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        let device = s1.device();
+        let command_buffer = device.command_buffer()?;
+        let kernels = device.kernels();
+        let name = match (s1.dtype(), s2.dtype()) {
+            (DType::F32, DType::F32) => "rmsnorm_f32",
+            (DType::F16, DType::F16) => "rmsnorm_f16",
+            (DType::BF16, DType::BF16) => "rmsnorm_bf16",
+            (dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
+        };
+
+        if !(l1.is_contiguous() && l2.is_contiguous()) {
+            candle::bail!("Non contiguous rmsnorm is not implemented");
+        }
+
+        let last_dim = l1.dims()[l1.shape().rank() - 1];
+        let elem_count = l1.shape().elem_count();
+        let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
+        candle_metal_kernels::call_rms_norm(
+            device.metal_device(),
+            &command_buffer,
+            kernels,
+            name,
+            elem_count,
+            last_dim,
+            self.eps,
+            s1.buffer(),
+            l1.start_offset() * s1.dtype().size_in_bytes(),
+            s2.buffer(),
+            l2.start_offset() * s2.dtype().size_in_bytes(),
+            &output,
+        )
+        .map_err(candle::Error::wrap)?;
+        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
+        Ok((newstorage, l1.shape().clone()))
+    }
+}
 
 pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
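Since the commit message mentions getting the tests to pass, here is a hedged sketch of the kind of check that applies: comparing the Metal path against rms_norm_slow, which the diff keeps as a reference implementation. Shapes and tolerance are illustrative, and a Metal device is assumed:

#[test]
fn rms_norm_metal_matches_slow() -> candle::Result<()> {
    use candle::{Device, Tensor};
    let device = Device::new_metal(0)?;
    let x = Tensor::randn(0f32, 1f32, (2, 512), &device)?;
    let alpha = Tensor::randn(0f32, 1f32, 512, &device)?;
    let fast = candle_nn::ops::rms_norm(&x, &alpha, 1e-5)?;
    let slow = candle_nn::ops::rms_norm_slow(&x, &alpha, 1e-5)?;
    // l2 distance between the kernel output and the reference implementation
    let diff = (&fast - &slow)?.sqr()?.sum_all()?.sqrt()?.to_scalar::<f32>()?;
    assert!(diff < 1e-4);
    Ok(())
}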