forked from mindspore-Ecosystem/mindspore
commit
f9ee0622ea
|
@ -236,7 +236,15 @@ REG_BPROP_BUILDER("ParallelResizeBilinear").SetUnusedInputs({i2}).SetBody(BODYFU
|
|||
{"half_pixel_centers", MakeValue(false)}});
|
||||
return {dx, ib->ZerosLike(size)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SiLU").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto sigmoid_input = ib->Emit("Sigmoid", {x});
|
||||
auto bc_dx = ib->Mul(x, dout);
|
||||
auto bc_dy = ib->Mul(sigmoid_input, dout);
|
||||
auto dx = ib->Emit("SigmoidGrad", {sigmoid_input, bc_dx});
|
||||
return {ib->Add(dx, bc_dy)};
|
||||
});
|
||||
REG_BPROP_BUILDER("DynamicBroadcastTo").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto shp = ib->GetInput(kIndex1);
|
||||
|
|
|
@ -958,8 +958,13 @@ REG_BPROP_BUILDER("ReduceProd").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
{{"exclusive", MakeValue(true)}, {"reverse", MakeValue(true)}});
|
||||
auto y = ib->Reshape(ib->Mul(left, right), permuted_shape);
|
||||
auto out = ib->Mul(ib->Transpose(y, InvertPermutation(perm)), grad);
|
||||
auto dx = ib->Reshape(out, input_shape);
|
||||
return {dx, ib->ZerosLike(axis)};
|
||||
auto x_dtype_id = ib->GetDtypeId(x);
|
||||
if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
|
||||
MS_EXCEPTION(TypeError) << "For 'ReduceProd', gradient not support for complex type currently.";
|
||||
} else {
|
||||
auto dx = ib->Reshape(out, input_shape);
|
||||
return {dx, ib->ZerosLike(axis)};
|
||||
}
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ReduceMax").SetBody(BODYFUNC(ib) {
|
||||
|
@ -967,8 +972,13 @@ REG_BPROP_BUILDER("ReduceMax").SetBody(BODYFUNC(ib) {
|
|||
auto axis = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto dx = MinOrMaxGrad(ib, x, GetIntList(axis), out, dout);
|
||||
return {dx, ib->ZerosLike(axis)};
|
||||
auto x_dtype_id = ib->GetDtypeId(x);
|
||||
if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
|
||||
MS_EXCEPTION(TypeError) << "For 'ReduceMax', gradient not support for complex type currently.";
|
||||
} else {
|
||||
auto dx = MinOrMaxGrad(ib, x, GetIntList(axis), out, dout);
|
||||
return {dx, ib->ZerosLike(axis)};
|
||||
}
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ReduceMin").SetBody(BODYFUNC(ib) {
|
||||
|
@ -976,8 +986,13 @@ REG_BPROP_BUILDER("ReduceMin").SetBody(BODYFUNC(ib) {
|
|||
auto axis = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto dx = MinOrMaxGrad(ib, x, GetIntList(axis), out, dout);
|
||||
return {dx, ib->ZerosLike(axis)};
|
||||
auto x_dtype_id = ib->GetDtypeId(x);
|
||||
if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
|
||||
MS_EXCEPTION(TypeError) << "For 'ReduceMin', gradient not support for complex type currently.";
|
||||
} else {
|
||||
auto dx = MinOrMaxGrad(ib, x, GetIntList(axis), out, dout);
|
||||
return {dx, ib->ZerosLike(axis)};
|
||||
}
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ReduceMean").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
|
@ -1000,8 +1015,13 @@ REG_BPROP_BUILDER("ReduceMean").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
|||
MS_EXCEPTION(ValueError) << "out shape size can not be 0";
|
||||
}
|
||||
auto div_shape = getSize(shape_x) / shape_out_sz;
|
||||
auto dx = ib->RealDiv(grad, ib->Tensor(div_shape, ib->GetDtype(grad)));
|
||||
return {dx, ib->ZerosLike(axis)};
|
||||
auto x_dtype_id = ib->GetDtypeId(x);
|
||||
if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
|
||||
MS_EXCEPTION(TypeError) << "For 'ReduceMean', gradient not support for complex type currently.";
|
||||
} else {
|
||||
auto dx = ib->RealDiv(grad, ib->Tensor(div_shape, ib->GetDtype(grad)));
|
||||
return {dx, ib->ZerosLike(axis)};
|
||||
}
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ArgMaxWithValue").SetBody(BODYFUNC(ib) {
|
||||
|
|
Loading…
Reference in New Issue