Add a custom softmax implementation. (#744)

* Add a custom softmax implementation.

* Add softmax_last_dim to the benchmarks.

* And add a test.

* Support more dtypes.

* Polish the code.

* Use the slow implementation on cuda.

* Add a todo for the cuda kernel.
Laurent Mazare 2023-09-05 15:20:23 +02:00 committed by GitHub
parent a8410bf35e
commit 1c9e5394a5
5 changed files with 109 additions and 18 deletions
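For reference, the new op is exposed as candle_nn::ops::softmax_last_dim and is a drop-in replacement for softmax taken over the last dimension. A minimal usage sketch, assuming candle-core (aliased as candle) and candle-nn as dependencies; shapes and values are arbitrary, not taken from this commit:

use candle::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    // Any tensor works; softmax_last_dim always normalizes over the last dimension.
    let xs = Tensor::randn(0f32, 1., (2, 3, 4), &Device::Cpu)?;
    let generic = candle_nn::ops::softmax(&xs, D::Minus1)?; // existing generic path
    let fused = candle_nn::ops::softmax_last_dim(&xs)?; // new single-pass CPU path
    // Slices along the last dimension are non-negative and sum to 1 in both cases.
    println!("{generic}\n{fused}");
    Ok(())
}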

Changed file 1 of 5

@@ -198,7 +198,7 @@ impl CrossAttention {
let xs = query.matmul(&(key.t()? * self.scale)?)?;
let xs = {
let _enter = self.span_softmax.enter();
nn::ops::softmax(&xs, D::Minus1)?
nn::ops::softmax_last_dim(&xs)?
};
xs.matmul(&value)?.to_dtype(in_dtype)?
};
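The softmax in scaled dot-product attention always normalizes the score matrix over its last (key) dimension, which is why the specialized op can be swapped in here. A rough standalone sketch of the same pattern, with illustrative shapes rather than the model's real ones:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Illustrative sizes: batch 1, 4 query positions, 4 key positions, head dim 8.
    let (b, t, s, d): (usize, usize, usize, usize) = (1, 4, 4, 8);
    let query = Tensor::randn(0f32, 1., (b, t, d), &Device::Cpu)?;
    let key = Tensor::randn(0f32, 1., (b, s, d), &Device::Cpu)?;
    let value = Tensor::randn(0f32, 1., (b, s, d), &Device::Cpu)?;
    let scale = 1. / (d as f64).sqrt();
    // Scores have shape (b, t, s): the key dimension comes last, so softmax_last_dim applies.
    let scores = query.matmul(&(key.t()? * scale)?)?;
    let weights = candle_nn::ops::softmax_last_dim(&scores)?;
    let out = weights.matmul(&value)?;
    println!("{out}");
    Ok(())
}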

Changed file 2 of 5

@@ -12,12 +12,16 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
half = { workspace = true }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
[features]
default = []

Changed file 3 of 5

@@ -5,19 +5,10 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_core::quantized::GgmlType;
use candle_core::{Device, Result, Tensor, D};
use candle::quantized::GgmlType;
use candle::{Device, Result, Tensor, D};
use clap::{Parser, Subcommand};
fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
let dim = dim.to_index(xs.shape(), "softmax")?;
let max = xs.max_keepdim(dim)?;
let diff = xs.broadcast_sub(&max)?;
let num = diff.exp()?;
let den = num.sum_keepdim(dim)?;
num.broadcast_div(&den)
}
trait Benchmark {
type PreProcessData;
type RunResult;
@@ -86,12 +77,12 @@ impl Benchmark for Matmul {
// https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp
struct QMatMul;
impl Benchmark for QMatMul {
type PreProcessData = (candle_core::quantized::QMatMul, Tensor);
type PreProcessData = (candle::quantized::QMatMul, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let zeros = vec![candle_core::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
let mm = candle_core::quantized::QTensor::new(zeros, (4096, 11008))?;
let mm = candle_core::quantized::QMatMul::from_qtensor(mm);
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
let mm = candle::quantized::QMatMul::from_qtensor(mm);
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
Ok((mm, arg))
}
@@ -114,7 +105,24 @@ impl Benchmark for Softmax {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
softmax(d, D::Minus1)
candle_nn::ops::softmax(d, D::Minus1)
}
const ITERS: usize = 100;
}
struct SoftmaxLastDim;
impl Benchmark for SoftmaxLastDim {
type PreProcessData = Tensor;
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
// Typical whisper tiny size.
let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
Ok(x)
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
candle_nn::ops::softmax_last_dim(d)
}
const ITERS: usize = 100;
@@ -140,6 +148,7 @@ enum Task {
Matmul,
Qmatmul,
Softmax,
SoftmaxLastDim,
}
#[derive(Parser, Debug)]
@@ -160,6 +169,7 @@ fn main() -> Result<()> {
Task::Conv2d => run::<Conv2d>(args.iters)?,
Task::Matmul => run::<Matmul>(args.iters)?,
Task::Softmax => run::<Softmax>(args.iters)?,
Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
Task::Qmatmul => run::<QMatMul>(args.iters)?,
}
Ok(())
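Outside of the benchmark harness above, the two paths can also be compared with a quick standalone timing sketch; the shape mirrors the whisper-tiny case and the 100 iterations match ITERS, but absolute numbers will vary by machine:

use candle::{Device, Result, Tensor, D};
use std::time::Instant;

fn main() -> Result<()> {
    // Same shape as the SoftmaxLastDim benchmark above.
    let xs = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
    let start = Instant::now();
    for _ in 0..100 {
        candle_nn::ops::softmax(&xs, D::Minus1)?;
    }
    println!("generic softmax:  {:?}", start.elapsed());
    let start = Instant::now();
    for _ in 0..100 {
        candle_nn::ops::softmax_last_dim(&xs)?;
    }
    println!("softmax_last_dim: {:?}", start.elapsed());
    Ok(())
}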

Changed file 4 of 5

@@ -1,4 +1,5 @@
use candle::{Result, Tensor};
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
use rayon::prelude::*;
/// Applies the softmax function to the input tensor, rescaling the elements so that elements on
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@@ -77,3 +78,69 @@ impl Dropout {
}
}
}
struct SoftmaxLastDim;
impl candle::CustomOp1 for SoftmaxLastDim {
fn name(&self) -> &'static str {
"softmax-last-dim"
}
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
fn softmax<T: candle::WithDType + num_traits::Float>(
src: &[T],
layout: &Layout,
) -> Result<(CpuStorage, Shape)> {
let src = match layout.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => &src[o1..o2],
};
let el_count = layout.shape().elem_count();
let dims = layout.shape().dims();
let dim_m1 = dims[dims.len() - 1];
let mut dst = vec![T::zero(); el_count];
src.par_chunks(dim_m1)
.zip(dst.par_chunks_mut(dim_m1))
.for_each(|(src, dst)| {
let mut max = T::neg_infinity();
for &s in src.iter() {
max = T::max(s, max)
}
let mut sum_exp = T::zero();
for (s, d) in src.iter().zip(dst.iter_mut()) {
*d = (*s - max).exp();
sum_exp += *d
}
for d in dst.iter_mut() {
*d /= sum_exp
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, Shape::from_dims(dims)))
}
match storage {
CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
_ => candle::bail!("unsupported dtype for softmax {:?}", storage),
}
}
fn cuda_fwd(
&self,
_storage: &candle::CudaStorage,
_layout: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
candle::bail!("TODO: implement a cuda kernel")
}
}
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
if xs.device().is_cpu() {
xs.apply_op1_no_bwd(&SoftmaxLastDim)
} else {
softmax(xs, candle::D::Minus1)
}
}
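Note that the fast path only covers the CPU forward pass: non-CPU devices fall back to the generic softmax, the cuda_fwd stub bails with a TODO, and apply_op1_no_bwd registers no backward pass for the custom path. A small sketch checking the custom op against the generic implementation on random data:

use candle::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let xs = Tensor::randn(0f32, 1., (4, 32), &Device::Cpu)?;
    let reference = candle_nn::ops::softmax(&xs, D::Minus1)?;
    let fused = candle_nn::ops::softmax_last_dim(&xs)?;
    // Both results should agree up to floating-point rounding.
    let diff = (&reference - &fused)?.abs()?;
    let max_diff = diff
        .to_vec2::<f32>()?
        .into_iter()
        .flatten()
        .fold(0f32, f32::max);
    println!("max abs difference: {max_diff}");
    Ok(())
}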

Changed file 5 of 5

@@ -41,6 +41,16 @@ fn softmax() -> Result<()> {
[[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
]
);
let t2 = candle_nn::ops::softmax_last_dim(&tensor.log()?)?;
assert_eq!(
to_vec3_round(&t2, 4)?,
&[
// (3, 1, 4) / 8, (1, 5, 9) / 15
[[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
// (2, 1, 7) / 10, (8, 2, 8) / 18
[[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
]
);
Ok(())
}
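For reference on the expected values: softmax(log(x))_i = exp(log(x_i)) / (exp(log(x_1)) + ... + exp(log(x_n))) = x_i / (x_1 + ... + x_n), and the max subtracted inside the kernel cancels in that ratio, so a row proportional to (3, 1, 4) maps to (3, 1, 4) / 8 = [0.375, 0.125, 0.5].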