forked from mindspore-Ecosystem/mindspore
fix fracz_group transmit in multi graph
This commit is contained in:
parent
0155e9630a
commit
b61381ac3f
|
@ -96,8 +96,11 @@ void SetGroupAttr(const ParameterPtr ¶m, const AnfNodePtr &out_trans, const
|
|||
// not set with groups attr and cannot be eliminated in EliminateReduntantOp. So to solve this problem,
|
||||
// set the groups and fracz_group attr here for these paired TransData nodes.
|
||||
if (fz_group > 1) {
|
||||
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), out_trans);
|
||||
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), out_trans);
|
||||
if (out_trans->isa<CNode>()) {
|
||||
// if has transdata after parameter
|
||||
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), out_trans);
|
||||
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), out_trans);
|
||||
}
|
||||
if (dest_format == kOpFormat_FRAC_Z) {
|
||||
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), in_trans);
|
||||
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), in_trans);
|
||||
|
|
|
@ -200,6 +200,30 @@ bool SetAttrFraczGroup(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SetAttrFraczGroup(const FuncGraphPtr &func_graph, const ParameterPtr ¶m) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
auto groups = param->fracz_group();
|
||||
if (groups == 1) {
|
||||
return false;
|
||||
}
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<KernelWithIndex> todo{};
|
||||
auto used_cnodes = GetNeighborFraczNodes(manager, param, 0, groups);
|
||||
std::copy(used_cnodes.begin(), used_cnodes.end(), std::back_inserter(todo));
|
||||
while (!todo.empty()) {
|
||||
KernelWithIndex node_index = todo.back();
|
||||
if (HasFraczGroupAttrAndSet(node_index.first, node_index.second, groups)) {
|
||||
todo.pop_back();
|
||||
continue;
|
||||
}
|
||||
auto next_nodes = GetNeighborFraczNodes(manager, node_index.first, node_index.second, groups);
|
||||
std::copy(next_nodes.begin(), next_nodes.end(), std::back_inserter(todo));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool SetFraczGroupAttr::Run(const FuncGraphPtr &func_graph) {
|
||||
|
@ -207,17 +231,23 @@ bool SetFraczGroupAttr::Run(const FuncGraphPtr &func_graph) {
|
|||
bool changed = false;
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
|
||||
for (auto node : node_list) {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
if (node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
continue;
|
||||
if (node->isa<Parameter>()) {
|
||||
auto param = node->cast<ParameterPtr>();
|
||||
changed = SetAttrFraczGroup(func_graph, param) || changed;
|
||||
}
|
||||
auto node_name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (node_name == kConv2DOpName || node_name == kConv2DBackpropInputOpName ||
|
||||
node_name == kConv2DBackpropFilterOpName) {
|
||||
changed = SetAttrFraczGroup(func_graph, cnode) || changed;
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto node_name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (node_name == kConv2DOpName || node_name == kConv2DBackpropInputOpName ||
|
||||
node_name == kConv2DBackpropFilterOpName) {
|
||||
changed = SetAttrFraczGroup(func_graph, cnode) || changed;
|
||||
}
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
|
|
Loading…
Reference in New Issue