fix group FracZ transdata in multi-graph scene

This commit is contained in:
yuchaojie 2021-07-01 21:20:46 +08:00
parent 0ac3cd3aef
commit 45810c2adc
2 changed files with 24 additions and 2 deletions

View File

@ -386,7 +386,6 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
other_pm->AddPass(std::make_shared<SplitOpOptimizer>());
other_pm->AddPass(std::make_shared<SetFraczGroupAttr>());
other_pm->AddPass(std::make_shared<EliminateRedundantOp>());
optimizer->AddPassManager(other_pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();

View File

@ -78,6 +78,25 @@ void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) {
}
}
void SetGroupAttr(const ParameterPtr &param, const AnfNodePtr &out_trans, const AnfNodePtr &in_trans,
const std::string &dest_format) {
MS_EXCEPTION_IF_NULL(param);
auto fz_group = param->fracz_group();
// in the scenario of gradient freezing or infer while training, the parameters are already set with
// fracz_group in first graph, so the inserted transdata will trans format from FracZwithgroup(param)
// to default and default to FracZwithoutgroup(cnode, such as Conv2D, Opt). These paired TransDatas are
// not set with groups attr and cannot be eliminated in EliminateReduntantOp. So to solve this problem,
// set the groups and fracz_group attr here for these paired TransData nodes.
if (fz_group > 1) {
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), out_trans);
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), out_trans);
if (dest_format == kOpFormat_FRAC_Z) {
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), in_trans);
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), in_trans);
}
}
}
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
const KernelSelectPtr &kernel_select) {
MS_EXCEPTION_IF_NULL(node);
@ -99,7 +118,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
<< " To DefaultFormat , index: " << index;
return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
if (real_input->isa<Parameter>()) {
SetGroupAttr(real_input->cast<ParameterPtr>(), input_node, transdata, dest_format);
}
return transdata;
}
return input_node;
}