!6705 [MSLITE]transformat op optimize for strideslice

Merge pull request !6705 from zhengjun10/stride
This commit is contained in:
mindspore-ci-bot 2020-09-23 19:13:13 +08:00 committed by Gitee
commit d61683b456
6 changed files with 111 additions and 38 deletions

View File

@ -1,4 +1,4 @@
mtk_detect-mbv2-shortcut-400-400-simplified.onnx mtk_detect-mbv2-shortcut-400-400-simplified.onnx
mtk_emotions-d2012-75.8%.onnx mtk_emotions-d2012-75.8%.onnx
mtk_face_features_v3.onnx # mtk_face_features_v3.onnx
ml_face_3d.onnx ml_face_3d.onnx

View File

@ -524,6 +524,30 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
MS_ASSERT(postNode != nullptr); MS_ASSERT(postNode != nullptr);
auto &postTensor = graphT->allTensors.at(postTensorIdx); auto &postTensor = graphT->allTensors.at(postTensorIdx);
MS_ASSERT(postTensor != nullptr); MS_ASSERT(postTensor != nullptr);
// for multi-output: when one output is also used as another node's input, one more node needs to be added
if (IsContain(graphT->outputIndex, postTensorIdx)) {
auto toAddTensor = CopyTensorDefT(postTensor);
if (toAddTensor == nullptr) {
MS_LOG(ERROR) << "Copy TensorT failed";
*errorCode = RET_NULL_PTR;
return graphT->nodes.end();
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn.get());
toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++);
toAddNode->inputIndex.clear();
toAddNode->inputIndex.push_back(postTensorIdx);
toAddNode->outputIndex.clear();
toAddNode->outputIndex.push_back(toAddTensorIdx);
for (auto iter = graphT->outputIndex.begin(); iter != graphT->outputIndex.end(); iter++) {
if (*iter == postTensorIdx) {
*iter = toAddTensorIdx;
break;
}
}
toAddNodes.emplace_back(std::move(toAddNode));
}
auto toAddTensor = CopyTensorDefT(postTensor); auto toAddTensor = CopyTensorDefT(postTensor);
if (toAddTensor == nullptr) { if (toAddTensor == nullptr) {
MS_LOG(ERROR) << "Copy TensorT failed"; MS_LOG(ERROR) << "Copy TensorT failed";

View File

@ -82,8 +82,13 @@ static const std::vector<schema::PrimitiveType> int8OpList = {
schema::PrimitiveType_Pad}; schema::PrimitiveType_Pad};
static const std::vector<schema::PrimitiveType> needInsertOpList = { static const std::vector<schema::PrimitiveType> needInsertOpList = {
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
schema::PrimitiveType_Power}; schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,
schema::PrimitiveType_Split};
// Axis lookup table used when an op's axis attribute has to be rewritten after a layout change.
// Keys are the original axis indices 0..3 and values are the replacement indices: 0 -> 0, 1 -> -1, 2 -> 1, 3 -> 2.
// NOTE(review): judging by the name this is NCHW -> NHWC (N stays, C becomes the last axis, H and W shift down) — confirm.
// Only the keys 0..3 are present; negative or larger axes have no entry.
static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};
// Returns a copy of nc2NhAxisMap (returned by value, so callers cannot modify the shared table).
std::unordered_map<int, int> GetNc2NhAxisMap() { return nc2NhAxisMap; }
std::vector<schema::PrimitiveType> GetInsertOpList() { return needInsertOpList; } std::vector<schema::PrimitiveType> GetInsertOpList() { return needInsertOpList; }

View File

@ -19,6 +19,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <unordered_map>
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "src/common/common.h" #include "src/common/common.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
@ -30,6 +31,8 @@ namespace lite {
using STATUS = int; using STATUS = int;
STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node); STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);
std::unordered_map<int, int> GetNc2NhAxisMap();
std::vector<schema::PrimitiveType> GetInsertOpList(); std::vector<schema::PrimitiveType> GetInsertOpList();
std::vector<schema::PrimitiveType> GetNhwcOpList(); std::vector<schema::PrimitiveType> GetNhwcOpList();

View File

@ -16,6 +16,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <vector>
#include <utility> #include <utility>
#include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h" #include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h"
#include "tools/common/converter_op_utils.h" #include "tools/common/converter_op_utils.h"
@ -117,48 +118,86 @@ STATUS TransOpInsertPass::FindOutTransType() {
return RET_OK; return RET_OK;
} }
// Rewrites the axis-dependent attributes of `node` so that they follow the axis order produced by
// GetNc2NhAxisMap() / the NCHW_* index constants instead of the original one.
//
// Handled primitive types:
//   - Concat:       `axis` is remapped through GetNc2NhAxisMap().
//   - Split:        `splitDim` is remapped through GetNc2NhAxisMap().
//   - StridedSlice: `begin`, `end` and `stride` are permuted to {N, H, W, C} order.
// Any other primitive type is left unchanged.
//
// @param graph  graph owning the tensors referenced by `node`; must not be null.
// @param node   node whose primitive attributes are rewritten in place.
// @return RET_OK on success;
//         RET_NULL_PTR if `node`, its primitive or its first input tensor is null;
//         RET_NOT_SUPPORT if the node has no input, the first input is not 4-dimensional,
//         an axis cannot be mapped, or a StridedSlice attribute does not have 4 entries.
STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
  static constexpr size_t kSupportedDims = 4;
  MS_ASSERT(graph != nullptr);
  // Must be `||`: with `&&` a null node would be dereferenced by the check itself and a null
  // primitive would slip through to the code below.
  if (node == nullptr || node->primitive == nullptr) {
    MS_LOG(ERROR) << "node or primitive null";
    return RET_NULL_PTR;
  }
  if (node->inputIndex.empty()) {
    MS_LOG(ERROR) << "node " << node->name << " has no input tensor";
    return RET_NOT_SUPPORT;
  }
  const auto &first_input = graph->allTensors.at(node->inputIndex[0]);
  if (first_input == nullptr) {
    MS_LOG(ERROR) << "input tensor of node " << node->name << " is null";
    return RET_NULL_PTR;
  }
  if (first_input->dims.size() != kSupportedDims) {
    MS_LOG(ERROR) << "change op axis only support 4 dims";
    return RET_NOT_SUPPORT;
  }

  // Maps one axis through the table. Negative axes are normalised first (e.g. -1 -> 3), and the
  // lookup uses find() rather than operator[] so that an unknown axis is reported instead of being
  // silently default-inserted and rewritten to 0.
  auto remap_axis = [](int origin_axis, int *mapped_axis) {
    const auto axis_map = GetNc2NhAxisMap();
    const int normalized = origin_axis < 0 ? origin_axis + static_cast<int>(kSupportedDims) : origin_axis;
    const auto found = axis_map.find(normalized);
    if (found == axis_map.end()) {
      return false;
    }
    *mapped_axis = found->second;
    return true;
  };

  const auto type = node->primitive->value.type;
  if (type == PrimitiveType_Concat) {
    auto attr = node->primitive->value.AsConcat();
    int mapped_axis = 0;
    if (!remap_axis(attr->axis, &mapped_axis)) {
      MS_LOG(ERROR) << "unsupported concat axis " << attr->axis << " in node " << node->name;
      return RET_NOT_SUPPORT;
    }
    attr->axis = mapped_axis;
  }
  if (type == PrimitiveType_StridedSlice) {
    auto attr = node->primitive->value.AsStridedSlice();
    // All three attributes are indexed at the four NCHW_* positions below, so each needs 4 entries.
    if (attr->begin.size() != kSupportedDims || attr->end.size() != kSupportedDims ||
        attr->stride.size() != kSupportedDims) {
      MS_LOG(ERROR) << "strided slice begin/end/stride of node " << node->name << " must have 4 elements";
      return RET_NOT_SUPPORT;
    }
    // Copies are taken first because each attribute is overwritten from its own old values.
    const auto origin_begin = attr->begin;
    attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]};
    const auto origin_end = attr->end;
    attr->end = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]};
    const auto origin_stride = attr->stride;
    attr->stride = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]};
  }
  if (type == PrimitiveType_Split) {
    auto attr = node->primitive->value.AsSplit();
    int mapped_axis = 0;
    if (!remap_axis(attr->splitDim, &mapped_axis)) {
      MS_LOG(ERROR) << "unsupported split dim " << attr->splitDim << " in node " << node->name;
      return RET_NOT_SUPPORT;
    }
    attr->splitDim = mapped_axis;
  }
  return RET_OK;
}
STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr); MS_ASSERT(graph != nullptr);
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { bool changed = true;
auto &node = *iter; int run_counts = 0;
auto type = node->primitive->value.type; std::vector<CNodeT *> has_insert_nodes;
if (!IsContain(GetInsertOpList(), type)) { while (changed && run_counts < 10) {
continue; changed = false;
} for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto node_name = node->name; auto &node = *iter;
if (!CanFusion(graph, node)) { auto type = node->primitive->value.type;
continue; if (IsContain(has_insert_nodes, node.get()) || !IsContain(GetInsertOpList(), type)) {
}
auto ret = FindOutTransType();
if (ret != RET_OK) {
MS_LOG(ERROR) << "FindOutTransType error";
return ret;
}
// 4 dims means infershape success,can delete
if (type == PrimitiveType_Concat) {
if (graph->allTensors.at(node->inputIndex[0])->dims.size() == 4) {
node->primitive->value.AsConcat()->axis = -1;
} else {
continue; continue;
} }
} auto node_name = node->name;
STATUS status = RET_OK; if (!CanFusion(graph, node)) {
auto input_tensor_size = (*iter)->inputIndex.size(); continue;
for (size_t i = 0; i < input_tensor_size; i++) {
iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed";
return status;
} }
} auto ret = FindOutTransType();
auto output_tensor_size = (*iter)->outputIndex.size(); if (ret != RET_OK) {
for (size_t i = 0; i < output_tensor_size; i++) { MS_LOG(ERROR) << "FindOutTransType error";
iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type_, &status); return ret;
if (status != RET_OK) {
MS_LOG(ERROR) << "Insert" << post_insert_trans_type_ << "Node before " << (*iter)->name << " failed";
return status;
} }
ret = ChangeOpAxis(graph, node);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ChangeOpAxis error";
return ret;
}
has_insert_nodes.push_back(node.get());
STATUS status = RET_OK;
auto input_tensor_size = (*iter)->inputIndex.size();
for (size_t i = 0; i < input_tensor_size; i++) {
iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed";
return status;
}
}
auto output_tensor_size = (*iter)->outputIndex.size();
for (size_t i = 0; i < output_tensor_size; i++) {
iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type_, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "Insert" << post_insert_trans_type_ << "Node before " << (*iter)->name << " failed";
return status;
}
}
changed = true;
} }
run_counts++;
} }
return RET_OK; return RET_OK;
} }

View File

@ -37,6 +37,8 @@ class TransOpInsertPass : public FormatTransPass {
STATUS FindOutTransType(); STATUS FindOutTransType();
STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node);
private: private:
FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW; FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW;
FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW; FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW;