diff --git a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_inner_ops.cc b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_inner_ops.cc
index a65fb731380..f72af8ba6e6 100644
--- a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_inner_ops.cc
+++ b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_inner_ops.cc
@@ -236,7 +236,15 @@ REG_BPROP_BUILDER("ParallelResizeBilinear").SetUnusedInputs({i2}).SetBody(BODYFU
                       {"half_pixel_centers", MakeValue(false)}});
   return {dx, ib->ZerosLike(size)};
 });
-
+REG_BPROP_BUILDER("SiLU").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
+  auto x = ib->GetInput(kIndex0);
+  auto dout = ib->GetInput(kIndex2);
+  auto sigmoid_input = ib->Emit("Sigmoid", {x});
+  auto bc_dx = ib->Mul(x, dout);
+  auto bc_dy = ib->Mul(sigmoid_input, dout);
+  auto dx = ib->Emit("SigmoidGrad", {sigmoid_input, bc_dx});
+  return {ib->Add(dx, bc_dy)};
+});
 REG_BPROP_BUILDER("DynamicBroadcastTo").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
   auto x = ib->GetInput(kIndex0);
   auto shp = ib->GetInput(kIndex1);
diff --git a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_math_ops.cc b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_math_ops.cc
index b20f2e8e65e..21805a8af9e 100644
--- a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_math_ops.cc
+++ b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_math_ops.cc
@@ -958,8 +958,13 @@ REG_BPROP_BUILDER("ReduceProd").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
                         {{"exclusive", MakeValue(true)}, {"reverse", MakeValue(true)}});
   auto y = ib->Reshape(ib->Mul(left, right), permuted_shape);
   auto out = ib->Mul(ib->Transpose(y, InvertPermutation(perm)), grad);
-  auto dx = ib->Reshape(out, input_shape);
-  return {dx, ib->ZerosLike(axis)};
+  auto x_dtype_id = ib->GetDtypeId(x);
+  if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
+    MS_EXCEPTION(TypeError) << "For 'ReduceProd', gradient not support for complex type currently.";
+  } else {
+    auto dx = ib->Reshape(out, input_shape);
+    return {dx, ib->ZerosLike(axis)};
+  }
 });
 
 REG_BPROP_BUILDER("ReduceMax").SetBody(BODYFUNC(ib) {
@@ -967,8 +972,13 @@ REG_BPROP_BUILDER("ReduceMax").SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto axis = ib->GetInput(kIndex1);
   auto out = ib->GetInput(kIndex2);
   auto dout = ib->GetInput(kIndex3);
-  auto dx = MinOrMaxGrad(ib, x, GetIntList(axis), out, dout);
-  return {dx, ib->ZerosLike(axis)};
+  auto x_dtype_id = ib->GetDtypeId(x);
+  if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
+    MS_EXCEPTION(TypeError) << "For 'ReduceMax', gradient not support for complex type currently.";
+  } else {
+    auto dx = MinOrMaxGrad(ib, x, GetIntList(axis), out, dout);
+    return {dx, ib->ZerosLike(axis)};
+  }
 });
 
 REG_BPROP_BUILDER("ReduceMin").SetBody(BODYFUNC(ib) {
@@ -976,8 +986,13 @@ REG_BPROP_BUILDER("ReduceMin").SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto axis = ib->GetInput(kIndex1);
   auto out = ib->GetInput(kIndex2);
   auto dout = ib->GetInput(kIndex3);
-  auto dx = MinOrMaxGrad(ib, x, GetIntList(axis), out, dout);
-  return {dx, ib->ZerosLike(axis)};
+  auto x_dtype_id = ib->GetDtypeId(x);
+  if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
+    MS_EXCEPTION(TypeError) << "For 'ReduceMin', gradient not support for complex type currently.";
+  } else {
+    auto dx = MinOrMaxGrad(ib, x, GetIntList(axis), out, dout);
+    return {dx, ib->ZerosLike(axis)};
+  }
 });
 
 REG_BPROP_BUILDER("ReduceMean").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
@@ -1000,8 +1015,13 @@ REG_BPROP_BUILDER("ReduceMean").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
     MS_EXCEPTION(ValueError) << "out shape size can not be 0";
   }
   auto div_shape = getSize(shape_x) / shape_out_sz;
-  auto dx = ib->RealDiv(grad, ib->Tensor(div_shape, ib->GetDtype(grad)));
-  return {dx, ib->ZerosLike(axis)};
+  auto x_dtype_id = ib->GetDtypeId(x);
+  if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
+    MS_EXCEPTION(TypeError) << "For 'ReduceMean', gradient not support for complex type currently.";
+  } else {
+    auto dx = ib->RealDiv(grad, ib->Tensor(div_shape, ib->GetDtype(grad)));
+    return {dx, ib->ZerosLike(axis)};
+  }
 });
 
 REG_BPROP_BUILDER("ArgMaxWithValue").SetBody(BODYFUNC(ib) {