Improve the quantized whisper setup. (#1018)

* Improve the quantized whisper setup.

* Fix the config file paths.

* Use the standard matmul where possible.
This commit is contained in:
Laurent Mazare 2023-10-02 17:17:46 +01:00 committed by GitHub
parent e04c789230
commit 089fc3b584
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 66 additions and 49 deletions

View File

@ -232,19 +232,25 @@ impl QTensor {
}
#[derive(Clone, Debug)]
pub struct QMatMul(std::sync::Arc<QTensor>);
pub enum QMatMul {
QTensor(std::sync::Arc<QTensor>),
Tensor(Tensor),
}
impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
Self(qtensor)
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
let t = match qtensor.dtype() {
GgmlDType::F32 | GgmlDType::F16 => {
let tensor = qtensor.dequantize(&Device::Cpu)?;
Self::Tensor(tensor)
}
_ => Self::QTensor(qtensor),
};
Ok(t)
}
pub fn from_qtensor(qtensor: QTensor) -> Self {
Self(std::sync::Arc::new(qtensor))
}
pub fn inner(&self) -> &std::sync::Arc<QTensor> {
&self.0
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
Self::from_arc(std::sync::Arc::new(qtensor))
}
}
@ -289,6 +295,9 @@ impl crate::CustomOp1 for QTensor {
impl QMatMul {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply_op1_no_bwd(self.0.as_ref())
match self {
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
Self::Tensor(t) => xs.matmul(&t.t()?),
}
}
}

View File

@ -43,7 +43,7 @@ fn quantized_matmul() -> Result<()> {
);
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor);
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
assert_eq!(
to_vec2_round(&res, 0)?,
@ -91,7 +91,7 @@ fn quantized_matmul_neg() -> Result<()> {
);
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor);
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
assert_eq!(
to_vec2_round(&res, 0)?,
@ -576,7 +576,7 @@ fn quantized_matmul_q2k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
@ -602,7 +602,7 @@ fn quantized_matmul_q3k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
@ -628,7 +628,7 @@ fn quantized_matmul_q4k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
@ -654,7 +654,7 @@ fn quantized_matmul_q5k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
@ -681,7 +681,7 @@ fn quantized_matmul_q6k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);

View File

@ -484,17 +484,25 @@ fn main() -> Result<()> {
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
dataset.get("samples_jfk.wav")?
};
let config = if args.quantized {
repo.get("config-tiny.json")?
let (config, tokenizer, model) = if args.quantized {
let ext = match args.model {
WhichModel::TinyEn => "tiny-en",
WhichModel::Tiny => "tiny",
_ => unimplemented!("no quantized support for {:?}", args.model),
};
(
repo.get(&format!("config-{ext}.json"))?,
repo.get(&format!("tokenizer-{ext}.json"))?,
repo.get(&format!("model-{ext}-q40.gguf"))?,
)
} else {
repo.get("config.json")?
(
repo.get("config.json")?,
repo.get("tokenizer.json")?,
repo.get("model.safetensors")?,
)
};
let model = if args.quantized {
repo.get("model-tiny-q40.gguf")?
} else {
repo.get("model.safetensors")?
};
(config, repo.get("tokenizer.json")?, model, sample)
(config, tokenizer, model, sample)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

View File

@ -206,7 +206,7 @@ impl Benchmark for QMatMul {
fn preprocess() -> Result<Self::PreProcessData> {
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
let mm = candle::quantized::QMatMul::from_qtensor(mm);
let mm = candle::quantized::QMatMul::from_qtensor(mm)?;
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
Ok((mm, arg))
}

View File

@ -867,7 +867,7 @@ impl PyQTensor {
/// Performs a quantized matrix multiplication, with the quantized tensor as the right hand side.
/// &RETURNS&: Tensor
fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone());
let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone()).map_err(wrap_err)?;
let res = qmatmul.forward(lhs).map_err(wrap_err)?;
Ok(PyTensor(res))
}

View File

@ -33,10 +33,10 @@ struct QMatMul {
}
impl QMatMul {
fn from_qtensor(qtensor: QTensor) -> Self {
let inner = candle::quantized::QMatMul::from_qtensor(qtensor);
fn from_qtensor(qtensor: QTensor) -> Result<Self> {
let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
Self { inner, span }
Ok(Self { inner, span })
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
@ -217,14 +217,14 @@ impl ModelWeights {
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq),
attention_wk: QMatMul::from_qtensor(attention_wk),
attention_wv: QMatMul::from_qtensor(attention_wv),
attention_wo: QMatMul::from_qtensor(attention_wo),
attention_wq: QMatMul::from_qtensor(attention_wq)?,
attention_wk: QMatMul::from_qtensor(attention_wk)?,
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
n_head: ct.hparams.n_head as usize,
n_kv_head: ct.hparams.n_head as usize / gqa,
@ -243,7 +243,7 @@ impl ModelWeights {
tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
layers,
norm,
output: QMatMul::from_qtensor(output),
output: QMatMul::from_qtensor(output)?,
masks: HashMap::new(),
span,
span_output,
@ -294,14 +294,14 @@ impl ModelWeights {
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq),
attention_wk: QMatMul::from_qtensor(attention_wk),
attention_wv: QMatMul::from_qtensor(attention_wv),
attention_wo: QMatMul::from_qtensor(attention_wo),
attention_wq: QMatMul::from_qtensor(attention_wq)?,
attention_wk: QMatMul::from_qtensor(attention_wk)?,
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
@ -320,7 +320,7 @@ impl ModelWeights {
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
layers,
norm,
output: QMatMul::from_qtensor(output),
output: QMatMul::from_qtensor(output)?,
masks: HashMap::new(),
span,
span_output,

View File

@ -90,7 +90,7 @@ impl QMatMul {
vb: crate::quantized_var_builder::VarBuilder,
) -> Result<Self> {
let ws = vb.get((in_dim, out_dim), "weight")?;
let inner = candle::quantized::QMatMul::from_arc(ws);
let inner = candle::quantized::QMatMul::from_arc(ws)?;
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
Ok(Self { inner, span })
}

View File

@ -41,7 +41,7 @@ fn quantized_matmul_neg() -> Result<()> {
);
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor);
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
assert_eq!(
to_vec2_round(&res, 0)?,