!39171 update groups of transdata in ref output
Merge pull request !39171 from yuchaojie/ir_fusion2
This commit is contained in:
commit
22e8680e53
|
@ -254,7 +254,7 @@ abstract::ShapePtr GetPadShape(const ShapeVector &padding_shape, const ShapeVect
|
|||
AnfNodePtr AddTransOpNodeToGraphWithFormat(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
|
||||
const AnfNodePtr &node, const KernelSelectPtr &kernel_select,
|
||||
const std::string &input_format, const std::string &dst_format,
|
||||
const std::string &reshape_type, const TypeId &type_id) {
|
||||
const std::string &reshape_type, const TypeId &type_id, int64_t groups) {
|
||||
if (input_format == dst_format) {
|
||||
MS_LOG(INFO) << "Input format[" << input_format << "] is equal to dst format, no need to insert transdata.";
|
||||
return input_node;
|
||||
|
@ -326,10 +326,10 @@ AnfNodePtr AddTransOpNodeToGraphWithFormat(const FuncGraphPtr &func_graph, const
|
|||
common::AnfAlgo::CopyNodeAttr(kAttrHiddenSize, node, trans_data);
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrInputSize, node, trans_data);
|
||||
}
|
||||
if (spec_format == kOpFormat_FRAC_Z && orig_node->isa<CNode>() &&
|
||||
common::AnfAlgo::HasNodeAttr(kAttrFracZGroup, orig_node->cast<CNodePtr>())) {
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrGroups, orig_node, trans_data);
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrFracZGroup, orig_node, trans_data);
|
||||
if (spec_format == kOpFormat_FRAC_Z && groups != 1 &&
|
||||
!common::AnfAlgo::HasNodeAttr(kAttrFracZGroup, trans_data->cast<CNodePtr>())) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), trans_data);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), trans_data);
|
||||
}
|
||||
// refresh the transdata's format to ori format & dst format
|
||||
RefreshKernelBuildInfo(input_format, dst_format, trans_data, reshape_type, type_id);
|
||||
|
@ -411,6 +411,12 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
|||
}
|
||||
if (op_name == prim::kPrimTranspose->name()) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), trans_node);
|
||||
} else if (op_name == prim::kPrimTransData->name()) {
|
||||
if (orig_node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrFracZGroup, orig_node->cast<CNodePtr>())) {
|
||||
auto fracz_group = common::AnfAlgo::GetNodeAttr<int64_t>(orig_node, kAttrFracZGroup);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fracz_group), trans_node);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fracz_group), trans_node);
|
||||
}
|
||||
}
|
||||
if (is_dynamic_shape) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), trans_node);
|
||||
|
|
|
@ -128,7 +128,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
AnfNodePtr AddTransOpNodeToGraphWithFormat(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
|
||||
const AnfNodePtr &node, const KernelSelectPtr &kernel_select,
|
||||
const std::string &input_format, const std::string &dst_format,
|
||||
const std::string &reshape_type, const TypeId &type_id = kTypeUnknown);
|
||||
const std::string &reshape_type, const TypeId &type_id = kTypeUnknown,
|
||||
int64_t groups = 1);
|
||||
|
||||
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
|
||||
} // namespace opt
|
||||
|
|
|
@ -125,9 +125,10 @@ AnfNodePtr DealRefOutput::AddAdditionalToRefOutput(const FuncGraphPtr &func_grap
|
|||
kOpFormat_DEFAULT, cur_reshape_type, cur_type);
|
||||
}
|
||||
if (origin_format != kOpFormat_DEFAULT) {
|
||||
int64_t groups = common::AnfAlgo::GetAttrGroups(origin_pair.first, origin_pair.second);
|
||||
auto origin_reshape_type = AnfAlgo::GetOutputReshapeType(origin_pair.first, origin_pair.second);
|
||||
final_node = AddTransOpNodeToGraphWithFormat(func_graph, final_node, final_node, kernel_select, kOpFormat_DEFAULT,
|
||||
origin_format, origin_reshape_type, cur_type);
|
||||
origin_format, origin_reshape_type, cur_type, groups);
|
||||
}
|
||||
final_index = 0;
|
||||
need_refresh_ref_addr = true;
|
||||
|
|
|
@ -79,6 +79,7 @@ bool HasFraczGroupAttrAndSet(const AnfNodePtr &node, size_t index, int64_t group
|
|||
if (node_name == kDependName && index != 0) {
|
||||
return true;
|
||||
}
|
||||
bool has_group_attr = false;
|
||||
if (kInOutOperatorSet.find(node_name) != kInOutOperatorSet.end()) {
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(node);
|
||||
if (index >= input_num) {
|
||||
|
@ -91,24 +92,20 @@ bool HasFraczGroupAttrAndSet(const AnfNodePtr &node, size_t index, int64_t group
|
|||
if (input_num > fz_group_idx.size()) {
|
||||
(void)fz_group_idx.insert(fz_group_idx.cbegin(), input_num - fz_group_idx.size(), 1);
|
||||
}
|
||||
if (fz_group_idx[index] == 1) {
|
||||
fz_group_idx[index] = groups;
|
||||
common::AnfAlgo::SetNodeAttr(kAttrFracZGroupIdx, MakeValue(fz_group_idx), cnode);
|
||||
return false;
|
||||
if (fz_group_idx[index] != 1) {
|
||||
has_group_attr = true;
|
||||
}
|
||||
} else {
|
||||
fz_group_idx[index] = groups;
|
||||
common::AnfAlgo::SetNodeAttr(kAttrFracZGroupIdx, MakeValue(fz_group_idx), cnode);
|
||||
}
|
||||
fz_group_idx[index] = groups;
|
||||
common::AnfAlgo::SetNodeAttr(kAttrFracZGroupIdx, MakeValue(fz_group_idx), cnode);
|
||||
return has_group_attr;
|
||||
}
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) {
|
||||
return true;
|
||||
}
|
||||
has_group_attr = common::AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), cnode);
|
||||
if (node_name == kTransDataOpName) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), cnode);
|
||||
}
|
||||
return false;
|
||||
return has_group_attr;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -243,7 +240,8 @@ bool SetFraczGroupAttr::Run(const FuncGraphPtr &func_graph) {
|
|||
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
|
||||
// clear cnode fracz_group first, since the fracz_group info may be out-of-date in later graph of multi-graph scene
|
||||
for (auto &node : node_list) {
|
||||
if (node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrFracZGroup, node->cast<CNodePtr>())) {
|
||||
if (node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrFracZGroup, node->cast<CNodePtr>()) &&
|
||||
common::AnfAlgo::GetCNodeName(node) != kTransDataOpName) {
|
||||
common::AnfAlgo::EraseNodeAttr(kAttrFracZGroup, node);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue