forked from mindspore-Ecosystem/mindspore
!1496 revert decoupled of 1313
Merge pull request !1496 from lianliguang/revert_decoupled
This commit is contained in:
commit
6be8929f62
|
@ -31,156 +31,18 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
||||
namespace {
|
||||
// Creates a Reshape CNode whose single data input is `input_node`.
//
// The new node's inferred output type is copied from output 0 of `input_node`, its inferred output shape and
// its kAttrShape attribute are both set to `dst_shape`, it is flagged with kAttrVisited, it inherits the scope
// of `input_node`, and a kernel is selected for it through `kernel_select`.
//
// Raises (via MS_EXCEPTION_IF_NULL) when `func_graph`, `input_node`, `kernel_select` or the created node is
// null.
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                             const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
  // All three pointers are dereferenced below; fail with a clear exception instead of a null dereference.
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input_node);
  MS_EXCEPTION_IF_NULL(kernel_select);
  std::vector<AnfNodePtr> trans_inputs;
  auto prim = std::make_shared<Primitive>(prim::kPrimReshape->name());
  trans_inputs.emplace_back(NewValueNode(prim));
  trans_inputs.emplace_back(input_node);
  auto reshape = func_graph->NewCNode(trans_inputs);
  MS_EXCEPTION_IF_NULL(reshape);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get());
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
  AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape);
  reshape->set_scope(input_node->scope());
  kernel_select->SelectKernel(reshape);
  return reshape;
}
|
||||
|
||||
// Inserts a TransData node (plus a Reshape node when padding is needed) next to `node`.
//
//   is_insert_input == true : the trans op is built for input `insert_index` of `node`, converting from
//                             kOpFormat_DEFAULT to that input's format. `node` must be a CNode.
//   is_insert_input == false: the trans op is built behind `node`, converting from the format of output 0
//                             to kOpFormat_DEFAULT.
//
// Returns the last node of the inserted chain: the TransData node, or the trailing Reshape node when an
// output needed padding. The TransData node's kernel build info is refreshed with the source format,
// destination format, device type and reshape type determined here.
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
  // Check `node` before the first AnfAlgo query touches it.
  MS_EXCEPTION_IF_NULL(node);
  AnfNodePtr trans_node = nullptr;
  AnfNodePtr input_node = node;
  CNodePtr trans_data = nullptr;
  std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
  std::string dst_format = kOpFormat_DEFAULT;
  TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
  std::vector<kernel::Axis> padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
  // if insert transdata for input we need to change the input
  if (is_insert_input) {
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
    dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
    input_node = AnfAlgo::GetInputNode(cnode, insert_index);
    // Use the reshape type of the input being converted, not always that of input 0.
    padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index);
  }
  // Padding is decided by the device-side format: the destination for inputs, the source for outputs.
  const std::string &device_format = is_insert_input ? dst_format : input_format;
  const bool need_padding = trans::IsNeedPadding(device_format, AnfAlgo::GetOutputInferShape(input_node, 0).size());
  if (!need_padding) {
    // don't need padding insert transdata only
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
  } else if (is_insert_input) {
    // if need padding & is input need insert a transdata
    // reshape[padding shape] -> transdata[padding shape] -> node
    auto padding_shape = trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), padding_axis);
    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
  } else {
    // if need padding & is output need insert a transdata
    // node -> transdata[padding shape] -> reshape[ori_shape]
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    auto reshape_node =
      CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
    trans_node = reshape_node;
  }
  // refresh the transdata's format to ori format & dst format
  RefreshKernelBuildInfo(input_format, dst_format, dtype, trans_data, padding_axis);
  return trans_node;
}
|
||||
|
||||
// Returns the node that should feed input `index` of `node`.
//
// When the real producer of that input (found through AnfAlgo::VisitKernel) is a ValueNode or a Parameter,
// the producer is first passed through InsertTransOpForOutput and the result is rewired as input `index`.
// Afterwards, if the input's format is in kNeedTransFormatSet and the producer's inferred shape has more
// than one dimension, a trans op is inserted for this input and returned; otherwise the (possibly rewired)
// input node is returned unchanged.
//
// Raises when `node` or the visited producer is null, or when the input format is kOpFormat_NC1KHKWHWC0.
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
                                const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  auto input_node = AnfAlgo::GetInputNode(node, index);
  auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
  MS_EXCEPTION_IF_NULL(node_with_index.first);
  auto real_input = node_with_index.first;
  if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
    input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
    MS_EXCEPTION_IF_NULL(input_node);
    AnfAlgo::SetNodeInput(node, input_node, index);
  }
  // Query the input format once and reuse it for the check, the log and the lookup.
  const std::string dest_format = AnfAlgo::GetInputFormat(node, index);
  if (dest_format == kOpFormat_NC1KHKWHWC0) {
    MS_LOG(EXCEPTION) << "got the format " << dest_format << " when inserting the transdata node "
                      << node->DebugString();
  }
  const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
  if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
    // For an input the conversion goes from the default format to the input's format.
    MS_LOG(DEBUG) << node->DebugString() << " Insert transdata from DefaultFormat to " << dest_format
                  << " , index: " << index;
    return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
  }
  return input_node;
}
|
||||
|
||||
// Inserts a trans op behind a single-output `node` when its output format requires one.
//
// A trans op is added only if the format of output 0 is in kNeedTransFormatSet and the inferred shape of
// output 0 has more than one dimension; otherwise `node` itself is returned.
//
// Raises when `node` is null or when the output format is kOpFormat_NC1KHKWHWC0.
AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  const std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
  const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
  if (output_format == kOpFormat_NC1KHKWHWC0) {
    // Keep a space between the format name and the rest of the message so it stays readable.
    MS_LOG(EXCEPTION) << "got the hw format " << output_format << " when insert the transdata node "
                      << node->DebugString();
  }
  const bool format_needs_trans = kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end();
  if (format_needs_trans && origin_shape.size() > 1) {
    MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
    return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
  }
  return node;
}
|
||||
|
||||
// Builds a MakeTuple node that re-collects every output of a multi-output `node`.
//
// For each output a TupleGetItem node is created. If the output's format is in kNeedTransFormatSet and its
// inferred shape has more than one dimension, a trans op is inserted behind the TupleGetItem and the trans
// op becomes the tuple element; otherwise the TupleGetItem itself is used.
//
// Raises when `func_graph` or `node` is null, or when any output format is kOpFormat_NC1KHKWHWC0.
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  std::vector<AnfNodePtr> make_tuple_inputs;
  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  // The number of outputs of `node` does not change inside the loop, so query it once.
  const size_t output_num = AnfAlgo::GetOutputTensorNum(node);
  for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
    const std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
    if (output_format == kOpFormat_NC1KHKWHWC0) {
      MS_LOG(EXCEPTION) << "Got the special format " << output_format << " when insert the transdata node "
                        << node->DebugString();
    }
    auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
    const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
    if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
      make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false));
    } else {
      // No need insert trans op.
      make_tuple_inputs.push_back(tuple_getitem);
    }
  }
  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
|
||||
} // namespace
|
||||
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type,
|
||||
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) {
|
||||
MS_EXCEPTION_IF_NULL(trans_data);
|
||||
MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
|
||||
auto ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
|
||||
kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
||||
const AnfNodePtr &node, const TypeId device_type,
|
||||
const kernel::KernelBuildInfo &ori_build_info) {
|
||||
KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsFormat({input_format});
|
||||
builder.SetInputReshapeType({reshape_type});
|
||||
builder.SetInputReshapeType({reshape_type});
|
||||
builder.SetOutputsFormat({output_format});
|
||||
builder.SetInputsDeviceType({device_type});
|
||||
builder.SetOutputsDeviceType({device_type});
|
||||
builder.SetKernelType(ori_build_info->kernel_type());
|
||||
builder.SetFusionType(ori_build_info->fusion_type());
|
||||
builder.SetProcessor(ori_build_info->processor());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), trans_data.get());
|
||||
builder.SetKernelType(ori_build_info.kernel_type());
|
||||
builder.SetFusionType(ori_build_info.fusion_type());
|
||||
builder.SetProcessor(ori_build_info.processor());
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
|
||||
|
@ -193,7 +55,8 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
|||
trans_inputs.push_back(input);
|
||||
CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
|
||||
MS_EXCEPTION_IF_NULL(trans_node);
|
||||
auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
|
||||
std::vector<kernel::Axis> padding_axis;
|
||||
padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
|
||||
if (need_padding) {
|
||||
// if need padding we should set the transdata node's shape to the padding shape
|
||||
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
|
||||
|
@ -216,6 +79,154 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
|||
return trans_node;
|
||||
}
|
||||
|
||||
// Creates a Reshape CNode whose single data input is `input_node`.
//
// The new node's inferred output type is copied from output 0 of `input_node`, its inferred output shape and
// its kAttrShape attribute are both set to `dst_shape`, it is flagged with kAttrVisited, it inherits the scope
// of `input_node`, and a kernel is selected for it through `kernel_select`.
//
// Raises (via MS_EXCEPTION_IF_NULL) when `func_graph`, `input_node`, `kernel_select` or the created node is
// null.
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                             const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
  // All three pointers are dereferenced below; fail with a clear exception instead of a null dereference.
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input_node);
  MS_EXCEPTION_IF_NULL(kernel_select);
  std::vector<AnfNodePtr> trans_inputs;
  auto prim = std::make_shared<Primitive>(prim::kPrimReshape->name());
  trans_inputs.emplace_back(NewValueNode(prim));
  trans_inputs.emplace_back(input_node);
  auto reshape = func_graph->NewCNode(trans_inputs);
  MS_EXCEPTION_IF_NULL(reshape);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get());
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
  AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape);
  reshape->set_scope(input_node->scope());
  kernel_select->SelectKernel(reshape);
  return reshape;
}
|
||||
|
||||
// Returns the node that should feed input `index` of `node`.
//
// When the real producer of that input (found through AnfAlgo::VisitKernel) is a ValueNode or a Parameter,
// the producer is first passed through InsertTransOpForOutput and the result is rewired as input `index`.
// Afterwards, if the input's format is in kNeedTransFormatSet and the producer's inferred shape has more
// than one dimension, a TransData op converting kOpFormat_DEFAULT to the input's format is inserted for this
// input and returned; otherwise the (possibly rewired) input node is returned unchanged.
//
// Raises when `node` or the visited producer is null, or when the input format is kOpFormat_NC1KHKWHWC0.
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
                                const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  auto input_node = AnfAlgo::GetInputNode(node, index);
  auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
  MS_EXCEPTION_IF_NULL(node_with_index.first);
  auto real_input = node_with_index.first;
  if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
    input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
    MS_EXCEPTION_IF_NULL(input_node);
    AnfAlgo::SetNodeInput(node, input_node, index);
  }
  // Query the input format once and reuse it for the check, the log and the conversion target.
  const std::string dest_format = AnfAlgo::GetInputFormat(node, index);
  if (dest_format == kOpFormat_NC1KHKWHWC0) {
    MS_LOG(EXCEPTION) << "got the format " << dest_format << " when inserting the transdata node "
                      << node->DebugString();
  }
  const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
  if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
    const std::string origin_format = kOpFormat_DEFAULT;
    MS_LOG(DEBUG) << node->DebugString() << " Insert transdata from " << origin_format << " to " << dest_format
                  << " , index: " << index;
    return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, origin_format, dest_format, kTransDataOpName,
                                 true);
  }
  return input_node;
}
|
||||
|
||||
// Inserts a TransData op behind a single-output `node` when its output format requires one.
//
// A trans op converting the format of output 0 to kOpFormat_DEFAULT is added only if that format is in
// kNeedTransFormatSet and the inferred shape of output 0 has more than one dimension; otherwise `node`
// itself is returned.
//
// Raises when `node` is null or when the output format is kOpFormat_NC1KHKWHWC0.
AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  const std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
  const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
  if (output_format == kOpFormat_NC1KHKWHWC0) {
    // Keep a space between the format name and the rest of the message so it stays readable.
    MS_LOG(EXCEPTION) << "got the hw format " << output_format << " when insert the transdata node "
                      << node->DebugString();
  }
  const bool format_needs_trans = kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end();
  if (format_needs_trans && origin_shape.size() > 1) {
    const std::string dest_format = kOpFormat_DEFAULT;
    MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
    return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, output_format, dest_format, kTransDataOpName,
                                 false);
  }
  return node;
}
|
||||
|
||||
// Builds a MakeTuple node that re-collects every output of a multi-output `node`.
//
// For each output a TupleGetItem node is created. If the output's format is in kNeedTransFormatSet and its
// inferred shape has more than one dimension, a TransData op converting that format to kOpFormat_DEFAULT is
// inserted behind the TupleGetItem and becomes the tuple element; otherwise the TupleGetItem itself is used.
//
// Raises when `func_graph` or `node` is null, or when any output format is kOpFormat_NC1KHKWHWC0.
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  std::vector<AnfNodePtr> make_tuple_inputs;
  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  const std::string dest_format = kOpFormat_DEFAULT;
  // The number of outputs of `node` does not change inside the loop, so query it once.
  const size_t output_num = AnfAlgo::GetOutputTensorNum(node);
  for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
    const std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
    if (output_format == kOpFormat_NC1KHKWHWC0) {
      MS_LOG(EXCEPTION) << "Got the special format " << output_format << " when insert the transdata node "
                        << node->DebugString();
    }
    auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
    const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
    if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
      make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, output_format,
                                                           dest_format, kTransDataOpName, false));
    } else {
      // No need insert trans op.
      make_tuple_inputs.push_back(tuple_getitem);
    }
  }
  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
|
||||
} // namespace
|
||||
// Inserts a trans op named `op_name` (plus a Reshape node when padding is needed) next to `node`, converting
// from `origin_format` to `dest_format`.
//
//   is_insert_input == true : the op is built for input `insert_index` of `node`; `node` must be a CNode.
//   is_insert_input == false: the op is built behind `node` itself.
//
// Padding is only considered when `op_name` equals kTransDataOpName. Returns the last node of the inserted
// chain: the trans op, or the trailing Reshape node when an output needed padding. The trans op's kernel
// build info is rebuilt from `origin_format`, `dest_format` and the device type determined here.
//
// Raises when `node` is null, when either format string is empty, when an input insertion is requested on a
// non-CNode, or when the created trans op has no kernel info / selected build info.
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select, size_t insert_index,
                                 const std::string &origin_format, const std::string &dest_format,
                                 const std::string &op_name, bool is_insert_input) {
  // Check `node` before the first AnfAlgo query touches it.
  MS_EXCEPTION_IF_NULL(node);
  AnfNodePtr trans_node = nullptr;
  AnfNodePtr input_node = node;
  AnfNodePtr trans_data = nullptr;
  TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
  if (origin_format.empty() || dest_format.empty()) {
    MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << dest_format;
  }
  // if insert transdata for input we need to change the input
  if (is_insert_input) {
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
    input_node = AnfAlgo::GetInputNode(cnode, insert_index);
  }
  // Padding is decided by the device-side format: the destination for inputs, the source for outputs.
  const std::string &device_format = is_insert_input ? dest_format : origin_format;
  const bool need_padding =
    (trans::IsNeedPadding(device_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
     op_name == kTransDataOpName);
  if (!need_padding) {
    // don't need padding insert transdata only
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
    trans_node = trans_data;
  } else if (is_insert_input) {
    // if need padding & is input need insert a transdata
    // reshape[padding shape] -> transdata[padding shape] -> node
    // Use the reshape type of the input being converted, not always that of input 0.
    auto padding_shape = trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0),
                                                 AnfAlgo::GetInputReshapeType(node, insert_index));
    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, op_name);
    trans_node = trans_data;
  } else {
    // if need padding & is output need insert a transdata
    // node -> transdata[padding shape] -> reshape[ori_shape]
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
    auto reshape_node =
      CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
    trans_node = reshape_node;
  }
  // refresh the transdata's format to ori format & dst format
  MS_EXCEPTION_IF_NULL(trans_data);
  MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
  auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
  // The build info is dereferenced on the next line, so make sure kernel selection produced one.
  MS_EXCEPTION_IF_NULL(trans_ori_build_info);
  auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
  return trans_node;
}
|
||||
|
||||
AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
|
||||
const TypeId &input_type, const TypeId &output_type,
|
||||
const std::vector<size_t> &origin_shape, const TypeId &origin_type) {
|
||||
|
|
|
@ -58,11 +58,11 @@ class KernelQuery {
|
|||
}
|
||||
};
|
||||
using KernelQueryPtr = std::shared_ptr<KernelQuery>;
|
||||
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type,
|
||||
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {});
|
||||
|
||||
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
|
||||
const bool need_padding, const std::string &op_name);
|
||||
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const KernelSelectPtr &kernel_select, size_t insert_index,
|
||||
const std::string &origin_format, const std::string &dest_format,
|
||||
const std::string &op_name, bool is_insert_input);
|
||||
|
||||
AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
|
||||
const TypeId &input_type, const TypeId &output_type,
|
||||
|
|
|
@ -105,8 +105,8 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
|
|||
// insert trans
|
||||
if (origin_format != cur_format && cur_shape.size() > 1) {
|
||||
auto kernel_select = std::make_shared<KernelSelect>();
|
||||
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
|
||||
RefreshKernelBuildInfo(cur_format, origin_format, origin_type, final_node);
|
||||
final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format,
|
||||
kTransDataOpName, false);
|
||||
final_index = 0;
|
||||
MS_EXCEPTION_IF_NULL(final_node);
|
||||
MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
|
||||
|
|
|
@ -67,30 +67,22 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
|
|||
// if output_format=default transdata need split transdata->transpose else transpose->transdata
|
||||
if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
|
||||
// trans input_format to hwcn
|
||||
new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
|
||||
false, prim::KPrimTransData->name());
|
||||
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0),
|
||||
new_transdata_node);
|
||||
new_transdata_node =
|
||||
AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, kTransDataOpName, true);
|
||||
// trans hwcn to default_format
|
||||
new_transpose_node =
|
||||
NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name());
|
||||
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0),
|
||||
new_transpose_node);
|
||||
new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, kOpFormat_HWCN,
|
||||
output_format, prim::kPrimTranspose->name(), false);
|
||||
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node);
|
||||
new_replace_node = new_transpose_node;
|
||||
} else {
|
||||
// trans default to hwcn
|
||||
new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
|
||||
false, prim::kPrimTranspose->name());
|
||||
new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN,
|
||||
prim::kPrimTranspose->name(), true);
|
||||
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
|
||||
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0),
|
||||
new_transpose_node);
|
||||
|
||||
// trans hwcn to output_format
|
||||
new_transdata_node =
|
||||
NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name());
|
||||
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0),
|
||||
new_transdata_node);
|
||||
new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, kOpFormat_HWCN,
|
||||
output_format, kTransDataOpName, false);
|
||||
new_replace_node = new_transdata_node;
|
||||
}
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
|
|
Loading…
Reference in New Issue