forked from mindspore-Ecosystem/mindspore
fix transdata insert cast
This commit is contained in:
parent
e32d539b5f
commit
e3a26c2229
|
@ -54,7 +54,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
CNodePtr trans_data = nullptr;
|
||||
std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
|
||||
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
|
||||
TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
|
||||
std::vector<kernel::Axis> padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// if insert transdata for input we need to change the input
|
||||
|
@ -63,7 +62,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
|
||||
dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
|
||||
input_node = AnfAlgo::GetInputNode(cnode, insert_index);
|
||||
padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index);
|
||||
|
@ -95,7 +93,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
trans_node = reshape_node;
|
||||
}
|
||||
// refresh the transdata's format to ori format & dst format
|
||||
RefreshKernelBuildInfo(input_format, dst_format, dtype, trans_data, padding_axis);
|
||||
RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis);
|
||||
return trans_node;
|
||||
}
|
||||
|
||||
|
@ -162,22 +160,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
|
|||
return make_tuple;
|
||||
}
|
||||
} // namespace
|
||||
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type,
|
||||
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
||||
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) {
|
||||
MS_EXCEPTION_IF_NULL(trans_data);
|
||||
MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
|
||||
auto ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
|
||||
KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsFormat({input_format});
|
||||
builder.SetInputReshapeType({reshape_type});
|
||||
builder.SetInputReshapeType({reshape_type});
|
||||
builder.SetOutputsFormat({output_format});
|
||||
builder.SetInputsDeviceType({device_type});
|
||||
builder.SetOutputsDeviceType({device_type});
|
||||
builder.SetKernelType(ori_build_info->kernel_type());
|
||||
builder.SetFusionType(ori_build_info->fusion_type());
|
||||
builder.SetProcessor(ori_build_info->processor());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), trans_data.get());
|
||||
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
|
||||
MS_EXCEPTION_IF_NULL(ori_build_info);
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
|
||||
builder->SetInputsFormat({input_format});
|
||||
builder->SetInputReshapeType({reshape_type});
|
||||
builder->SetOutputReshapeType({reshape_type});
|
||||
builder->SetOutputsFormat({output_format});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
|
||||
}
|
||||
|
||||
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
|
||||
|
|
|
@ -70,7 +70,7 @@ class KernelQuery {
|
|||
}
|
||||
};
|
||||
using KernelQueryPtr = std::shared_ptr<KernelQuery>;
|
||||
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type,
|
||||
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
||||
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {});
|
||||
|
||||
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
|
||||
|
|
|
@ -107,7 +107,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
|
|||
if (origin_format != cur_format && cur_shape.size() > 1) {
|
||||
auto kernel_select = std::make_shared<KernelSelect>();
|
||||
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
|
||||
RefreshKernelBuildInfo(cur_format, origin_format, origin_type, final_node);
|
||||
RefreshKernelBuildInfo(cur_format, origin_format, final_node);
|
||||
final_index = 0;
|
||||
MS_EXCEPTION_IF_NULL(final_node);
|
||||
MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
|
||||
|
|
|
@ -69,13 +69,11 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
|
|||
// trans input_format to hwcn
|
||||
new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
|
||||
false, prim::KPrimTransData->name());
|
||||
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0),
|
||||
new_transdata_node);
|
||||
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node);
|
||||
// trans hwcn to default_format
|
||||
new_transpose_node =
|
||||
NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name());
|
||||
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0),
|
||||
new_transpose_node);
|
||||
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node);
|
||||
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node);
|
||||
new_replace_node = new_transpose_node;
|
||||
} else {
|
||||
|
@ -83,14 +81,12 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
|
|||
new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
|
||||
false, prim::kPrimTranspose->name());
|
||||
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
|
||||
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0),
|
||||
new_transpose_node);
|
||||
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node);
|
||||
|
||||
// trans hwcn to output_format
|
||||
new_transdata_node =
|
||||
NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name());
|
||||
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0),
|
||||
new_transdata_node);
|
||||
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node);
|
||||
new_replace_node = new_transdata_node;
|
||||
}
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
|
|
Loading…
Reference in New Issue