forked from mindspore-Ecosystem/mindspore
!2164 add more needtrans format for transdata
Merge pull request !2164 from lianliguang/master
This commit is contained in:
commit
04a23d138f
|
@ -70,7 +70,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
|
|||
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
|
||||
auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
|
||||
if (AnfAlgo::IsFeatureMapInput(cnode, index) &&
|
||||
kNeedTransFormatSet.find(pre_output_format) != kNeedTransFormatSet.end()) {
|
||||
kHWSpecialFormatSet.find(pre_output_format) != kHWSpecialFormatSet.end()) {
|
||||
priority_matched_format = !is_init ? pre_output_format : priority_matched_format;
|
||||
is_init = true;
|
||||
}
|
||||
|
|
|
@ -31,6 +31,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
||||
namespace {
|
||||
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW};
|
||||
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
|
||||
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
|
||||
std::vector<AnfNodePtr> trans_inputs;
|
||||
|
@ -110,13 +111,9 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
|
|||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
AnfAlgo::SetNodeInput(node, input_node, index);
|
||||
}
|
||||
if (AnfAlgo::GetInputFormat(node, index) == kOpFormat_NC1KHKWHWC0) {
|
||||
MS_LOG(EXCEPTION) << "got the format " << AnfAlgo::GetInputFormat(node, index)
|
||||
<< "when inserting the transdata node " << node->DebugString();
|
||||
}
|
||||
std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
|
||||
std::string dest_format = AnfAlgo::GetInputFormat(node, index);
|
||||
if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
|
||||
if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
|
||||
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
|
||||
<< " To DefaultFormat , index: " << index;
|
||||
return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
|
||||
|
@ -133,7 +130,7 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
|
|||
MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
|
||||
<< node->DebugString();
|
||||
}
|
||||
if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
|
||||
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
|
||||
MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
|
||||
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
|
||||
}
|
||||
|
@ -154,7 +151,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
|
|||
}
|
||||
auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
|
||||
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
|
||||
if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
|
||||
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
|
||||
make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false));
|
||||
} else {
|
||||
// No need insert trans op.
|
||||
|
|
|
@ -97,7 +97,7 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
|
|||
std::string convert_format;
|
||||
for (const auto &do_mask : do_mask_node_list) {
|
||||
auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0);
|
||||
if (special_format.empty() && kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end()) {
|
||||
if (special_format.empty() && kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end()) {
|
||||
special_format = do_mask_data_format;
|
||||
}
|
||||
if (format_counter.find(do_mask_data_format) == format_counter.end()) {
|
||||
|
@ -111,7 +111,7 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
|
|||
convert_format = kOpFormat_DEFAULT;
|
||||
break;
|
||||
}
|
||||
if (kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end() &&
|
||||
if (kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end() &&
|
||||
special_format != do_mask_data_format) {
|
||||
convert_format = kOpFormat_DEFAULT;
|
||||
break;
|
||||
|
@ -133,7 +133,7 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string
|
|||
if (counter < iter.second) {
|
||||
convert_format = iter.first;
|
||||
}
|
||||
if (counter == iter.second && kNeedTransFormatSet.find(convert_format) == kNeedTransFormatSet.end()) {
|
||||
if (counter == iter.second && kHWSpecialFormatSet.find(convert_format) == kHWSpecialFormatSet.end()) {
|
||||
convert_format = iter.first;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -265,7 +265,7 @@ const std::set<std::string> kOptOperatorSet = {
|
|||
kApplyRMSPropOpName,
|
||||
};
|
||||
|
||||
const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
|
||||
const std::set<std::string> kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
|
||||
kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
|
||||
kOpFormat_FRACTAL_Z_C04};
|
||||
|
||||
|
|
|
@ -58,6 +58,8 @@ trans_data_op_info = TBERegOp("TransData") \
|
|||
.dtype_format(DataType.F32_HWCN, DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_HWCN, DataType.F32_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_HWCN) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_HWCN, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue