forked from mindspore-Ecosystem/mindspore
graph input support NCHW
This commit is contained in:
parent
223f500bab
commit
d5926ddab7
|
@ -257,6 +257,7 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/graph/transpose_strategy.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/reduce_same_act_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/split_one_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/specify_graph_input_format.cc
|
||||
${LITE_DIR}/tools/optimizer/fisson/eliminate_concat_split.cc
|
||||
${LITE_DIR}/tools/optimizer/fisson/fisson_util.cc
|
||||
${LITE_DIR}/tools/optimizer/fisson/iter_node_outputs.cc
|
||||
|
|
|
@ -739,7 +739,6 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch
|
|||
graph_inputs_has_exported_.find(input_node) == graph_inputs_has_exported_.end()) {
|
||||
graph_inputs_has_exported_.insert(input_node);
|
||||
meta_graphT->inputIndex.push_back(meta_graphT->allTensors.size() - 1);
|
||||
meta_graphT->allTensors.back()->format = schema::Format_NHWC;
|
||||
}
|
||||
} else if (input_node->isa<ValueNode>()) {
|
||||
auto ret = ConvertInputValueNode(cnode, i, primitive_c, meta_graphT, fb_node);
|
||||
|
|
|
@ -46,7 +46,6 @@ class AnfExporter {
|
|||
public:
|
||||
AnfExporter() = default;
|
||||
virtual ~AnfExporter() = default;
|
||||
void set_train_flag(bool train_flag) { train_flag_ = train_flag; }
|
||||
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false,
|
||||
bool train_flag = false);
|
||||
void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||
|
|
|
@ -286,15 +286,15 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim->GetAttr(ops::kFormat) == nullptr && !param_node->has_default()) {
|
||||
data_info->format_ = mindspore::NHWC;
|
||||
}
|
||||
if (prim->GetAttr(ops::kFormat) != nullptr && !opt::CheckPrimitiveType(cnode, prim::kPrimResize)) {
|
||||
auto value = prim->GetAttr(ops::kFormat);
|
||||
if (value->isa<mindspore::Int64Imm>()) {
|
||||
data_info->format_ = GetValue<int64_t>(value);
|
||||
}
|
||||
}
|
||||
if (!param_node->has_default()) {
|
||||
data_info->format_ = NHWC;
|
||||
}
|
||||
// attr weightFormat is only used by conv-like ops' second input
|
||||
if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
|
||||
opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) ||
|
||||
|
|
|
@ -113,6 +113,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/graph/reduce_same_act_pass.cc
|
||||
../optimizer/graph/split_one_pass.cc
|
||||
../optimizer/graph/find_const_subgraph_pass.cc
|
||||
../optimizer/graph/specify_graph_input_format.cc
|
||||
)
|
||||
|
||||
add_subdirectory(../anf_exporter anf_exporter)
|
||||
|
|
|
@ -58,6 +58,7 @@
|
|||
#include "tools/optimizer/graph/reduce_same_act_pass.h"
|
||||
#include "tools/optimizer/graph/split_one_pass.h"
|
||||
#include "tools/optimizer/graph/decrease_transpose_algo.h"
|
||||
#include "tools/optimizer/graph/specify_graph_input_format.h"
|
||||
#include "tools/converter/quantizer/post_training_quantizer.h"
|
||||
#include "tools/converter/quantizer/quant_cast.h"
|
||||
#include "tools/converter/quantizer/weight_quantizer.h"
|
||||
|
@ -382,6 +383,11 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
|
|||
MS_LOG(ERROR) << "Do Quantize failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!RunOptimizerPass(old_graph, {"SpecifyGraphInputFormat"})) {
|
||||
MS_LOG(ERROR) << "Run transpose opt pass failed.";
|
||||
return nullptr;
|
||||
}
|
||||
return old_graph;
|
||||
}
|
||||
|
||||
|
@ -393,6 +399,8 @@ void AnfTransform::AppendPassToStoreRoom(const converter::Flags *config) {
|
|||
registry::PassRegistry("InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train));
|
||||
registry::PassRegistry("ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train));
|
||||
registry::PassRegistry("ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train));
|
||||
registry::PassRegistry("SpecifyGraphInputFormat",
|
||||
std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat));
|
||||
}
|
||||
|
||||
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
|
||||
|
|
|
@ -76,6 +76,8 @@ Flags::Flags() {
|
|||
"set this option. Model input shapes is same with origin model by default."
|
||||
"e.g. inTensor1:1,32,32,32;inTensor2:1,1,32,32,4",
|
||||
"");
|
||||
AddFlag(&Flags::graphInputFormatStr, "inputFormat",
|
||||
"Assign the format of model inputs. Valid only for 4-dimensional input. NHWC | NCHW", "NHWC");
|
||||
}
|
||||
|
||||
int Flags::InitInputOutputDataType() {
|
||||
|
@ -211,6 +213,9 @@ int Flags::InitTrainModel() {
|
|||
}
|
||||
|
||||
int Flags::InitInTensorShape() {
|
||||
if (this->inTensorShape.empty()) {
|
||||
return RET_OK;
|
||||
}
|
||||
std::string content = this->inTensorShape;
|
||||
std::vector<int64_t> shape;
|
||||
auto shape_strs = lite::StrSplit(content, std::string(";"));
|
||||
|
@ -242,6 +247,18 @@ int Flags::InitInTensorShape() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int Flags::InitGraphInputFormat() {
|
||||
if (this->graphInputFormatStr == "NHWC") {
|
||||
graphInputFormat = mindspore::NHWC;
|
||||
} else if (this->graphInputFormatStr == "NCHW") {
|
||||
graphInputFormat = mindspore::NCHW;
|
||||
} else if (!this->graphInputFormatStr.empty()) {
|
||||
MS_LOG(ERROR) << "graph input format is invalid.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Flags::InitConfigFile() {
|
||||
auto plugins_path_str = GetStrFromConfigFile(this->configFile, "plugin_path");
|
||||
if (!plugins_path_str.empty()) {
|
||||
|
@ -351,12 +368,16 @@ int Flags::Init(int argc, const char **argv) {
|
|||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (!this->inTensorShape.empty()) {
|
||||
ret = InitInTensorShape();
|
||||
if (ret != RET_OK) {
|
||||
std::cerr << "Init input tensor shape failed." << std::endl;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
ret = InitGraphInputFormat();
|
||||
if (ret != RET_OK) {
|
||||
std::cerr << "Init graph input format failed." << std::endl;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/api/format.h"
|
||||
#include "include/registry/parser_context.h"
|
||||
#include "tools/common/flag_parser.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
@ -60,6 +61,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
|||
|
||||
int InitInTensorShape();
|
||||
|
||||
int InitGraphInputFormat();
|
||||
|
||||
int Init(int argc, const char **argv);
|
||||
|
||||
public:
|
||||
|
@ -93,6 +96,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
|||
std::string inTensorShape;
|
||||
std::string dec_key = "";
|
||||
std::string dec_mode = "AES-GCM";
|
||||
std::string graphInputFormatStr;
|
||||
mindspore::Format graphInputFormat = mindspore::NHWC;
|
||||
};
|
||||
|
||||
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config);
|
||||
|
|
|
@ -343,6 +343,7 @@ STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph,
|
|||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
ModifyCNodeFormat(cnode, trans_insert_info->pre_);
|
||||
status = node_infer_shape_.InferShape(cnode);
|
||||
|
||||
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
||||
|
@ -442,6 +443,7 @@ STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_grap
|
|||
MS_LOG(ERROR) << "change op attr failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ModifyCNodeFormat(middle_cnode, trans_info.post_);
|
||||
status = node_infer_shape_.InferShape(middle_cnode);
|
||||
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "infer shape failed.";
|
||||
|
@ -587,9 +589,22 @@ void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const Fun
|
|||
prim->AddAttr(kInferDone, MakeValue<bool>(infer_done));
|
||||
}
|
||||
|
||||
void DecreaseTransposeAlgo::ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (pre_trans_type == kNONE) {
|
||||
return;
|
||||
}
|
||||
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
if (pre_trans_type == kNHWC2NCHW) {
|
||||
primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NCHW));
|
||||
} else {
|
||||
primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NHWC));
|
||||
}
|
||||
}
|
||||
|
||||
bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
|
||||
auto manager = Manage(func_graph, true);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(ERROR) << "manager is nullptr.";
|
||||
|
|
|
@ -62,6 +62,7 @@ class DecreaseTransposeAlgo : public Pass {
|
|||
void ResetSubGraphInput();
|
||||
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type);
|
||||
FmkType fmk_type_{converter::FmkType_MS};
|
||||
bool train_flag_{false};
|
||||
NodeInferShape node_infer_shape_;
|
||||
|
|
|
@ -18,6 +18,76 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
int GetCNodeCertainInputFormat(const CNodePtr cnode, int index, mindspore::Format *format) {
|
||||
MS_ASSERT(cnode != nullptr && format != nullptr);
|
||||
auto origin_inputs = cnode->inputs();
|
||||
lite::RemoveIfDepend(cnode);
|
||||
lite::RemoveIfMakeTuple(cnode);
|
||||
RemoveIfMonad(cnode);
|
||||
if (index <= 0 || static_cast<size_t>(index) >= cnode->size()) {
|
||||
MS_LOG(ERROR) << "input index out of range";
|
||||
cnode->set_inputs(origin_inputs);
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (!utils::isa<CNode>(cnode->input(index))) {
|
||||
cnode->set_inputs(origin_inputs);
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto real_cnode = cnode->input(index)->cast<CNodePtr>();
|
||||
if (CheckPrimitiveType(real_cnode, prim::kPrimTupleGetItem)) {
|
||||
real_cnode = real_cnode->input(1)->cast<CNodePtr>();
|
||||
}
|
||||
cnode->set_inputs(origin_inputs);
|
||||
MS_ASSERT(real_cnode != nullptr);
|
||||
auto primitive = GetValueNode<PrimitivePtr>(real_cnode->input(0));
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
if (primitive->GetAttr(ops::kFormat) == nullptr) {
|
||||
MS_LOG(ERROR) << "cnode has no format attr. " << real_cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
*format = static_cast<mindspore::Format>(GetValue<int64_t>(primitive->GetAttr(ops::kFormat)));
|
||||
if (CheckPrimitiveType(real_cnode, prim::kPrimTranspose)) {
|
||||
std::vector<int> perm;
|
||||
if (GetTransposePerm(real_cnode, &perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "get transpose perm failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (perm.size() != 4) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (perm == kNH2NC && *format == mindspore::NHWC) {
|
||||
*format = mindspore::NCHW;
|
||||
} else if (perm == kNC2NH && *format == mindspore::NCHW) {
|
||||
*format = mindspore::NHWC;
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
int ModifySubGraphInputCNodeFormat(const FuncGraphPtr &sub_graph, const ParameterPtr &certain_input,
|
||||
mindspore::Format format) {
|
||||
MS_ASSERT(sub_graph != nullptr && certain_input != nullptr);
|
||||
auto manager = sub_graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
auto node_users = manager->node_users()[certain_input];
|
||||
for (auto &node_user : node_users) {
|
||||
if (node_user.second != 1) {
|
||||
continue;
|
||||
}
|
||||
auto post_cnode = node_user.first->cast<CNodePtr>();
|
||||
if (post_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "post node is not cnode, which is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto primitive = GetValueNode<PrimitivePtr>(post_cnode->input(0));
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(format));
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "func_graph is nullptr.";
|
||||
|
@ -149,6 +219,14 @@ void InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr
|
|||
if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(opt::kInferDone))) {
|
||||
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
|
||||
}
|
||||
mindspore::Format format = mindspore::NHWC;
|
||||
if (GetCNodeCertainInputFormat(cnode, index, &format) != lite::RET_OK) {
|
||||
MS_LOG(DEBUG) << "has no change for current control node." << cnode->fullname_with_scope();
|
||||
continue;
|
||||
}
|
||||
if (ModifySubGraphInputCNodeFormat(sub_graph, param_node, format) != lite::RET_OK) {
|
||||
MS_LOG(DEBUG) << "modify subgraph input cnode format failed." << cnode->func_graph_as_var();
|
||||
}
|
||||
} else {
|
||||
lite::DataInfo data_info;
|
||||
if (utils::isa<ParameterPtr>(cnode->input(index))) {
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/graph/specify_graph_input_format.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool SpecifyGraphInputFormat::Run(const FuncGraphPtr &graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
if (format_ == mindspore::NHWC) {
|
||||
return true;
|
||||
}
|
||||
if (format_ != mindspore::NCHW) {
|
||||
MS_LOG(ERROR) << "this pass only support to transfer graph input format from nhwc to nchw.";
|
||||
return false;
|
||||
}
|
||||
auto manager = Manage(graph, true);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(ERROR) << "manager is nullptr.";
|
||||
return false;
|
||||
}
|
||||
if (HandleGraphInput(graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "transfer graph input format from nhwc to nchw failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
STATUS SpecifyGraphInputFormat::HandleGraphInput(const FuncGraphPtr &graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto manager = graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
auto graph_inputs = graph->get_inputs();
|
||||
for (const auto &input : graph_inputs) {
|
||||
auto input_node = input->cast<ParameterPtr>();
|
||||
MS_ASSERT(input_node != nullptr);
|
||||
auto abstract = input_node->abstract();
|
||||
MS_ASSERT(abstract != nullptr);
|
||||
ShapeVector shape;
|
||||
if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "fetch shape failed." << input->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (shape.size() != kInputSizeFour) {
|
||||
continue;
|
||||
}
|
||||
ShapeVector transfer_shape;
|
||||
if (format_ == mindspore::NCHW) {
|
||||
transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
|
||||
} else {
|
||||
transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
|
||||
}
|
||||
CNodePtr trans_cnode;
|
||||
if (format_ == mindspore::NCHW) {
|
||||
trans_cnode = opt::GenTransposeNode(graph, input, kNC2NH, input->fullname_with_scope() + "_nc2nh");
|
||||
} else {
|
||||
trans_cnode = opt::GenTransposeNode(graph, input, kNH2NC, input->fullname_with_scope() + "_nh2nc");
|
||||
}
|
||||
if (trans_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "create transpose cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
|
||||
MS_ASSERT(trans_prim != nullptr);
|
||||
if (format_ == mindspore::NCHW) {
|
||||
trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
|
||||
} else {
|
||||
trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
|
||||
}
|
||||
trans_cnode->set_abstract(abstract->Clone());
|
||||
abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
|
||||
manager->Replace(input, trans_cnode);
|
||||
if (PostTransposeFusion(graph, trans_cnode) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "post transpose and transpose fusion failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS SpecifyGraphInputFormat::PostTransposeFusion(const FuncGraphPtr &graph, const CNodePtr &cnode) {
|
||||
MS_ASSERT(graph != nullptr && cnode != nullptr);
|
||||
std::vector<int> cur_perm;
|
||||
if (GetTransposePerm(cnode, &cur_perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "get transpose perm failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto node_users = graph->manager()->node_users()[cnode];
|
||||
for (auto &node_user : node_users) {
|
||||
auto post_node = node_user.first;
|
||||
if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
|
||||
std::vector<int> post_trans_perm;
|
||||
auto post_trans_node = post_node->cast<CNodePtr>();
|
||||
MS_ASSERT(post_trans_node != nullptr);
|
||||
if (GetTransposePerm(post_trans_node, &post_trans_perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "get post transpose node perm failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if ((cur_perm == kNH2NC && post_trans_perm == kNC2NH) || (cur_perm == kNC2NH && post_trans_perm == kNH2NC)) {
|
||||
graph->manager()->Replace(post_node, cnode->input(1));
|
||||
}
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SPECIFY_GRAPH_INPUT_FORMAT_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SPECIFY_GRAPH_INPUT_FORMAT_H_
|
||||
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class SpecifyGraphInputFormat : public Pass {
|
||||
public:
|
||||
explicit SpecifyGraphInputFormat(mindspore::Format format = mindspore::NHWC)
|
||||
: Pass("SpecifyGraphInputFormat"), format_(format) {}
|
||||
~SpecifyGraphInputFormat() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
STATUS HandleGraphInput(const FuncGraphPtr &graph);
|
||||
STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
mindspore::Format format_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SPECIFY_GRAPH_INPUT_FORMAT_H_
|
Loading…
Reference in New Issue