Transpose the rhs in linear.
This commit is contained in:
parent
df6667ba88
commit
a35a935118
|
@ -27,13 +27,14 @@ pub struct Linear {
|
|||
|
||||
impl Linear {
|
||||
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
|
||||
let weight = weight.t().unwrap().contiguous().unwrap();
|
||||
Self { weight, bias }
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let w = match x.dims() {
|
||||
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
||||
_ => self.weight.t()?,
|
||||
&[bsize, _, _] => self.weight.broadcast_left(bsize)?,
|
||||
_ => self.weight.clone(),
|
||||
};
|
||||
let x = x.matmul(&w)?;
|
||||
match &self.bias {
|
||||
|
|
Loading…
Reference in New Issue