fix group FracZ transdata in multi-graph scene

This commit is contained in:
yuchaojie 2021-07-01 21:20:46 +08:00
parent 0ac3cd3aef
commit 45810c2adc
2 changed files with 24 additions and 2 deletions

View File

@ -386,7 +386,6 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
other_pm->AddPass(std::make_shared<RefreshParameterFormat>()); other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
other_pm->AddPass(std::make_shared<SplitOpOptimizer>()); other_pm->AddPass(std::make_shared<SplitOpOptimizer>());
other_pm->AddPass(std::make_shared<SetFraczGroupAttr>()); other_pm->AddPass(std::make_shared<SetFraczGroupAttr>());
other_pm->AddPass(std::make_shared<EliminateRedundantOp>());
optimizer->AddPassManager(other_pm); optimizer->AddPassManager(other_pm);
(void)optimizer->Optimize(kernel_graph); (void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault(); kernel_graph->SetExecOrderByDefault();

View File

@ -78,6 +78,25 @@ void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) {
} }
} }
void SetGroupAttr(const ParameterPtr &param, const AnfNodePtr &out_trans, const AnfNodePtr &in_trans,
const std::string &dest_format) {
MS_EXCEPTION_IF_NULL(param);
auto fz_group = param->fracz_group();
// in the scenario of gradient freezing or infer while training, the parameters are already set with
// fracz_group in first graph, so the inserted transdata will trans format from FracZwithgroup(param)
// to default and default to FracZwithoutgroup(cnode, such as Conv2D, Opt). These paired TransDatas are
// not set with groups attr and cannot be eliminated in EliminateReduntantOp. So to solve this problem,
// set the groups and fracz_group attr here for these paired TransData nodes.
if (fz_group > 1) {
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), out_trans);
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), out_trans);
if (dest_format == kOpFormat_FRAC_Z) {
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), in_trans);
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), in_trans);
}
}
}
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
const KernelSelectPtr &kernel_select) { const KernelSelectPtr &kernel_select) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
@ -99,7 +118,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index) MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
<< " To DefaultFormat , index: " << index; << " To DefaultFormat , index: " << index;
return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true); auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
if (real_input->isa<Parameter>()) {
SetGroupAttr(real_input->cast<ParameterPtr>(), input_node, transdata, dest_format);
}
return transdata;
} }
return input_node; return input_node;
} }