diff --git a/mindspore/lite/test/models_onnx.cfg b/mindspore/lite/test/models_onnx.cfg index da693fe0a28..ab17ccd60ca 100644 --- a/mindspore/lite/test/models_onnx.cfg +++ b/mindspore/lite/test/models_onnx.cfg @@ -1,4 +1,4 @@ mtk_detect-mbv2-shortcut-400-400-simplified.onnx mtk_emotions-d2012-75.8%.onnx -mtk_face_features_v3.onnx +# mtk_face_features_v3.onnx ml_face_3d.onnx diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc index c6c746f739b..8f9bb4fd02a 100644 --- a/mindspore/lite/tools/common/graph_util.cc +++ b/mindspore/lite/tools/common/graph_util.cc @@ -524,6 +524,30 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz MS_ASSERT(postNode != nullptr); auto &postTensor = graphT->allTensors.at(postTensorIdx); MS_ASSERT(postTensor != nullptr); + // for multioutput,when one outpout as other node input,need add one more node + if (IsContain(graphT->outputIndex, postTensorIdx)) { + auto toAddTensor = CopyTensorDefT(postTensor); + if (toAddTensor == nullptr) { + MS_LOG(ERROR) << "Copy TensorT failed"; + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + graphT->allTensors.emplace_back(std::move(toAddTensor)); + size_t toAddTensorIdx = graphT->allTensors.size() - 1; + auto toAddNode = opDefCopyer(toAddNodeIn.get()); + toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++); + toAddNode->inputIndex.clear(); + toAddNode->inputIndex.push_back(postTensorIdx); + toAddNode->outputIndex.clear(); + toAddNode->outputIndex.push_back(toAddTensorIdx); + for (auto iter = graphT->outputIndex.begin(); iter != graphT->outputIndex.end(); iter++) { + if (*iter == postTensorIdx) { + *iter = toAddTensorIdx; + break; + } + } + toAddNodes.emplace_back(std::move(toAddNode)); + } auto toAddTensor = CopyTensorDefT(postTensor); if (toAddTensor == nullptr) { MS_LOG(ERROR) << "Copy TensorT failed"; diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 38bc85d7b6d..524da3d50a3 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -82,8 +82,13 @@ static const std::vector int8OpList = { schema::PrimitiveType_Pad}; static const std::vector needInsertOpList = { - schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, - schema::PrimitiveType_Power}; + schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, + schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add, + schema::PrimitiveType_Split}; + +static const std::unordered_map nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}}; + +std::unordered_map GetNc2NhAxisMap() { return nc2NhAxisMap; } std::vector GetInsertOpList() { return needInsertOpList; } diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h index 0a7bc94a9f0..0a238c589cb 100644 --- a/mindspore/lite/tools/common/node_util.h +++ b/mindspore/lite/tools/common/node_util.h @@ -19,6 +19,7 @@ #include #include +#include #include "schema/inner/model_generated.h" #include "src/common/common.h" #include "utils/log_adapter.h" @@ -30,6 +31,8 @@ namespace lite { using STATUS = int; STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr &node); +std::unordered_map GetNc2NhAxisMap(); + std::vector GetInsertOpList(); std::vector GetNhwcOpList(); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc index 81f61dbe973..f2999c2d1d8 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h" #include "tools/common/converter_op_utils.h" @@ -117,48 +118,86 @@ STATUS TransOpInsertPass::FindOutTransType() { return RET_OK; } +STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr &node) { + if (node == nullptr && node->primitive == nullptr) { + MS_LOG(ERROR) << "node or primitive null"; + return RET_NULL_PTR; + } + auto type = node->primitive->value.type; + if (graph->allTensors.at(node->inputIndex[0])->dims.size() != 4) { + MS_LOG(ERROR) << "change op axis only support 4 dims"; + return RET_NOT_SUPPORT; + } + if (type == PrimitiveType_Concat) { + auto origin_axis = node->primitive->value.AsConcat()->axis; + auto axis_map = GetNc2NhAxisMap(); + node->primitive->value.AsConcat()->axis = axis_map[origin_axis]; + } + if (type == PrimitiveType_StridedSlice) { + auto attr = node->primitive->value.AsStridedSlice(); + auto origin_begin = attr->begin; + attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]}; + auto origin_end = attr->end; + attr->end = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]}; + auto origin_stride = attr->stride; + attr->stride = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]}; + } + if (type == PrimitiveType_Split) { + auto origin_axis = node->primitive->value.AsSplit()->splitDim; + auto axis_map = GetNc2NhAxisMap(); + node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis]; + } + return RET_OK; +} + STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); - for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - auto &node = *iter; - auto type = node->primitive->value.type; - if (!IsContain(GetInsertOpList(), type)) { - continue; - } - auto node_name = node->name; - if (!CanFusion(graph, node)) { - continue; - } - auto ret = FindOutTransType(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "FindOutTransType error"; - return ret; - } - // 4 dims means infershape success,can delete - if (type == PrimitiveType_Concat) { - if (graph->allTensors.at(node->inputIndex[0])->dims.size() == 4) { - node->primitive->value.AsConcat()->axis = -1; - } else { + bool changed = true; + int run_counts = 0; + std::vector has_insert_nodes; + while (changed && run_counts < 10) { + changed = false; + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + auto &node = *iter; + auto type = node->primitive->value.type; + if (IsContain(has_insert_nodes, node.get()) || !IsContain(GetInsertOpList(), type)) { continue; } - } - STATUS status = RET_OK; - auto input_tensor_size = (*iter)->inputIndex.size(); - for (size_t i = 0; i < input_tensor_size; i++) { - iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status); - if (status != RET_OK) { - MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed"; - return status; + auto node_name = node->name; + if (!CanFusion(graph, node)) { + continue; } - } - auto output_tensor_size = (*iter)->outputIndex.size(); - for (size_t i = 0; i < output_tensor_size; i++) { - iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type_, &status); - if (status != RET_OK) { - MS_LOG(ERROR) << "Insert" << post_insert_trans_type_ << "Node before " << (*iter)->name << " failed"; - return status; + auto ret = FindOutTransType(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "FindOutTransType error"; + return ret; } + ret = ChangeOpAxis(graph, node); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ChangeOpAxis error"; + return ret; + } + has_insert_nodes.push_back(node.get()); + STATUS status = RET_OK; + auto input_tensor_size = (*iter)->inputIndex.size(); + for (size_t i = 0; i < input_tensor_size; i++) { + iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed"; + return status; + } + } + auto output_tensor_size = (*iter)->outputIndex.size(); + for (size_t i = 0; i < output_tensor_size; i++) { + iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type_, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "Insert" << post_insert_trans_type_ << "Node before " << (*iter)->name << " failed"; + return status; + } + } + changed = true; } + run_counts++; } return RET_OK; } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h index 22a3798931f..53046321622 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h @@ -37,6 +37,8 @@ class TransOpInsertPass : public FormatTransPass { STATUS FindOutTransType(); + STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr &node); + private: FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW; FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW;