feat: add backprop for elu (#1269)
* feat: add backprop for elu * Cosmetic tweaks. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
parent
dc68c130e4
commit
7051fb8098
|
@ -554,7 +554,16 @@ impl Tensor {
|
|||
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
||||
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
||||
}
|
||||
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
|
||||
Op::Elu(arg, alpha) => {
|
||||
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let zeros = arg.zeros_like()?;
|
||||
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
|
||||
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
|
||||
let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
|
||||
let combined_mask = (positive_mask + negative_exp_mask)?;
|
||||
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
|
||||
}
|
||||
Op::Powf(arg, e) => {
|
||||
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
|
|
@ -246,6 +246,30 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||
[1.0119, 1.0833, 1.0005, 0.6188],
|
||||
);
|
||||
|
||||
// Testing compared to pytorch elu
|
||||
//
|
||||
// import torch
|
||||
// import torch.nn.functional as F
|
||||
// x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
|
||||
// y = F.elu(x, alpha=2.0)
|
||||
// print(y)
|
||||
// loss = y.min
|
||||
// loss = y.sum()
|
||||
// loss.backward()
|
||||
// print(x.grad)
|
||||
let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
|
||||
let y = elu_x.elu(2.)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&elu_x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[-1.2642, 0.0000, -1.7293, 3.0000]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[0.7358, 2.0000, 0.2707, 1.0000]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue