Improved mamba model optimized for inference (#1694)

* Sketch the mamba model for inference.

* Complete the forward pass.

* Add the mamba example.

* Optimize the selective-scan part.

* Fix a couple shape mismatches and get inference to work.

* Tweak the readmes.

* More readme tweaks.
Laurent Mazare 2024-02-11 17:04:57 +01:00 committed by GitHub
parent 74497e6bf7
commit 1e26d539d9
6 changed files with 533 additions and 2 deletions

README.md

@@ -67,7 +67,7 @@ We also provide some command line based examples using state of the art models
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
- [Mamba](./candle-examples/examples/mamba/): an inference-only
implementation of the Mamba state space model.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
better performance than all publicly available 13b models as of 2023-09-28.
@@ -186,7 +186,7 @@ If you have an addition to this list, please submit a pull request.
- Falcon.
- StarCoder.
- Phi 1, 1.5, and 2.
- Minimal Mamba
- Mamba, Minimal Mamba
- Mistral 7b v0.1.
- Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.

candle-examples/examples/mamba-minimal/README.md

@@ -2,6 +2,9 @@
This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).
Compared to the mamba example, this version can handle training but is much
slower.
## Running the example
```bash

candle-examples/examples/mamba/README.md

@@ -0,0 +1,17 @@
# candle-mamba: Mamba implementation
An inference-only Candle implementation of *Mamba* [1]. Mamba is an alternative
to the transformer architecture: it leverages State Space Models (SSMs) with the
goal of being computationally efficient on long sequences. The implementation is
based on [mamba.rs](https://github.com/LaurentMazare/mamba.rs).
- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).
Compared to the mamba-minimal example, this version is far more efficient but
only supports inference.
## Running the example
```bash
$ cargo run --example mamba --release -- --prompt "Mamba is the"
```
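
The example also exposes the command line flags defined in `main.rs` below. For
instance, a different checkpoint size can be selected with `--which` and the
number of generated tokens capped with `-n`/`--sample-len`; a hypothetical
invocation based on those flag definitions:
```bash
$ cargo run --example mamba --release -- --which mamba370m -n 100 --prompt "Mamba is the"
```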

candle-examples/examples/mamba/main.rs

@@ -0,0 +1,299 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::mamba::{Config, Model, State};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
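// Groups the model, config, tokenizer, logits processor and repeat-penalty
// settings used by the generation loop.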
struct TextGeneration {
model: Model,
config: Config,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
config: Config,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
config,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
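// Encodes the prompt, feeds it through the model to build the recurrent state,
// then samples up to `sample_len` tokens.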
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
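// Recurrent state for a batch of size 1, updated in place at every forward step.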
let mut state = State::new(1, &self.config, &self.device)?;
let mut next_logits = None;
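// Process the prompt one token at a time to warm up the state; keep the
// logits of the last step to start sampling.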
for &t in tokens.iter() {
let input = Tensor::new(&[t], &self.device)?;
let logits = self.model.forward(&input, &mut state)?;
next_logits = Some(logits);
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
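// Autoregressive loop: optionally apply the repeat penalty over the last
// `repeat_last_n` tokens, sample the next token, stream it out and run a
// single-step forward pass for the next logits.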
for _ in 0..sample_len {
let logits = match next_logits.as_ref() {
Some(logits) => logits,
None => anyhow::bail!("cannot work on an empty prompt"),
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
let input = Tensor::new(&[next_token], &self.device)?;
next_logits = Some(self.model.forward(&input, &mut state)?)
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
Mamba130m,
Mamba370m,
Mamba790m,
Mamba1_4b,
Mamba2_8b,
Mamba2_8bSlimPj,
}
impl std::fmt::Display for Which {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl Which {
fn model_id(&self) -> &'static str {
match self {
Self::Mamba130m => "state-spaces/mamba-130m",
Self::Mamba370m => "state-spaces/mamba-370m",
Self::Mamba790m => "state-spaces/mamba-790m",
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj",
}
}
fn revision(&self) -> &'static str {
match self {
Self::Mamba130m
| Self::Mamba370m
| Self::Mamba790m
| Self::Mamba1_4b
| Self::Mamba2_8bSlimPj => "refs/pr/1",
Self::Mamba2_8b => "refs/pr/4",
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long, default_value = "mamba130m")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
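// Entry point: parse the flags, fetch the tokenizer, config and weights from
// the hub, build the model and run the generation pipeline.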
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
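// Resolve the checkpoint: an explicit --model-id/--revision overrides the
// defaults of the selected --which variant; the tokenizer defaults to the one
// from EleutherAI/gpt-neox-20b.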
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id
.unwrap_or_else(|| args.which.model_id().to_string()),
RepoType::Model,
args.revision
.unwrap_or_else(|| args.which.revision().to_string()),
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("EleutherAI/gpt-neox-20b".to_string())
.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
vec![repo.get("model.safetensors")?]
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
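// The model weights live under the `backbone` prefix in the checkpoint.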
let model = Model::new(&config, vb.pp("backbone"))?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
config,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

candle-transformers/src/models/mamba.rs

@@ -0,0 +1,211 @@
#![allow(unused)]
//! A fast implementation of mamba for inference only.
//! This is based on: https://github.com/LaurentMazare/mamba.rs
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{RmsNorm, VarBuilder};
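// Conv kernel size and SSM state dimension, hard-coded to the defaults used by
// the released Mamba checkpoints (d_conv = 4, d_state = 16).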
const D_CONV: usize = 4;
const D_STATE: usize = 16;
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
d_model: usize,
n_layer: usize,
vocab_size: usize,
pad_vocab_size_multiple: usize,
}
impl Config {
fn vocab_size(&self) -> usize {
let pad = self.pad_vocab_size_multiple;
(self.vocab_size + pad - 1) / pad * pad
}
fn dt_rank(&self) -> usize {
(self.d_model + 15) / 16
}
fn d_inner(&self) -> usize {
self.d_model * 2
}
}
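// Per-layer recurrent state for inference: `hs` is the SSM hidden state of
// shape (batch, d_inner, D_STATE) and `prev_xs` a circular buffer holding the
// last D_CONV projected inputs, indexed through `pos`.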
pub struct State {
hs: Vec<Tensor>,
prev_xs: Vec<[Tensor; D_CONV]>,
pos: usize,
}
impl State {
pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result<Self> {
let mut hs = Vec::with_capacity(cfg.n_layer);
let mut prev_xs = Vec::with_capacity(cfg.n_layer);
for _i in 0..cfg.n_layer {
let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?;
let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?;
hs.push(h);
prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]);
}
Ok(Self {
hs,
prev_xs,
pos: 0,
})
}
}
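// A single Mamba mixer: input projection, depthwise causal conv unrolled into
// D_CONV per-channel weight vectors, the selective SSM parameters (x_proj,
// dt_proj, A_log, D), and the output projection.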
#[derive(Clone, Debug)]
pub struct MambaBlock {
in_proj: Linear,
conv1d_bias: Tensor,
conv1d_weights: [Tensor; D_CONV],
x_proj: Linear,
dt_proj: Linear,
a_log: Tensor,
d: Tensor,
out_proj: Linear,
dt_rank: usize,
layer_index: usize,
d_inner: usize,
}
impl MambaBlock {
pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let d_inner = cfg.d_inner();
let dt_rank = cfg.dt_rank();
let in_proj = linear_no_bias(cfg.d_model, d_inner * 2, vb.pp("in_proj"))?;
let x_proj = linear_no_bias(d_inner, dt_rank + D_STATE * 2, vb.pp("x_proj"))?;
let dt_proj = linear(dt_rank, d_inner, vb.pp("dt_proj"))?;
let a_log = vb.get((d_inner, D_STATE), "A_log")?;
let d = vb.get(d_inner, "D")?;
let out_proj = linear_no_bias(d_inner, cfg.d_model, vb.pp("out_proj"))?;
let conv1d_bias = vb.get(d_inner, "conv1d.bias")?;
let conv1d_weight = vb.get((d_inner, 1, D_CONV), "conv1d.weight")?;
let conv1d_weights = [
conv1d_weight.i((.., 0, 0))?,
conv1d_weight.i((.., 0, 1))?,
conv1d_weight.i((.., 0, 2))?,
conv1d_weight.i((.., 0, 3))?,
];
Ok(Self {
in_proj,
conv1d_bias,
conv1d_weights,
x_proj,
dt_proj,
a_log,
d,
out_proj,
dt_rank,
layer_index,
d_inner,
})
}
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let (b_sz, _dim) = xs.dims2()?;
let li = self.layer_index;
let mut xs = xs.apply(&self.in_proj)?.chunk(2, D::Minus1)?;
let proj_for_silu = xs.remove(1);
state.prev_xs[li][state.pos % D_CONV] = xs.remove(0);
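// Causal conv1d for the current step: start from the bias and accumulate the
// last D_CONV inputs, oldest to newest, weighted by the per-channel kernel taps.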
let mut proj_for_conv = self.conv1d_bias.broadcast_as((b_sz, self.d_inner))?;
for d_c in 0..D_CONV {
proj_for_conv = (proj_for_conv
+ self.conv1d_weights[d_c]
.broadcast_mul(&state.prev_xs[li][(d_c + 1 + state.pos) % D_CONV])?)?;
}
let proj_for_conv = candle_nn::ops::silu(&proj_for_conv)?;
// SSM + Selection, we're doing inference here so only need the last step of
// the sequence.
// Algorithm 3.2 on page 6, https://arxiv.org/pdf/2312.00752.pdf
let x_proj = self.x_proj.forward(&proj_for_conv)?;
let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?;
let b = x_proj.narrow(D::Minus1, self.dt_rank, D_STATE)?;
let c = x_proj.narrow(D::Minus1, self.dt_rank + D_STATE, D_STATE)?;
let delta = delta.apply(&self.dt_proj)?;
// softplus
let delta = (delta.exp()? + 1.)?.log()?;
let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
let d = self.d.to_dtype(candle::DType::F32)?;
// Selective scan part
// Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t
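// Discretization as in the paper: Ab = exp(delta * A), while Bb x_t is
// approximated by delta * B * x_t.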
let delta = delta
.unsqueeze(D::Minus1)?
.broadcast_as((b_sz, self.d_inner, D_STATE))?;
let a = a.broadcast_as((b_sz, self.d_inner, D_STATE))?;
let b = b.broadcast_as((b_sz, self.d_inner, D_STATE))?;
let proj_for_conv_b =
proj_for_conv
.unsqueeze(D::Minus1)?
.broadcast_as((b_sz, self.d_inner, D_STATE))?;
state.hs[li] = ((&state.hs[li] * (&delta * &a)?.exp()?)? + &delta * &b * &proj_for_conv_b)?;
let ss = (state.hs[li]
.matmul(&c.unsqueeze(D::Minus1)?)?
.squeeze(D::Minus1)?
+ proj_for_conv.broadcast_mul(&d)?)?;
let ys = (ss * candle_nn::ops::silu(&proj_for_silu))?;
ys.apply(&self.out_proj)
}
}
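// Pre-norm residual block: RMSNorm, then the Mamba mixer, then a skip
// connection back to the input.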
#[derive(Clone, Debug)]
pub struct ResidualBlock {
mixer: MambaBlock,
norm: RmsNorm,
}
impl ResidualBlock {
pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?;
let mixer = MambaBlock::new(layer_index, cfg, vb.pp("mixer"))?;
Ok(Self { mixer, norm })
}
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
self.mixer.forward(&xs.apply(&self.norm)?, state)? + xs
}
}
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L56
#[derive(Clone, Debug)]
pub struct Model {
embedding: candle_nn::Embedding,
layers: Vec<ResidualBlock>,
norm_f: RmsNorm,
lm_head: Linear,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embedding = candle_nn::embedding(cfg.vocab_size(), cfg.d_model, vb.pp("embedding"))?;
let mut layers = Vec::with_capacity(cfg.n_layer);
let vb_l = vb.pp("layers");
for layer_idx in 0..cfg.n_layer {
let layer = ResidualBlock::new(layer_idx, cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
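// The lm head is tied to the token embedding weights.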
let lm_head = Linear::from_weights(embedding.embeddings().clone(), None);
Ok(Self {
embedding,
layers,
norm_f,
lm_head,
})
}
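// Single-step forward: `input_ids` has shape (batch,); each layer updates its
// slice of the state in place and the position counter advances once per call.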
pub fn forward(&self, input_ids: &Tensor, state: &mut State) -> Result<Tensor> {
let _b_size = input_ids.dims1()?;
let mut xs = self.embedding.forward(input_ids)?;
for layer in self.layers.iter() {
xs = layer.forward(&xs, state)?
}
state.pos += 1;
xs.apply(&self.norm_f)?.apply(&self.lm_head)
}
}

candle-transformers/src/models/mod.rs

@@ -13,6 +13,7 @@ pub mod jina_bert;
pub mod llama;
pub mod llama2_c;
pub mod llama2_c_weights;
pub mod mamba;
pub mod marian;
pub mod mistral;
pub mod mixformer;