Improve the quantized whisper setup. (#1018)

* Improve the quantized whisper setup.
* Fix the config file paths.
* Use the standard matmul where possible.
parent e04c789230
commit 089fc3b584
@@ -232,19 +232,25 @@ impl QTensor {
 }
 
 #[derive(Clone, Debug)]
-pub struct QMatMul(std::sync::Arc<QTensor>);
+pub enum QMatMul {
+    QTensor(std::sync::Arc<QTensor>),
+    Tensor(Tensor),
+}
 
 impl QMatMul {
-    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
-        Self(qtensor)
+    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
+        let t = match qtensor.dtype() {
+            GgmlDType::F32 | GgmlDType::F16 => {
+                let tensor = qtensor.dequantize(&Device::Cpu)?;
+                Self::Tensor(tensor)
+            }
+            _ => Self::QTensor(qtensor),
+        };
+        Ok(t)
     }
 
-    pub fn from_qtensor(qtensor: QTensor) -> Self {
-        Self(std::sync::Arc::new(qtensor))
-    }
-
-    pub fn inner(&self) -> &std::sync::Arc<QTensor> {
-        &self.0
+    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
+        Self::from_arc(std::sync::Arc::new(qtensor))
     }
 }
 
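This first hunk replaces the `QMatMul` newtype with a two-variant enum: genuinely quantized weights keep the custom quantized kernel, while weights stored as `F32` or `F16` are dequantized once into a plain `Tensor` at construction time. This is the "use the standard matmul where possible" part of the commit message: for those dtypes the regular matmul is the natural path, and doing the dequantization once up front presumably avoids repeating that work on every forward call. The constructors become fallible (`Result<Self>`) because dequantization can fail, which is why every call site below gains a `?`; the `inner` accessor is dropped since there is no longer a single inner representation.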
@@ -289,6 +295,9 @@ impl crate::CustomOp1 for QTensor {
 
 impl QMatMul {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.apply_op1_no_bwd(self.0.as_ref())
+        match self {
+            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
+            Self::Tensor(t) => xs.matmul(&t.t()?),
+        }
     }
 }
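For context, a minimal end-to-end sketch of the updated API. This is not part of the commit: it assumes the crate is used as `candle` (as the in-tree examples do), and the shapes are made up for illustration.

use candle::quantized::{self, k_quants::BlockQ4_0};
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let lhs = Tensor::randn(0f32, 1., (2, 64), &Device::Cpu)?;
    let rhs = Tensor::randn(0f32, 1., (4, 64), &Device::Cpu)?;
    // A Q4_0-quantized weight takes the custom quantized-kernel path;
    // an F16/F32 QTensor would instead be dequantized once at
    // construction and dispatched to the standard matmul.
    let qrhs = quantized::QTensor::quantize::<BlockQ4_0>(&rhs)?;
    let mm = quantized::QMatMul::from_qtensor(qrhs)?; // now fallible
    let ys = mm.forward(&lhs)?;
    assert_eq!(ys.dims(), [2, 4]); // rhs is applied transposed
    Ok(())
}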
@@ -43,7 +43,7 @@ fn quantized_matmul() -> Result<()> {
     );
 
     let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
-    let matmul = quantized::QMatMul::from_qtensor(qtensor);
+    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
     assert_eq!(
         to_vec2_round(&res, 0)?,
@@ -91,7 +91,7 @@ fn quantized_matmul_neg() -> Result<()> {
     );
 
     let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
-    let matmul = quantized::QMatMul::from_qtensor(qtensor);
+    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
     assert_eq!(
         to_vec2_round(&res, 0)?,
@@ -576,7 +576,7 @@ fn quantized_matmul_q2k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
     assert_eq!(mm.dims(), [m, n]);
@@ -602,7 +602,7 @@ fn quantized_matmul_q3k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
     assert_eq!(mm.dims(), [m, n]);
@@ -628,7 +628,7 @@ fn quantized_matmul_q4k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
     assert_eq!(mm.dims(), [m, n]);
@@ -654,7 +654,7 @@ fn quantized_matmul_q5k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
     assert_eq!(mm.dims(), [m, n]);
@@ -681,7 +681,7 @@ fn quantized_matmul_q6k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let mm = rhs.forward(&lhs)?;
 
     assert_eq!(mm.dims(), [m, n]);
@@ -484,17 +484,25 @@ fn main() -> Result<()> {
             println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
             dataset.get("samples_jfk.wav")?
         };
-        let config = if args.quantized {
-            repo.get("config-tiny.json")?
+        let (config, tokenizer, model) = if args.quantized {
+            let ext = match args.model {
+                WhichModel::TinyEn => "tiny-en",
+                WhichModel::Tiny => "tiny",
+                _ => unimplemented!("no quantized support for {:?}", args.model),
+            };
+            (
+                repo.get(&format!("config-{ext}.json"))?,
+                repo.get(&format!("tokenizer-{ext}.json"))?,
+                repo.get(&format!("model-{ext}-q40.gguf"))?,
+            )
         } else {
-            repo.get("config.json")?
+            (
+                repo.get("config.json")?,
+                repo.get("tokenizer.json")?,
+                repo.get("model.safetensors")?,
+            )
         };
-        let model = if args.quantized {
-            repo.get("model-tiny-q40.gguf")?
-        } else {
-            repo.get("model.safetensors")?
-        };
-        (config, repo.get("tokenizer.json")?, model, sample)
+        (config, tokenizer, model, sample)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
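Previously the quantized path always fetched `config-tiny.json` together with the generic `tokenizer.json`, so selecting `WhichModel::TinyEn` ran with the wrong config; this is the "fix the config file paths" part of the commit message. Per the format strings above, each variant now resolves to the following hub files (listed for reference; the repo layout is inferred from the diff):

// WhichModel::Tiny   -> config-tiny.json,    tokenizer-tiny.json,    model-tiny-q40.gguf
// WhichModel::TinyEn -> config-tiny-en.json, tokenizer-tiny-en.json, model-tiny-en-q40.gguf

Any other variant hits the `unimplemented!` arm, presumably because only the two tiny models have quantized gguf weights available at this point.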
@@ -206,7 +206,7 @@ impl Benchmark for QMatMul {
     fn preprocess() -> Result<Self::PreProcessData> {
         let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
         let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
-        let mm = candle::quantized::QMatMul::from_qtensor(mm);
+        let mm = candle::quantized::QMatMul::from_qtensor(mm)?;
         let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
         Ok((mm, arg))
     }
@@ -867,7 +867,7 @@ impl PyQTensor {
     /// Performs a quantized matrix multiplication, with the quantized tensor as the right hand side.
     /// &RETURNS&: Tensor
     fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
-        let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone());
+        let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone()).map_err(wrap_err)?;
         let res = qmatmul.forward(lhs).map_err(wrap_err)?;
         Ok(PyTensor(res))
     }
@@ -33,10 +33,10 @@ struct QMatMul {
 }
 
 impl QMatMul {
-    fn from_qtensor(qtensor: QTensor) -> Self {
-        let inner = candle::quantized::QMatMul::from_qtensor(qtensor);
+    fn from_qtensor(qtensor: QTensor) -> Result<Self> {
+        let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
         let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
-        Self { inner, span }
+        Ok(Self { inner, span })
     }
 
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
@@ -217,14 +217,14 @@ impl ModelWeights {
         let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
         let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
         layers.push(LayerWeights {
-            attention_wq: QMatMul::from_qtensor(attention_wq),
-            attention_wk: QMatMul::from_qtensor(attention_wk),
-            attention_wv: QMatMul::from_qtensor(attention_wv),
-            attention_wo: QMatMul::from_qtensor(attention_wo),
+            attention_wq: QMatMul::from_qtensor(attention_wq)?,
+            attention_wk: QMatMul::from_qtensor(attention_wk)?,
+            attention_wv: QMatMul::from_qtensor(attention_wv)?,
+            attention_wo: QMatMul::from_qtensor(attention_wo)?,
             attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
-            feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
-            feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
-            feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
+            feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+            feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+            feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
             ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
             n_head: ct.hparams.n_head as usize,
             n_kv_head: ct.hparams.n_head as usize / gqa,
@@ -243,7 +243,7 @@ impl ModelWeights {
             tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
             layers,
             norm,
-            output: QMatMul::from_qtensor(output),
+            output: QMatMul::from_qtensor(output)?,
             masks: HashMap::new(),
             span,
             span_output,
@@ -294,14 +294,14 @@ impl ModelWeights {
         let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
         let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
         layers.push(LayerWeights {
-            attention_wq: QMatMul::from_qtensor(attention_wq),
-            attention_wk: QMatMul::from_qtensor(attention_wk),
-            attention_wv: QMatMul::from_qtensor(attention_wv),
-            attention_wo: QMatMul::from_qtensor(attention_wo),
+            attention_wq: QMatMul::from_qtensor(attention_wq)?,
+            attention_wk: QMatMul::from_qtensor(attention_wk)?,
+            attention_wv: QMatMul::from_qtensor(attention_wv)?,
+            attention_wo: QMatMul::from_qtensor(attention_wo)?,
             attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
-            feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
-            feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
-            feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
+            feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+            feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+            feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
             ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
             n_head: head_count,
             n_kv_head: head_count_kv,
@@ -320,7 +320,7 @@ impl ModelWeights {
             tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
             layers,
             norm,
-            output: QMatMul::from_qtensor(output),
+            output: QMatMul::from_qtensor(output)?,
             masks: HashMap::new(),
             span,
             span_output,
@@ -90,7 +90,7 @@ impl QMatMul {
         vb: crate::quantized_var_builder::VarBuilder,
     ) -> Result<Self> {
         let ws = vb.get((in_dim, out_dim), "weight")?;
-        let inner = candle::quantized::QMatMul::from_arc(ws);
+        let inner = candle::quantized::QMatMul::from_arc(ws)?;
         let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
         Ok(Self { inner, span })
     }
@@ -41,7 +41,7 @@ fn quantized_matmul_neg() -> Result<()> {
     );
 
     let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
-    let matmul = quantized::QMatMul::from_qtensor(qtensor);
+    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
     assert_eq!(
         to_vec2_round(&res, 0)?,