Add topk sampling. (#1923)
This commit is contained in:
parent
fdfe8fd129
commit
a62a97340c
|
@ -1,24 +1,35 @@
|
|||
use candle::{DType, Error, Result, Tensor};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
|
||||
/// Strategy used to pick the next token from a logits vector.
#[derive(Clone, PartialEq, Debug)]
pub enum Sampling {
    /// Greedy decoding: always take the highest-probability token.
    ArgMax,
    /// Sample from the full softmax distribution at the given temperature.
    All { temperature: f64 },
    /// Sample only among the `k` most probable tokens.
    TopK { k: usize, temperature: f64 },
    /// Nucleus sampling: sample from the smallest token set whose
    /// cumulative probability exceeds `p`.
    TopP { p: f64, temperature: f64 },
}
|
||||
|
||||
pub struct LogitsProcessor {
|
||||
rng: rand::rngs::StdRng,
|
||||
temperature: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
sampling: Sampling,
|
||||
}
|
||||
|
||||
impl LogitsProcessor {
|
||||
pub fn from_sampling(seed: u64, sampling: Sampling) -> Self {
|
||||
let rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
Self { rng, sampling }
|
||||
}
|
||||
|
||||
pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
|
||||
let temperature = if temperature.map_or(true, |v| v < 1e-7) {
|
||||
None
|
||||
} else {
|
||||
temperature
|
||||
let temperature = temperature.and_then(|v| if v < 1e-7 { None } else { Some(v) });
|
||||
let sampling = match temperature {
|
||||
None => Sampling::ArgMax,
|
||||
Some(temperature) => match top_p {
|
||||
None => Sampling::All { temperature },
|
||||
Some(p) => Sampling::TopP { p, temperature },
|
||||
},
|
||||
};
|
||||
Self {
|
||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||
temperature,
|
||||
top_p,
|
||||
}
|
||||
Self::from_sampling(seed, sampling)
|
||||
}
|
||||
|
||||
fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
|
||||
|
@ -38,14 +49,14 @@ impl LogitsProcessor {
|
|||
Ok(next_token)
|
||||
}
|
||||
|
||||
/// top-p sampling (or "nucleus sampling") samples from the smallest set of tokens that exceed
|
||||
/// probability top_p. This way we never sample tokens that have very low probabilities and are
|
||||
/// less likely to go "off the rails".
|
||||
fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
|
||||
// top-p sampling (or "nucleus sampling") samples from the smallest set of
|
||||
// tokens that exceed probability top_p. This way we never sample tokens that
|
||||
// have very low probabilities and are less likely to go "off the rails".
|
||||
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
|
||||
|
||||
// Sort by descending probability.
|
||||
argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
|
||||
argsort_indices.sort_by(|&i, &j| prs[j].total_cmp(&prs[i]));
|
||||
|
||||
// Clamp smaller probabilities to zero.
|
||||
let mut cumsum = 0.;
|
||||
|
@ -60,23 +71,49 @@ impl LogitsProcessor {
|
|||
self.sample_multinomial(prs)
|
||||
}
|
||||
|
||||
// top-k sampling samples from the k tokens with the largest probabilities.
|
||||
fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
|
||||
if top_k >= prs.len() {
|
||||
self.sample_multinomial(prs)
|
||||
} else {
|
||||
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
|
||||
// Sort by descending probability.
|
||||
let (indices, _, _) =
|
||||
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
|
||||
let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
|
||||
let index = self.sample_multinomial(&prs)?;
|
||||
Ok(indices[index as usize] as u32)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
||||
let logits = logits.to_dtype(DType::F32)?;
|
||||
let next_token = match self.temperature {
|
||||
None => self.sample_argmax(logits)?,
|
||||
Some(temperature) => {
|
||||
let logits = &(&logits / temperature)?;
|
||||
let prs = candle_nn::ops::softmax_last_dim(logits)?;
|
||||
let mut prs: Vec<f32> = prs.to_vec1()?;
|
||||
let top_p = self.top_p.unwrap_or(1.);
|
||||
if top_p <= 0.0 || top_p >= 1.0 {
|
||||
let prs = |temperature: f64| -> Result<Vec<f32>> {
|
||||
let logits = (&logits / temperature)?;
|
||||
let prs = candle_nn::ops::softmax_last_dim(&logits)?;
|
||||
prs.to_vec1()
|
||||
};
|
||||
|
||||
let next_token = match &self.sampling {
|
||||
Sampling::ArgMax => self.sample_argmax(logits)?,
|
||||
Sampling::All { temperature } => {
|
||||
let prs = prs(*temperature)?;
|
||||
self.sample_multinomial(&prs)?
|
||||
}
|
||||
Sampling::TopP { p, temperature } => {
|
||||
let mut prs = prs(*temperature)?;
|
||||
if *p <= 0.0 || *p >= 1.0 {
|
||||
// simply sample from the predicted probability distribution
|
||||
self.sample_multinomial(&prs)?
|
||||
} else {
|
||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||
self.sample_topp(&mut prs, top_p as f32)?
|
||||
self.sample_topp(&mut prs, *p as f32)?
|
||||
}
|
||||
}
|
||||
Sampling::TopK { k, temperature } => {
|
||||
let mut prs = prs(*temperature)?;
|
||||
self.sample_topk(&mut prs, *k)?
|
||||
}
|
||||
};
|
||||
Ok(next_token)
|
||||
}
|
||||
|
|
|
@ -27,3 +27,30 @@ fn sample_with_top_p() -> Result<()> {
|
|||
assert_eq!(token, 2);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
fn sample_with_top_k() -> Result<()> {
    // With k = 1, top-k degenerates to argmax: the largest logit
    // (index 3) must always be picked.
    let sampling = candle_transformers::generation::Sampling::TopK {
        k: 1,
        temperature: 1.0,
    };
    let mut logits_process = LogitsProcessor::from_sampling(42, sampling);
    let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
    assert_eq!(logits_process.sample(&logits)?, 3);

    // With k = 2, draws come only from the two largest logits
    // (indices 3 and 2); the seeded RNG makes the sequence deterministic.
    let sampling = candle_transformers::generation::Sampling::TopK {
        k: 2,
        temperature: 1.0,
    };
    let mut logits_process = LogitsProcessor::from_sampling(42, sampling);
    let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
    assert_eq!(logits_process.sample(&logits)?, 3);
    assert_eq!(logits_process.sample(&logits)?, 2);
    Ok(())
}
|
||||
|
|
Loading…
Reference in New Issue