Add support to flan-t5 (#840)
parent 9a465e1b26
commit 49d3f7f708
@@ -6,6 +6,8 @@ use serde::Deserialize;
 pub enum Activation {
     #[default]
     Gelu,
+    #[serde(rename = "gated-gelu")]
+    NewGelu,
     Relu,
     Elu(f64),
 }
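
A sketch (not from this commit) of what the rename attribute buys: flan-t5 checkpoints carry "feed_forward_proj": "gated-gelu" in their config.json, and this is the string that must land on NewGelu. The enum below is a trimmed stand-in for the one above (Elu(f64) dropped, and the rename_all = "lowercase" attribute is an assumption about the surrounding derive, matching how "gelu"/"relu" would deserialize); serde and serde_json are assumed as dependencies.

use serde::Deserialize;

// Trimmed mirror of the Activation enum in the hunk above.
#[derive(Debug, Default, PartialEq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum Activation {
    #[default]
    Gelu,
    #[serde(rename = "gated-gelu")]
    NewGelu,
    Relu,
}

fn main() {
    // The rename attribute maps the config string onto the NewGelu variant.
    let act: Activation = serde_json::from_str("\"gated-gelu\"").unwrap();
    assert_eq!(act, Activation::NewGelu);
    println!("{act:?}");
}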
@@ -14,6 +16,10 @@ impl super::Module for Activation {
     fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         match self {
             Self::Gelu => xs.gelu(),
+            // TODO: This is "gelu_new", not the original "gelu".
+            // There's some small numerical difference:
+            // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
+            Self::NewGelu => xs.gelu(),
             Self::Relu => xs.relu(),
             &Self::Elu(alpha) => xs.elu(alpha),
         }
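
A dependency-free sketch (not from this commit) of the numerical gap the TODO refers to: the "original" gelu is erf-based, while gelu_new is the tanh approximation from the linked transformers file. Since Rust's std has no erf, the sketch uses the Abramowitz & Stegun 7.1.26 approximation (|error| <= 1.5e-7), which is accurate enough to show a small but nonzero difference between the two.

// erf via Abramowitz & Stegun 7.1.26 (std has no erf).
fn erf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t
        + 0.254829592)
        * t;
    sign * (1.0 - poly * (-x * x).exp())
}

// "Original" gelu: 0.5 * x * (1 + erf(x / sqrt(2))).
fn gelu(x: f64) -> f64 {
    0.5 * x * (1.0 + erf(x / std::f64::consts::SQRT_2))
}

// "gelu_new": 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn gelu_new(x: f64) -> f64 {
    let c = (2.0 / std::f64::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
}

fn main() {
    for x in [-2.0_f64, -0.5, 0.5, 2.0] {
        println!(
            "x = {x:4}: gelu = {:.6}, gelu_new = {:.6}, diff = {:+.2e}",
            gelu(x),
            gelu_new(x),
            gelu(x) - gelu_new(x)
        );
    }
}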
@@ -148,27 +148,71 @@ impl T5DenseActDense {
     }
 }
 
+#[derive(Debug)]
+struct T5DenseGatedActDense {
+    wi_0: Linear,
+    wi_1: Linear,
+    wo: Linear,
+    act: Activation,
+}
+
+impl T5DenseGatedActDense {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+        let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        Ok(Self {
+            wi_0,
+            wi_1,
+            wo,
+            act: Activation::NewGelu,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
+        let hidden_linear = self.wi_1.forward(xs)?;
+        let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
+        let xs = self.wo.forward(&xs)?;
+        Ok(xs)
+    }
+}
+
 #[derive(Debug)]
 struct T5LayerFF {
-    dense_relu_dense: T5DenseActDense,
+    dense_act: Option<T5DenseActDense>,
+    gated_dense_act: Option<T5DenseGatedActDense>,
     layer_norm: T5LayerNorm,
 }
 
 impl T5LayerFF {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        // is_gated_act is not supported.
-        let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
         let layer_norm =
             T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
+            (
+                None,
+                Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
+            )
+        } else {
+            (
+                Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
+                None,
+            )
+        };
         Ok(Self {
-            dense_relu_dense,
+            dense_act,
+            gated_dense_act,
             layer_norm,
         })
     }
 
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let ys = self.layer_norm.forward(xs)?;
-        let ys = self.dense_relu_dense.forward(&ys)?;
+        let ys = match &self.dense_act {
+            Some(dense_act) => dense_act.forward(&ys)?,
+            None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
+        };
         let xs = (xs + ys)?;
         Ok(xs)
     }
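
A dependency-free sketch (not from this commit) of the dataflow in T5DenseGatedActDense::forward: the hidden state is gelu_new(wi_0 x) gated elementwise by wi_1 x, then projected back by wo. Plain matvecs stand in for candle's bias-free Linear layers, and the tanh-form gelu_new (same as the previous sketch) stands in for Activation::NewGelu; matvec and gated_forward are hypothetical names.

// Each row of w dotted with x; stands in for a bias-free Linear.
fn matvec(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum())
        .collect()
}

fn gelu_new(x: f64) -> f64 {
    let c = (2.0 / std::f64::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
}

// hidden = gelu_new(wi_0 x) * (wi_1 x); out = wo hidden.
// Shapes mirror load() above: wi_0, wi_1 are (d_ff, d_model); wo is (d_model, d_ff).
fn gated_forward(xs: &[f64], wi_0: &[Vec<f64>], wi_1: &[Vec<f64>], wo: &[Vec<f64>]) -> Vec<f64> {
    let hidden_gelu: Vec<f64> = matvec(wi_0, xs).into_iter().map(gelu_new).collect();
    let hidden_linear = matvec(wi_1, xs);
    let hidden: Vec<f64> = hidden_gelu
        .iter()
        .zip(&hidden_linear)
        .map(|(a, b)| a * b)
        .collect();
    matvec(wo, &hidden)
}

fn main() {
    // Toy sizes: d_model = 2, d_ff = 3.
    let wi_0 = vec![vec![0.1, 0.2], vec![0.3, 0.4], vec![0.5, 0.6]];
    let wi_1 = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let wo = vec![vec![0.1, 0.1, 0.1], vec![0.2, 0.2, 0.2]];
    println!("{:?}", gated_forward(&[1.0, -1.0], &wi_0, &wi_1, &wo));
}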
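
And a sketch (not from this commit) of the load-time selection in T5LayerFF::load: exactly one of the two Options is populated, which is the invariant forward() relies on when it unwraps the gated branch after finding dense_act empty. The unit structs and select_ffn below are hypothetical stand-ins for the candle types in the diff.

#[derive(Debug)]
struct DenseActDense;
#[derive(Debug)]
struct DenseGatedActDense;

#[derive(PartialEq)]
enum Activation {
    NewGelu,
    Relu,
}

fn select_ffn(feed_forward_proj: Activation) -> (Option<DenseActDense>, Option<DenseGatedActDense>) {
    if feed_forward_proj == Activation::NewGelu {
        // flan-t5 configs ("feed_forward_proj": "gated-gelu") take this arm.
        (None, Some(DenseGatedActDense))
    } else {
        // Classic t5 configs ("feed_forward_proj": "relu") keep the old path.
        (Some(DenseActDense), None)
    }
}

fn main() {
    let (dense_act, gated_dense_act) = select_ffn(Activation::NewGelu);
    assert!(dense_act.is_none() && gated_dense_act.is_some());
    println!("{dense_act:?} / {gated_dense_act:?}");
}

An alternative encoding would be a single two-variant enum field, which makes the one-of-two invariant impossible to violate; the diff's Option pair arguably keeps loading simpler, at the cost of the unwrap in forward().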