!49292 Fix bprop bug

Merge pull request !49292 from ZengZitao/add_silu
This commit is contained in:
i-robot 2023-02-24 01:15:09 +00:00 committed by Gitee
commit f9ee0622ea
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 37 additions and 9 deletions

View File

@ -236,7 +236,15 @@ REG_BPROP_BUILDER("ParallelResizeBilinear").SetUnusedInputs({i2}).SetBody(BODYFU
{"half_pixel_centers", MakeValue(false)}});
return {dx, ib->ZerosLike(size)};
});
REG_BPROP_BUILDER("SiLU").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
  // Backward rule for SiLU. With s = Sigmoid(x), the single returned gradient is
  //   SigmoidGrad(s, x * dout) + s * dout.
  // Input i1 is declared unused via SetUnusedInputs and is never read here.
  const auto input = ib->GetInput(kIndex0);
  const auto grad_out = ib->GetInput(kIndex2);
  const auto gate = ib->Emit("Sigmoid", {input});
  // Both partial terms are scaled by the incoming gradient before being combined.
  const auto scaled_by_input = ib->Mul(input, grad_out);
  const auto scaled_by_gate = ib->Mul(gate, grad_out);
  const auto through_gate = ib->Emit("SigmoidGrad", {gate, scaled_by_input});
  return {ib->Add(through_gate, scaled_by_gate)};
});
REG_BPROP_BUILDER("DynamicBroadcastTo").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
auto x = ib->GetInput(kIndex0);
auto shp = ib->GetInput(kIndex1);

View File

@ -958,8 +958,13 @@ REG_BPROP_BUILDER("ReduceProd").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
{{"exclusive", MakeValue(true)}, {"reverse", MakeValue(true)}});
auto y = ib->Reshape(ib->Mul(left, right), permuted_shape);
auto out = ib->Mul(ib->Transpose(y, InvertPermutation(perm)), grad);
auto dx = ib->Reshape(out, input_shape);
return {dx, ib->ZerosLike(axis)};
auto x_dtype_id = ib->GetDtypeId(x);
if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
MS_EXCEPTION(TypeError) << "For 'ReduceProd', gradient not support for complex type currently.";
} else {
auto dx = ib->Reshape(out, input_shape);
return {dx, ib->ZerosLike(axis)};
}
});
REG_BPROP_BUILDER("ReduceMax").SetBody(BODYFUNC(ib) {
@ -967,8 +972,13 @@ REG_BPROP_BUILDER("ReduceMax").SetBody(BODYFUNC(ib) {
auto axis = ib->GetInput(kIndex1);
auto out = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex3);
auto dx = MinOrMaxGrad(ib, x, GetIntList(axis), out, dout);
return {dx, ib->ZerosLike(axis)};
auto x_dtype_id = ib->GetDtypeId(x);
if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
MS_EXCEPTION(TypeError) << "For 'ReduceMax', gradient not support for complex type currently.";
} else {
auto dx = MinOrMaxGrad(ib, x, GetIntList(axis), out, dout);
return {dx, ib->ZerosLike(axis)};
}
});
REG_BPROP_BUILDER("ReduceMin").SetBody(BODYFUNC(ib) {
@ -976,8 +986,13 @@ REG_BPROP_BUILDER("ReduceMin").SetBody(BODYFUNC(ib) {
auto axis = ib->GetInput(kIndex1);
auto out = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex3);
auto dx = MinOrMaxGrad(ib, x, GetIntList(axis), out, dout);
return {dx, ib->ZerosLike(axis)};
auto x_dtype_id = ib->GetDtypeId(x);
if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
MS_EXCEPTION(TypeError) << "For 'ReduceMin', gradient not support for complex type currently.";
} else {
auto dx = MinOrMaxGrad(ib, x, GetIntList(axis), out, dout);
return {dx, ib->ZerosLike(axis)};
}
});
REG_BPROP_BUILDER("ReduceMean").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
@ -1000,8 +1015,13 @@ REG_BPROP_BUILDER("ReduceMean").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
MS_EXCEPTION(ValueError) << "out shape size can not be 0";
}
auto div_shape = getSize(shape_x) / shape_out_sz;
auto dx = ib->RealDiv(grad, ib->Tensor(div_shape, ib->GetDtype(grad)));
return {dx, ib->ZerosLike(axis)};
auto x_dtype_id = ib->GetDtypeId(x);
if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
MS_EXCEPTION(TypeError) << "For 'ReduceMean', gradient not support for complex type currently.";
} else {
auto dx = ib->RealDiv(grad, ib->Tensor(div_shape, ib->GetDtype(grad)));
return {dx, ib->ZerosLike(axis)};
}
});
REG_BPROP_BUILDER("ArgMaxWithValue").SetBody(BODYFUNC(ib) {