!1496 revert the decoupling change of !1313

Merge pull request !1496 from lianliguang/revert_decoupled
This commit is contained in:
mindspore-ci-bot 2020-05-27 10:32:22 +08:00 committed by Gitee
commit 6be8929f62
4 changed files with 171 additions and 168 deletions

View File

@ -31,156 +31,18 @@ namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace {
// Builds a Reshape CNode over `input_node` whose inferred output shape is `dst_shape`,
// marks it visited, copies the input's scope, and runs kernel selection on it.
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                             const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
  auto reshape_prim = std::make_shared<Primitive>(prim::kPrimReshape->name());
  std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(reshape_prim), input_node};
  auto reshape = func_graph->NewCNode(reshape_inputs);
  // Data type is carried over from the input; only the shape changes.
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get());
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
  AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape);
  reshape->set_scope(input_node->scope());
  kernel_select->SelectKernel(reshape);
  return reshape;
}
// Inserts a TransData node (plus a companion Reshape when padding is required) around `node`.
// is_insert_input == true: transform input `insert_index` (default format -> device format);
// otherwise transform output 0 (device format -> default format).
// Returns the node callers should wire into the graph in place of the original edge.
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
  // Fixed: validate `node` before the format/dtype queries below dereference it
  // (the check previously came after the first uses of `node`).
  MS_EXCEPTION_IF_NULL(node);
  AnfNodePtr trans_node = nullptr;
  AnfNodePtr input_node = node;
  CNodePtr trans_data = nullptr;
  std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
  TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
  std::vector<kernel::Axis> padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
  // if insert transdata for input we need to change the input
  if (is_insert_input) {
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
    }
    auto cnode = node->cast<CNodePtr>();
    dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
    dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
    input_node = AnfAlgo::GetInputNode(cnode, insert_index);
    // NOTE(review): index 0 here while the other input queries use insert_index —
    // confirm whether this should be GetInputReshapeType(node, insert_index).
    padding_axis = AnfAlgo::GetInputReshapeType(node, 0);
  }
  // Padding is needed when the inferred rank does not match the target format's layout.
  bool need_padding = false;
  if (is_insert_input) {
    need_padding = (trans::IsNeedPadding(dst_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()));
  } else {
    need_padding = (trans::IsNeedPadding(input_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()));
  }
  if (!need_padding) {
    // don't need padding insert transdata only
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
  } else if (is_insert_input) {
    // if need padding & is input need insert a transdata
    // reshape[padding shape] -> transdata[padding shape] -> node
    auto padding_shape =
      trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
  } else {
    // if need padding & is output need insert a transdata
    // node -> transdata[padding shape] -> reshape[ori_shape]
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    auto reshape_node =
      CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
    trans_node = reshape_node;
  }
  // refresh the transdata's format to ori format & dst format
  RefreshKernelBuildInfo(input_format, dst_format, dtype, trans_data, padding_axis);
  return trans_node;
}
// Returns the node to use as input `index` of `node`, inserting a TransData when the
// input's device format differs from the default format and requires conversion.
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
                                const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  auto input_node = AnfAlgo::GetInputNode(node, index);
  auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
  MS_EXCEPTION_IF_NULL(node_with_index.first);
  auto real_input = node_with_index.first;
  if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
    // Constants and parameters may themselves need an output trans op before being consumed.
    input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
    MS_EXCEPTION_IF_NULL(input_node);
    AnfAlgo::SetNodeInput(node, input_node, index);
  }
  if (AnfAlgo::GetInputFormat(node, index) == kOpFormat_NC1KHKWHWC0) {
    // Fixed: added the missing space before "when" (message previously rendered fused text).
    MS_LOG(EXCEPTION) << "got the format " << AnfAlgo::GetInputFormat(node, index)
                      << " when inserting the transdata node " << node->DebugString();
  }
  std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
  std::string dest_format = AnfAlgo::GetInputFormat(node, index);
  // Rank-1 shapes and formats outside kNeedTransFormatSet never need a trans op.
  if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
    MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
                  << " To DefaultFormat , index: " << index;
    return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
  }
  return input_node;
}
// Appends a TransData after `node`'s single output when its device format must be
// converted back to the default format. Returns the replacement node, or `node` as-is.
AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
  std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
  if (output_format == kOpFormat_NC1KHKWHWC0) {
    // Fixed: added the missing space before "when" (message previously rendered fused text).
    MS_LOG(EXCEPTION) << "got the hw format " << output_format << " when insert the transdata node "
                      << node->DebugString();
  }
  if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
    MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
    return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
  }
  return node;
}
// For a multi-output node, wraps each output in a TupleGetItem and, where the output
// format requires conversion, a TransData; results are bundled into a MakeTuple node.
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  std::vector<AnfNodePtr> make_tuple_inputs;
  make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
    if (output_format == kOpFormat_NC1KHKWHWC0) {
      // Fixed: added the missing space after "format" (message previously rendered fused text).
      MS_LOG(EXCEPTION) << "Got the special format " << output_format << " when insert the transdata node "
                        << node->DebugString();
    }
    auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
    std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
    if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
      make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false));
    } else {
      // No need insert trans op.
      make_tuple_inputs.emplace_back(tuple_getitem);
    }
  }
  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
} // namespace
// NOTE(review): this span is garbled diff residue — the OLD in-place
//   `void RefreshKernelBuildInfo(..., const AnfNodePtr &trans_data, reshape_type)`
// and the NEW value-returning
//   `kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(..., const kernel::KernelBuildInfo &ori_build_info)`
// are interleaved without +/- markers. Exactly one definition must survive before this compiles.
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) {
MS_EXCEPTION_IF_NULL(trans_data);
MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
auto ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
// NOTE(review): the nested signature below belongs to the replacement definition, not to a local.
kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &node, const TypeId device_type,
const kernel::KernelBuildInfo &ori_build_info) {
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({input_format});
// NOTE(review): SetInputReshapeType is called twice with the same argument — diff artifact; keep one.
builder.SetInputReshapeType({reshape_type});
builder.SetInputReshapeType({reshape_type});
builder.SetOutputsFormat({output_format});
builder.SetInputsDeviceType({device_type});
builder.SetOutputsDeviceType({device_type});
// Old variant: ori_build_info is a pointer and the result is attached to trans_data directly.
builder.SetKernelType(ori_build_info->kernel_type());
builder.SetFusionType(ori_build_info->fusion_type());
builder.SetProcessor(ori_build_info->processor());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), trans_data.get());
// New variant: ori_build_info is a const reference and the built info is returned to the caller.
builder.SetKernelType(ori_build_info.kernel_type());
builder.SetFusionType(ori_build_info.fusion_type());
builder.SetProcessor(ori_build_info.processor());
return builder.Build();
}
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
@ -193,7 +55,8 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
trans_inputs.push_back(input);
CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
MS_EXCEPTION_IF_NULL(trans_node);
auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
std::vector<kernel::Axis> padding_axis;
padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
if (need_padding) {
// if need padding we should set the transdata node's shape to the padding shape
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
@ -216,6 +79,154 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
return trans_node;
}
// Creates a Reshape CNode that reshapes `input_node`'s output 0 to `dst_shape`,
// inherits the input's scope, and lets `kernel_select` pick a kernel for it.
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                             const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
  std::vector<AnfNodePtr> new_inputs;
  new_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())));
  new_inputs.emplace_back(input_node);
  auto reshape_cnode = func_graph->NewCNode(new_inputs);
  // The reshape keeps the source data type; only the inferred shape is replaced.
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape},
                                      reshape_cnode.get());
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape_cnode);
  AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape_cnode);
  reshape_cnode->set_scope(input_node->scope());
  kernel_select->SelectKernel(reshape_cnode);
  return reshape_cnode;
}
// Returns the node to use as input `index` of `node`, inserting a TransData
// (dest_format -> default) when the input's device format needs conversion.
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
                                const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  auto input_node = AnfAlgo::GetInputNode(node, index);
  auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
  MS_EXCEPTION_IF_NULL(node_with_index.first);
  auto real_input = node_with_index.first;
  if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
    // Constants and parameters may themselves need an output trans op before being consumed.
    input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
    MS_EXCEPTION_IF_NULL(input_node);
    AnfAlgo::SetNodeInput(node, input_node, index);
  }
  if (AnfAlgo::GetInputFormat(node, index) == kOpFormat_NC1KHKWHWC0) {
    // Fixed: added the missing space before "when" (message previously rendered fused text).
    MS_LOG(EXCEPTION) << "got the format " << AnfAlgo::GetInputFormat(node, index)
                      << " when inserting the transdata node " << node->DebugString();
  }
  std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
  std::string origin_format = kOpFormat_DEFAULT;
  std::string dest_format = AnfAlgo::GetInputFormat(node, index);
  // Rank-1 shapes and formats outside kNeedTransFormatSet never need a trans op.
  if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
    MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
                  << " To DefaultFormat , index: " << index;
    return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, origin_format, dest_format, kTransDataOpName,
                                 true);
  }
  return input_node;
}
// Appends a TransData after `node`'s single output when its device format must be
// converted back to the default format. Returns the replacement node, or `node` as-is.
AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
  std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
  if (output_format == kOpFormat_NC1KHKWHWC0) {
    // Fixed: added the missing space before "when" (message previously rendered fused text).
    MS_LOG(EXCEPTION) << "got the hw format " << output_format << " when insert the transdata node "
                      << node->DebugString();
  }
  std::string origin_format = output_format;
  std::string dest_format = kOpFormat_DEFAULT;
  if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
    MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
    return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, origin_format, dest_format, kTransDataOpName,
                                 false);
  }
  return node;
}
// For a multi-output node, wraps each output in a TupleGetItem and, where the output
// format requires conversion, a TransData; results are bundled into a MakeTuple node.
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  std::vector<AnfNodePtr> make_tuple_inputs;
  make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
    if (output_format == kOpFormat_NC1KHKWHWC0) {
      // Fixed: added the missing space after "format" (message previously rendered fused text).
      MS_LOG(EXCEPTION) << "Got the special format " << output_format << " when insert the transdata node "
                        << node->DebugString();
    }
    auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
    std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
    std::string dest_format = kOpFormat_DEFAULT;
    if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
      make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, output_format,
                                                           dest_format, kTransDataOpName, false));
    } else {
      // No need insert trans op.
      make_tuple_inputs.emplace_back(tuple_getitem);
    }
  }
  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
} // namespace
// Inserts a trans op (`op_name`, e.g. TransData or Transpose) converting
// `origin_format` -> `dest_format` around `node`; when padding is required for TransData,
// a companion Reshape node is added as well. `is_insert_input` selects whether input
// `insert_index` or output 0 is transformed. Returns the node callers should wire in
// place of the original edge.
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select, size_t insert_index,
                                 const std::string &origin_format, const std::string &dest_format,
                                 const std::string &op_name, bool is_insert_input) {
  // Fixed: validate `node` before GetOutputDeviceDataType dereferences it
  // (the check previously came after the first use).
  MS_EXCEPTION_IF_NULL(node);
  AnfNodePtr trans_node = nullptr;
  AnfNodePtr input_node = node;
  AnfNodePtr trans_data = nullptr;
  TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
  if (origin_format.empty() || dest_format.empty()) {
    // Fixed: the message printed `origin_format` twice; it now reports `dest_format`.
    MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << dest_format;
  }
  // if insert transdata for input we need to change the input
  if (is_insert_input) {
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
    }
    auto cnode = node->cast<CNodePtr>();
    // Fixed: check `cnode` before it is dereferenced (the check used to come after the first use).
    MS_EXCEPTION_IF_NULL(cnode);
    dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
    input_node = AnfAlgo::GetInputNode(cnode, insert_index);
  }
  // Padding applies only to TransData; the inferred rank decides whether it is needed.
  bool need_padding = false;
  if (is_insert_input) {
    need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
                    op_name == kTransDataOpName);
  } else {
    need_padding = (trans::IsNeedPadding(origin_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
                    op_name == kTransDataOpName);
  }
  if (!need_padding) {
    // don't need padding insert transdata only
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
    trans_node = trans_data;
  } else if (is_insert_input) {
    // if need padding & is input need insert a transdata
    // reshape[padding shape] -> transdata[padding shape] -> node
    auto padding_shape =
      trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, op_name);
    trans_node = trans_data;
  } else {
    // if need padding & is output need insert a transdata
    // node -> transdata[padding shape] -> reshape[ori_shape]
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
    auto reshape_node =
      CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
    trans_node = reshape_node;
  }
  // refresh the transdata's format to ori format & dst format
  MS_EXCEPTION_IF_NULL(trans_data);
  MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
  auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
  auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
  return trans_node;
}
AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type,
const std::vector<size_t> &origin_shape, const TypeId &origin_type) {

View File

@ -58,11 +58,11 @@ class KernelQuery {
}
};
using KernelQueryPtr = std::shared_ptr<KernelQuery>;
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {});
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
const bool need_padding, const std::string &op_name);
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select, size_t insert_index,
const std::string &origin_format, const std::string &dest_format,
const std::string &op_name, bool is_insert_input);
AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type,

View File

@ -105,8 +105,8 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
// insert trans
if (origin_format != cur_format && cur_shape.size() > 1) {
auto kernel_select = std::make_shared<KernelSelect>();
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(cur_format, origin_format, origin_type, final_node);
final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format,
kTransDataOpName, false);
final_index = 0;
MS_EXCEPTION_IF_NULL(final_node);
MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();

View File

@ -67,30 +67,22 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
// if output_format=default transdata need split transdata->transpose else transpose->transdata
if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
// trans input_format to hwcn
new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0),
new_transdata_node);
new_transdata_node =
AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, kTransDataOpName, true);
// trans hwcn to default_format
new_transpose_node =
NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name());
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0),
new_transpose_node);
new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, kOpFormat_HWCN,
output_format, prim::kPrimTranspose->name(), false);
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node);
new_replace_node = new_transpose_node;
} else {
// trans default to hwcn
new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
false, prim::kPrimTranspose->name());
new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN,
prim::kPrimTranspose->name(), true);
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0),
new_transpose_node);
// trans hwcn to output_format
new_transdata_node =
NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0),
new_transdata_node);
new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, kOpFormat_HWCN,
output_format, kTransDataOpName, false);
new_replace_node = new_transdata_node;
}
FuncGraphManagerPtr manager = func_graph->manager();