fix fracz_group transmit in multi graph

This commit is contained in:
yuchaojie 2022-01-18 10:51:51 +08:00
parent 0155e9630a
commit b61381ac3f
2 changed files with 43 additions and 10 deletions

View File

@ -96,8 +96,11 @@ void SetGroupAttr(const ParameterPtr &param, const AnfNodePtr &out_trans, const
// not set with groups attr and cannot be eliminated in EliminateReduntantOp. So to solve this problem,
// set the groups and fracz_group attr here for these paired TransData nodes.
if (fz_group > 1) {
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), out_trans);
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), out_trans);
if (out_trans->isa<CNode>()) {
// if has transdata after parameter
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), out_trans);
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), out_trans);
}
if (dest_format == kOpFormat_FRAC_Z) {
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), in_trans);
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), in_trans);

View File

@ -200,6 +200,30 @@ bool SetAttrFraczGroup(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
}
return true;
}
bool SetAttrFraczGroup(const FuncGraphPtr &func_graph, const ParameterPtr &param) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param);
auto groups = param->fracz_group();
if (groups == 1) {
return false;
}
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<KernelWithIndex> todo{};
auto used_cnodes = GetNeighborFraczNodes(manager, param, 0, groups);
std::copy(used_cnodes.begin(), used_cnodes.end(), std::back_inserter(todo));
while (!todo.empty()) {
KernelWithIndex node_index = todo.back();
if (HasFraczGroupAttrAndSet(node_index.first, node_index.second, groups)) {
todo.pop_back();
continue;
}
auto next_nodes = GetNeighborFraczNodes(manager, node_index.first, node_index.second, groups);
std::copy(next_nodes.begin(), next_nodes.end(), std::back_inserter(todo));
}
return true;
}
} // namespace
bool SetFraczGroupAttr::Run(const FuncGraphPtr &func_graph) {
@ -207,17 +231,23 @@ bool SetFraczGroupAttr::Run(const FuncGraphPtr &func_graph) {
bool changed = false;
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
for (auto node : node_list) {
if (node == nullptr || !node->isa<CNode>()) {
if (node == nullptr) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
if (node->isa<Parameter>()) {
auto param = node->cast<ParameterPtr>();
changed = SetAttrFraczGroup(func_graph, param) || changed;
}
auto node_name = AnfAlgo::GetCNodeName(cnode);
if (node_name == kConv2DOpName || node_name == kConv2DBackpropInputOpName ||
node_name == kConv2DBackpropFilterOpName) {
changed = SetAttrFraczGroup(func_graph, cnode) || changed;
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
auto node_name = AnfAlgo::GetCNodeName(cnode);
if (node_name == kConv2DOpName || node_name == kConv2DBackpropInputOpName ||
node_name == kConv2DBackpropFilterOpName) {
changed = SetAttrFraczGroup(func_graph, cnode) || changed;
}
}
}
return changed;