diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index 5429c18bbc9..1a96f7e1570 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -386,7 +386,6 @@ void AscendBackendOptimization(const std::shared_ptr &kern other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); - other_pm->AddPass(std::make_shared()); optimizer->AddPassManager(other_pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index e39c50745b1..9fbc60f4ddb 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -78,6 +78,25 @@ void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) { } } +void SetGroupAttr(const ParameterPtr ¶m, const AnfNodePtr &out_trans, const AnfNodePtr &in_trans, + const std::string &dest_format) { + MS_EXCEPTION_IF_NULL(param); + auto fz_group = param->fracz_group(); + // in the scenario of gradient freezing or infer while training, the parameters are already set with + // fracz_group in first graph, so the inserted transdata will trans format from FracZwithgroup(param) + // to default and default to FracZwithoutgroup(cnode, such as Conv2D, Opt). These paired TransDatas are + // not set with groups attr and cannot be eliminated in EliminateReduntantOp. So to solve this problem, + // set the groups and fracz_group attr here for these paired TransData nodes. + if (fz_group > 1) { + AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), out_trans); + AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), out_trans); + if (dest_format == kOpFormat_FRAC_Z) { + AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), in_trans); + AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), in_trans); + } + } +} + AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, const KernelSelectPtr &kernel_select) { MS_EXCEPTION_IF_NULL(node); @@ -99,7 +118,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index) << " To DefaultFormat , index: " << index; - return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true); + auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true); + if (real_input->isa()) { + SetGroupAttr(real_input->cast(), input_node, transdata, dest_format); + } + return transdata; } return input_node; }