RmsNorm kernel for metal. (#1895)

* RmsNorm kernel for metal.

* Wrapper for the metal kernel.

* Get the ops to actually work.

* Fix, get the tests to pass.
This commit is contained in:
Laurent Mazare 2024-03-21 09:48:56 +01:00 committed by GitHub
parent 74b7f59261
commit 0fddec762e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 160 additions and 1 deletions

View File

@ -750,6 +750,64 @@ pub fn call_last_softmax(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rms_norm(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
elements_to_sum: usize,
eps: f32,
input: &Buffer,
input_offset: usize,
alpha: &Buffer,
alpha_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
length,
elements_to_sum,
(input, input_offset),
output,
(alpha, alpha_offset),
eps
)
);
let out_length = length / elements_to_sum;
let thread_group_count = MTLSize {
width: out_length as u64,
height: 1,
depth: 1,
};
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
elements_to_sum as u64,
)
.next_power_of_two();
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_affine(
device: &Device,

View File

@ -260,6 +260,59 @@ kernel void NAME(
} \
} \
#define RMSNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
device const T *alpha, \
constant float &eps, \
\
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = 0; \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
\
\
float tmp = 0; \
while (idx < stop_idx) { \
tmp = tmp + float(src[idx]) * float(src[idx]); \
idx += block_dim; \
} \
shared_memory[tid] = tmp; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
/* wait for shared_memory[0] to be filled */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps); \
float inv_norm = 1.0f / norm; \
idx = start_idx + tid; \
while (idx < stop_idx) { \
float val = float(src[idx]) * inv_norm; \
if (alpha != nullptr) { \
val *= float(alpha[idx - start_idx]); \
} \
dst[idx] = T(val); \
idx += block_dim; \
} \
} \
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)
@ -286,6 +339,8 @@ ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
#if __METAL_VERSION__ >= 220
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
@ -303,4 +358,5 @@ REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
#endif

View File

@ -236,7 +236,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
layout.start_offset() * storage.dtype().size_in_bytes(),
&output,
)
.unwrap();
.map_err(candle::Error::wrap)?;
let newstorage =
candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
Ok((newstorage, layout.shape().clone()))
@ -383,6 +383,51 @@ impl candle::CustomOp2 for RmsNorm {
};
Ok((dst, l1.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
s1: &candle::MetalStorage,
l1: &Layout,
s2: &candle::MetalStorage,
l2: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
let device = s1.device();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
let name = match (s1.dtype(), s2.dtype()) {
(DType::F32, DType::F32) => "rmsnorm_f32",
(DType::F16, DType::F16) => "rmsnorm_f16",
(DType::BF16, DType::BF16) => "rmsnorm_bf16",
(dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
};
if !(l1.is_contiguous() && l2.is_contiguous()) {
candle::bail!("Non contiguous rmsnorm is not implemented");
}
let last_dim = l1.dims()[l1.shape().rank() - 1];
let elem_count = l1.shape().elem_count();
let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
candle_metal_kernels::call_rms_norm(
device.metal_device(),
&command_buffer,
kernels,
name,
elem_count,
last_dim,
self.eps,
s1.buffer(),
l1.start_offset() * s1.dtype().size_in_bytes(),
s2.buffer(),
l2.start_offset() * s2.dtype().size_in_bytes(),
&output,
)
.map_err(candle::Error::wrap)?;
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
Ok((newstorage, l1.shape().clone()))
}
}
pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {