diff --git a/mindspore/ccsrc/frontend/expander/bprop/bprop_irbuilder.cc b/mindspore/ccsrc/frontend/expander/bprop/bprop_irbuilder.cc
index 076d72e157c..104b33bc306 100644
--- a/mindspore/ccsrc/frontend/expander/bprop/bprop_irbuilder.cc
+++ b/mindspore/ccsrc/frontend/expander/bprop/bprop_irbuilder.cc
@@ -37,6 +37,7 @@ NodePtrList BpropBuilder::Run(const NodePtrList &inputs, const mindspore::HashMa
   inputs_ptr_ = &inputs;
   attrs_ptr_ = &attrs;
   instance_name_ = instance_name;
+  input_size_ = inputs.size() - kIndex2;
   return handle.func(this);
 }
 
diff --git a/mindspore/ccsrc/frontend/expander/bprop/bprop_irbuilder.h b/mindspore/ccsrc/frontend/expander/bprop/bprop_irbuilder.h
index 522c4584957..1804af98103 100644
--- a/mindspore/ccsrc/frontend/expander/bprop/bprop_irbuilder.h
+++ b/mindspore/ccsrc/frontend/expander/bprop/bprop_irbuilder.h
@@ -48,7 +48,6 @@ class COMMON_EXPORT BpropBuilder : public Emitter {
   /// \brief Run irbuilder to generate a graph
   NodePtrList Run(const NodePtrList &inputs, const mindspore::HashMap<std::string, ValuePtr> &attrs,
                   const BpropHandle &handle, const std::string &instance_name);
-
   ValuePtr GetAttr(const std::string &attr) const;
   template <typename S>
   S GetAttr(const std::string &attr) const {
@@ -113,12 +112,14 @@ class COMMON_EXPORT BpropBuilder : public Emitter {
   std::string GetInstanceName() const { return instance_name_; }
   NodePtr TanhGrad(const NodePtr &y, const NodePtr &dy) { return Emit("TanhGrad", {y, dy}); }
   virtual NodePtr OutZeros(const NodePtr &node) { return ZerosLike(node); }
+  size_t input_size() const { return input_size_; }
 
  protected:
   std::string name_;
   std::string instance_name_;
   const NodePtrList *inputs_ptr_{nullptr};
   const mindspore::HashMap<std::string, ValuePtr> *attrs_ptr_{nullptr};
+  size_t input_size_;
 };
 
 class IrBuilder : public BpropBuilder {
diff --git a/mindspore/ccsrc/frontend/expander/bprop/grad_ops/common_utils.cc b/mindspore/ccsrc/frontend/expander/bprop/grad_ops/common_utils.cc
index b3e8aa56506..737f3f20815 100644
--- a/mindspore/ccsrc/frontend/expander/bprop/grad_ops/common_utils.cc
+++ b/mindspore/ccsrc/frontend/expander/bprop/grad_ops/common_utils.cc
@@ -259,48 +259,62 @@ std::vector<int64_t> GetIntList(const NodePtr &node) {
   return GetIntList(value);
 }
 
-NodePtrList BinopGradCommon(BpropBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx,
-                            const NodePtr &dy, size_t shift) {
+NodePtr StaticBinopGradCommon(BpropBuilder *ib, const NodePtr &dx, const ShapeArray &shape,
+                              const ShapeArray &broadcast_shape, size_t shift, size_t index, bool *is_dynamic_shape) {
+  NodePtr reduce_dx = dx;
+  if (broadcast_shape[kIndex0].empty() || broadcast_shape[kIndex1].empty()) {
+    if (broadcast_shape[index].empty()) {
+      if (shift) {
+        std::vector<int64_t> axis(broadcast_shape[index ^ 1].size());
+        std::iota(axis.begin(), axis.end(), 0LL);
+        reduce_dx = ib->ReduceSum(reduce_dx, axis);
+      } else {
+        reduce_dx = ib->ReduceSum(reduce_dx);
+      }
+    }
+  } else if (!IsDynamic(broadcast_shape[0]) && !IsDynamic(broadcast_shape[1])) {
+    std::vector<std::vector<int64_t>> bc_axis = BroadcastGradientArgsInferValue(broadcast_shape[0], broadcast_shape[1]);
+    if (!bc_axis[index].empty()) {
+      reduce_dx = ib->ReduceSum(reduce_dx, bc_axis[index], ib->GetRank(reduce_dx) == shape[index].size());
+    }
+    if (ib->GetRank(reduce_dx) != shape[index].size()) {
+      reduce_dx = ib->Reshape(reduce_dx, shape[index]);
+    }
+  } else {
+    *is_dynamic_shape = true;
+  }
+  return reduce_dx;
+}
+
+NodePtrList BinopGradCommon(BpropBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx, const NodePtr &dy,
+                            size_t shift) {
   // Common grad definition for binary operations with shift.
   // The function is usually used in backprop op to reduce additional dimensions
   // created by broadcasting.
   NodePtrList inputs{x, y};
-  ShapeArray shape{ib->GetShape(inputs[0]), ib->GetShape(inputs[1])};
+  ShapeArray shape{ib->GetShape(inputs[kIndex0]), ib->GetShape(inputs[kIndex1])};
   NodePtrList reduce = {dx, dy};
-  if (IsDynamicRank(shape[0]) || IsDynamicRank(shape[1])) {
+  if (IsDynamicRank(shape[kIndex0]) || IsDynamicRank(shape[kIndex1])) {
     return DynBinopGradCommon(ib, x, y, dx, dy, shift);
   }
   if (shape[kIndex0].size() <= shift && shape[kIndex0].size() == shape[kIndex1].size()) {
     return reduce;
   }
-  ShapeVector broadcast_shape[kDim2];
+  ShapeArray broadcast_shape(kDim2);
   for (size_t i = 0; i < kDim2; i++) {
     broadcast_shape[i] = ShapeVector(shape[i].begin(), shape[i].end() - shift);
   }
-
-  if (broadcast_shape[0].empty() || broadcast_shape[1].empty()) {
-    for (size_t i = 0; i < kDim2; i++) {
-      if (broadcast_shape[i].empty()) {
-        if (shift) {
-          std::vector<int64_t> axis(broadcast_shape[i ^ 1].size());
-          std::iota(axis.begin(), axis.end(), 0LL);
-          reduce[i] = ib->ReduceSum(reduce[i], axis);
-        } else {
-          reduce[i] = ib->ReduceSum(reduce[i]);
-        }
-      }
-    }
-  } else if (!IsDynamic(broadcast_shape[0]) && !IsDynamic(broadcast_shape[1])) {
-    std::vector<std::vector<int64_t>> bc_axis = BroadcastGradientArgsInferValue(broadcast_shape[0], broadcast_shape[1]);
-    for (size_t i = 0; i < kDim2; i++) {
-      if (!bc_axis[i].empty()) {
-        reduce[i] = ib->ReduceSum(reduce[i], bc_axis[i], ib->GetRank(reduce[i]) == shape[i].size());
-      }
-      if (ib->GetRank(reduce[i]) != shape[i].size()) {
-        reduce[i] = ib->Reshape(reduce[i], shape[i]);
-      }
-    }
-  } else {
+  bool is_x_shape_dynamic = false;
+  bool is_y_shape_dynamic = false;
+  if (dx != nullptr) {
+    reduce[kIndex0] =
+      StaticBinopGradCommon(ib, reduce[kIndex0], shape, broadcast_shape, shift, kIndex0, &is_x_shape_dynamic);
+  }
+  if (dy != nullptr) {
+    reduce[kIndex1] =
+      StaticBinopGradCommon(ib, reduce[kIndex1], shape, broadcast_shape, shift, kIndex1, &is_y_shape_dynamic);
+  }
+  if (is_x_shape_dynamic || is_y_shape_dynamic) {
     return DynBinopGradCommon(ib, x, y, dx, dy, shift);
   }
   return reduce;
diff --git a/mindspore/ccsrc/frontend/expander/bprop/grad_ops/grad_math_ops.cc b/mindspore/ccsrc/frontend/expander/bprop/grad_ops/grad_math_ops.cc
index 2536d345dfe..5a0640e44cf 100644
--- a/mindspore/ccsrc/frontend/expander/bprop/grad_ops/grad_math_ops.cc
+++ b/mindspore/ccsrc/frontend/expander/bprop/grad_ops/grad_math_ops.cc
@@ -89,20 +89,27 @@ NodePtrList MinimumMaximumGrad(BpropBuilder *ib, const NodePtr &x, const NodePtr
   auto half_ratio = ib->Emit("FillV2", {ib->Shape(dout), ib->Tensor(2, ib->GetDtype(dout))});
   auto half_dout = ib->Div(dout, half_ratio);
   NodePtr equal_mask = ib->Equal(x, y);
-  auto grad_x = ib->Select(equal_mask, half_dout, dout);
-  auto grad_y = ib->Select(equal_mask, half_dout, dout);
-
   auto zeros = ib->Emit("FillV2", {ib->Shape(dout), ib->Tensor(0, ib->GetDtype(dout))});
   NodePtr is_less = ib->Less(x, y);
   NodePtr is_greater = ib->Greater(x, y);
-  if (is_minimum) {
-    grad_x = ib->Select(is_greater, zeros, grad_x);
-    grad_y = ib->Select(is_less, zeros, grad_y);
-  } else {
-    grad_x = ib->Select(is_less, zeros, grad_x);
-    grad_y = ib->Select(is_greater, zeros, grad_y);
+  NodePtr grad_x = nullptr;
+  NodePtr grad_y = nullptr;
+  if (x->need_compute_grad_out()) {
+    grad_x = ib->Select(equal_mask, half_dout, dout);
+    if (is_minimum) {
+      grad_x = ib->Select(is_greater, zeros, grad_x);
+    } else {
+      grad_x = ib->Select(is_less, zeros, grad_x);
+    }
+  }
+  if (y->need_compute_grad_out()) {
+    grad_y = ib->Select(equal_mask, half_dout, dout);
+    if (is_minimum) {
+      grad_y = ib->Select(is_less, zeros, grad_y);
+    } else {
+      grad_y = ib->Select(is_greater, zeros, grad_y);
+    }
   }
-
   return BinopGradCommon(ib, x, y, grad_x, grad_y);
 }
 
@@ -473,7 +480,15 @@ REG_BPROP_BUILDER("Add").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  return BinopGradCommon(ib, x, y, dout, dout);
+  NodePtr dx = nullptr;
+  NodePtr dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    dx = dout;
+  }
+  if (y->need_compute_grad_out()) {
+    dy = dout;
+  }
+  return BinopGradCommon(ib, x, y, dx, dy);
 });
 
 REG_BPROP_BUILDER("AddExt").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
@@ -526,8 +541,14 @@ REG_BPROP_BUILDER("Mul").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
     MS_EXCEPTION(TypeError) << "For 'Mul', gradient not support for complex type currently.";
   }
-  auto bc_dx = ib->Mul(y, dout);
-  auto bc_dy = ib->Mul(x, dout);
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    bc_dx = ib->Mul(y, dout);
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = ib->Mul(x, dout);
+  }
   return BinopGradCommon(ib, x, y, bc_dx, bc_dy);
 });
 
@@ -535,7 +556,15 @@ REG_BPROP_BUILDER("Sub").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  return BinopGradCommon(ib, x, y, dout, ib->Emit(kNegOpName, {dout}));
+  NodePtr dx = nullptr;
+  NodePtr dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    dx = dout;
+  }
+  if (y->need_compute_grad_out()) {
+    dy = ib->Emit(kNegOpName, {dout});
+  }
+  return BinopGradCommon(ib, x, y, dx, dy);
 });
 
@@ -543,14 +572,20 @@ REG_BPROP_BUILDER("Div").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
   auto y = ib->GetInput(kIndex1);
   auto out = ib->GetInput(kIndex2);
   auto dout = ib->GetInput(kIndex3);
-  auto bc_x = ib->Emit(kDivOpName, {dout, y});
-  auto bc_y = -(bc_x * out);
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
   auto x_dtype_id = ib->GetDtypeId(x);
-  if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
-    auto result = BinopGradCommon(ib, x, y, bc_x, bc_y);
-    return {ib->Conj(result[0]), ib->Conj(result[1])};
+  bc_dx = ib->Emit(kDivOpName, {dout, y});
+  if (y->need_compute_grad_out()) {
+    bc_dy = -(bc_dx * out);
   }
-  return BinopGradCommon(ib, x, y, bc_x, bc_y);
+  auto result = BinopGradCommon(ib, x, y, bc_dx, bc_dy);
+  bool is_complex = (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128);
+  if (is_complex) {
+    result[kIndex0] = ib->Conj(result[kIndex0]);
+    result[kIndex1] = y->need_compute_grad_out() ? ib->Conj(result[kIndex1]) : ib->OutZeros(y);
+  }
+  return result;
 });
 
 REG_BPROP_BUILDER("BitwiseAnd").SetUnusedInputs({i0, i1, i2, i3}).SetBody(ReturnZeros);
@@ -714,8 +749,14 @@ REG_BPROP_BUILDER("Atan2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
   auto tmp = ib->RealDiv(dout, (ib->Add((ib->Emit("Square", {x})), (ib->Emit("Square", {y})))));
-  auto bc_dx = ib->Mul(tmp, y);
-  auto bc_dy = ib->Mul(tmp, (ib->Emit("Neg", {x})));
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    bc_dx = ib->Mul(tmp, y);
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = ib->Mul(tmp, (ib->Emit("Neg", {x})));
+  }
   return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });
 
@@ -781,11 +822,17 @@ REG_BPROP_BUILDER("Pow").SetBody(BODYFUNC(ib) {
   if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
     MS_EXCEPTION(TypeError) << "For 'Pow', gradient not support for complex type currently.";
   }
-  auto bc_dx = ib->Mul((ib->Mul(power, (ib->Pow(x, ib->Sub(power, ib->Tensor(1.0, ib->GetDtype(x))))))), dout);
-  x =
-    ib->Select(ib->Less(x, ib->Tensor(0, ib->GetDtype(x))), ib->Fill(1.0, ib->Shape(x), ib->GetDtype(x)->type_id()), x);
-  auto bc_dpower = ib->Mul((ib->Mul(out, (ib->Log(x)))), dout);
-  return {BinopGradCommon(ib, x, power, bc_dx, bc_dpower)};
+  NodePtrList grad_outputs{ib->input_size(), nullptr};
+  if (x->need_compute_grad_out()) {
+    grad_outputs[kIndex0] =
+      ib->Mul((ib->Mul(power, (ib->Pow(x, ib->Sub(power, ib->Tensor(1.0, ib->GetDtype(x))))))), dout);
+  }
+  if (power->need_compute_grad_out()) {
+    x = ib->Select(ib->Less(x, ib->Tensor(0, ib->GetDtype(x))), ib->Fill(1.0, ib->Shape(x), ib->GetDtype(x)->type_id()),
+                   x);
+    grad_outputs[kIndex1] = ib->Mul((ib->Mul(out, (ib->Log(x)))), dout);
+  }
+  return {BinopGradCommon(ib, x, power, grad_outputs[kIndex0], grad_outputs[kIndex1])};
 });
 
 REG_BPROP_BUILDER("Exp").SetBody(BODYFUNC(ib) {
@@ -1148,9 +1195,12 @@ REG_BPROP_BUILDER("RealDiv").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
   }
   auto out = ib->GetInput(kIndex2);
   auto dout = ib->GetInput(kIndex3);
-  auto bc_x = ib->RealDiv(dout, y);
-  auto bc_y = -(bc_x * out);
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  auto bc_dx = ib->RealDiv(dout, y);
+  NodePtr bc_dy = nullptr;
+  if (y->need_compute_grad_out()) {
+    bc_dy = -(bc_dx * out);
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });
 
@@ -1158,9 +1208,12 @@ REG_BPROP_BUILDER("DivNoNan").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
   auto y = ib->GetInput(kIndex1);
   auto out = ib->GetInput(kIndex2);
   auto dout = ib->GetInput(kIndex3);
-  auto bc_x = ib->DivNoNan(dout, y);
-  auto bc_y = -(bc_x * out);
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  auto bc_dx = ib->DivNoNan(dout, y);
+  NodePtr bc_dy = nullptr;
+  if (y->need_compute_grad_out()) {
+    bc_dy = -(bc_dx * out);
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });
 
@@ -1168,10 +1221,16 @@ REG_BPROP_BUILDER("Xdivy").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
   auto x_dtype = ib->GetDtype(x);
-  auto not_zero_x = ib->Cast(ib->NotEqual(x, ib->Tensor(0.0, x_dtype)), x_dtype);
-  auto bc_x = (ib->Xdivy(not_zero_x, y)) * dout;
-  auto bc_y = (ib->Xdivy(-x, ib->Emit("Square", {y}))) * dout;
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    auto not_zero_x = ib->Cast(ib->NotEqual(x, ib->Tensor(0.0, x_dtype)), x_dtype);
+    bc_dx = (ib->Xdivy(not_zero_x, y)) * dout;
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = (ib->Xdivy(-x, ib->Emit("Square", {y}))) * dout;
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });
 
 REG_BPROP_BUILDER("FloorDiv").SetUnusedInputs({i0, i1, i2, i3}).SetBody(ReturnZeros);
@@ -1180,11 +1239,16 @@ REG_BPROP_BUILDER("FloorMod").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  auto bc_x = dout;
-  auto bc_y = (-dout) * (ib->FloorDiv(x, y));
-  bc_x = ib->Cast(bc_x, ib->GetDtype(x));
-  bc_y = ib->Cast(bc_y, ib->GetDtype(y));
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    bc_dx = ib->Cast(dout, ib->GetDtype(x));
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = (-dout) * (ib->FloorDiv(x, y));
+    bc_dy = ib->Cast(bc_dy, ib->GetDtype(y));
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });
 
 REG_BPROP_BUILDER("TruncateDiv").SetUnusedInputs({i0, i1, i2, i3}).SetBody(ReturnZeros);
@@ -1193,20 +1257,31 @@ REG_BPROP_BUILDER("TruncateMod").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  auto bc_x = dout;
-  auto bc_y = (-dout) * (ib->Emit("TruncateDiv", {x, y}));
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    bc_dx = dout;
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = (-dout) * (ib->Emit("TruncateDiv", {x, y}));
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });
 
 REG_BPROP_BUILDER("Mod").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  auto bc_x = dout;
-  auto bc_y = (-dout) * (ib->FloorDiv(x, y));
-  bc_x = ib->Cast(bc_x, ib->GetDtype(x));
-  bc_y = ib->Cast(bc_y, ib->GetDtype(y));
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    bc_dx = ib->Cast(dout, ib->GetDtype(x));
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = (-dout) * (ib->FloorDiv(x, y));
+    bc_dy = ib->Cast(bc_dy, ib->GetDtype(y));
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });
 
@@ -1214,10 +1289,16 @@ REG_BPROP_BUILDER("Xlogy").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
   auto x_dtype = ib->GetDtype(x);
-  auto not_zero_x = ib->Cast(ib->NotEqual(x, ib->Tensor(0.0, x_dtype)), x_dtype);
-  auto bc_x = ib->Xlogy(not_zero_x, y) * dout;
-  auto bc_y = ib->Xdivy(x, y) * dout;
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    auto not_zero_x = ib->Cast(ib->NotEqual(x, ib->Tensor(0.0, x_dtype)), x_dtype);
+    bc_dx = ib->Xlogy(not_zero_x, y) * dout;
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = ib->Xdivy(x, y) * dout;
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });
 
 REG_BPROP_BUILDER("Sqrt").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
@@ -1307,11 +1388,17 @@ REG_BPROP_BUILDER("Hypot").SetBody(BODYFUNC(ib) {
   auto x2_f32 = ib->Cast(x2, kFloat32);
   auto out_f32 = ib->Cast(out, kFloat32);
   auto dout_f32 = ib->Cast(dout, kFloat32);
-  auto dx1 = ib->Mul(ib->Div(x1_f32, out_f32), dout_f32);
-  auto dx2 = ib->Mul(ib->Div(x2_f32, out_f32), dout_f32);
+  NodePtr dx1 = nullptr;
+  NodePtr dx2 = nullptr;
+  if (x1->need_compute_grad_out()) {
+    dx1 = ib->Mul(ib->Div(x1_f32, out_f32), dout_f32);
+  }
+  if (x2->need_compute_grad_out()) {
+    dx2 = ib->Mul(ib->Div(x2_f32, out_f32), dout_f32);
+  }
   auto tmp = BinopGradCommon(ib, x1_f32, x2_f32, dx1, dx2);
-  auto result_dx1 = ib->Cast(tmp[0], ib->GetDtype(x1));
-  auto result_dx2 = ib->Cast(tmp[1], ib->GetDtype(x2));
+  auto result_dx1 = x1->need_compute_grad_out() ? ib->Cast(tmp[0], ib->GetDtype(x1)) : ib->OutZeros(x1);
+  auto result_dx2 = x2->need_compute_grad_out() ? ib->Cast(tmp[1], ib->GetDtype(x2)) : ib->OutZeros(x2);
   return {result_dx1, result_dx2};
 });
 
@@ -1920,7 +2007,15 @@ REG_BPROP_BUILDER("AddV2").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  return {BinopGradCommon(ib, x, y, dout, dout)};
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    bc_dx = dout;
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = dout;
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });
 
 REG_BPROP_BUILDER("Addcdiv").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
@@ -2191,20 +2286,22 @@ REG_BPROP_BUILDER("BatchMatMul").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto w = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
 
-  NodePtr dx;
-  if (ta) {
-    dx = ib->BatchMatMul(w, dout, tb, true);
-  } else {
-    dx = ib->BatchMatMul(dout, w, false, !tb);
+  NodePtr dx = nullptr;
+  if (x->need_compute_grad_out()) {
+    if (ta) {
+      dx = ib->BatchMatMul(w, dout, tb, true);
+    } else {
+      dx = ib->BatchMatMul(dout, w, false, !tb);
+    }
   }
-
-  NodePtr dw;
-  if (tb) {
-    dw = ib->BatchMatMul(dout, x, true, ta);
-  } else {
-    dw = ib->BatchMatMul(x, dout, !ta, false);
+  NodePtr dw = nullptr;
+  if (w->need_compute_grad_out()) {
+    if (tb) {
+      dw = ib->BatchMatMul(dout, x, true, ta);
+    } else {
+      dw = ib->BatchMatMul(x, dout, !ta, false);
+    }
   }
-
   return BinopGradCommon(ib, x, w, dx, dw, 2);
 });
diff --git a/mindspore/ccsrc/include/common/expander/core/node.h b/mindspore/ccsrc/include/common/expander/core/node.h
index 8d91f105243..34d6719f103 100644
--- a/mindspore/ccsrc/include/common/expander/core/node.h
+++ b/mindspore/ccsrc/include/common/expander/core/node.h
@@ -58,6 +58,7 @@ class COMMON_EXPORT Node : public std::enable_shared_from_this<Node> {
   virtual bool is_used_value() const {
     MS_EXCEPTION(NotImplementedError) << "Base Node not implement is_used_value() method";
   }
+  virtual bool need_compute_grad_out() const { return true; }
 
  protected:
   // hold the emitter who created this node.
@@ -112,10 +113,13 @@ class COMMON_EXPORT FuncNode : public Node {
   std::string ToString() const override { return value_->ToString(); }
   void set_debug_info(const std::string &debug_info) override {}
   std::string debug_info() const override { return ""; }
+  bool need_compute_grad_out() const override { return need_compute_grad_out_; }
+  void set_need_compute_grad_out(bool need_compute_grad_out) { need_compute_grad_out_ = need_compute_grad_out; }
 
  private:
   AbstractBasePtr abstract_;
   InputType input_type_;
+  bool need_compute_grad_out_{true};
 };
 using FuncNodePtr = std::shared_ptr<FuncNode>;
 }  // namespace expander
diff --git a/mindspore/ccsrc/pipeline/pynative/grad/function/func_builder.cc b/mindspore/ccsrc/pipeline/pynative/grad/function/func_builder.cc
index c2940eafd44..132557594f6 100644
--- a/mindspore/ccsrc/pipeline/pynative/grad/function/func_builder.cc
+++ b/mindspore/ccsrc/pipeline/pynative/grad/function/func_builder.cc
@@ -252,6 +252,7 @@ void FuncBuilder::SetInputs(std::string instance_name, const std::vector<NodePtr> *inputs,
+  input_size_ = inputs->size() - kSizeTwo;
 }
 
 NodePtrList FuncBuilder::FlattenNode(const NodePtr &input) {
diff --git a/mindspore/ccsrc/pipeline/pynative/grad/function/func_grad.cc b/mindspore/ccsrc/pipeline/pynative/grad/function/func_grad.cc
index 9d5e747e056..1058f309434 100644
--- a/mindspore/ccsrc/pipeline/pynative/grad/function/func_grad.cc
+++ b/mindspore/ccsrc/pipeline/pynative/grad/function/func_grad.cc
@@ -176,6 +176,27 @@ bool IsValidTensorInput(const ValuePtr &v) {
   MS_EXCEPTION_IF_NULL(v);
   return v->isa<tensor::Tensor>() || v->isa<tensor::MetaSparseTensor>();
 }
+
+bool IsNeedComputeGrad(const ValuePtr &input) {
+  MS_EXCEPTION_IF_NULL(input);
+  if (input->isa<tensor::Tensor>()) {
+    auto input_tensor = input->cast<tensor::TensorPtr>();
+    auto auto_grad_meta_data = input_tensor->auto_grad_meta_data();
+    MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
+    auto variable = auto_grad_meta_data->variable();
+    if (variable != nullptr && variable->is_need_grad()) {
+      return true;
+    }
+  } else if (input->isa<ValueSequence>()) {
+    auto seq = input->cast<ValueSequencePtr>();
+    if (!seq->value().empty() && !seq->value().front()->isa<tensor::Tensor>()) {
+      return false;
+    }
+    return std::any_of(seq->value().begin(), seq->value().end(),
+                       [](const ValuePtr &val) { return IsNeedComputeGrad(val); });
+  }
+  return false;
+}
 }  // namespace
 
 TensorPtrList FuncBackwardNode::CallBackward(const TensorPtrList &gradients_in) {
@@ -187,7 +208,12 @@ TensorPtrList FuncBackwardNode::CallBackward(const TensorPtrList &gradients_in)
   const std::vector<NodePtr> cal_grads_node = func()(&ir_builder);
   ValuePtrList cal_grads_values;
   std::transform(cal_grads_node.begin(), cal_grads_node.end(), std::back_inserter(cal_grads_values),
-                 [](const NodePtr &node) { return node->Value(); });
+                 [](const NodePtr &node) -> ValuePtr {
+                   if (node == nullptr) {
+                     return kNone;
+                   }
+                   return node->Value();
+                 });
   auto gradients = PostProcess(cal_grads_values);
   MS_LOG(DEBUG) << "End CallBackward: " << name();
   return gradients;
@@ -197,7 +223,9 @@ NodePtrList FuncBackwardNode::PreProcess(const TensorPtrList &dout, FuncBuilder
   NodePtrList node_inputs;
   node_inputs.reserve(op_inputs_.size() + kSizeFive);
   for (size_t i = 0; i < op_inputs_.size(); ++i) {
-    (void)node_inputs.emplace_back(emitter->NewFuncNode(op_inputs_[i], input_abstract_[i], grad_type_[i]));
+    auto func_node = emitter->NewFuncNode(op_inputs_[i], input_abstract_[i], grad_type_[i]);
+    func_node->set_need_compute_grad_out(IsNeedComputeGrad(op_inputs_[i]));
+    (void)node_inputs.emplace_back(func_node);
   }
   (void)node_inputs.emplace_back(emitter->NewFuncNode(op_output_, out_abstract_, InputType::kOpOutput));
   if (dout.size() == kSizeOne) {
@@ -730,7 +758,8 @@ void FuncGrad::CheckSensShapeAndType(const ValuePtr &sens_gradient) {
   }
 }
 
-void FuncGrad::PruningGradGraph(const TensorPtrList &weights, const GradAttr &grad_attr, const std::vector<size_t> &grad_position) {
+void FuncGrad::PruningGradGraph(const TensorPtrList &weights, const GradAttr &grad_attr,
+                                const std::vector<size_t> &grad_position) {
   mindspore::HashSet<size_t> grad_pos_list{grad_position.begin(), grad_position.end()};
   // Pruning inputs by position in grad graph
   if (grad_attr.get_by_position) {
@@ -775,7 +804,7 @@ void FuncGrad::PruningGradGraph(const TensorPtrList &weights, const GradAttr &gr
       continue;
     }
     bool is_need_grad = false;
-    for (const auto &edge: variable->func_node()->next_edges()) {
+    for (const auto &edge : variable->func_node()->next_edges()) {
       is_need_grad = is_need_grad || edge.variable->is_need_grad();
    }
     if (!is_need_grad) {
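
Reviewer note (not part of the patch): the recurring pattern in the bprop bodies above is to build a branch gradient only when the corresponding input reports need_compute_grad_out(), and to hand nullptr to BinopGradCommon otherwise so the broadcast reduction for that branch is skipped. A minimal standalone C++ sketch of that gating idea follows; FakeNode and make_grad are hypothetical stand-ins, not the real MindSpore NodePtr/BpropBuilder APIs.

// Illustrative sketch only: gate gradient emission on need_compute_grad_out().
#include <iostream>
#include <memory>

struct FakeNode {
  bool needs_grad{true};
  bool need_compute_grad_out() const { return needs_grad; }
};
using FakeNodePtr = std::shared_ptr<FakeNode>;

// Stand-in for an emitted gradient computation (e.g. ib->Mul(y, dout)).
FakeNodePtr make_grad() { return std::make_shared<FakeNode>(); }

int main() {
  auto x = std::make_shared<FakeNode>();
  auto y = std::make_shared<FakeNode>();
  y->needs_grad = false;  // e.g. y does not require a gradient

  FakeNodePtr dx = nullptr;
  FakeNodePtr dy = nullptr;
  if (x->need_compute_grad_out()) {
    dx = make_grad();  // graph nodes for dx are emitted
  }
  if (y->need_compute_grad_out()) {
    dy = make_grad();  // skipped: no graph nodes are emitted for dy
  }
  std::cout << "dx emitted: " << (dx != nullptr) << ", dy emitted: " << (dy != nullptr) << "\n";
  return 0;
}

Compiled with any C++14 compiler this prints "dx emitted: 1, dy emitted: 0", mirroring how the patch avoids emitting graph nodes for gradients that will never be consumed.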