Fixed TP sharded version.
commit ed58de7551
parent 1735e4831e

@@ -3,7 +3,7 @@ members = [
     "candle-core",
     "candle-examples",
     "candle-nn",
-    "candle-pyo3",
+    # "candle-pyo3",
     "candle-transformers",
     "candle-wasm-examples/llama2-c",
     "candle-wasm-examples/whisper",

@@ -103,7 +103,7 @@ pub enum Op {
 }
 
 /// Unary ops that can be defined in user-land.
-pub trait CustomOp1: Send + Sync {
+pub trait CustomOp1 {
     // Box<dyn> does not support const yet, so use a function to get the name.
     fn name(&self) -> &'static str;
 
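
Dropping the `Send + Sync` bounds means a custom op may now hold state that is not thread-safe; the `AllReduce` op added later in this commit relies on exactly that by storing an `Rc<Comm>`. A minimal sketch of the pattern, with a hypothetical `MyOp` type that is not part of the repository (depending on the candle version, backend methods such as `cuda_fwd` may have default implementations and are omitted here):

use std::rc::Rc;

use candle::{CpuStorage, CustomOp1, Layout, Result, Shape};

// Hypothetical example only: `Rc<String>` stands in for any non-Send/non-Sync
// handle, e.g. an NCCL communicator wrapped in `Rc` as in the diff below.
struct MyOp {
    state: Rc<String>,
}

impl CustomOp1 for MyOp {
    fn name(&self) -> &'static str {
        "my-op"
    }

    fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
        // Placeholder: a real op would transform the storage here.
        todo!("cpu forward for {} with state {}", self.name(), self.state)
    }
}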

@@ -247,20 +247,24 @@ fn main() -> Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
         new_tokens.push(next_token);
-        println!("> {:?}", start_gen.elapsed());
-        println!(
-            "{} token: {} '{}'",
-            index + 1,
-            next_token,
-            tokenizer.decode(vec![next_token], true).map_err(E::msg)?
-        );
+        if rank == 0 {
+            println!("> {:?}", start_gen.elapsed());
+            println!(
+                "{} token: {} '{}'",
+                index + 1,
+                next_token,
+                tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+            );
+        }
     }
     let dt = start_gen.elapsed();
-    println!(
-        "{} tokens generated ({} token/s)\n----\n{}\n----",
-        args.sample_len,
-        args.sample_len as f64 / dt.as_secs_f64(),
-        tokenizer.decode(new_tokens, true).map_err(E::msg)?
-    );
+    if rank == 0 {
+        println!(
+            "{} tokens generated ({} token/s)\n----\n{}\n----",
+            args.sample_len,
+            args.sample_len as f64 / dt.as_secs_f64(),
+            tokenizer.decode(new_tokens, true).map_err(E::msg)?
+        );
+    }
     Ok(())
 }
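
The change above gates all per-token and summary logging on the NCCL rank: with one process per GPU running the same generation loop, unconditional printing would repeat every line once per process. A standalone sketch of the same pattern, using a hypothetical `log_token` helper rather than the example's actual code:

// Hypothetical helper: only rank 0 writes to stdout so the multi-process run
// produces a single, readable output stream.
fn log_token(rank: usize, index: usize, token: u32, text: &str) {
    if rank == 0 {
        println!("{} token: {} '{}'", index + 1, token, text);
    }
}

fn main() {
    // Pretend this process is rank 0 of the run.
    let rank = 0;
    log_token(rank, 0, 29871, "hello");
}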

@@ -1,6 +1,6 @@
-use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle::backend::BackendStorage;
+use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
 use candle_nn::{Embedding, Linear, VarBuilder};
-use cudarc::driver::safe::CudaSlice;
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::collections::HashMap;

@@ -27,14 +27,42 @@ struct TensorParallelRowLinear {
     comm: Rc<Comm>,
 }
 
+struct AllReduce {
+    comm: Rc<Comm>,
+}
+
+impl CustomOp1 for AllReduce {
+    fn name(&self) -> &'static str {
+        "allreduce"
+    }
+
+    fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
+        todo!("implement allreduce for cpu is not necessary for single node");
+    }
+
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        s: &candle::CudaStorage,
+        l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        use candle::cuda_backend::WrapErr;
+        let elem_count = l.shape().elem_count();
+        let dev = s.device().clone();
+        let s = s.as_cuda_slice::<f16>()?;
+        // let s = match l.contiguous_offsets() {
+        //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+        //     Some((o1, o2)) => s.slice(o1..o2),
+        // };
+        let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+        self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
+        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
+        Ok((dst, l.shape().clone()))
+    }
+}
+
 fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
-    Ok(x.clone())
-    // let n = x.shape().elem_count();
-    // let cuda_slice: CudaSlice<f16> = x.try_into()?;
-    // let dev = cuda_slice.device();
-    // let mut slice_receive = dev.alloc_zeros(n).unwrap();
-    // comm.all_reduce(cuda_slice, &mut slice_receive, &ReduceOp::Sum).unwrap();
-    // Tensor::from_raw_storage(slice_receive, x.shape())
+    x.custom_op1(AllReduce { comm: comm.clone() })
 }
 
 impl TensorParallelRowLinear {
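
The new `AllReduce` op is what makes the row-parallel linear layer correct: each rank multiplies by its shard of the weight matrix, and the partial results must be summed across ranks. A hedged sketch of how `TensorParallelRowLinear::forward` might call `all_reduce_sum` (the `linear` field is an assumption for illustration; only the `comm` field is shown in the diff):

// Sketch only, not necessarily the repository's exact implementation.
impl TensorParallelRowLinear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Each rank's input carries only its shard of the hidden features, so
        // the local matmul yields a partial sum of the full output...
        let partial = self.linear.forward(x)?;
        // ...and the NCCL all-reduce completes the sum across all ranks.
        all_reduce_sum(&partial, &self.comm)
    }
}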

@@ -187,11 +215,11 @@ impl RmsNorm {
         let in_dtype = x.dtype();
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
-        let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
+        let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
        let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
-        let size = self.scale.shape().r1()?;
+        let size = self.scale.shape().dims1()?;
         let scale = self
             .scale
             .to_dtype(DType::F32)?
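
This and the remaining hunks in the file are the same mechanical rename: the arity-specific shape accessors `r1()`/`r2()`/`r3()`/`r4()` become `dims1()`/`dims2()`/`dims3()`/`dims4()`, with unchanged semantics. A small standalone sketch of the new spelling (not repo code):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    // Previously written as `t.shape().r3()?`; the destructuring is identical.
    let (b_sz, seq_len, hidden) = t.shape().dims3()?;
    assert_eq!((b_sz, seq_len, hidden), (2, 3, 4));
    Ok(())
}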
|
@@ -213,7 +241,7 @@ struct CausalSelfAttention {
 
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (b_sz, _, seq_len, n_embd) = x.shape().r4()?;
+        let (b_sz, _, seq_len, n_embd) = x.shape().dims4()?;
         let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
         let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;

@@ -227,7 +255,7 @@ impl CausalSelfAttention {
 
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
         let x_dtype = x.dtype();
-        let (b_sz, seq_len, _) = x.shape().r3()?;
+        let (b_sz, seq_len, _) = x.shape().dims3()?;
 
         let qkv = self.qkv_proj.forward(x)?;
         let n_embd = self.n_head * self.head_dim;

@@ -302,7 +330,7 @@ impl CausalSelfAttention {
         if n_rep == 1 {
             Ok(x)
         } else {
-            let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?;
+            let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
             let x = x
                 .unsqueeze(2)?
                 .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?

@@ -312,10 +340,6 @@ impl CausalSelfAttention {
     }
 
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
-        let size_in = cfg.hidden_size;
-        let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
-        let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
-
         let qkv_proj = TensorParallelColumnLinear::load_multi(
             vb.clone(),
             &["q_proj", "k_proj", "v_proj"],

@@ -364,9 +388,7 @@ impl Mlp {
         self.c_proj.forward(&x)
     }
 
-    fn load(vb: VarBuilder, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
-        let h_size = cfg.hidden_size;
-        let i_size = cfg.intermediate_size;
+    fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;

@@ -433,7 +455,7 @@ impl Llama {
     }
 
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (_b_sz, seq_len) = x.shape().r2()?;
+        let (_b_sz, seq_len) = x.shape().dims2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = block.forward(&x, index_pos, block_idx)?;