forked from mindspore-Ecosystem/mindspore
!8613 [MSLITE] add global format trans
From: @zhengjun10 Reviewed-by: @HilbertDavid,@hangangqiang Signed-off-by: @HilbertDavid
commit 576e6d1577

@@ -601,6 +601,138 @@ STATUS ValidateFileStr(const std::string &modelFile, std::string fileType) {
  }
}

void TransformAttrByAxes(int *origin_attr, int *axes, int element_size) {
  if (origin_attr == nullptr || axes == nullptr || element_size == 0) {
    MS_LOG(INFO) << "Attr data is from other nodes.";
    return;
  }
  auto axis_map = GetNc2NhAxisMap();
  std::vector<int> cur_attr;
  for (int dim = 0; dim < 4; ++dim) {
    for (int index = 0; index < element_size; ++index) {
      int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]];
      if (nhwc_dim == dim || (nhwc_dim + 4) == dim) {
        cur_attr.push_back(origin_attr[index]);
      }
    }
  }
  for (int index = 0; index < element_size; ++index) {
    origin_attr[index] = cur_attr[index];
  }
}
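
For reference, a minimal standalone sketch of the reordering TransformAttrByAxes performs when the axes cover all of {0, 1, 2, 3}: per-axis values given in NCHW order come back out in NHWC order. The function and parameter names below are illustrative, not part of the patch.

#include <iostream>
#include <vector>

// Sketch only: NCHW axis i lands at NHWC position {0, 3, 1, 2}[i]
// (N stays first, C moves last, H and W shift forward by one).
std::vector<int> ReorderNchwAttrToNhwc(const std::vector<int> &attr_nchw) {
  const int nchw_to_nhwc[4] = {0, 3, 1, 2};
  std::vector<int> attr_nhwc(4, 0);
  for (int i = 0; i < 4; ++i) {
    attr_nhwc[nchw_to_nhwc[i]] = attr_nchw[i];
  }
  return attr_nhwc;
}

int main() {
  // A Slice begin of {0, 1, 2, 3} written per NCHW axis becomes {0, 2, 3, 1} per NHWC axis.
  for (int v : ReorderNchwAttrToNhwc({0, 1, 2, 3})) {
    std::cout << v << " ";  // prints: 0 2 3 1
  }
  std::cout << std::endl;
  return 0;
}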

STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) {
  auto type = node->primitive->value.type;
  if (type == schema::PrimitiveType_StridedSlice) {
    // An ONNX StridedSlice node always has 5 inputs.
    if (node->inputIndex.size() == 5) {
      for (int index = 1; index < 5; ++index) {
        if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) {
          MS_LOG(INFO) << "Input coming from other nodes is not handled here.";
          return RET_NOT_SUPPORT;
        }
      }
      int element_num = graph->allTensors[node->inputIndex[1]]->dims[0];
      auto axes = graph->allTensors[node->inputIndex[3]]->data;
      for (int index = 1; index < 5; ++index) {
        TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()),
                            reinterpret_cast<int *>(axes.data()), element_num);
      }
    }
  }
  if (type == schema::PrimitiveType_Slice) {
    auto attr = node->primitive->value.AsSlice();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr.";
      return RET_NULL_PTR;
    }
    // transform attr
    attr->format = schema::Format_NHWC;
    if (attr->begin.empty() || attr->size.empty()) {
      MS_LOG(INFO) << "Attributes coming from other nodes are not handled here.";
      return RET_NOT_SUPPORT;
    }
    int element_num = attr->begin.size();
    if (attr->axes.empty()) {
      for (int index = 0; index < element_num; ++index) {
        attr->axes.push_back(index);
      }
    }
    TransformAttrByAxes(attr->begin.data(), attr->axes.data(), element_num);
    TransformAttrByAxes(attr->size.data(), attr->axes.data(), element_num);
    TransformAttrByAxes(attr->axes.data(), attr->axes.data(), element_num);
  }
  return RET_OK;
}

STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) {
  MS_ASSERT(node->primitive->value != nullptr);
  auto type = node->primitive->value.type;
  auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size();
  if (input1_ndim != 4 && input1_ndim != 0) {
    if (node->inputIndex.size() > 1) {
      auto input2_ndim = graph->allTensors.at(node->inputIndex[1])->dims.size();
      if (input2_ndim != 4 && input2_ndim != 0) {
        MS_LOG(ERROR) << "change op axis only supports 4 dims";
        return RET_NOT_SUPPORT;
      }
    } else {
      MS_LOG(ERROR) << "change op axis only supports 4 dims";
      return RET_NOT_SUPPORT;
    }
  }
  if (type == schema::PrimitiveType_Concat) {
    MS_ASSERT(node->primitive->value.AsConcat() != nullptr);
    auto origin_axis = node->primitive->value.AsConcat()->axis;
    auto axis_map = GetNc2NhAxisMap();
    if (node->primitive->value.AsConcat() == nullptr) {
      MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr";
      return RET_NULL_PTR;
    }
    node->primitive->value.AsConcat()->axis = axis_map[origin_axis];
  }
  if (type == schema::PrimitiveType_Split) {
    MS_ASSERT(node->primitive->value.AsSplit() != nullptr);
    auto origin_axis = node->primitive->value.AsSplit()->splitDim;
    auto axis_map = GetNc2NhAxisMap();
    if (node->primitive->value.AsSplit() == nullptr) {
      MS_LOG(ERROR) << "node->primitive->value.AsSplit() is nullptr";
      return RET_NULL_PTR;
    }
    node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis];
  }
  if (type == schema::PrimitiveType_Crop) {
    MS_ASSERT(node->primitive->value.AsCrop() != nullptr);
    auto origin_axis = node->primitive->value.AsCrop()->axis;
    auto offsets = node->primitive->value.AsCrop()->offsets;
    auto axis_map = GetNc2NhAxisMap();
    if (node->primitive->value.AsCrop() == nullptr) {
      MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr";
      return RET_NULL_PTR;
    }
    node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
    // nchw -> nhwc: the offsets may also need to be reordered or padded with 0.
    if (axis_map[origin_axis] == 0) {
      offsets = {offsets[0], offsets[2], offsets[3], offsets[1]};
    } else if (axis_map[origin_axis] == 1 || axis_map[origin_axis] == 2) {
      // origin_axis = 2 or origin_axis = 3
      offsets.push_back(0);
    } else if (axis_map[origin_axis] == -1) {
      // origin_axis = 1
      offsets = {offsets[1], offsets[2], offsets[0]};
    } else {
      // axis error
      MS_LOG(ERROR) << "Crop error";
      return RET_ERROR;
    }
    node->primitive->value.AsCrop()->offsets = offsets;
  }
  if (type == schema::PrimitiveType_Slice || type == schema::PrimitiveType_StridedSlice) {
    return ChangeOpAttrForSlice(graph, node);
  }
  return RET_OK;
}
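
The axis remapping used above for Concat, Split, and Crop follows the nc2NhAxisMap constant added later in this change ({0, 0}, {1, -1}, {2, 1}, {3, 2}). A small self-contained sketch of that lookup; the helper name is hypothetical, only the map values come from the patch:

#include <cassert>
#include <unordered_map>

// Sketch only: mirrors the nc2NhAxisMap lookup performed via GetNc2NhAxisMap().
int Nc2NhAxis(int nchw_axis) {
  static const std::unordered_map<int, int> kNc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};
  return kNc2NhAxisMap.at(nchw_axis);
}

int main() {
  assert(Nc2NhAxis(0) == 0);   // batch stays the batch axis
  assert(Nc2NhAxis(1) == -1);  // NCHW channel axis maps to -1, i.e. the last NHWC axis
  assert(Nc2NhAxis(2) == 1);   // height shifts forward by one
  assert(Nc2NhAxis(3) == 2);   // width shifts forward by one
  return 0;
}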

std::string GetModelName(const std::string &modelFile) {
  std::string modelName = modelFile;
  modelName = modelName.substr(modelName.find_last_of('/') + 1);

@@ -86,6 +86,13 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
                        std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer);

STATUS ValidateFileStr(const std::string &modelFile, std::string fileType);

void TransformAttrByAxes(int *origin_attr, int *axes, int element_size);

STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);

STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);

std::string GetModelName(const std::string &modelFile);
} // namespace lite
} // namespace mindspore

@@ -139,7 +139,8 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT
static const std::vector<schema::PrimitiveType> needInsertOpList = {
  schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
  schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,
  schema::PrimitiveType_Split, schema::PrimitiveType_Slice, schema::PrimitiveType_Crop};
  schema::PrimitiveType_Split, schema::PrimitiveType_Slice, schema::PrimitiveType_Crop,
  schema::PrimitiveType_Mul, schema::PrimitiveType_Maximum};

static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};

@@ -28,6 +28,7 @@
#include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
#include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h"
#include "tools/converter/legacy_optimizer/graph/global_format_transform_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"

@@ -114,6 +115,10 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
  formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass());
  formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
  formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
  if (ctx.trainModel == false && ctx.fmk != converter::FmkType_ONNX) {
    formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
    formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
  }
  status = formatTransOptimizer.Run(graphDefT);
  if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
    MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";

@@ -12,6 +12,7 @@ file(GLOB GRAPH_PASS
        ${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/tensor_quant_pass.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc
        )
set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)

@@ -0,0 +1,197 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/legacy_optimizer/graph/global_format_transform_pass.h"
#include <algorithm>
#include "third_party/securec/include/securec.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "tools/common/node_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {

STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  std::set<size_t> need_del_nodes;
  std::set<size_t> need_trans_format_nodes;
  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
    auto &node = *iter;
    auto type = node->primitive->value.type;
    if (type != schema::PrimitiveType_Nchw2Nhwc) {
      continue;
    }
    std::vector<size_t> pre_nh2nc_nodes;
    std::vector<size_t> pre_not_trans_nodes;
    auto status = FindPreNh2NcNodes(graph, iter - graph->nodes.begin(), &pre_nh2nc_nodes, &pre_not_trans_nodes);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "FindPreNh2NcNodes failed: " << status;
      return status;
    }
    std::copy(pre_nh2nc_nodes.begin(), pre_nh2nc_nodes.end(), std::inserter(need_del_nodes, need_del_nodes.end()));
    std::copy(pre_not_trans_nodes.begin(), pre_not_trans_nodes.end(),
              std::inserter(need_trans_format_nodes, need_trans_format_nodes.end()));
    if (!pre_nh2nc_nodes.empty()) {
      need_del_nodes.insert(iter - graph->nodes.begin());
    }
  }
  if (need_del_nodes.empty()) {
    return RET_OK;
  }
  for (auto del_node_index : need_del_nodes) {
    auto node_name = graph->nodes.at(del_node_index)->name;
    auto status = IsolateOneWayNode(graph, del_node_index);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Isolate Node failed, node: " << node_name << ", error: " << status;
      return status;
    }
  }

  auto status = TransWeightToNhwc(graph, need_trans_format_nodes);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "trans weight to nhwc failed";
    return status;
  }
  return RET_OK;
}
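
In effect, when every path above an Nchw2Nhwc node leads back to an Nhwc2Nchw node through ops in the insert-op list, the pass drops both transpose nodes and fixes up what sits in between; a rough before/after sketch (node names illustrative):

// before: ... (NHWC) -> Nhwc2Nchw -> Concat(axis given in NCHW) -> Nchw2Nhwc -> (NHWC) ...
// after : ... (NHWC) -> Concat(axis remapped by ChangeOpAxis, const inputs transposed to NHWC) -> (NHWC) ...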

STATUS ConvertNcTensor2Nh(TensorT *tensor, const std::vector<int> &pad_dims) {
  if (pad_dims.size() != 4) {
    MS_LOG(ERROR) << "pad dims error";
    return RET_ERROR;
  }
  auto batch = pad_dims[NCHW_N];
  auto channel = pad_dims[NCHW_C];
  auto area = pad_dims[NCHW_H] * pad_dims[NCHW_W];
  auto size = batch * channel * area;
  auto new_nhwc_data = new (std::nothrow) float[size];
  if (new_nhwc_data == nullptr) {
    MS_LOG(ERROR) << "create new nhwc data failed";
    return RET_ERROR;
  }
  memset(new_nhwc_data, 0, sizeof(float) * size);
  auto nchw_data = reinterpret_cast<float *>(tensor->data.data());
  // nchw to nhwc
  for (auto i = 0; i < batch; i++) {
    float *src_batch = nchw_data + i * channel * area;
    float *dst_batch = new_nhwc_data + i * channel * area;
    for (int j = 0; j < area; ++j) {
      float *src_area = src_batch + j;
      float *dst_area = dst_batch + j * channel;
      for (int k = 0; k < channel; ++k) {
        dst_area[k] = src_area[k * area];
      }
    }
  }
  memcpy(nchw_data, new_nhwc_data, sizeof(float) * size);
  delete[] new_nhwc_data;
  return RET_OK;
}
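
Written out with explicit indices, the copy above amounts to dst[n][h][w][c] = src[n][c][h][w]. A minimal sketch in flat-index form, assuming raw float buffers; the function and parameter names are illustrative:

// Sketch only: flat-index form of the NCHW -> NHWC copy done by ConvertNcTensor2Nh,
// where area = H * W and hw enumerates positions within one H*W plane.
void Nchw2NhwcCopy(const float *src, float *dst, int batch, int channel, int area) {
  for (int n = 0; n < batch; ++n) {
    for (int hw = 0; hw < area; ++hw) {
      for (int c = 0; c < channel; ++c) {
        dst[(n * area + hw) * channel + c] = src[(n * channel + c) * area + hw];
      }
    }
  }
}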

STATUS GlobalFormatTransformPass::TransWeightToNhwc(MetaGraphT *graph, const std::set<size_t> &pre_not_trans_nodes) {
  if (pre_not_trans_nodes.empty()) {
    return RET_OK;
  }
  for (auto index : pre_not_trans_nodes) {
    auto &cur_node = graph->nodes.at(index);
    // Ops such as Concat and Slice need their axis changed from NCHW to NHWC.
    auto ret = ChangeOpAxis(graph, cur_node);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "ChangeOpAxis error";
      return ret;
    }
    auto node_input_indexs = cur_node->inputIndex;
    for (auto input_index : node_input_indexs) {
      // Constant weight data needs to be transposed to NHWC layout.
      if (!IsContain(graph->inputIndex, input_index) &&
          graph->allTensors.at(input_index)->nodeType == NodeType_ValueNode) {
        auto &weight_tensor = graph->allTensors.at(input_index);
        auto origin_dims = weight_tensor->dims;
        weight_tensor->format = Format_NHWC;
        if (origin_dims.size() > 4) {
          MS_LOG(ERROR) << "origin tensor dims size error";
          return RET_ERROR;
        }
        if (origin_dims.size() == 0) {
          continue;
        }
        auto pad_dims = origin_dims;
        if (origin_dims.size() == 1) {
          pad_dims = {1, 1, 1, origin_dims[0]};
        } else if (origin_dims.size() == 2) {
          pad_dims = {1, 1, origin_dims[0], origin_dims[1]};
        } else if (origin_dims.size() == 3) {
          pad_dims = {1, origin_dims[0], origin_dims[1], origin_dims[2]};
        }
        if (ConvertNcTensor2Nh(weight_tensor.get(), pad_dims) != RET_OK) {
          MS_LOG(ERROR) << "Convert nchw to nhwc failed";
          return RET_ERROR;
        }
        weight_tensor->dims = {pad_dims[NCHW_N], pad_dims[NCHW_H], pad_dims[NCHW_W], pad_dims[NCHW_C]};
      }
    }
  }
  return RET_OK;
}

STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc2nh_index,
                                                    std::vector<size_t> *pre_nh2nc_nodes,
                                                    std::vector<size_t> *pre_not_trans_nodes) {
  MS_ASSERT(graph != nullptr);
  std::vector<size_t> bfs_queue = {nc2nh_index};
  // Walk back from the Nchw2Nhwc node to find the Nhwc2Nchw nodes that feed it.
  while (!bfs_queue.empty()) {
    auto cur_node_index = bfs_queue.back();
    auto &cur_node = graph->nodes.at(cur_node_index);
    bfs_queue.pop_back();
    auto input_node_indexes = GetInputNodeIdx(*graph, *cur_node);
    for (auto input_node_index : input_node_indexes) {
      MS_ASSERT(graph->nodes.size() > input_node_index);
      auto &pre_node = graph->nodes.at(input_node_index);
      MS_ASSERT(pre_node != nullptr);
      auto node_type = pre_node->primitive->value.type;
      if (node_type == schema::PrimitiveType_Nhwc2Nchw) {
        if (!IsContain(*pre_nh2nc_nodes, input_node_index)) {
          pre_nh2nc_nodes->emplace_back(input_node_index);
        }
      } else if (IsContain(GetInsertOpList(), node_type)) {
        if (!IsContain(bfs_queue, input_node_index)) {
          bfs_queue.emplace_back(input_node_index);
        }
        // TODO: for multi-output nodes, the other output edges would need an Nhwc2Nchw node inserted.
        auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node);
        if ((pre_node_output_indexs.size() != 1) && (node_type == schema::PrimitiveType_Activation)) {
          pre_nh2nc_nodes->clear();
          pre_not_trans_nodes->clear();
          return RET_OK;
        }
      } else {
        pre_nh2nc_nodes->clear();
        pre_not_trans_nodes->clear();
        return RET_OK;
      }
      if (!IsContain(*pre_not_trans_nodes, cur_node_index) && cur_node_index != nc2nh_index) {
        pre_not_trans_nodes->emplace_back(cur_node_index);
      }
    }
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

@@ -0,0 +1,48 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H
#define MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H

#include <unordered_map>
#include <set>
#include <vector>
#include <memory>
#include <string>
#include <utility>
#include "tools/common/graph_util.h"
#include "tools/converter/optimizer.h"

using mindspore::schema::TensorT;
namespace mindspore {
namespace lite {
class GlobalFormatTransformPass : public GraphPass {
 public:
  GlobalFormatTransformPass() = default;

  ~GlobalFormatTransformPass() = default;

  STATUS Run(MetaGraphT *graph) override;

 protected:
  STATUS TransWeightToNhwc(MetaGraphT *graph, const std::set<size_t> &pre_not_trans_nodes);

  STATUS FindPreNh2NcNodes(MetaGraphT *graph, size_t nc2nh_index, std::vector<size_t> *to_do_insert_nodes,
                           std::vector<size_t> *pre_not_trans_nodes);
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H

@@ -127,146 +127,6 @@ STATUS TransOpInsertPass::FindOutTransType() {
  return RET_OK;
}

void TransOpInsertPass::TransformAttrByAxes(int *origin_attr, int *axes, int element_size) {
  if (origin_attr == nullptr || axes == nullptr || element_size == 0) {
    MS_LOG(INFO) << "Attr data is from other nodes.";
    return;
  }
  auto axis_map = GetNc2NhAxisMap();
  std::vector<int> cur_attr;
  for (int dim = 0; dim < 4; ++dim) {
    for (int index = 0; index < element_size; ++index) {
      int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]];
      if (nhwc_dim == dim || (nhwc_dim + 4) == dim) {
        cur_attr.push_back(origin_attr[index]);
      }
    }
  }
  for (int index = 0; index < element_size; ++index) {
    origin_attr[index] = cur_attr[index];
  }
}

STATUS TransOpInsertPass::ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
  if (node == nullptr && node->primitive == nullptr) {
    MS_LOG(ERROR) << "node or primitive null";
    return RET_NULL_PTR;
  }
  auto type = node->primitive->value.type;
  if (type == PrimitiveType_StridedSlice) {
    // onnx input size is equal to 5 always.
    if (node->inputIndex.size() == 5) {
      for (int index = 1; index < 5; ++index) {
        if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) {
          MS_LOG(INFO) << "Here don't consider input is from other nodes.";
          return RET_NOT_SUPPORT;
        }
      }
      int element_num = graph->allTensors[node->inputIndex[1]]->dims[0];
      auto axes = graph->allTensors[node->inputIndex[3]]->data;
      for (int index = 1; index < 5; ++index) {
        TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()),
                            reinterpret_cast<int *>(axes.data()), element_num);
      }
    }
  }
  if (type == PrimitiveType_Slice) {
    auto attr = node->primitive->value.AsSlice();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr.";
      return RET_NULL_PTR;
    }
    // transform attr
    attr->format = schema::Format_NHWC;
    if (attr->begin.empty() || attr->size.empty()) {
      MS_LOG(INFO) << "Here don't consider these attr are from other nodes.";
      return RET_NOT_SUPPORT;
    }
    int element_num = attr->begin.size();
    if (attr->axes.empty()) {
      for (int index = 0; index < element_num; ++index) {
        attr->axes.push_back(index);
      }
    }
    TransformAttrByAxes(attr->begin.data(), attr->axes.data(), element_num);
    TransformAttrByAxes(attr->size.data(), attr->axes.data(), element_num);
    TransformAttrByAxes(attr->axes.data(), attr->axes.data(), element_num);
  }
  return RET_OK;
}

STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
  if (node == nullptr && node->primitive == nullptr) {
    MS_LOG(ERROR) << "node or primitive null";
    return RET_NULL_PTR;
  }
  MS_ASSERT(node->primitive->value != nullptr);
  auto type = node->primitive->value.type;
  auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size();
  if (input1_ndim != 4) {
    if (node->inputIndex.size() > 1) {
      auto input2_ndim = graph->allTensors.at(node->inputIndex[1])->dims.size();
      if (input2_ndim != 4 && input2_ndim != 0) {
        MS_LOG(ERROR) << "change op axis only support 4 dims";
        return RET_NOT_SUPPORT;
      }
    } else {
      MS_LOG(ERROR) << "change op axis only support 4 dims";
      return RET_NOT_SUPPORT;
    }
  }
  if (type == PrimitiveType_Concat) {
    MS_ASSERT(node->primitive->value.AsConcat() != nullptr);
    auto origin_axis = node->primitive->value.AsConcat()->axis;
    auto axis_map = GetNc2NhAxisMap();
    if (node->primitive->value.AsConcat() == nullptr) {
      MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr";
      return RET_NULL_PTR;
    }
    node->primitive->value.AsConcat()->axis = axis_map[origin_axis];
  }
  if (type == PrimitiveType_Split) {
    MS_ASSERT(node->primitive->value.AsSplit() != nullptr);
    auto origin_axis = node->primitive->value.AsSplit()->splitDim;
    auto axis_map = GetNc2NhAxisMap();
    if (node->primitive->value.AsSplit() == nullptr) {
      MS_LOG(ERROR) << "node->primitive->value.AsSplit() is nullptr";
      return RET_NULL_PTR;
    }
    node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis];
  }
  if (type == PrimitiveType_Crop) {
    MS_ASSERT(node->primitive->value.AsCrop() != nullptr);
    auto origin_axis = node->primitive->value.AsCrop()->axis;
    auto offsets = node->primitive->value.AsCrop()->offsets;
    auto axis_map = GetNc2NhAxisMap();
    if (node->primitive->value.AsCrop() == nullptr) {
      MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr";
      return RET_NULL_PTR;
    }
    node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
    // nchw->nhwc,offsets need pad 0;
    if (axis_map[origin_axis] == 0) {
      offsets = {offsets[0], offsets[2], offsets[3], offsets[1]};
    } else if (axis_map[origin_axis] == 1 || axis_map[origin_axis] == 2) {
      // orgin_axis = 2 or orgin_axis = 3
      offsets.push_back(0);
    } else if (axis_map[origin_axis] == -1) {
      // origin_axis = 1
      offsets = {offsets[1], offsets[2], offsets[0]};
    } else {
      // axis error
      MS_LOG(ERROR) << "Crop error";
      return RET_ERROR;
    }
    node->primitive->value.AsCrop()->offsets = offsets;
  }
  if (type == PrimitiveType_Slice || type == PrimitiveType_StridedSlice) {
    return ChangeOpAttrForSlice(graph, node);
  }
  return RET_OK;
}

STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  bool changed = true;

@@ -41,8 +41,6 @@ class TransOpInsertPass : public FormatTransPass {

  STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node);

  STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node);

 private:
  FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW;
  FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW;