Support NCHW format for graph inputs

Add a converter flag, inputFormat (NHWC | NCHW, default NHWC), and a SpecifyGraphInputFormat pass: when NCHW is requested, each 4-D graph input is rewritten to NCHW and a transpose back to NHWC is inserted (then fused away where it cancels an adjacent transpose), so the rest of the graph keeps running in NHWC.

xuanyue 2021-08-11 12:54:33 +08:00
parent 223f500bab
commit d5926ddab7
13 changed files with 304 additions and 12 deletions

View File

@@ -257,6 +257,7 @@ if(MSLITE_ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/transpose_strategy.cc
${LITE_DIR}/tools/optimizer/graph/reduce_same_act_pass.cc
${LITE_DIR}/tools/optimizer/graph/split_one_pass.cc
${LITE_DIR}/tools/optimizer/graph/specify_graph_input_format.cc
${LITE_DIR}/tools/optimizer/fisson/eliminate_concat_split.cc
${LITE_DIR}/tools/optimizer/fisson/fisson_util.cc
${LITE_DIR}/tools/optimizer/fisson/iter_node_outputs.cc

View File

@@ -739,7 +739,6 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch
graph_inputs_has_exported_.find(input_node) == graph_inputs_has_exported_.end()) {
graph_inputs_has_exported_.insert(input_node);
meta_graphT->inputIndex.push_back(meta_graphT->allTensors.size() - 1);
meta_graphT->allTensors.back()->format = schema::Format_NHWC;
}
} else if (input_node->isa<ValueNode>()) {
auto ret = ConvertInputValueNode(cnode, i, primitive_c, meta_graphT, fb_node);

View File

@@ -46,7 +46,6 @@ class AnfExporter {
public:
AnfExporter() = default;
virtual ~AnfExporter() = default;
void set_train_flag(bool train_flag) { train_flag_ = train_flag; }
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false,
bool train_flag = false);
void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,

View File

@@ -286,15 +286,15 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
return RET_ERROR;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->GetAttr(ops::kFormat) == nullptr && !param_node->has_default()) {
data_info->format_ = mindspore::NHWC;
}
if (prim->GetAttr(ops::kFormat) != nullptr && !opt::CheckPrimitiveType(cnode, prim::kPrimResize)) {
auto value = prim->GetAttr(ops::kFormat);
if (value->isa<mindspore::Int64Imm>()) {
data_info->format_ = GetValue<int64_t>(value);
}
}
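// A parameter without default data is a graph input; it is recorded as NHWC
// here, and the SpecifyGraphInputFormat pass rewrites it if NCHW was requested.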
if (!param_node->has_default()) {
data_info->format_ = NHWC;
}
// attr weightFormat is only used by conv-like ops' second input
if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) ||

View File

@@ -113,6 +113,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/reduce_same_act_pass.cc
../optimizer/graph/split_one_pass.cc
../optimizer/graph/find_const_subgraph_pass.cc
../optimizer/graph/specify_graph_input_format.cc
)
add_subdirectory(../anf_exporter anf_exporter)

View File

@@ -58,6 +58,7 @@
#include "tools/optimizer/graph/reduce_same_act_pass.h"
#include "tools/optimizer/graph/split_one_pass.h"
#include "tools/optimizer/graph/decrease_transpose_algo.h"
#include "tools/optimizer/graph/specify_graph_input_format.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
@@ -382,6 +383,11 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
MS_LOG(ERROR) << "Do Quantize failed.";
return nullptr;
}
if (!RunOptimizerPass(old_graph, {"SpecifyGraphInputFormat"})) {
MS_LOG(ERROR) << "Run transpose opt pass failed.";
return nullptr;
}
return old_graph;
}
@@ -393,6 +399,8 @@ void AnfTransform::AppendPassToStoreRoom(const converter::Flags *config) {
registry::PassRegistry("InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train));
registry::PassRegistry("ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train));
registry::PassRegistry("ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train));
registry::PassRegistry("SpecifyGraphInputFormat",
std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat));
}
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {

View File

@@ -76,6 +76,8 @@ Flags::Flags() {
"set this option. Model input shapes is same with origin model by default."
"e.g. inTensor1:1,32,32,32;inTensor2:1,1,32,32,4",
"");
AddFlag(&Flags::graphInputFormatStr, "inputFormat",
"Assign the format of model inputs. Valid only for 4-dimensional input. NHWC | NCHW", "NHWC");
}
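For context, the new option is passed like any other converter flag; a hypothetical invocation (file names are placeholders, the remaining flags are the converter's existing ones):

./converter_lite --fmk=TFLITE --modelFile=model.tflite --outputFile=model --inputFormat=NCHW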
int Flags::InitInputOutputDataType() {
@@ -211,6 +213,9 @@ int Flags::InitTrainModel() {
}
int Flags::InitInTensorShape() {
if (this->inTensorShape.empty()) {
return RET_OK;
}
std::string content = this->inTensorShape;
std::vector<int64_t> shape;
auto shape_strs = lite::StrSplit(content, std::string(";"));
@@ -242,6 +247,18 @@ int Flags::InitInTensorShape() {
return RET_OK;
}
int Flags::InitGraphInputFormat() {
if (this->graphInputFormatStr == "NHWC") {
graphInputFormat = mindspore::NHWC;
} else if (this->graphInputFormatStr == "NCHW") {
graphInputFormat = mindspore::NCHW;
} else if (!this->graphInputFormatStr.empty()) {
MS_LOG(ERROR) << "graph input format is invalid.";
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
}
int Flags::InitConfigFile() {
auto plugins_path_str = GetStrFromConfigFile(this->configFile, "plugin_path");
if (!plugins_path_str.empty()) {
@@ -351,12 +368,16 @@ int Flags::Init(int argc, const char **argv) {
return RET_INPUT_PARAM_INVALID;
}
if (!this->inTensorShape.empty()) {
ret = InitInTensorShape();
if (ret != RET_OK) {
std::cerr << "Init input tensor shape failed." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
ret = InitInTensorShape();
if (ret != RET_OK) {
std::cerr << "Init input tensor shape failed." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
ret = InitGraphInputFormat();
if (ret != RET_OK) {
std::cerr << "Init graph input format failed." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
}

View File

@@ -19,6 +19,7 @@
#include <string>
#include <vector>
#include "include/api/format.h"
#include "include/registry/parser_context.h"
#include "tools/common/flag_parser.h"
#include "ir/dtype/type_id.h"
@@ -60,6 +61,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
int InitInTensorShape();
int InitGraphInputFormat();
int Init(int argc, const char **argv);
public:
@@ -93,6 +96,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
std::string inTensorShape;
std::string dec_key = "";
std::string dec_mode = "AES-GCM";
std::string graphInputFormatStr;
mindspore::Format graphInputFormat = mindspore::NHWC;
};
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config);

View File

@@ -343,6 +343,7 @@ STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph,
return lite::RET_ERROR;
}
}
ModifyCNodeFormat(cnode, trans_insert_info->pre_);
status = node_infer_shape_.InferShape(cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
@@ -442,6 +443,7 @@ STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_grap
MS_LOG(ERROR) << "change op attr failed.";
return lite::RET_ERROR;
}
ModifyCNodeFormat(middle_cnode, trans_info.post_);
status = node_infer_shape_.InferShape(middle_cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer shape failed.";
@@ -587,9 +589,22 @@ void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const Fun
prim->AddAttr(kInferDone, MakeValue<bool>(infer_done));
}
void DecreaseTransposeAlgo::ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type) {
MS_ASSERT(cnode != nullptr);
if (pre_trans_type == kNONE) {
return;
}
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(primitive != nullptr);
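// A pre-inserted NHWC2NCHW transpose means this node now consumes NCHW data,
// otherwise NHWC; record the layout on the primitive's format attr.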
if (pre_trans_type == kNHWC2NCHW) {
primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NCHW));
} else {
primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NHWC));
}
}
bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";

View File

@@ -62,6 +62,7 @@ class DecreaseTransposeAlgo : public Pass {
void ResetSubGraphInput();
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type);
FmkType fmk_type_{converter::FmkType_MS};
bool train_flag_{false};
NodeInferShape node_infer_shape_;

View File

@@ -18,6 +18,76 @@
namespace mindspore {
namespace opt {
namespace {
int GetCNodeCertainInputFormat(const CNodePtr cnode, int index, mindspore::Format *format) {
MS_ASSERT(cnode != nullptr && format != nullptr);
auto origin_inputs = cnode->inputs();
lite::RemoveIfDepend(cnode);
lite::RemoveIfMakeTuple(cnode);
RemoveIfMonad(cnode);
if (index <= 0 || static_cast<size_t>(index) >= cnode->size()) {
MS_LOG(ERROR) << "input index out of range";
cnode->set_inputs(origin_inputs);
return lite::RET_ERROR;
}
if (!utils::isa<CNode>(cnode->input(index))) {
cnode->set_inputs(origin_inputs);
return lite::RET_NO_CHANGE;
}
auto real_cnode = cnode->input(index)->cast<CNodePtr>();
if (CheckPrimitiveType(real_cnode, prim::kPrimTupleGetItem)) {
real_cnode = real_cnode->input(1)->cast<CNodePtr>();
}
cnode->set_inputs(origin_inputs);
MS_ASSERT(real_cnode != nullptr);
auto primitive = GetValueNode<PrimitivePtr>(real_cnode->input(0));
MS_ASSERT(primitive != nullptr);
if (primitive->GetAttr(ops::kFormat) == nullptr) {
MS_LOG(ERROR) << "cnode has no format attr. " << real_cnode->fullname_with_scope();
return lite::RET_ERROR;
}
*format = static_cast<mindspore::Format>(GetValue<int64_t>(primitive->GetAttr(ops::kFormat)));
if (CheckPrimitiveType(real_cnode, prim::kPrimTranspose)) {
std::vector<int> perm;
if (GetTransposePerm(real_cnode, &perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get transpose perm failed.";
return lite::RET_ERROR;
}
if (perm.size() != 4) {
return RET_OK;
}
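// A transpose ahead of this input flips the producer's recorded layout:
// an NH2NC perm turns an NHWC producer into NCHW, and NC2NH the reverse.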
if (perm == kNH2NC && *format == mindspore::NHWC) {
*format = mindspore::NCHW;
} else if (perm == kNC2NH && *format == mindspore::NCHW) {
*format = mindspore::NHWC;
}
}
return lite::RET_OK;
}
int ModifySubGraphInputCNodeFormat(const FuncGraphPtr &sub_graph, const ParameterPtr &certain_input,
mindspore::Format format) {
MS_ASSERT(sub_graph != nullptr && certain_input != nullptr);
auto manager = sub_graph->manager();
MS_ASSERT(manager != nullptr);
auto node_users = manager->node_users()[certain_input];
for (auto &node_user : node_users) {
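// node_user.second is the operand position of certain_input in the consuming
// cnode; only consumers taking it as their first input inherit the new format.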
if (node_user.second != 1) {
continue;
}
auto post_cnode = node_user.first->cast<CNodePtr>();
if (post_cnode == nullptr) {
MS_LOG(ERROR) << "post node is not cnode, which is invalid.";
return lite::RET_ERROR;
}
auto primitive = GetValueNode<PrimitivePtr>(post_cnode->input(0));
MS_ASSERT(primitive != nullptr);
primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(format));
}
return lite::RET_OK;
}
} // namespace
bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func_graph is nullptr.";
@@ -149,6 +219,14 @@ void InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr
if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(opt::kInferDone))) {
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
}
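// Propagate the layout actually produced for this call argument down to the
// subgraph, so users of the parameter see the caller's real format.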
mindspore::Format format = mindspore::NHWC;
if (GetCNodeCertainInputFormat(cnode, index, &format) != lite::RET_OK) {
MS_LOG(DEBUG) << "has no change for current control node." << cnode->fullname_with_scope();
continue;
}
if (ModifySubGraphInputCNodeFormat(sub_graph, param_node, format) != lite::RET_OK) {
MS_LOG(DEBUG) << "modify subgraph input cnode format failed." << cnode->func_graph_as_var();
}
} else {
lite::DataInfo data_info;
if (utils::isa<ParameterPtr>(cnode->input(index))) {

View File

@@ -0,0 +1,124 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/specify_graph_input_format.h"
#include <memory>
#include <vector>
#include "tools/optimizer/common/format_utils.h"
#include "src/common/log_adapter.h"
namespace mindspore {
namespace opt {
bool SpecifyGraphInputFormat::Run(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
if (format_ == mindspore::NHWC) {
return true;
}
if (format_ != mindspore::NCHW) {
MS_LOG(ERROR) << "this pass only support to transfer graph input format from nhwc to nchw.";
return false;
}
auto manager = Manage(graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return false;
}
if (HandleGraphInput(graph) != lite::RET_OK) {
MS_LOG(ERROR) << "transfer graph input format from nhwc to nchw failed.";
return false;
}
return true;
}
STATUS SpecifyGraphInputFormat::HandleGraphInput(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto manager = graph->manager();
MS_ASSERT(manager != nullptr);
auto graph_inputs = graph->get_inputs();
for (const auto &input : graph_inputs) {
auto input_node = input->cast<ParameterPtr>();
MS_ASSERT(input_node != nullptr);
auto abstract = input_node->abstract();
MS_ASSERT(abstract != nullptr);
ShapeVector shape;
if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
MS_LOG(ERROR) << "fetch shape failed." << input->fullname_with_scope();
return lite::RET_ERROR;
}
if (shape.size() != kInputSizeFour) {
continue;
}
ShapeVector transfer_shape;
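// The input's shape is NHWC here; reorder {N,H,W,C} into the requested layout
// (assuming kInputIndexTwo == 2 and kInputIndexThree == 3).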
if (format_ == mindspore::NCHW) {
transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
} else {
transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
}
CNodePtr trans_cnode;
if (format_ == mindspore::NCHW) {
trans_cnode = opt::GenTransposeNode(graph, input, kNC2NH, input->fullname_with_scope() + "_nc2nh");
} else {
trans_cnode = opt::GenTransposeNode(graph, input, kNH2NC, input->fullname_with_scope() + "_nh2nc");
}
if (trans_cnode == nullptr) {
MS_LOG(ERROR) << "create transpose cnode failed.";
return lite::RET_ERROR;
}
auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
MS_ASSERT(trans_prim != nullptr);
if (format_ == mindspore::NCHW) {
trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
} else {
trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
}
trans_cnode->set_abstract(abstract->Clone());
abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
manager->Replace(input, trans_cnode);
if (PostTransposeFusion(graph, trans_cnode) != lite::RET_OK) {
MS_LOG(ERROR) << "post transpose and transpose fusion failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
STATUS SpecifyGraphInputFormat::PostTransposeFusion(const FuncGraphPtr &graph, const CNodePtr &cnode) {
MS_ASSERT(graph != nullptr && cnode != nullptr);
std::vector<int> cur_perm;
if (GetTransposePerm(cnode, &cur_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get transpose perm failed.";
return lite::RET_ERROR;
}
auto node_users = graph->manager()->node_users()[cnode];
for (auto &node_user : node_users) {
auto post_node = node_user.first;
if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
std::vector<int> post_trans_perm;
auto post_trans_node = post_node->cast<CNodePtr>();
MS_ASSERT(post_trans_node != nullptr);
if (GetTransposePerm(post_trans_node, &post_trans_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get post transpose node perm failed.";
return lite::RET_ERROR;
}
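// Inverse perms cancel: an NH2NC transpose followed by an NC2NH one (or vice
// versa) is an identity, so consumers can read the original input directly.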
if ((cur_perm == kNH2NC && post_trans_perm == kNC2NH) || (cur_perm == kNC2NH && post_trans_perm == kNH2NC)) {
graph->manager()->Replace(post_node, cnode->input(1));
}
}
}
return lite::RET_OK;
}
} // namespace opt
} // namespace mindspore
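As a standalone illustration of the shape rewrite performed above for a 4-D NHWC input (a minimal sketch; the helper name and the standalone setting are invented for the example, not part of the pass):

#include <cstdint>
#include <cstdio>
#include <vector>

// NHWC {N,H,W,C} -> NCHW {N,C,H,W}, mirroring the transfer_shape computation
// in SpecifyGraphInputFormat::HandleGraphInput.
static std::vector<int64_t> NhwcToNchw(const std::vector<int64_t> &nhwc) {
  return {nhwc[0], nhwc[3], nhwc[1], nhwc[2]};
}

int main() {
  const auto nchw = NhwcToNchw({1, 224, 224, 3});
  // Prints "1 3 224 224": the graph input becomes NCHW, while the inserted
  // NC2NH transpose restores NHWC for the rest of the graph.
  std::printf("%lld %lld %lld %lld\n", static_cast<long long>(nchw[0]),
              static_cast<long long>(nchw[1]), static_cast<long long>(nchw[2]),
              static_cast<long long>(nchw[3]));
  return 0;
}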

View File

@@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SPECIFY_GRAPH_INPUT_FORMAT_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SPECIFY_GRAPH_INPUT_FORMAT_H_
#include "backend/optimizer/common/pass.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace opt {
class SpecifyGraphInputFormat : public Pass {
public:
explicit SpecifyGraphInputFormat(mindspore::Format format = mindspore::NHWC)
: Pass("SpecifyGraphInputFormat"), format_(format) {}
~SpecifyGraphInputFormat() override = default;
bool Run(const FuncGraphPtr &graph) override;
private:
STATUS HandleGraphInput(const FuncGraphPtr &graph);
STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
mindspore::Format format_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SPECIFY_GRAPH_INPUT_FORMAT_H_