!49889 fix bessel dynamic shape bprop
Merge pull request !49889 from luoyang/fix-bessel
This commit is contained in:
commit
37054b1543
|
@ -431,7 +431,7 @@ REG_BPROP_BUILDER("BesselI1").SetBody(BODYFUNC(ib) {
|
|||
auto dout = ib->GetInput(kIndex2);
|
||||
auto bessel_i0 = ib->Emit("BesselI0", {x});
|
||||
auto zero = ib->ZerosLike(x);
|
||||
auto one = ib->Fill(1.0, ib->GetShape(x), ib->GetDtype(x)->type_id());
|
||||
auto one = ib->Fill(1.0, ib->Shape(x), ib->GetDtype(x)->type_id());
|
||||
auto dout_dx = ib->Select(ib->Equal(x, zero), one, ib->Sub(bessel_i0, (ib->Div(out, x))));
|
||||
auto dx = ib->Mul(dout, dout_dx);
|
||||
return {dx};
|
||||
|
@ -451,7 +451,7 @@ REG_BPROP_BUILDER("BesselJ1").SetBody(BODYFUNC(ib) {
|
|||
auto dout = ib->GetInput(kIndex2);
|
||||
auto bessel_j0 = ib->Emit("BesselJ0", {x});
|
||||
auto zero = ib->ZerosLike(x);
|
||||
auto zero_p5 = ib->Fill(0.5, ib->GetShape(x), ib->GetDtype(x)->type_id());
|
||||
auto zero_p5 = ib->Fill(0.5, ib->Shape(x), ib->GetDtype(x)->type_id());
|
||||
auto dout_dx = ib->Select(ib->Equal(x, zero), zero_p5, ib->Sub(bessel_j0, (ib->Div(out, x))));
|
||||
auto dx = ib->Mul(dout, dout_dx);
|
||||
return {dx};
|
||||
|
|
Loading…
Reference in New Issue