diff --git a/mindspore/ccsrc/frontend/optimizer/ad/auto_grad.cc b/mindspore/ccsrc/frontend/optimizer/ad/auto_grad.cc index bf530bd820e..71ded55e582 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/auto_grad.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/auto_grad.cc @@ -278,24 +278,30 @@ AutoGradCellImpl::AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std: } } -bool AutoGradCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(out); +bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) { + MS_EXCEPTION_IF_NULL(grad_param); - MS_LOG(DEBUG) << "Forward cnode: " << cnode->DebugString(); - auto prim = GetCNodePrimitive(cnode); + MS_LOG(DEBUG) << "Forward cnode: " << grad_param->cnode->DebugString(); + auto prim = GetCNodePrimitive(grad_param->cnode); if (prim == nullptr) { - MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString(); + MS_LOG(EXCEPTION) << "Should be primitive, but: " << grad_param->cnode->DebugString(); } if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimTupleGetItem) || IsPrimitiveEquals(prim, prim::kPrimStopGradient) || IsPrimitiveEquals(prim, prim::kPrimUpdateState)) { + MS_LOG(DEBUG) << "Prim " << prim->name() << " does not need to do op grad"; return true; } // anfnode_to_variable_adjoint_ hold out value, to avoid device not release, clear its device_address - auto cloned_value = ShallowCopyTensorValue(out); + auto cloned_value = ShallowCopyTensorValue(grad_param->out); ClearDeviceAddress(cloned_value); AnfNodePtr dout = BuildSpecialLikeValue(tape_, cloned_value, SpecialType::kZerosLikeType); - CNodePtr input_node = ConstructBpropGraphInput(cnode, op_args, out, dout); + auto fn = std::make_shared<FunctionNode>(tape_, dout); + auto variable_adjoint = std::make_shared<VariableNode>(fn, cloned_value); + if (!grad_param->grad_by_value) { + BuildKNode(grad_param, variable_adjoint); + need_do_manager_replace_ = true; + } + CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout); MS_LOG(DEBUG) << "Construct input cnode: " << input_node->DebugString(); std::vector outputs; #ifndef ENABLE_TEST @@ -304,7 +310,7 @@ bool AutoGradCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op } else { mindspore::BuildBprop(input_node, &outputs, &users_); if (outputs.empty()) { - MS_LOG(DEBUG) << "The bprop output should not be empty" << cnode->DebugString(); + MS_LOG(DEBUG) << "The bprop output should not be empty" << grad_param->cnode->DebugString(); BuildCustomBpropCNode(input_node, &outputs); } } @@ -316,15 +322,12 @@ bool AutoGradCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op } #endif if (outputs.empty()) { - MS_LOG(EXCEPTION) << "the bprop output should not be empty" << cnode->DebugString(); + MS_LOG(EXCEPTION) << "The bprop output should not be empty" << grad_param->cnode->DebugString(); } - auto fn = std::make_shared<FunctionNode>(tape_, dout); - auto variable_adjoint = std::make_shared<VariableNode>(fn, cloned_value); - UpdateNextEdges(fn, cnode, outputs, op_args); - anfnode_to_variable_adjoint_.insert(std::make_pair(cnode, variable_adjoint)); - + UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args); + anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint)); // record last_node for backpropagation - last_node_ = cnode; + last_node_ = grad_param->cnode; return true; } @@ -350,15 +353,8 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) { } bprop_cnode = 
GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout); } else { - // Set current knode for cnode - k_node = BuildKNode(grad_param); - const auto it = anfnode_to_variable_adjoint_.find(grad_param->cnode); - if (it == anfnode_to_variable_adjoint_.end()) { - MS_LOG(EXCEPTION) << "Can not find cnode " << grad_param->cnode->DebugString(); - } - // Get current cnode all inputs knode - const auto &k_node_list = BuildKNodeListFromPrimalCNode(grad_param->cnode, it->second); - bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, k_node_list, grad_param->out, &dout); + BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &args_node_list); + bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout); } std::vector outputs; @@ -373,50 +369,10 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) { variable_adjoint->set_k_node(k_node); UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args); anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint)); - has_fbprop_ = true; + need_do_manager_replace_ = true; return true; } -AnfNodePtr AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param) { - AnfNodePtrList node_list; - MS_EXCEPTION_IF_NULL(grad_param); - for (size_t i = 0; i < grad_param->cnode->inputs().size(); ++i) { - (void)node_list.emplace_back(BuildKNodeForCNodeInput(grad_param->op_args, grad_param->cnode->input(i), i)); - } - auto k_node = tape_->NewCNode(node_list); - k_node->set_abstract(grad_param->out->ToAbstract()->Broaden()); - return k_node; -} - -AnfNodePtrList AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const VariableNodePtr &adjoint) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(adjoint); - AnfNodePtrList node_list; - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i)); - if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) { - MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << cnode->input(i)->DebugString(); - } - MS_EXCEPTION_IF_NULL(input_adjoint_iter->second->k_node()); - (void)node_list.emplace_back(input_adjoint_iter->second->k_node()); - } - return node_list; -} - -AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const ValuePtrList &op_args, const AnfNodePtr &input_node, - size_t input_index) { - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa()) { - const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(input_node); - if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) { - MS_LOG(EXCEPTION) << "cannot find input in adjoint map, inp: " << input_node->DebugString(); - } - return input_adjoint_iter->second->k_node(); - } else { - return input_node; - } -} - CNodePtr AutoGradCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args, const ValuePtr &out, AnfNodePtr *const tape_dout) { // Wrap tuple_getitem(fprop_app, 1) in a FuncGraph and optimize it; @@ -464,7 +420,7 @@ void AutoGradCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, } FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::vector &grad_position, - const GradAttr &grad_attr, bool build_formal_param) { + const GradAttr &grad_attr) { // Set sens node and weights node SetSensAndWeights(weights, grad_attr.has_sens); @@ -483,41 +439,84 @@ FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std:: return tape_; 
} -CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const CNodePtr &cnode, const ValuePtrList &op_args, - const ValuePtr &out, const AnfNodePtr &dout) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(out); - MS_EXCEPTION_IF_NULL(dout); - - if (cnode->size() == 0) { - MS_LOG(EXCEPTION) << "cnode do not have inputs"; - } - std::vector node_lists; - (void)node_lists.emplace_back(cnode->input(0)); - for (size_t i = 0; i < op_args.size(); ++i) { - auto v = op_args[i]; - auto node = cnode->input(i + 1); - if (node->isa()) { - node_lists.emplace_back(node); - node->set_abstract(v->ToAbstract()); - continue; +CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout) { + MS_EXCEPTION_IF_NULL(grad_param); + std::vector node_list; + (void)node_list.emplace_back(grad_param->cnode->input(0)); + if (grad_param->grad_by_value) { + for (size_t i = 0; i < grad_param->op_args.size(); ++i) { + const auto &v = grad_param->op_args[i]; + auto node = grad_param->cnode->input(i + 1); + if (node->isa()) { + node_list.emplace_back(node); + node->set_abstract(v->ToAbstract()); + continue; + } + auto v_node = NewValueNode(grad_param->op_args[i]); + v_node->set_abstract(grad_param->op_args[i]->ToAbstract()); + node_list.emplace_back(v_node); } - auto v_node = NewValueNode(op_args[i]); - v_node->set_abstract(op_args[i]->ToAbstract()); - node_lists.emplace_back(v_node); + } else { + // Input is a Parameter or cnode, not a value node + BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &node_list); } - auto out_node = NewValueNode(out); - out_node->set_abstract(out->ToAbstract()); - node_lists.emplace_back(out_node); - node_lists.emplace_back(dout); - CNodePtr input_node = tape_->NewCNode(node_lists); - input_node->set_abstract(out->ToAbstract()->Broaden()); + auto out_node = NewValueNode(grad_param->out); + auto out_abs = grad_param->out->ToAbstract()->Broaden(); + out_node->set_abstract(out_abs); + // set out + node_list.emplace_back(out_node); + // set dout + node_list.emplace_back(dout); + auto input_node = tape_->NewCNode(node_list); + input_node->set_abstract(out_abs); return input_node; } -bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args, - const ValuePtr &out) { - return k_cell->KPynativeOp(cnode, op_args, out); +void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const ValuePtrList &op_args, + std::vector *const node_list) { + MS_EXCEPTION_IF_NULL(cnode); + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + MS_LOG(DEBUG) << "Find input knode of node " << cnode->input(i)->DebugString(); + if (cnode->input(i)->isa()) { + const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i)); + if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) { + MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << cnode->input(i)->DebugString(); + } + MS_EXCEPTION_IF_NULL(input_adjoint_iter->second->k_node()); + (void)node_list->emplace_back(input_adjoint_iter->second->k_node()); + } else { + cnode->input(i)->set_abstract(op_args[i - 1]->ToAbstract()); + (void)node_list->emplace_back(cnode->input(i)); + } + } +} + +void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode) { + MS_EXCEPTION_IF_NULL(grad_param); + AnfNodePtrList node_list; + for (size_t i = 0; i < grad_param->cnode->inputs().size(); ++i) { + (void)node_list.emplace_back(BuildKNodeForCNodeInput(grad_param->cnode->input(i))); 
+ } + auto k_node = tape_->NewCNode(node_list); + k_node->set_abstract(grad_param->out->ToAbstract()->Broaden()); + VariableNode->set_k_node(k_node); +} + +AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_node) { + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa()) { + const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(input_node); + if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) { + MS_LOG(EXCEPTION) << "cannot find input in adjoint map, inp: " << input_node->DebugString(); + } + return input_adjoint_iter->second->k_node(); + } else { + return input_node; + } +} + +bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const GradParamPtr &grad_param) { + return k_cell->KPynativeOp(grad_param); } void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode, @@ -992,7 +991,7 @@ void AutoGradCellImpl::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new return; } auto &old_node_users = users_[old_node]; - for (auto pair_node : old_node_users) { + for (const auto &pair_node : old_node_users) { auto cnode = pair_node.first; size_t index = pair_node.second; if (index >= cnode->size()) { @@ -1034,7 +1033,7 @@ void AutoGradCellImpl::ClearDeviceAddress(const ValuePtr &out) { void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg) { const auto ¶meters = tape_->parameters(); auto cell_inputs_size = cell_inputs_.size(); - if (has_fbprop_) { + if (need_do_manager_replace_) { MS_LOG(DEBUG) << "Do parameter replace by manager"; auto mng = MakeManager({tape_}, false); auto tr = mng->Transact(); @@ -1051,7 +1050,7 @@ void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, boo (void)tr.Replace(weights[i], parameters[weight_offset + i]); } tr.Commit(); - has_fbprop_ = false; + need_do_manager_replace_ = false; } else { for (size_t i = 0; i < cell_inputs_size; ++i) { Replace(cell_inputs_[i], parameters[i]); @@ -1093,9 +1092,8 @@ AutoGradCellImplPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs, } FuncGraphPtr GradPynativeCellEnd(const AutoGradCellImplPtr &auto_grad_cell, const AnfNodePtrList &weights, - const std::vector &grad_position, const GradAttr &grad_attr, - bool build_formal_param) { - return auto_grad_cell->Finish(weights, grad_position, grad_attr, build_formal_param); + const std::vector &grad_position, const GradAttr &grad_attr) { + return auto_grad_cell->Finish(weights, grad_position, grad_attr); } } // namespace ad } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/ad/auto_grad.h b/mindspore/ccsrc/frontend/optimizer/ad/auto_grad.h index 9ecf8cbb14b..56817c9e4bc 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/auto_grad.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/auto_grad.h @@ -42,8 +42,9 @@ struct GradAttr { }; struct GradParam { - GradParam(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, FuncGraphPtr fprop_fg = nullptr) - : cnode(cnode), op_args(op_args), out(out), fprop_fg(std::move(fprop_fg)) {} + GradParam(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, FuncGraphPtr fprop_fg, + bool grad_by_value) + : cnode(cnode), op_args(op_args), out(out), fprop_fg(std::move(fprop_fg)), grad_by_value(grad_by_value) {} // Primal CNode create by op forward process const CNodePtr cnode; @@ -111,7 +112,7 @@ class AutoGradCellImpl { AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std::vector &input_param_values); ~AutoGradCellImpl() = default; // Reverse 
connect bprop of op - bool KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out); + bool KPynativeOp(const GradParamPtr &grad_param); // Reverse connect ms_function or higher order sub bprop funcgraph bool KPynativeWithFProp(const GradParamPtr &grad_param); CNodePtr GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args, const ValuePtr &out, @@ -121,7 +122,7 @@ class AutoGradCellImpl { // Build a back propagate funcgraph, each cnode in primal funcgraph is replaced by value node or formal cnode, so it // can be grad again. FuncGraphPtr Finish(const AnfNodePtrList &weights, const std::vector &grad_position, - const GradAttr &grad_attr, bool build_formal_param); + const GradAttr &grad_attr); private: // Last cnode of this Cell, may be a primitive op or cell with user defined bprop. @@ -139,14 +140,13 @@ class AutoGradCellImpl { // Record cnode's input map for tape_ UserType users_; // Flag for ms_function and high order - bool has_fbprop_{false}; + bool need_do_manager_replace_{false}; bool IsCNodeNeedGrad(const AnfNodePtr &node_ptr) const; std::vector GetNeedGradFlags(const CNodePtr &cnode); // construct input as cnode for expander - CNodePtr ConstructBpropGraphInput(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, - const AnfNodePtr &dout); + CNodePtr ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout); // Back propagate for one node; void UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode, const std::vector &dins, const ValuePtrList &op_args); @@ -182,9 +182,10 @@ class AutoGradCellImpl { void ClearDeviceAddress(const ValuePtr &out); // Fbprop - AnfNodePtr BuildKNode(const GradParamPtr &grad_param); - AnfNodePtrList BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const VariableNodePtr &adjoint); - AnfNodePtr BuildKNodeForCNodeInput(const ValuePtrList &op_args, const AnfNodePtr &input_node, size_t input_index); + void BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode); + void BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const ValuePtrList &op_args, + std::vector *const node_list); + AnfNodePtr BuildKNodeForCNodeInput(const AnfNodePtr &input_node); }; using AutoGradCellImplPtr = std::shared_ptr<AutoGradCellImpl>; @@ -209,15 +210,13 @@ AutoGradCellImplPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs, // else: // each cnode in primal funcgraph is replaced by value node FuncGraphPtr GradPynativeCellEnd(const AutoGradCellImplPtr &k_cell, const AnfNodePtrList &weights, - const std::vector &grad_position, const GradAttr &grad_attr, - bool build_formal_param = false); + const std::vector &grad_position, const GradAttr &grad_attr); // Grad for each operation. // c_node: CNode which contains the prim (index 0) and the formal input parameters of that prim. // op_args: the argument list of each input parameter. // out: the op result. -bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args, - const ValuePtr &out); +bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const GradParamPtr &grad_param); // adjoint bprop from ms_function and high grad void GradPynativeFBprop(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, diff --git a/mindspore/ccsrc/frontend/optimizer/expander.cc b/mindspore/ccsrc/frontend/optimizer/expander.cc index 69f5592058e..c4932bc5ce9 100644 --- a/mindspore/ccsrc/frontend/optimizer/expander.cc +++ b/mindspore/ccsrc/frontend/optimizer/expander.cc @@ -33,14 +33,17 @@ namespace mindspore { /* namespace to support opt */ namespace opt { -bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) { - static const std::map> op2attrs = { - {prim::kPrimBroadcastTo->name(), {kAttrShape}}, - {prim::kPrimReduceMax->name(), {kAttrKeepDims}}, - {prim::kPrimReduceMin->name(), {kAttrKeepDims}}, - {prim::kPrimReduceSum->name(), {kAttrKeepDims}}}; +namespace { +const std::map> op2attrs = {{prim::kPrimBroadcastTo->name(), {kAttrShape}}, {prim::kPrimReduceMax->name(), {kAttrKeepDims}}, {prim::kPrimReduceMin->name(), {kAttrKeepDims}}, {prim::kPrimReduceSum->name(), {kAttrKeepDims}}}; +} +bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); auto todos = TopoSort(graph->get_return()); + auto mng = Manage({graph}, false); for (const auto &node : todos) { if (!node->isa() || !AnfUtils::IsRealKernel(node)) { continue; } @@ -66,7 +69,8 @@ bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) { AnfNodePtrList inputs = {NewValueNode(new_prim)}; auto cnode = dyn_cast_ptr(node); (void)inputs.insert(inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend()); - cnode->set_inputs(inputs); + auto new_cnode = graph->NewCNodeInOrder(inputs); + (void)mng->Replace(node, new_cnode); } return true; } diff --git a/mindspore/ccsrc/frontend/optimizer/expander.h b/mindspore/ccsrc/frontend/optimizer/expander.h index ccea68877b2..1825d6d8ecc 100644 --- a/mindspore/ccsrc/frontend/optimizer/expander.h +++ b/mindspore/ccsrc/frontend/optimizer/expander.h @@ -24,6 +24,7 @@ namespace opt { * Try Expand cnode for front end graph. 
*/ AnfNodePtr TryExpandCNodeFE(const AnfNodePtr &node); +bool ConvertPrimToPrimPy(const FuncGraphPtr &graph); } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_EXPANDER_H diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc index 69c2c8ff17e..bdff8eae66a 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc @@ -33,9 +33,11 @@ std::string GetOpPythonPath(const OperatorName &op_name) { // almost all ops are defined in two main paths const std::string ops_module = OP_PATH; const std::string inner_ops_module = INNER_OP_PATH; + const std::string grad_ops_module = GRAD_OP_PATH; const std::string functional_op_module = FUNCTIONAL_OP_PATH; py::module mod = py::module::import(common::SafeCStr(ops_module)); py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module)); + py::module grad_mod = py::module::import(common::SafeCStr(grad_ops_module)); py::module functional_mod = py::module::import(common::SafeCStr(functional_op_module)); if (py::hasattr(inner_mod, common::SafeCStr(op_name))) { @@ -44,9 +46,12 @@ std::string GetOpPythonPath(const OperatorName &op_name) { if (py::hasattr(mod, common::SafeCStr(op_name))) { return ops_module; } + if (py::hasattr(grad_mod, common::SafeCStr(op_name))) { + return grad_ops_module; + } if (!py::hasattr(functional_mod, common::SafeCStr(op_name))) { - MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << functional_op_module - << " don't have op:" << op_name; + MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << grad_ops_module << " and " + << functional_op_module << " don't have op:" << op_name; } return functional_op_module; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index d8af9595353..67f0215a58c 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -126,6 +126,7 @@ constexpr char REDUCE_OP_ALL[] = "prod"; constexpr char REDUCE_OP_PROD[] = "prod"; constexpr char OP_PATH[] = "mindspore.ops.operations"; constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops"; +constexpr char GRAD_OP_PATH[] = "mindspore.ops.operations._grad_ops"; constexpr char FUNCTIONAL_OP_PATH[] = "mindspore.ops.functional"; constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils"; constexpr char GET_OP_FUNCTION[] = "_get_python_op"; diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index cd9044c609b..4fb1b69aa06 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -200,6 +200,16 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource) { irpass.depend_value_elim_, }); OptPassGroupMap map({{"ad_final_opt", bg_final_opt}}); + if (pynative::PyNativeExecutor::GetInstance()->grad_executor()->need_renormalize()) { + (void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize())); + opt::OptPassConfig real_op_eliminate = opt::OptPassConfig{irpass.real_op_eliminate_}; + (void)map.emplace_back(std::make_pair("real_op_eliminate", real_op_eliminate)); + opt::OptPassConfig environ_eliminate = opt::OptPassConfig({ + irpass.incorporate_call_, + irpass.incorporate_call_switch_, + }); + 
(void)map.emplace_back(std::make_pair("environ_eliminate", environ_eliminate)); + } auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", resource, map); MS_EXCEPTION_IF_NULL(resource); auto func_graph = resource->func_graph(); diff --git a/mindspore/ccsrc/pipeline/pynative/grad/grad.cc b/mindspore/ccsrc/pipeline/pynative/grad/grad.cc index 4f05695eaf0..54cefc8a928 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/grad.cc +++ b/mindspore/ccsrc/pipeline/pynative/grad/grad.cc @@ -22,10 +22,10 @@ #include "ir/cell.h" #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/debug/trace.h" -#include "frontend/optimizer/ad/prim_bprop_optimizer.h" #include "backend/common/optimizer/helper.h" #include "include/common/utils/convert_utils_py.h" #include "frontend/optimizer/ad/grad.h" +#include "frontend/optimizer/expander.h" #include "pipeline/jit/pass.h" namespace mindspore { @@ -136,10 +136,159 @@ ValuePtr ConvertOutputValueToTensor(const ValuePtr &v) { int64_t input = v->cast()->value(); return std::make_shared(input, kInt64); } else { - MS_LOG(EXCEPTION) << "Output is " << v->ToString() << ", abstract " << v->ToAbstract()->Broaden(); + MS_LOG(DEBUG) << "Output is " << v->ToString() << ", abstract " << v->ToAbstract()->Broaden(); + return v; } } +bool IsAbsDifferent(const AbstractBasePtr &old_abs, const AbstractBasePtr &new_abs) { + if (old_abs == new_abs) { + return false; + } + if (old_abs == nullptr || new_abs == nullptr) { + MS_LOG(DEBUG) << "Graph is dynamic, old_abs is different with new_abs"; + return true; + } + if (!common::IsEqual(old_abs->BuildType(), new_abs->BuildType()) || + !common::IsEqual(old_abs->BuildShape(), new_abs->BuildShape())) { + MS_LOG(DEBUG) << "Graph is dynamic, old_abs is different with new_abs, old abs: " << old_abs->ToString() + << " new abs: " << new_abs->ToString(); + return true; + } + return false; +} + +bool IsValuePtrEqual(const ValuePtr &v1, const ValuePtr &v2) { + if (v1 == v2) { + return true; + } + if (v1 == nullptr || v2 == nullptr) { + return false; + } + if (v1->isa() && v2->isa()) { + return v1->cast()->ValueEqual(*(v2->cast())); + } + return *v1 == *v2; +} + +bool IsParamInfoEqual(const ParamInfoPtr &p1, const ParamInfoPtr &p2) { + if (p1 == p2) { + return true; + } + if (p1 == nullptr || p2 == nullptr) { + return false; + } + return p1->key() == p2->key(); +} + +bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, const std::vector &new_anf_inputs, + const TopCellInfoPtr &top_cell) { + MS_EXCEPTION_IF_NULL(old_node_info); + auto old_input_size = old_node_info->input_cnode_info.size() + old_node_info->input_values.size() + + old_node_info->input_param_infos.size(); + if (old_input_size != new_anf_inputs.size() - 1) { + MS_LOG(DEBUG) << "Graph is dynamic, old input size: " << old_input_size + << " new input_infos: " << (new_anf_inputs.size() - 1); + return true; + } + + for (size_t i = 1; i < new_anf_inputs.size(); i++) { + const auto &new_anf_input = new_anf_inputs[i]; + MS_EXCEPTION_IF_NULL(new_anf_input); + if (new_anf_input->isa()) { + const auto &value_iter = old_node_info->input_values.find(i); + if (value_iter == old_node_info->input_values.end()) { + MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a value, old input is not a value."; + return true; + } + + if (!IsValuePtrEqual(value_iter->second, GetValueNode(new_anf_input))) { + MS_LOG(DEBUG) << "The " << i << "th input, value is different."; + return true; + } + } else if (new_anf_input->isa()) { + // 
Compare cnode abstract. + const auto &node_iter = old_node_info->input_cnode_info.find(i); + if (node_iter == old_node_info->input_cnode_info.end()) { + MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a cnode, old input is not a cnode."; + return true; + } + + size_t old_op_index = 0; + AbstractBasePtr old_abs = nullptr; + std::tie(old_op_index, old_abs) = node_iter->second; + if (IsAbsDifferent(old_abs, new_anf_input->abstract())) { + MS_LOG(DEBUG) << "The " << i << "th input, abs is different."; + return true; + } + + // Compare cnode edge. + MS_EXCEPTION_IF_NULL(top_cell); + if (old_op_index != top_cell->get_op_index_by_cnode_hash(new_anf_input->hash())) { + MS_LOG(DEBUG) << "The " << i << "th input, op_index is different, old op_index: " << old_op_index + << " new op_index: " << top_cell->get_op_index_by_cnode_hash(new_anf_input->hash()); + return true; + } + } else { + // Compare parameter. + if (!new_anf_input->isa()) { + MS_LOG(EXCEPTION) << "new_anf_input: " << new_anf_input->fullname_with_scope() + << " is none of value node, cnode and parameter."; + } + + const auto &node_iter = old_node_info->input_param_infos.find(i); + if (node_iter == old_node_info->input_param_infos.end()) { + MS_LOG(DEBUG) << "The " << i + << "th input is different, cur input is a parameter, old input is not a parameter."; + return true; + } + + const auto ¶m = new_anf_input->cast(); + MS_EXCEPTION_IF_NULL(param); + if (!IsParamInfoEqual(node_iter->second, param->param_info())) { + MS_LOG(DEBUG) << "The " << i << "th input, param info is different."; + return true; + } + } + } + return false; +} + +bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, const CNodePtr &new_cnode, + bool is_ms_function_node, const std::string &graph_phase, + const TopCellInfoPtr &top_cell) { + MS_EXCEPTION_IF_NULL(old_node_info); + // 1.Detect ms_function phase + if (is_ms_function_node) { + if (!old_node_info->is_graph_node || graph_phase != old_node_info->graph_phase) { + MS_LOG(DEBUG) << "Graph is dynamic, old is_graph_node: " << old_node_info->is_graph_node + << " new is_graph_node: " << is_ms_function_node << " old graph_phase " + << old_node_info->graph_phase << " new graph_phase: " << graph_phase; + return true; + } + return false; + } + + // 2.Detect cnode prim + MS_EXCEPTION_IF_NULL(new_cnode); + auto new_prim = GetCNodePrimitive(new_cnode); + if (!common::IsEqual(new_prim, old_node_info->prim)) { + MS_LOG(DEBUG) << "Graph is dynamic, old prim: " + << (old_node_info->prim == nullptr ? "nullptr" : old_node_info->prim->name()) + << " new prim: " << (new_prim == nullptr ? "nullptr" : new_prim->name()); + return true; + } + + // 3.Detect output abs + if (IsAbsDifferent(old_node_info->output_abs, new_cnode->abstract())) { + MS_LOG(DEBUG) << "Graph is dynamic, output_abs is different"; + return true; + } + + // 4.Detect inputs + return IsCnodeInputsDynamic(old_node_info, new_cnode->inputs(), top_cell); +} + FuncGraphPtr BpropGraphFinalOpt(const FuncGraphPtr &bprop_graph) { auto resource = std::make_shared(); resource->set_func_graph(bprop_graph); @@ -243,10 +392,10 @@ void GradExecutor::HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_ MS_EXCEPTION_IF_NULL(input_args_info); if (is_bprop_top) { // Convert input args to parameters for top cell graph in bprop. 
- for (size_t i = 0; i < input_args_info->input_arg_id_vec.size(); ++i) { + for (auto &id : input_args_info->input_arg_id_vec) { auto new_param = curr_g()->add_parameter(); - MS_LOG(DEBUG) << "Top bprop graph set input parameter " << input_args_info->input_arg_id_vec[i]; - top_cell()->SetParamNodeMapInGraphInfoMap(input_args_info->input_arg_id_vec[i], new_param); + MS_LOG(DEBUG) << "Top bprop graph set input parameter " << id; + top_cell()->SetParamNodeMapInGraphInfoMap(id, new_param); } return; } @@ -369,8 +518,9 @@ void GradExecutor::MakeNewTopGraph(const InputArgsInfoPtr &input_args_info) { fg->debug_info()->set_name("pynative_forward_graph"); auto resource = std::make_shared(); const auto &already_run_cell_id = GetAlreadyRunCellId(input_args_info->cell_id); - top_cell_ = std::make_shared(input_args_info->grad_order, input_args_info->obj_id, - input_args_info->cell_id, already_run_cell_id, resource, fg); + top_cell_ = + std::make_shared(input_args_info->is_high_order_top_cell, input_args_info->grad_order, + input_args_info->obj_id, input_args_info->cell_id, already_run_cell_id, resource, fg); top_cell_->set_forward_already_run(true); top_cell_->set_is_run_cell(input_args_info->is_run_cell); top_cell_->set_input_args_id(input_args_info->input_args_id); @@ -402,8 +552,7 @@ void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string & // Set last output abstract and will be used for sens auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr(); MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr); - auto sens_v = ConvertOutputValueToTensor(v); - auto cloned_value = ShallowCopyTensorValue(sens_v); + auto cloned_value = ShallowCopyTensorValue(v); if (!MsContext::GetInstance()->get_param(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) { AsyncUpdateOutputNodeOfTopCell(output_node, cloned_value); } else { @@ -435,9 +584,11 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) { const auto &cell_id = input_args_info->cell_id; MS_LOG(DEBUG) << "EndGraphInner start " << input_args_info->input_size << ", cell_id " << cell_id << ", input args info ptr " << input_args_info.get(); - const auto &out_value = input_args_info->out_value; - MS_EXCEPTION_IF_NULL(out_value); - const auto &out_id = PyNativeAlgo::Common::GetIdByValue(out_value); + bool is_top_cell_end = (cell_id == top_cell()->cell_id()); + if (is_top_cell_end) { + input_args_info->out_value = ConvertOutputValueToTensor(input_args_info->out_value); + } + const auto &out_id = PyNativeAlgo::Common::GetIdByValue(input_args_info->out_value); DoGradForCustomBprop(input_args_info, out_id); // Update bprop grad stack if (grad_is_running_ && !bprop_grad_stack_.empty()) { @@ -450,16 +601,15 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) { } } // Just only dump the last forward graph - bool is_top_cell_end = cell_id == top_cell()->cell_id(); if (is_top_cell_end && MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - curr_g()->set_output(GetInput(out_value, out_id)); + curr_g()->set_output(GetInput(input_args_info->out_value, out_id)); PyNativeAlgo::Common::DumpGraphIR("fg.ir", curr_g()); } // Reset grad flag and update output node of the outermost cell if (input_args_info->is_grad_topest_cell && is_top_cell_end) { MS_LOG(DEBUG) << "Cur top last cell " << cell_id; (void)PopHighOrderGraphStack(); - SetForwardLastNodeInfo(out_value, out_id); + SetForwardLastNodeInfo(input_args_info->out_value, out_id); top_cell()->ClearCellHookOp(); cell_order_ = 0; // set_grad_flag(false); @@ -468,7 +618,7 @@ 
void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) { if (is_top_cell_end) { // In high grad cases, the output of the internal graph may be a tuple, and node needs to be created in the getobj if (!input_args_info->is_grad_topest_cell) { - SetForwardLastNodeInfo(out_value, out_id); + SetForwardLastNodeInfo(input_args_info->out_value, out_id); } top_cell()->CheckSubCellHookChanged(); CheckNeedCompileGraph(input_args_info); @@ -476,7 +626,7 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) { } } -void GradExecutor::AsyncEndGraphImpl(const InputArgsInfoPtr input_args_info) { +void GradExecutor::AsyncEndGraphImpl(const InputArgsInfoPtr &input_args_info) { const auto fn = [this, input_args_info]() { this->EndGraphImpl(input_args_info); }; auto task = std::make_shared(fn); async_executor_->Push(task); @@ -618,13 +768,12 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob py::gil_scoped_release gil_release; async_executor_->Wait(); } - MS_EXCEPTION_IF_NULL(grad); MS_EXCEPTION_IF_NULL(top_input_args_info_); MS_LOG(DEBUG) << "GradNetInner start " << args.size() << ", cell_id " << top_input_args_info_->cell_id << ", input args info ptr " << top_input_args_info_.get(); + MS_EXCEPTION_IF_NULL(grad); if (grad->sens_param()) { MS_LOG(DEBUG) << "Get sens param"; - top_input_args_info_->has_sens = true; size_t forward_args_size = args.size() - 1; auto sens_v = PyNativeAlgo::DataConvert::PyObjToValue(args[forward_args_size]); const auto &sens_tensor = ConvertOutputValueToTensor(sens_v); @@ -632,17 +781,16 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob if (top_input_args_info_->input_arg_value_vec.size() == args.size()) { top_input_args_info_->input_arg_value_vec.pop_back(); } - (void)top_input_args_info_->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_v)); + (void)top_input_args_info_->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_tensor)); top_input_args_info_->has_sens = true; } + // For async, top can not be change when run SetForwardLastNodeInfo if (pre_top_cell_ != nullptr) { set_top_cell(pre_top_cell_); } - if (!top_cell()->need_compile_graph()) { MS_LOG(DEBUG) << "No need compile graph"; top_cell_list_.pop_back(); - UpdateTopCellInfo(false, false); return; } @@ -671,7 +819,7 @@ void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector auto bprop_graph = GetBpropGraph(grad_attr, w_args, p_args); MS_EXCEPTION_IF_NULL(bprop_graph); bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true); - bool use_dynamic_shape_process = (forward()->device_target() == kAscendDevice ? false : use_dynamic_shape_process_); + bool use_dynamic_shape_process = !(forward()->device_target() == kAscendDevice) && use_dynamic_shape_process_; bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process); MS_EXCEPTION_IF_NULL(top_input_args_info_); bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id)); @@ -776,12 +924,14 @@ void GradExecutor::UpdateParamAbsByArgs(const std::vector &input_args, MS_EXCEPTION_IF_NULL(bprop_graph); std::vector tensor_args; size_t input_size = has_sens ? 
input_args.size() - 1 : input_args.size(); - // Sens may be a value tuple not a single tensor + // Sens may be a value tuple not a single tensor; the bprop graph has only one sens param, so tuple sens cannot be + // flattened for (size_t i = 0; i < input_size; ++i) { if (PyNativeAlgo::Common::IsTensor(input_args[i])) { (void)tensor_args.emplace_back(input_args[i]); } } + // No flatten if (has_sens) { (void)tensor_args.emplace_back(input_args[input_size]); } @@ -819,15 +969,9 @@ void GradExecutor::UpdateParamAbsByArgs(const std::vector &input_args, FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const vector &w_args, const vector &p_args) { MS_EXCEPTION_IF_NULL(top_input_args_info_); - bool build_formal_param = false; - if (!top_input_args_info_->has_custom_bprop && !top_input_args_info_->is_grad_topest_cell && IsNestedGrad()) { - build_formal_param = true; - need_renormalize_ = true; - } - auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr(); MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr); - FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(auto_grad_cell_ptr, w_args, p_args, grad_attr, build_formal_param); + FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(auto_grad_cell_ptr, w_args, p_args, grad_attr); MS_EXCEPTION_IF_NULL(bprop_graph); MS_LOG(DEBUG) << "Top graph input params size " << top_input_args_info_->input_arg_value_vec.size(); @@ -836,7 +980,7 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const ve bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); bprop_graph->debug_info()->set_name(ss.str()); UpdateParamAbsByArgs(top_input_args_info_->input_arg_value_vec, bprop_graph, grad_attr.has_sens); - if (top_cell()->ms_function_flag()) { + if (top_cell()->need_do_final_opt()) { bprop_graph = BpropGraphFinalOpt(bprop_graph); } if (top_input_args_info_->is_grad_topest_cell) { @@ -937,14 +1081,15 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector inputs{NewValueNode(first_grad_fg)}; ValuePtrList weights_args; DoParameterReplace(first_grad_fg, forward_args, &inputs, &weights_args); - pipeline::ResourcePtr r = std::make_shared(); - r->manager()->AddFuncGraph(first_grad_fg); + if (!opt::ConvertPrimToPrimPy(first_grad_fg)) { + MS_LOG(EXCEPTION) << "Convert PrimitiveC to PrimitivePy failed"; + } + + auto r = std::make_shared(); set_eliminate_forward(false); (void)first_grad_fg->transforms().erase(kGrad); // Do high order @@ -968,10 +1113,12 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector out_v{out_value}; out_value = std::make_shared(out_v); } - auto grad_param = std::make_shared(cnode, input_args, out_value, second_grad_fg); + auto grad_param = std::make_shared(cnode, input_args, out_value, second_grad_fg, + !top_cell()->is_high_order_top_cell()); if (!top_cell()->auto_grad_cell_ptr()->KPynativeWithFProp(grad_param)) { MS_LOG(EXCEPTION) << "Failed to run ad grad for second grad graph " << cnode->ToString(); } + top_cell()->set_need_do_final_opt(true); need_renormalize_ = true; } @@ -986,8 +1133,8 @@ void GradExecutor::DoParameterReplace(const FuncGraphPtr &first_grad_fg, const s // Replace inputs param MS_EXCEPTION_IF_NULL(inputs); - for (size_t i = 0; i < forward_args.size(); ++i) { - const auto &id = PyNativeAlgo::Common::GetIdByValue(forward_args[i]); + for (const auto &forward_arg : forward_args) { + const auto &id = PyNativeAlgo::Common::GetIdByValue(forward_arg); const auto it = outer_graph_info->input_params.find(id); if (it != outer_graph_info->input_params.end()) { // Can 
find in outer graph @@ -997,7 +1144,7 @@ void GradExecutor::DoParameterReplace(const FuncGraphPtr &first_grad_fg, const s } else { MS_LOG(DEBUG) << "Can't find input param id " << id; // Inner graph input param not find in outer graph, need add to outer graph - (void)inputs->emplace_back(GetInput(forward_args[i], id)); + (void)inputs->emplace_back(GetInput(forward_arg, id)); } } @@ -1055,12 +1202,10 @@ void GradExecutor::ClearGradRes() { if (top_cell_ != nullptr) { top_cell_->ClearDeviceMemory(); } - if (use_dynamic_shape_process_ || already_run_top_cell_.find(top_cell_->already_run_cell_id()) != already_run_top_cell_.end()) { top_cell_ = nullptr; } - DecreaseGradOrder(); ClearGlobalRes(); } @@ -1320,29 +1465,26 @@ void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNode std::back_inserter(cloned_op_args), [](const ValuePtr &value) { return ShallowCopyTensorValue(value); }); ValuePtr cloned_out = ShallowCopyTensorValue(op_out); - if (!ad::GradPynativeOp(top_cell()->auto_grad_cell_ptr(), cnode, cloned_op_args, cloned_out)) { - MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << op_run_info->base_op_run_info.op_name; - } + auto grad_param = + std::make_shared(cnode, cloned_op_args, cloned_out, nullptr, !top_cell()->is_high_order_top_cell()); auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr(); if (!MsContext::GetInstance()->get_param(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) { - AsyncGradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out); + AsyncGradPynativeOp(auto_grad_cell_ptr, grad_param); } else { - GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out); + GradPynativeOp(auto_grad_cell_ptr, grad_param); } } -void GradExecutor::GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode, - const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const { - if (!ad::GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out)) { +void GradExecutor::GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, + const ad::GradParamPtr &grad_param) const { + if (!ad::GradPynativeOp(auto_grad_cell_ptr, grad_param)) { MS_LOG(EXCEPTION) << "Failed to run ad grad for op "; } } -void GradExecutor::AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode, - const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const { - const auto fn = [this, auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out]() { - this->GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out); - }; +void GradExecutor::AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, + const ad::GradParamPtr &grad_param) const { + const auto fn = [this, auto_grad_cell_ptr, grad_param]() { this->GradPynativeOp(auto_grad_cell_ptr, grad_param); }; auto task = std::make_shared(fn); async_executor_->Push(task); } @@ -1608,7 +1750,7 @@ void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, c node_info->input_cnode_info[i] = std::make_pair(op_index, node_abs); } else { if (!input_node->isa()) { - MS_LOG(EXCEPTION) << "input_node:" << input_node->fullname_with_scope() + MS_LOG(EXCEPTION) << "input_node: " << input_node->fullname_with_scope() << " is none of value node, cnode and parameter."; } const auto ¶m = input_node->cast(); @@ -1626,157 +1768,6 @@ void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, c (void)cell_id_with_dynamic_detect_nodes_[cell_id].emplace_back(node_info); } -bool IsAbsDifferent(const 
AbstractBasePtr &old_abs, const AbstractBasePtr &new_abs) { - if (old_abs == new_abs) { - return false; - } - - if (old_abs == nullptr || new_abs == nullptr) { - MS_LOG(DEBUG) << "graph is dynamic, old_abs is different with new_abs"; - return true; - } - - if (!common::IsEqual(old_abs->BuildType(), new_abs->BuildType()) || - !common::IsEqual(old_abs->BuildShape(), new_abs->BuildShape())) { - MS_LOG(DEBUG) << "graph is dynamic, old_abs is different with new_abs, old abs:" << old_abs->ToString() - << " new abs:" << new_abs->ToString(); - return true; - } - return false; -} - -bool IsValuePtrEqual(const ValuePtr &v1, const ValuePtr &v2) { - if (v1 == v2) { - return true; - } - if (v1 == nullptr || v2 == nullptr) { - return false; - } - - if (v1->isa() && v2->isa()) { - return v1->cast()->ValueEqual(*(v2->cast())); - } - return *v1 == *v2; -} - -bool IsParamInfoEqual(const ParamInfoPtr &p1, const ParamInfoPtr &p2) { - if (p1 == p2) { - return true; - } - if (p1 == nullptr || p2 == nullptr) { - return false; - } - - return p1->key() == p2->key(); -} - -bool GradExecutor::IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, - const std::vector &new_anf_inputs) const { - MS_EXCEPTION_IF_NULL(old_node_info); - - auto old_input_size = old_node_info->input_cnode_info.size() + old_node_info->input_values.size() + - old_node_info->input_param_infos.size(); - if (old_input_size != new_anf_inputs.size() - 1) { - MS_LOG(DEBUG) << "graph is dynamic, old input size:" << old_input_size - << " new input_infos:" << (new_anf_inputs.size() - 1); - return true; - } - - for (size_t i = 1; i < new_anf_inputs.size(); i++) { - const auto &new_anf_input = new_anf_inputs[i]; - MS_EXCEPTION_IF_NULL(new_anf_input); - if (new_anf_input->isa()) { - const auto &value_iter = old_node_info->input_values.find(i); - if (value_iter == old_node_info->input_values.end()) { - MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a value, old input is not a value."; - return true; - } - - if (!IsValuePtrEqual(value_iter->second, GetValueNode(new_anf_input))) { - MS_LOG(DEBUG) << "The " << i << "th input, value is different."; - return true; - } - } else if (new_anf_input->isa()) { - // Compare cnode abstract. - const auto &node_iter = old_node_info->input_cnode_info.find(i); - if (node_iter == old_node_info->input_cnode_info.end()) { - MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a cnode, old input is not a cnode."; - return true; - } - - size_t old_op_index = 0; - AbstractBasePtr old_abs = nullptr; - std::tie(old_op_index, old_abs) = node_iter->second; - if (IsAbsDifferent(old_abs, new_anf_input->abstract())) { - MS_LOG(DEBUG) << "The " << i << "th input, abs is different."; - return true; - } - - // Compare cnode edge. - if (old_op_index != top_cell()->get_op_index_by_cnode_hash(new_anf_input->hash())) { - MS_LOG(DEBUG) << "The " << i << "th input, op_index is different, old op_index:" << old_op_index - << " new op_index:" << top_cell()->get_op_index_by_cnode_hash(new_anf_input->hash()); - return true; - } - } else { - // Compare parameter. 
- if (!new_anf_input->isa()) { - MS_LOG(EXCEPTION) << "new_anf_input:" << new_anf_input->fullname_with_scope() - << " is none of value node, cnode and parameter."; - } - - const auto &node_iter = old_node_info->input_param_infos.find(i); - if (node_iter == old_node_info->input_param_infos.end()) { - MS_LOG(DEBUG) << "The " << i - << "th input is different, cur input is a parameter, old input is not a parameter."; - return true; - } - - const auto ¶m = new_anf_input->cast(); - MS_EXCEPTION_IF_NULL(param); - if (!IsParamInfoEqual(node_iter->second, param->param_info())) { - MS_LOG(DEBUG) << "The " << i << "th input, param info is different."; - return true; - } - } - } - - return false; -} - -bool GradExecutor::IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, - const CNodePtr &new_cnode, bool is_ms_function_node, - const std::string &graph_phase) const { - MS_EXCEPTION_IF_NULL(new_cnode); - MS_EXCEPTION_IF_NULL(old_node_info); - - // 1.Detect ms_function phase - if (is_ms_function_node != old_node_info->is_graph_node || - (is_ms_function_node && graph_phase != old_node_info->graph_phase)) { - MS_LOG(DEBUG) << "graph is dynamic, old is_graph_node:" << old_node_info->is_graph_node - << " new is_graph_node:" << is_ms_function_node << " old graph_phase" << old_node_info->graph_phase - << " new graph_phase:" << graph_phase; - return true; - } - - // 2.Detect cnode prim - auto new_prim = GetCNodePrimitive(new_cnode); - if (!common::IsEqual(new_prim, old_node_info->prim)) { - MS_LOG(DEBUG) << "graph is dynamic, old prim:" << (old_node_info->prim == nullptr ? 0 : old_node_info->prim->name()) - << " new prim:" << (new_prim == nullptr ? 0 : new_prim->name()); - return true; - } - - // 3.Detect output abs - if (IsAbsDifferent(old_node_info->output_abs, new_cnode->abstract())) { - MS_LOG(DEBUG) << "graph is dynamic, output_abs is different"; - return true; - } - - // 4.Detect inputs - return IsCnodeInputsDynamic(old_node_info, new_cnode->inputs()); -} - bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node, const std::string &graph_phase) const { MS_EXCEPTION_IF_NULL(cnode); @@ -1789,18 +1780,17 @@ bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, const auto &cell_id = top_cell()->c_cell_id() + "_" + std::to_string(top_cell()->grad_order()); const auto &dynamic_nodes = cell_id_with_dynamic_detect_nodes_[cell_id]; if (node_idx >= dynamic_nodes.size()) { - MS_LOG(DEBUG) << "old dynamic_nodes size:" << dynamic_nodes.size() << " cur node_idx is:" << node_idx + MS_LOG(DEBUG) << "Old dynamic_nodes size: " << dynamic_nodes.size() << " cur node_idx is: " << node_idx << ", graph is dynamic."; return true; } - if (IsDynamicDetectNodeInfoChange(dynamic_nodes[node_idx], cnode, is_ms_function_node, graph_phase)) { - MS_LOG(DEBUG) << "graph is dynamic, node_idx:" << node_idx - << " is different, cnode:" << cnode->fullname_with_scope(); + if (IsDynamicDetectNodeInfoChange(dynamic_nodes[node_idx], cnode, is_ms_function_node, graph_phase, top_cell())) { + MS_LOG(DEBUG) << "Graph is dynamic, node_idx: " << node_idx + << " is different, cnode: " << cnode->fullname_with_scope(); return true; } top_cell()->set_cnode_hash_with_op_index(cnode->hash(), node_idx); - return false; } @@ -1812,9 +1802,9 @@ void GradExecutor::CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_i use_dynamic_shape_process_ = IsGraphDynamic(cnode, node_idx, is_ms_function_node, graph_phase); if (use_dynamic_shape_process_) { - 
MS_LOG(DEBUG) << "cnode:" << cnode->fullname_with_scope() << ",node_idx:" << node_idx - << ",is_ms_function_node:" << is_ms_function_node << ",graph_phase:" << graph_phase - << ",use_dynamic_shape_process_:" << use_dynamic_shape_process_; + MS_LOG(DEBUG) << "Cnode: " << cnode->fullname_with_scope() << ", node_idx: " << node_idx + << ", is_ms_function_node: " << is_ms_function_node << ", graph_phase:" << graph_phase + << ", use_dynamic_shape_process_: " << use_dynamic_shape_process_; cell_id_with_dynamic_detect_nodes_.clear(); } } diff --git a/mindspore/ccsrc/pipeline/pynative/grad/grad.h b/mindspore/ccsrc/pipeline/pynative/grad/grad.h index 5dd56ade1cd..373ea9f4d9b 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/grad.h +++ b/mindspore/ccsrc/pipeline/pynative/grad/grad.h @@ -103,17 +103,13 @@ class GradExecutor { TopCellInfoPtr GetTopCell(const std::string &already_run_cell_id); void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const; void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const; - void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args); - void EndGraphImpl(const InputArgsInfoPtr &input_args_info); AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const; - void AsyncEndGraphImpl(const InputArgsInfoPtr input_args_info); AnfNodePtr GetParamInput(const ValuePtr &v, const std::string &id) const; void UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const; void UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector &pre_tensors) const; void ClearRes(); void WorkerJoin() { async_executor_->WorkerJoin(); } - void CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node = false, const std::string &graph_phase = "") const; @@ -126,10 +122,8 @@ class GradExecutor { void SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode) const; void DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode, const ValuePtr &op_out) const; - void GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode, - const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const; - void AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode, - const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const; + void GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const ad::GradParamPtr &grad_param) const; + void AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const ad::GradParamPtr &grad_param) const; void AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &cloned_value) const; AnfNodePtr GetRealInputNodeBySkipHook(const AnfNodePtr &input_node) const; void SetBpropGraphJitLevel(const py::object &obj) const; @@ -168,6 +162,9 @@ class GradExecutor { InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args); void NewGraphImpl(const InputArgsInfoPtr &input_args_info); void AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info); + void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args); + void EndGraphImpl(const InputArgsInfoPtr &input_args_info); + void AsyncEndGraphImpl(const InputArgsInfoPtr &input_args_info); void SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const; void GetCustomBpropPrim(const py::object &obj, const py::args &args, const py::object &out, const 
InputArgsInfoPtr &input_args_info); @@ -195,10 +192,7 @@ class GradExecutor { const std::string &graph_phase) const; bool IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node, const std::string &graph_phase) const; - bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, - const std::vector &new_anf_inputs) const; - bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, const CNodePtr &new_cnode, - bool is_ms_function_node, const std::string &graph_phase) const; + bool grad_flag_{false}; bool grad_is_running_{false}; bool need_renormalize_{false}; @@ -209,7 +203,7 @@ class GradExecutor { // Used in sub thread size_t cell_order_{0}; - std::string cur_cell_id_{""}; + std::string cur_cell_id_; // If grad_order=1, indicate first derivative; grad_order=2, indicate second derivative; ... size_t grad_order_{0}; diff --git a/mindspore/ccsrc/pipeline/pynative/grad/ms_function_grad.cc b/mindspore/ccsrc/pipeline/pynative/grad/ms_function_grad.cc index 0b879759892..9650481ca73 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/ms_function_grad.cc +++ b/mindspore/ccsrc/pipeline/pynative/grad/ms_function_grad.cc @@ -272,9 +272,8 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FrontendOpRunInfoPtr &op_run // Connect grad graph of ms_function to context. auto auto_grad_cell_ptr = top_cell->auto_grad_cell_ptr(); - MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr); - auto grad_param = - std::make_shared(ms_function_cnode, op_run_info->input_value, op_run_info->out_value, grad_graph); + auto grad_param = std::make_shared(ms_function_cnode, op_run_info->input_value, op_run_info->out_value, + grad_graph, !top_cell->is_high_order_top_cell()); { py::gil_scoped_release gil_release; grad_executor->async_executor()->Wait(); @@ -283,7 +282,7 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FrontendOpRunInfoPtr &op_run MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: " << ms_function_cnode->DebugString(); } - top_cell->set_ms_function_flag(true); + top_cell->set_need_do_final_opt(true); return ms_function_cnode; } @@ -291,8 +290,7 @@ void MsFunction::AsyncKPynativeWithFProp(const GradExecutor *grad_executor, const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const ad::GradParamPtr &grad_param) const { MS_EXCEPTION_IF_NULL(grad_executor); - - const auto fn = [this, grad_param, auto_grad_cell_ptr]() { + const auto fn = [grad_param, auto_grad_cell_ptr]() { MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr); if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) { MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode"; diff --git a/mindspore/ccsrc/pipeline/pynative/grad/top_cell.cc b/mindspore/ccsrc/pipeline/pynative/grad/top_cell.cc index bf7616d7de0..497318d039c 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/top_cell.cc +++ b/mindspore/ccsrc/pipeline/pynative/grad/top_cell.cc @@ -68,8 +68,7 @@ void TopCellInfo::GetOpInfo(const FrontendOpRunInfoPtr &op_run_info) { // else: // x = x + self.p // return x - for (size_t i = 0; i < op_run_info->base_op_run_info.input_tensor.size(); i++) { - const auto &t = op_run_info->base_op_run_info.input_tensor[i]; + for (auto &t : op_run_info->base_op_run_info.input_tensor) { MS_EXCEPTION_IF_NULL(t); if (t->is_parameter() && t->param_info() != nullptr && t->param_info()->requires_grad()) { input_args_info += "w"; @@ -127,6 +126,22 @@ void TopCellInfo::ClearDeviceMemory() const { } } +void TopCellInfo::Clear() { + MS_LOG(DEBUG) << "Clear top cell info. 
Cell id " << cell_id_; + hook_changed_ = false; + is_init_kpynative_ = false; + need_compile_graph_ = false; + forward_already_run_ = false; + op_index_ = 0; + resource_ = nullptr; + fg_ = nullptr; + graph_info_map_.clear(); + op_info_with_tensor_id_.clear(); + tensor_id_with_tensor_object_.clear(); + op_info_with_ms_func_forward_tensors_.clear(); + cnode_hash_with_op_index_.clear(); +} + void TopCellInfo::DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id) { auto &graph_info = graph_info_map().at(g); MS_EXCEPTION_IF_NULL(graph_info); @@ -145,12 +160,12 @@ void TopCellInfo::SetParamNodeMapInGraphInfoMap(const std::string &id, const Par } void TopCellInfo::SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index, - bool save_flag) const { + bool need_save_sub_id) const { auto &graph_info = graph_info_map().at(fg()); MS_EXCEPTION_IF_NULL(graph_info); graph_info->node_map[id] = std::make_pair(node, std::vector{index}); // For example, set id of ((A,B),C) = {CNode, -1} - if (save_flag) { + if (need_save_sub_id) { SetMultipleOutputToGraphInfoMap(id, node); } } @@ -188,23 +203,6 @@ void TopCellInfo::SetNestedMultipleOutputToGraphInfoMap(const string &id, const } } -void TopCellInfo::Clear() { - MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_; - hook_changed_ = false; - ms_function_flag_ = false; - is_init_kpynative_ = false; - need_compile_graph_ = false; - forward_already_run_ = false; - op_index_ = 0; - resource_ = nullptr; - fg_ = nullptr; - graph_info_map_.clear(); - op_info_with_tensor_id_.clear(); - tensor_id_with_tensor_object_.clear(); - op_info_with_ms_func_forward_tensors_.clear(); - cnode_hash_with_op_index_.clear(); -} - void TopCellInfo::SetUnpackOutputToGraphInfoMap(const std::string &id, const AnfNodePtr &node, const std::vector &index) const { auto &graph_info = graph_info_map().at(fg()); diff --git a/mindspore/ccsrc/pipeline/pynative/grad/top_cell.h b/mindspore/ccsrc/pipeline/pynative/grad/top_cell.h index 2ef75a8505a..e1b918152d5 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/top_cell.h +++ b/mindspore/ccsrc/pipeline/pynative/grad/top_cell.h @@ -58,9 +58,10 @@ using GraphInfoPtr = std::shared_ptr; class TopCellInfo { public: ~TopCellInfo() = default; - TopCellInfo(size_t grad_order, std::string c_cell_id, std::string cellid, std::string already_run_cell_id, - pipeline::ResourcePtr r, FuncGraphPtr fg) - : grad_order_(grad_order), + TopCellInfo(bool is_high_order_top_cell, size_t grad_order, std::string c_cell_id, std::string cellid, + std::string already_run_cell_id, pipeline::ResourcePtr r, FuncGraphPtr fg) + : is_high_order_top_cell_(is_high_order_top_cell), + grad_order_(grad_order), c_cell_id_(std::move(c_cell_id)), cell_id_(std::move(cellid)), already_run_cell_id_(std::move(already_run_cell_id)), @@ -76,12 +77,13 @@ class TopCellInfo { void RecordCellBackwardHookOp(const std::string &cell_order, const AnfNodePtr &hook_op); void GetOpInfo(const FrontendOpRunInfoPtr &op_run_info); inline void ClearCellHookOp() { cell_backward_hook_op_.clear(); } - inline bool ms_function_flag() const { return ms_function_flag_; } - inline void set_ms_function_flag(bool ms_function_flag) { ms_function_flag_ = ms_function_flag; } inline bool forward_already_run() const { return forward_already_run_; } inline void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; } inline bool need_compile_graph() const { return need_compile_graph_; } inline void set_need_compile_graph(bool 
need_compile_graph) { need_compile_graph_ = need_compile_graph; } + inline bool is_high_order_top_cell() const { return is_high_order_top_cell_; } + inline void set_need_do_final_opt(bool need_do_final_opt) { need_do_final_opt_ = need_do_final_opt; } + inline bool need_do_final_opt() const { return need_do_final_opt_; } inline pipeline::ResourcePtr resource() const { return resource_; } inline FuncGraphPtr fg() const { MS_EXCEPTION_IF_NULL(fg_); @@ -100,7 +102,7 @@ class TopCellInfo { graph_info_map_[fg] = graph_info; } inline void set_is_run_cell(bool is_run_cell) { is_run_cell_ = is_run_cell; } - inline bool is_run_cell() { return is_run_cell_; } + inline bool is_run_cell() const { return is_run_cell_; } inline const OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() const { return graph_info_map_; } inline ad::AutoGradCellImplPtr auto_grad_cell_ptr() const { MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr_); @@ -119,26 +121,25 @@ class TopCellInfo { return op_info_with_ms_func_forward_tensors_; } inline size_t op_index() const { return op_index_; } - inline void IncreaseOpIndex() { op_index_++; } + inline void IncreaseOpIndex() { ++op_index_; } inline void set_cnode_hash_with_op_index(const size_t &node_hash, const size_t &op_index) { cnode_hash_with_op_index_[node_hash] = op_index; } - inline size_t get_op_index_by_cnode_hash(const size_t &node_hash) { - auto iter = cnode_hash_with_op_index_.find(node_hash); + inline size_t get_op_index_by_cnode_hash(const size_t &node_hash) const { + const auto iter = cnode_hash_with_op_index_.find(node_hash); if (iter == cnode_hash_with_op_index_.end()) { MS_LOG(EXCEPTION) << "hash:" << node_hash << " is not found in cnode_hash_with_op_index_"; } return iter->second; } - void Clear(); - void DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id); void SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr &param, bool is_weight = false) const; void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1, - bool save_flag = true) const; + bool need_save_sub_id = true) const; void ClearDeviceMemory() const; + void Clear(); private: void SetMultipleOutputToGraphInfoMap(const string &id, const AnfNodePtr &node) const; @@ -148,12 +149,13 @@ class TopCellInfo { const std::vector<int64_t> &index) const; bool hook_changed_{false}; - bool ms_function_flag_{false}; bool is_init_kpynative_{false}; bool forward_already_run_{false}; bool need_compile_graph_{false}; bool is_run_cell_{false}; size_t op_index_{0}; + bool is_high_order_top_cell_{false}; + bool need_do_final_opt_{false}; size_t grad_order_{0}; std::string c_cell_id_; std::string cell_id_; diff --git a/mindspore/python/mindspore/ops/_grad/grad_implementations.py b/mindspore/python/mindspore/ops/_grad/grad_implementations.py index 64f8d93b56d..9d5a74d2ab2 100644 --- a/mindspore/python/mindspore/ops/_grad/grad_implementations.py +++ b/mindspore/python/mindspore/ops/_grad/grad_implementations.py @@ -225,3 +225,9 @@ def bprop_scalar_calc(x, y, out, dout): def bprop_scalar_not(x, out, dout): """Backpropagator for primitive `bool_not` and `string_not`.""" return (C.zeros_like(x),) + + +@bprops.register("TensorMove") +def bprop_tensor_move(x, out, dout): + """Backpropagator for primitive `TensorMove`.""" + return (dout,) diff --git a/tests/st/gradient/test_grad_pynative.py b/tests/st/gradient/test_grad_pynative.py index 445147bae29..266ed162405 100644 --- a/tests/st/gradient/test_grad_pynative.py +++ b/tests/st/gradient/test_grad_pynative.py @@ -148,7 +148,7 @@ def 
test_grad_multiple_inputs_multiple_outputs_cell_pynative(): assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy()) -@pytest.mark.level2 +@pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_grad_iteration_function_pynative():
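
Note on the TensorMove backpropagator registered in grad_implementations.py above: returning (dout,) makes the op a pure gradient pass-through, so values copied through TensorMove (for example, those introduced for mutable inputs) do not block backpropagation in PyNative mode. The sketch below is illustrative only and is not part of this patch; it assumes mindspore.ops.grad is available and uses x * 1 as a stand-in, since TensorMove itself is an internal primitive rather than a public op.

# Illustrative sketch -- not part of this patch. Assumes `mindspore.ops.grad`
# exists in the installed MindSpore build; `x * 1` stands in for TensorMove,
# whose bprop has the same pass-through behaviour as bprop_tensor_move above.
import numpy as np
from mindspore import Tensor, context, ops

context.set_context(mode=context.PYNATIVE_MODE)


def pass_through(x):
    # Identity-style computation: its gradient w.r.t. x is all ones, i.e. the
    # incoming dout is forwarded unchanged, mirroring `return (dout,)`.
    return x * 1


x = Tensor(np.array([1.0, 2.0, 3.0], np.float32))
print(ops.grad(pass_through)(x))  # expected output: [1. 1. 1.]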