Transpose the rhs in linear.

This commit is contained in:
laurent 2023-08-03 16:51:24 +01:00
parent df6667ba88
commit a35a935118
1 changed files with 3 additions and 2 deletions

View File

@ -27,13 +27,14 @@ pub struct Linear {
impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
let weight = weight.t().unwrap().contiguous().unwrap();
Self { weight, bias }
}
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
&[bsize, _, _] => self.weight.broadcast_left(bsize)?,
_ => self.weight.clone(),
};
let x = x.matmul(&w)?;
match &self.bias {