fix group FracZ transdata in multi-graph scene
This commit is contained in:
parent
0ac3cd3aef
commit
45810c2adc
|
@ -386,7 +386,6 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
|
||||
other_pm->AddPass(std::make_shared<SplitOpOptimizer>());
|
||||
other_pm->AddPass(std::make_shared<SetFraczGroupAttr>());
|
||||
other_pm->AddPass(std::make_shared<EliminateRedundantOp>());
|
||||
optimizer->AddPassManager(other_pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
|
|
|
@ -78,6 +78,25 @@ void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) {
|
|||
}
|
||||
}
|
||||
|
||||
void SetGroupAttr(const ParameterPtr ¶m, const AnfNodePtr &out_trans, const AnfNodePtr &in_trans,
|
||||
const std::string &dest_format) {
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
auto fz_group = param->fracz_group();
|
||||
// in the scenario of gradient freezing or infer while training, the parameters are already set with
|
||||
// fracz_group in first graph, so the inserted transdata will trans format from FracZwithgroup(param)
|
||||
// to default and default to FracZwithoutgroup(cnode, such as Conv2D, Opt). These paired TransDatas are
|
||||
// not set with groups attr and cannot be eliminated in EliminateReduntantOp. So to solve this problem,
|
||||
// set the groups and fracz_group attr here for these paired TransData nodes.
|
||||
if (fz_group > 1) {
|
||||
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), out_trans);
|
||||
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), out_trans);
|
||||
if (dest_format == kOpFormat_FRAC_Z) {
|
||||
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), in_trans);
|
||||
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), in_trans);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
|
||||
const KernelSelectPtr &kernel_select) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -99,7 +118,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
|
|||
if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
|
||||
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
|
||||
<< " To DefaultFormat , index: " << index;
|
||||
return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
|
||||
auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
|
||||
if (real_input->isa<Parameter>()) {
|
||||
SetGroupAttr(real_input->cast<ParameterPtr>(), input_node, transdata, dest_format);
|
||||
}
|
||||
return transdata;
|
||||
}
|
||||
return input_node;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue