forked from mindspore-Ecosystem/mindspore
!6705 [MSLITE]transformat op optimize for strideslice
Merge pull request !6705 from zhengjun10/stride
This commit is contained in:
commit
d61683b456
|
@ -1,4 +1,4 @@
|
|||
mtk_detect-mbv2-shortcut-400-400-simplified.onnx
|
||||
mtk_emotions-d2012-75.8%.onnx
|
||||
mtk_face_features_v3.onnx
|
||||
# mtk_face_features_v3.onnx
|
||||
ml_face_3d.onnx
|
||||
|
|
|
@ -524,6 +524,30 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
|
|||
MS_ASSERT(postNode != nullptr);
|
||||
auto &postTensor = graphT->allTensors.at(postTensorIdx);
|
||||
MS_ASSERT(postTensor != nullptr);
|
||||
// for multioutput,when one outpout as other node input,need add one more node
|
||||
if (IsContain(graphT->outputIndex, postTensorIdx)) {
|
||||
auto toAddTensor = CopyTensorDefT(postTensor);
|
||||
if (toAddTensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Copy TensorT failed";
|
||||
*errorCode = RET_NULL_PTR;
|
||||
return graphT->nodes.end();
|
||||
}
|
||||
graphT->allTensors.emplace_back(std::move(toAddTensor));
|
||||
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
|
||||
auto toAddNode = opDefCopyer(toAddNodeIn.get());
|
||||
toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++);
|
||||
toAddNode->inputIndex.clear();
|
||||
toAddNode->inputIndex.push_back(postTensorIdx);
|
||||
toAddNode->outputIndex.clear();
|
||||
toAddNode->outputIndex.push_back(toAddTensorIdx);
|
||||
for (auto iter = graphT->outputIndex.begin(); iter != graphT->outputIndex.end(); iter++) {
|
||||
if (*iter == postTensorIdx) {
|
||||
*iter = toAddTensorIdx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
toAddNodes.emplace_back(std::move(toAddNode));
|
||||
}
|
||||
auto toAddTensor = CopyTensorDefT(postTensor);
|
||||
if (toAddTensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Copy TensorT failed";
|
||||
|
|
|
@ -83,7 +83,12 @@ static const std::vector<schema::PrimitiveType> int8OpList = {
|
|||
|
||||
static const std::vector<schema::PrimitiveType> needInsertOpList = {
|
||||
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
|
||||
schema::PrimitiveType_Power};
|
||||
schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,
|
||||
schema::PrimitiveType_Split};
|
||||
|
||||
static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};
|
||||
|
||||
std::unordered_map<int, int> GetNc2NhAxisMap() { return nc2NhAxisMap; }
|
||||
|
||||
std::vector<schema::PrimitiveType> GetInsertOpList() { return needInsertOpList; }
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/common/common.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
@ -30,6 +31,8 @@ namespace lite {
|
|||
using STATUS = int;
|
||||
STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);
|
||||
|
||||
std::unordered_map<int, int> GetNc2NhAxisMap();
|
||||
|
||||
std::vector<schema::PrimitiveType> GetInsertOpList();
|
||||
|
||||
std::vector<schema::PrimitiveType> GetNhwcOpList();
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h"
|
||||
#include "tools/common/converter_op_utils.h"
|
||||
|
@ -117,12 +118,49 @@ STATUS TransOpInsertPass::FindOutTransType() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
|
||||
if (node == nullptr && node->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "node or primitive null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto type = node->primitive->value.type;
|
||||
if (graph->allTensors.at(node->inputIndex[0])->dims.size() != 4) {
|
||||
MS_LOG(ERROR) << "change op axis only support 4 dims";
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
if (type == PrimitiveType_Concat) {
|
||||
auto origin_axis = node->primitive->value.AsConcat()->axis;
|
||||
auto axis_map = GetNc2NhAxisMap();
|
||||
node->primitive->value.AsConcat()->axis = axis_map[origin_axis];
|
||||
}
|
||||
if (type == PrimitiveType_StridedSlice) {
|
||||
auto attr = node->primitive->value.AsStridedSlice();
|
||||
auto origin_begin = attr->begin;
|
||||
attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]};
|
||||
auto origin_end = attr->end;
|
||||
attr->end = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]};
|
||||
auto origin_stride = attr->stride;
|
||||
attr->stride = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]};
|
||||
}
|
||||
if (type == PrimitiveType_Split) {
|
||||
auto origin_axis = node->primitive->value.AsSplit()->splitDim;
|
||||
auto axis_map = GetNc2NhAxisMap();
|
||||
node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis];
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
bool changed = true;
|
||||
int run_counts = 0;
|
||||
std::vector<CNodeT *> has_insert_nodes;
|
||||
while (changed && run_counts < 10) {
|
||||
changed = false;
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
auto &node = *iter;
|
||||
auto type = node->primitive->value.type;
|
||||
if (!IsContain(GetInsertOpList(), type)) {
|
||||
if (IsContain(has_insert_nodes, node.get()) || !IsContain(GetInsertOpList(), type)) {
|
||||
continue;
|
||||
}
|
||||
auto node_name = node->name;
|
||||
|
@ -134,14 +172,12 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
|
|||
MS_LOG(ERROR) << "FindOutTransType error";
|
||||
return ret;
|
||||
}
|
||||
// 4 dims means infershape success,can delete
|
||||
if (type == PrimitiveType_Concat) {
|
||||
if (graph->allTensors.at(node->inputIndex[0])->dims.size() == 4) {
|
||||
node->primitive->value.AsConcat()->axis = -1;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
ret = ChangeOpAxis(graph, node);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ChangeOpAxis error";
|
||||
return ret;
|
||||
}
|
||||
has_insert_nodes.push_back(node.get());
|
||||
STATUS status = RET_OK;
|
||||
auto input_tensor_size = (*iter)->inputIndex.size();
|
||||
for (size_t i = 0; i < input_tensor_size; i++) {
|
||||
|
@ -159,6 +195,9 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
|
|||
return status;
|
||||
}
|
||||
}
|
||||
changed = true;
|
||||
}
|
||||
run_counts++;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -37,6 +37,8 @@ class TransOpInsertPass : public FormatTransPass {
|
|||
|
||||
STATUS FindOutTransType();
|
||||
|
||||
STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node);
|
||||
|
||||
private:
|
||||
FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW;
|
||||
FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW;
|
||||
|
|
Loading…
Reference in New Issue