fix bessel dynamic shape bprop
This commit is contained in:
parent
8559dd7572
commit
523a099439
|
@ -431,7 +431,7 @@ REG_BPROP_BUILDER("BesselI1").SetBody(BODYFUNC(ib) {
|
||||||
auto dout = ib->GetInput(kIndex2);
|
auto dout = ib->GetInput(kIndex2);
|
||||||
auto bessel_i0 = ib->Emit("BesselI0", {x});
|
auto bessel_i0 = ib->Emit("BesselI0", {x});
|
||||||
auto zero = ib->ZerosLike(x);
|
auto zero = ib->ZerosLike(x);
|
||||||
auto one = ib->Fill(1.0, ib->GetShape(x), ib->GetDtype(x)->type_id());
|
auto one = ib->Fill(1.0, ib->Shape(x), ib->GetDtype(x)->type_id());
|
||||||
auto dout_dx = ib->Select(ib->Equal(x, zero), one, ib->Sub(bessel_i0, (ib->Div(out, x))));
|
auto dout_dx = ib->Select(ib->Equal(x, zero), one, ib->Sub(bessel_i0, (ib->Div(out, x))));
|
||||||
auto dx = ib->Mul(dout, dout_dx);
|
auto dx = ib->Mul(dout, dout_dx);
|
||||||
return {dx};
|
return {dx};
|
||||||
|
@ -451,7 +451,7 @@ REG_BPROP_BUILDER("BesselJ1").SetBody(BODYFUNC(ib) {
|
||||||
auto dout = ib->GetInput(kIndex2);
|
auto dout = ib->GetInput(kIndex2);
|
||||||
auto bessel_j0 = ib->Emit("BesselJ0", {x});
|
auto bessel_j0 = ib->Emit("BesselJ0", {x});
|
||||||
auto zero = ib->ZerosLike(x);
|
auto zero = ib->ZerosLike(x);
|
||||||
auto zero_p5 = ib->Fill(0.5, ib->GetShape(x), ib->GetDtype(x)->type_id());
|
auto zero_p5 = ib->Fill(0.5, ib->Shape(x), ib->GetDtype(x)->type_id());
|
||||||
auto dout_dx = ib->Select(ib->Equal(x, zero), zero_p5, ib->Sub(bessel_j0, (ib->Div(out, x))));
|
auto dout_dx = ib->Select(ib->Equal(x, zero), zero_p5, ib->Sub(bessel_j0, (ib->Div(out, x))));
|
||||||
auto dx = ib->Mul(dout, dout_dx);
|
auto dx = ib->Mul(dout, dout_dx);
|
||||||
return {dx};
|
return {dx};
|
||||||
|
|
Loading…
Reference in New Issue