fix transdata insert cast

This commit is contained in:
WilliamLian 2020-06-18 15:59:09 +08:00
parent e32d539b5f
commit e3a26c2229
4 changed files with 16 additions and 27 deletions

View File

@ -54,7 +54,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
CNodePtr trans_data = nullptr;
std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
std::vector<kernel::Axis> padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
MS_EXCEPTION_IF_NULL(node);
// if insert transdata for input we need to change the input
@ -63,7 +62,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
}
auto cnode = node->cast<CNodePtr>();
dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
input_node = AnfAlgo::GetInputNode(cnode, insert_index);
padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index);
@ -95,7 +93,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
trans_node = reshape_node;
}
// refresh the transdata's format to ori format & dst format
RefreshKernelBuildInfo(input_format, dst_format, dtype, trans_data, padding_axis);
RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis);
return trans_node;
}
@ -162,22 +160,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
return make_tuple;
}
} // namespace
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type,
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) {
MS_EXCEPTION_IF_NULL(trans_data);
MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
auto ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({input_format});
builder.SetInputReshapeType({reshape_type});
builder.SetInputReshapeType({reshape_type});
builder.SetOutputsFormat({output_format});
builder.SetInputsDeviceType({device_type});
builder.SetOutputsDeviceType({device_type});
builder.SetKernelType(ori_build_info->kernel_type());
builder.SetFusionType(ori_build_info->fusion_type());
builder.SetProcessor(ori_build_info->processor());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), trans_data.get());
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
MS_EXCEPTION_IF_NULL(ori_build_info);
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
builder->SetInputsFormat({input_format});
builder->SetInputReshapeType({reshape_type});
builder->SetOutputReshapeType({reshape_type});
builder->SetOutputsFormat({output_format});
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
}
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,

View File

@ -70,7 +70,7 @@ class KernelQuery {
}
};
using KernelQueryPtr = std::shared_ptr<KernelQuery>;
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type,
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {});
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,

View File

@ -107,7 +107,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
if (origin_format != cur_format && cur_shape.size() > 1) {
auto kernel_select = std::make_shared<KernelSelect>();
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(cur_format, origin_format, origin_type, final_node);
RefreshKernelBuildInfo(cur_format, origin_format, final_node);
final_index = 0;
MS_EXCEPTION_IF_NULL(final_node);
MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();

View File

@ -69,13 +69,11 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
// trans input_format to hwcn
new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0),
new_transdata_node);
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node);
// trans hwcn to default_format
new_transpose_node =
NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name());
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0),
new_transpose_node);
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node);
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node);
new_replace_node = new_transpose_node;
} else {
@ -83,14 +81,12 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
false, prim::kPrimTranspose->name());
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0),
new_transpose_node);
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node);
// trans hwcn to output_format
new_transdata_node =
NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0),
new_transdata_node);
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node);
new_replace_node = new_transdata_node;
}
FuncGraphManagerPtr manager = func_graph->manager();