fix group FracZ transdata in multi-graph scene
This commit is contained in:
parent
0ac3cd3aef
commit
45810c2adc
|
@ -386,7 +386,6 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
||||||
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
|
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
|
||||||
other_pm->AddPass(std::make_shared<SplitOpOptimizer>());
|
other_pm->AddPass(std::make_shared<SplitOpOptimizer>());
|
||||||
other_pm->AddPass(std::make_shared<SetFraczGroupAttr>());
|
other_pm->AddPass(std::make_shared<SetFraczGroupAttr>());
|
||||||
other_pm->AddPass(std::make_shared<EliminateRedundantOp>());
|
|
||||||
optimizer->AddPassManager(other_pm);
|
optimizer->AddPassManager(other_pm);
|
||||||
(void)optimizer->Optimize(kernel_graph);
|
(void)optimizer->Optimize(kernel_graph);
|
||||||
kernel_graph->SetExecOrderByDefault();
|
kernel_graph->SetExecOrderByDefault();
|
||||||
|
|
|
@ -78,6 +78,25 @@ void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetGroupAttr(const ParameterPtr ¶m, const AnfNodePtr &out_trans, const AnfNodePtr &in_trans,
|
||||||
|
const std::string &dest_format) {
|
||||||
|
MS_EXCEPTION_IF_NULL(param);
|
||||||
|
auto fz_group = param->fracz_group();
|
||||||
|
// in the scenario of gradient freezing or infer while training, the parameters are already set with
|
||||||
|
// fracz_group in first graph, so the inserted transdata will trans format from FracZwithgroup(param)
|
||||||
|
// to default and default to FracZwithoutgroup(cnode, such as Conv2D, Opt). These paired TransDatas are
|
||||||
|
// not set with groups attr and cannot be eliminated in EliminateReduntantOp. So to solve this problem,
|
||||||
|
// set the groups and fracz_group attr here for these paired TransData nodes.
|
||||||
|
if (fz_group > 1) {
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), out_trans);
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), out_trans);
|
||||||
|
if (dest_format == kOpFormat_FRAC_Z) {
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), in_trans);
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), in_trans);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
|
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
|
||||||
const KernelSelectPtr &kernel_select) {
|
const KernelSelectPtr &kernel_select) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
@ -99,7 +118,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
|
||||||
if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
|
if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
|
||||||
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
|
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
|
||||||
<< " To DefaultFormat , index: " << index;
|
<< " To DefaultFormat , index: " << index;
|
||||||
return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
|
auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
|
||||||
|
if (real_input->isa<Parameter>()) {
|
||||||
|
SetGroupAttr(real_input->cast<ParameterPtr>(), input_node, transdata, dest_format);
|
||||||
|
}
|
||||||
|
return transdata;
|
||||||
}
|
}
|
||||||
return input_node;
|
return input_node;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue