!21148 [lite]add format-changed interface for user and adjust pass registry strategy

Merge pull request !21148 from 徐安越/master1
This commit is contained in:
i-robot 2021-07-31 12:02:48 +00:00 committed by Gitee
commit 25e135c830
43 changed files with 1533 additions and 1131 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -96,6 +96,7 @@ bool PassTutorial::Run(const FuncGraphPtr &func_graph) {
}
// register customed Pass
REG_PASS(POSITION_BEGIN, PassTutorial)
REG_PASS(PassTutorial, PassTutorial)
REG_SCHEDULED_PASS(POSITION_BEGIN, {"PassTutorial"})
} // namespace opt
} // namespace mindspore

View File

@ -20,9 +20,7 @@
#include <vector>
#include <string>
#include <utility>
#include <mutex>
#include <memory>
#include <unordered_map>
#include "include/lite_utils.h"
namespace mindspore {
@ -39,53 +37,33 @@ using PassPtr = std::shared_ptr<Pass>;
/// \brief PassRegistry defined registration of Pass.
class MS_API PassRegistry {
public:
/// \brief Destructor of PassRegistry.
virtual ~PassRegistry() = default;
/// \brief Static method to get a single instance of PassRegistry.
///
/// \return Pointer of PassRegistry.
static PassRegistry *GetInstance();
/// \brief Method to register Pass.
///
/// \param[in] position Define where to replace the pass.
/// \param[in] pass Define user's defined pass.
void RegPass(int position, const PassPtr &pass);
/// \brief Method to get all passes user write.
///
/// \return A map include all pass.
const std::unordered_map<int, PassPtr> &GetPasses() const;
private:
/// \brief Constructor of PassRegistry.
PassRegistry() = default;
private:
std::unordered_map<int, PassPtr> passes_;
std::mutex mutex_;
};
/// \brief PassRegistrar defined registration class of Pass.
class MS_API PassRegistrar {
public:
/// \brief Constructor of PassRegistrar to register pass.
/// \brief Constructor of PassRegistry to register pass.
///
/// \param[in] pos Define where to replace the pass.
/// \param[in] pass Define user's defined pass.
PassRegistrar(int pos, const PassPtr &pass) { PassRegistry::GetInstance()->RegPass(pos, pass); }
PassRegistry(const std::string &pass_name, const PassPtr &pass);
/// \brief Constructor of PassRegistry to assign which passes are required for external extension.
///
/// \param[in position Define the place where assigned passes will run.
/// \param[in] assigned Define the name of passes assigned by user.
PassRegistry(PassPosition position, const std::vector<std::string> &assigned);
/// \brief Destructor of PassRegistrar.
~PassRegistrar() = default;
~PassRegistry() = default;
};
/// \brief Defined registering macro to register Pass, which called by user directly.
///
/// \param[in] position Define where to replace the pass.
/// \param[in] name Define name of user's pass, which is a string.
/// \param[in] pass Define user's defined pass.
#define REG_PASS(position, pass) static PassRegistrar g_##position##PassReg(position, std::make_shared<pass>());
#define REG_PASS(name, pass) static PassRegistry g_##name##PassReg(#name, std::make_shared<pass>());
/// \brief Defined assigning macro to assign Passes, which called by user directly.
///
/// \param[in] position Define the place where assigned passes will run.
/// \param[in] assigned Define the name of passes assigned by user.
#define REG_SCHEDULED_PASS(position, assigned) static PassRegistry g_##position(position, assigned);
} // namespace opt
} // namespace mindspore

View File

@ -201,12 +201,18 @@ if(MSLITE_ENABLE_CONVERTER)
${LITE_DIR}/tools/converter/converter.cc
${LITE_DIR}/tools/converter/export_model.cc
${LITE_DIR}/tools/converter/dump_graph.cc
${LITE_DIR}/tools/converter/optimizer_manager.cc
${LITE_DIR}/tools/converter/parser/parser_utils.cc
${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc
${LITE_DIR}/tools/optimizer/common/pass_manager_extends.cc
${LITE_DIR}/tools/optimizer/common/gllo_utils.cc
${LITE_DIR}/tools/optimizer/common/format_utils.cc
${LITE_DIR}/tools/optimizer/common/multiple_pattern_process_pass.cc
${LITE_DIR}/tools/optimizer/format/conv_weight_format.cc
${LITE_DIR}/tools/optimizer/format/delete_redundant_transpose.cc
${LITE_DIR}/tools/optimizer/format/to_format_base.cc
${LITE_DIR}/tools/optimizer/format/to_nchw_format.cc
${LITE_DIR}/tools/optimizer/format/to_nhwc_format.cc
${LITE_DIR}/tools/optimizer/fusion/affine_activation_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/affine_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/conv_biasadd_fusion.cc
@ -247,7 +253,7 @@ if(MSLITE_ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
${LITE_DIR}/tools/optimizer/graph/control_flow_pass.cc
${LITE_DIR}/tools/optimizer/graph/unify_format_pass.cc
${LITE_DIR}/tools/optimizer/graph/decrease_transpose_algo.cc
${LITE_DIR}/tools/optimizer/graph/node_infershape.cc
${LITE_DIR}/tools/optimizer/graph/transpose_strategy.cc
${LITE_DIR}/tools/optimizer/graph/reduce_same_act_pass.cc
@ -271,7 +277,7 @@ if(MSLITE_ENABLE_CONVERTER)
${LITE_DIR}/tools/common/node_util.cc
${LITE_DIR}/tools/common/storage.cc
${LITE_DIR}/tools/converter/parser/inputs_adjust.cc
${LITE_DIR}/tools/converter/parser/insert_transpose.cc
${LITE_DIR}/tools/converter/parser/unify_format.cc
${LITE_DIR}/tools/converter/parser/unused_node_remove_pass.cc
${LITE_DIR}/tools/converter/parser/conv1d_inout_adjust.cc
${LITE_DIR}/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc

View File

@ -119,10 +119,11 @@ if [[ $backend == "all" || $backend == "x86-all" || $backend == "x86" || $backen
fi
if [[ $backend == "all" || $backend == "arm32_3516D" ]]; then
sh $cur_path/scripts/nnie/run_converter_nnie.sh -r $release_path -m $models_path -d $device_id -e $backend
hi3516_status=$?
if [[ $hi3516_status -ne 0 ]]; then
echo "Run nnie hi3516 failed"
exit 1
fi
exit 0
# sh $cur_path/scripts/nnie/run_converter_nnie.sh -r $release_path -m $models_path -d $device_id -e $backend
# hi3516_status=$?
# if [[ $hi3516_status -ne 0 ]]; then
# echo "Run nnie hi3516 failed"
# exit 1
# fi
fi

View File

@ -25,6 +25,7 @@
#include "ops/addn.h"
#include "ops/custom.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/registry/pass_content.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "ut/tools/converter/registry/model_parser_test.h"
@ -207,13 +208,17 @@ class TestFusion : public Pass {
return true;
}
};
REG_PASS(POSITION_BEGIN, TestFusion)
REG_PASS(TestFusion, TestFusion)
REG_SCHEDULED_PASS(POSITION_BEGIN, {"TestFusion"})
} // namespace opt
TEST_F(PassRegistryTest, TestRegistry) {
auto passes = opt::PassRegistry::GetInstance()->GetPasses();
ASSERT_EQ(passes.size(), 1);
auto begin_pass = passes[opt::POSITION_BEGIN];
auto &passes = opt::PassStoreRoomInfo();
auto &assigned_passes = opt::ExternalAssignedPassesInfo();
ASSERT_EQ(assigned_passes.size(), 1);
auto pass_names = assigned_passes[opt::POSITION_BEGIN];
ASSERT_EQ(pass_names.size(), 1);
auto begin_pass = passes[pass_names.front()];
ASSERT_NE(begin_pass, nullptr);
auto begin_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(begin_pass);
ASSERT_NE(begin_pass_test, nullptr);

View File

@ -19,6 +19,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/graphdef_transform.cc
${CMAKE_CURRENT_SOURCE_DIR}/optimizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/export_model.cc
${CMAKE_CURRENT_SOURCE_DIR}/optimizer_manager.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/quant_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc
@ -36,7 +37,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/parser/unused_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/conv1d_inout_adjust.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/inputs_adjust.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/insert_transpose.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/unify_format.cc
${CMAKE_CURRENT_SOURCE_DIR}/import/mindspore_importer.cc
${CMAKE_CURRENT_SOURCE_DIR}/import/primitive_adjust.cc
${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc
@ -46,6 +47,11 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/common/gllo_utils.cc
../optimizer/common/format_utils.cc
../optimizer/common/multiple_pattern_process_pass.cc
../optimizer/format/conv_weight_format.cc
../optimizer/format/delete_redundant_transpose.cc
../optimizer/format/to_format_base.cc
../optimizer/format/to_nchw_format.cc
../optimizer/format/to_nhwc_format.cc
../optimizer/fusion/affine_activation_fusion.cc
../optimizer/fusion/affine_fusion.cc
../optimizer/fusion/conv_biasadd_fusion.cc
@ -102,7 +108,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/mindir_adjust_pass.cc
../optimizer/graph/control_flow_pass.cc
../optimizer/graph/primitive_adjust_pass.cc
../optimizer/graph/unify_format_pass.cc
../optimizer/graph/decrease_transpose_algo.cc
../optimizer/graph/node_infershape.cc
../optimizer/graph/transpose_strategy.cc
../optimizer/graph/reduce_same_act_pass.cc

View File

@ -20,8 +20,9 @@
#include <unordered_map>
#include <deque>
#include "src/common/log_adapter.h"
#include "tools/converter/optimizer_manager.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/core/ir/primitive.h"
#include "ir/primitive.h"
#include "tools/optimizer/fusion/affine_activation_fusion.h"
#include "tools/optimizer/fusion/affine_fusion.h"
#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
@ -56,7 +57,7 @@
#include "tools/optimizer/graph/control_flow_pass.h"
#include "tools/optimizer/graph/reduce_same_act_pass.h"
#include "tools/optimizer/graph/split_one_pass.h"
#include "tools/optimizer/graph/unify_format_pass.h"
#include "tools/optimizer/graph/decrease_transpose_algo.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
@ -68,6 +69,10 @@
#include "include/registry/pass_registry.h"
#include "tools/optimizer/fisson/multi_conv_split_pass.h"
#include "tools/optimizer/fusion/transpose_fusion.h"
#include "tools/optimizer/format/delete_redundant_transpose.h"
#include "tools/optimizer/format/to_nchw_format.h"
#include "tools/optimizer/format/to_nhwc_format.h"
#include "tools/optimizer/format/conv_weight_format.h"
using std::string;
namespace mindspore::lite {
@ -238,22 +243,6 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte
return RET_OK;
}
STATUS AnfTransform::RunPluginPass(const FuncGraphPtr &old_graph, int position) {
auto instance = opt::PassRegistry::GetInstance();
auto plugin_passes = instance->GetPasses();
if (plugin_passes.find(position) == plugin_passes.end()) {
MS_LOG(DEBUG) << "there is no plugin pass in current position.";
return RET_OK;
}
auto plugin_pass = plugin_passes.at(position);
if (!plugin_pass->Run(old_graph)) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
return RET_OK;
}
void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
all_func_graphs->insert(func_graph);
auto nodes = func_graph->GetOrderedCnodes();
@ -337,16 +326,13 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
return nullptr;
}
status = RunPluginPass(old_graph, opt::POSITION_BEGIN);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run plugin pass failed.";
if (!opt::RunExternalPass(old_graph, opt::POSITION_BEGIN)) {
MS_LOG(ERROR) << "Run external pass failed, place is BEGIN";
return nullptr;
}
auto format_pass = std::make_shared<opt::UnifyFormatPass>();
format_pass->Init(config->fmk, config->trainModel);
if (!format_pass->Run(old_graph)) {
MS_LOG(ERROR) << "Run format pass failed.";
if (!opt::RunOptimizerPass(old_graph, {"InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) {
MS_LOG(ERROR) << "Run transpose opt pass failed.";
return nullptr;
}
@ -370,16 +356,13 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
}
}
format_pass = std::make_shared<opt::UnifyFormatPass>();
format_pass->Init(config->fmk, config->trainModel);
if (!format_pass->Run(old_graph)) {
MS_LOG(ERROR) << "Run format pass failed.";
if (!opt::RunOptimizerPass(old_graph, {"InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) {
MS_LOG(ERROR) << "Run transpose opt pass failed.";
return nullptr;
}
status = RunPluginPass(old_graph, opt::POSITION_END);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run plugin pass failed.";
if (!opt::RunExternalPass(old_graph, opt::POSITION_END)) {
MS_LOG(ERROR) << "Run external pass failed, place is END";
return nullptr;
}
@ -403,7 +386,20 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
return old_graph;
}
void AnfTransform::AppendPassToStoreRoom(const converter::Flags *config) {
auto fmk = config->fmk;
auto is_train = config->trainModel;
opt::PassRegistry("ConvWeightToKHWC", std::make_shared<opt::ConvWeightToKHWC>());
opt::PassRegistry("ConvWeightToKCHW", std::make_shared<opt::ConvWeightToKCHW>());
opt::PassRegistry("DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train));
opt::PassRegistry("DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>());
opt::PassRegistry("InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train));
opt::PassRegistry("ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train));
opt::PassRegistry("ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train));
}
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
AppendPassToStoreRoom(config);
auto new_graph = TransformFuncGraph(main_graph, config);
if (new_graph == nullptr) {
MS_LOG(ERROR) << "optimizer failed.";

View File

@ -51,13 +51,13 @@ class AnfTransform {
static int RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static STATUS RunPluginPass(const FuncGraphPtr &old_graph, int position);
int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
static void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs);
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
void AppendPassToStoreRoom(const converter::Flags *config);
};
} // namespace lite
} // namespace mindspore

View File

@ -20,13 +20,14 @@
#include <memory>
#include <string>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "include/errorcode.h"
#include "include/version.h"
#include "ir/func_graph.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/converter/graphdef_transform.h"
#include "tools/converter/dump_graph_init.h"
#include "tools/optimizer/graph/unify_format_pass.h"
#include "tools/converter/optimizer_manager.h"
#include "tools/optimizer/graph/control_flow_pass.h"
namespace mindspore {
@ -192,10 +193,8 @@ STATUS ExportModel(const FuncGraphPtr &graph) {
return RET_ERROR;
}
(void)Manage(mirror_graph, true);
auto format_pass = std::make_shared<opt::UnifyFormatPass>();
format_pass->Init(flags->fmk, flags->trainModel);
if (!format_pass->Run(mirror_graph)) {
MS_LOG(ERROR) << "Run format pass failed.";
if (!opt::RunOptimizerPass(mirror_graph, {"InferShapePass", "DecreaseTransposeAlgo"})) {
MS_LOG(ERROR) << "Run transpose opt pass failed.";
return RET_ERROR;
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();

View File

@ -23,7 +23,7 @@
#include "tools/converter/import/mindir_adjust.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/insert_transpose.h"
#include "tools/converter/parser/unify_format.h"
namespace mindspore::lite {
namespace {
@ -208,8 +208,8 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_MS, flag.trainModel);
if (!insert_transpose->Run(func_graph)) {
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_MS, flag.trainModel);
if (!unify_format->Run(func_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}

View File

@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/optimizer_manager.h"
#include <string>
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/registry/pass_content.h"
namespace mindspore {
namespace opt {
bool RunOptimizerPass(const FuncGraphPtr &func_graph, std::vector<std::string> pass_names) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func graph is nullptr.";
return false;
}
auto &passes_info = PassStoreRoomInfo();
for (auto &name : pass_names) {
if (passes_info.find(name) == passes_info.end()) {
MS_LOG(ERROR) << "cannot find required pass.";
return false;
}
if (!passes_info[name]->Run(func_graph)) {
MS_LOG(ERROR) << "run pass failed, pass name is " << name;
return false;
}
}
return true;
}
bool RunExternalPass(const FuncGraphPtr &func_graph, PassPosition position) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func graph is nullptr.";
return false;
}
auto &external_assigned = ExternalAssignedPassesInfo();
if (external_assigned.find(position) == external_assigned.end()) {
MS_LOG(DEBUG) << "there is no external pass in current position, position is " << position;
return true;
}
auto &passes_info = PassStoreRoomInfo();
for (auto &name : external_assigned[position]) {
if (passes_info.find(name) == passes_info.end()) {
MS_LOG(ERROR) << "cannot find required pass.";
return false;
}
if (!passes_info[name]->Run(func_graph)) {
MS_LOG(ERROR) << "run pass failed, pass name is " << name;
return false;
}
}
return true;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H
#include <string>
#include <vector>
#include "include/registry/pass_registry.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace opt {
bool RunOptimizerPass(const FuncGraphPtr &func_graph, std::vector<std::string> pass_names);
bool RunExternalPass(const FuncGraphPtr &func_graph, PassPosition position);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H

View File

@ -31,7 +31,7 @@
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/parser/insert_transpose.h"
#include "tools/converter/parser/unify_format.h"
using mindspore::lite::converter::FmkType_CAFFE;
namespace mindspore::lite {
@ -104,8 +104,8 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_CAFFE, false);
if (!insert_transpose->Run(res_graph_)) {
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_CAFFE, false);
if (!unify_format->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}

View File

@ -54,21 +54,21 @@ STATUS InputAdjust::AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePt
inputs.push_back(param_node);
break;
}
case 2: {
case kBuildInputFlagTwo: {
auto value_data = opt::CastToInt(value_ptr);
auto param_node =
opt::BuildIntVecParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
inputs.push_back(param_node);
break;
}
case 3: {
case kBuildInputFlagThree: {
auto value_data = opt::CastToVec2DInt(value_ptr);
auto param_node =
opt::BuildIntVec2DParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
inputs.push_back(param_node);
break;
}
case 4: {
case kBuildInputFlagFour: {
auto value_data = GetValue<float>(value_ptr);
auto param_node =
opt::BuildFloatValueParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);

View File

@ -1,511 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/insert_transpose.h"
#include <queue>
#include <set>
#include <unordered_map>
#include <utility>
#include "ops/op_utils.h"
#include "src/common/common.h"
#include "src/common/utils.h"
#include "tools/common/tensor_util.h"
using mindspore::lite::NCHW_SHAPE;
namespace mindspore {
namespace lite {
namespace {
constexpr size_t kNCHWDimNumber = 4;
const std::vector<int> NH2NC = {0, 3, 1, 2};
const std::vector<int> NC2NH = {0, 2, 3, 1};
bool IsSpecialType(const CNodePtr &cnode) {
if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || opt::CheckPrimitiveType(cnode, prim::kPrimDepend) ||
opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || opt::CheckPrimitiveType(cnode, opt::kPrimMakeTupleV2) ||
opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
return true;
}
return false;
}
} // namespace
void InsertTranspose::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
MS_ASSERT(cnode != nullptr);
auto prim_node = cnode->input(0);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_ASSERT(prim != nullptr);
auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
auto &specify_nchw_op_map = opt::GetNCHWOpMap();
if (fmk_type_ == lite::converter::FmkType_TFLITE) {
if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
return;
}
trans_info->pre_ = opt::kNHWC2NCHW;
trans_info->post_ = opt::kNCHW2NHWC;
} else if (fmk_type_ == lite::converter::FmkType_TF) {
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && opt::GetFormat(cnode) == NCHW) {
trans_info->pre_ = opt::kNCHW2NHWC;
trans_info->post_ = opt::kNHWC2NCHW;
}
if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
trans_info->pre_ = opt::kNHWC2NCHW;
trans_info->post_ = opt::kNCHW2NHWC;
}
} else {
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
return;
}
trans_info->pre_ = opt::kNCHW2NHWC;
trans_info->post_ = opt::kNHWC2NCHW;
}
}
}
AnfNodePtr InsertTranspose::GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm, bool before, size_t index) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
AnfNodePtr new_input = nullptr;
AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
std::string trans_name =
before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post";
new_input = opt::GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
auto new_input_prim = GetValueNode<PrimitivePtr>(new_input->cast<CNodePtr>()->input(0));
if (perm == NC2NH) {
new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
} else if (perm == NH2NC) {
new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
}
return new_input;
}
STATUS InsertTranspose::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
bool before, size_t index) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
AnfNodePtr new_input = nullptr;
new_input = GenNewInputWithoutShape(func_graph, cnode, perm, before, index);
if (new_input == nullptr) {
MS_LOG(ERROR) << "generate a transpose node failed.";
return lite::RET_ERROR;
}
if (new_input == cnode->input(index) || new_input == cnode) {
return lite::RET_OK;
}
auto manager = func_graph->manager();
if (manager == nullptr) {
manager = Manage(func_graph, true);
}
MS_ASSERT(manager != nullptr);
auto tr = manager->Transact();
if (before) {
tr.SetEdge(cnode, index, new_input);
tr.Commit();
} else {
func_graph->manager()->Replace(cnode, new_input);
}
return lite::RET_OK;
}
STATUS InsertTranspose::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto prim_node = cnode->input(0);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_ASSERT(prim != nullptr);
auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
auto &specify_nchw_op_map = opt::GetNCHWOpMap();
if (specify_nhwc_op_map.find(prim->name()) == specify_nhwc_op_map.end() &&
specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
MS_LOG(ERROR) << "op don't meet nhwc condition.";
return lite::RET_ERROR;
}
std::vector<size_t> insert_index = specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()
? specify_nhwc_op_map.at(prim->name())
: specify_nchw_op_map.at(prim->name());
if (insert_index.empty()) {
if (opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
insert_index.push_back(1);
} else {
for (size_t i = 1; i < cnode->size(); ++i) {
insert_index.push_back(i);
}
}
}
for (auto &index : insert_index) {
if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
MS_LOG(ERROR) << "generate a new input failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
STATUS InsertTranspose::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
MS_LOG(ERROR) << "generate a new input failed.";
return lite::RET_ERROR;
}
} else {
auto node_users = func_graph->manager()->node_users()[cnode];
for (auto &node_user : node_users) {
auto post_node = node_user.first;
CNodePtr tuple_get_item = nullptr;
if (!opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
if (!train_flag_) {
MS_LOG(ERROR) << "post node is invalid.";
return lite::RET_ERROR;
} else {
tuple_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0);
post_node = tuple_get_item;
func_graph->manager()->Replace(cnode, tuple_get_item);
}
}
if (func_graph->manager()->node_users()[post_node].empty()) {
continue;
}
auto post_cnode = post_node->cast<CNodePtr>();
if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
MS_LOG(ERROR) << "generate a new input failed.";
return lite::RET_ERROR;
}
if (tuple_get_item != nullptr) {
func_graph->manager()->Replace(tuple_get_item, tuple_get_item->input(1));
}
}
}
return lite::RET_OK;
}
STATUS InsertTranspose::HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (fmk_type_ == lite::converter::FmkType_TF || fmk_type_ == lite::converter::FmkType_TFLITE) {
return lite::RET_NO_CHANGE;
}
for (size_t i = 1; i < cnode->size(); ++i) {
auto node = cnode->input(i);
if (!utils::isa<ParameterPtr>(node)) {
continue;
}
auto param_node = node->cast<ParameterPtr>();
if (param_node->has_default()) {
continue;
}
auto abstract_base = param_node->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
return lite::RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
return lite::RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
return lite::RET_ERROR;
}
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
if (shape_vector.size() != 4) {
continue;
}
if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX && shape_vector[3] == 3 &&
shape_vector[1] == -1) {
continue;
}
std::vector<int64_t> new_dims = {shape_vector[NCHW_SHAPE::NCHW_N], shape_vector[NCHW_SHAPE::NCHW_H],
shape_vector[NCHW_SHAPE::NCHW_W], shape_vector[NCHW_SHAPE::NCHW_C]};
abstract_tensor->set_shape(std::make_shared<abstract::Shape>(new_dims));
auto trans_cnode = opt::GenTransposeNode(func_graph, param_node, NH2NC, param_node->fullname_with_scope() + "_pre");
auto new_input_prim = GetValueNode<PrimitivePtr>(trans_cnode->cast<CNodePtr>()->input(0));
new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
if (trans_cnode == nullptr) {
MS_LOG(ERROR) << "generate a transpose node failed.";
return lite::RET_ERROR;
}
func_graph->manager()->Replace(param_node, trans_cnode);
}
return lite::RET_OK;
}
STATUS InsertTranspose::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
opt::TransTypePair trans_info;
GetTransNodeFormatType(cnode, &trans_info);
if (trans_info.pre_ == opt::kNONE || trans_info.post_ == opt::kNONE) {
return lite::RET_NO_CHANGE;
}
auto before_perm = trans_info.pre_ == opt::kNHWC2NCHW ? NH2NC : NC2NH;
auto after_perm = trans_info.post_ == opt::kNCHW2NHWC ? NC2NH : NH2NC;
if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
return RET_OK;
}
if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
return lite::RET_OK;
}
void InsertTranspose::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto sub_inputs = sub_graph->get_inputs();
sub_inputs_map_[sub_graph] = sub_inputs;
for (auto &node : sub_inputs) {
auto param_node = node->cast<ParameterPtr>();
MS_ASSERT(param_node != nullptr);
auto node_name = node->fullname_with_scope();
auto last_underline = node_name.find_last_of("_");
node_name = node_name.substr(0, last_underline);
last_underline = node_name.find_last_of("_");
auto index = std::stoi(node_name.substr(last_underline + 1)) + 3;
param_node->set_abstract(opt::GetCNodeInputAbstract(cnode, index)->Clone());
if (utils::isa<CNodePtr>(cnode->input(index))) {
ShapeVector shape_vec = {-1};
auto out_cnode = cnode->input(index)->cast<CNodePtr>();
MS_ASSERT(trans_cnode != nullptr);
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(opt::kInferDone))) {
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
}
} else {
lite::DataInfo data_info;
if (utils::isa<ParameterPtr>(cnode->input(index))) {
if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
}
continue;
}
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info);
if (status != lite::RET_OK) {
continue;
}
ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
if (data_info.data_.empty()) {
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec));
} else {
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
data_info.data_.data(), data_info.data_.size()));
}
}
}
}
void InsertTranspose::ResetSubGraphInput() {
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
auto &sub_graph = iter->first;
auto &sub_inputs = iter->second;
auto manager = sub_graph->manager();
MS_ASSERT(manager != nullptr);
for (auto &sub_input : sub_inputs) {
auto param_node = sub_graph->add_parameter();
MS_ASSERT(param_node != nullptr);
param_node->set_abstract(sub_input->abstract()->Clone());
param_node->set_name(sub_input->fullname_with_scope());
manager->Replace(sub_input, param_node);
auto sub_param_input = sub_input->cast<ParameterPtr>();
MS_ASSERT(sub_param_input != nullptr);
sub_param_input->set_default_param(nullptr);
}
}
}
void InsertTranspose::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto return_node = sub_graph->get_return();
auto origin_input = return_node->inputs();
lite::RemoveIfDepend(return_node);
lite::RemoveIfMakeTuple(return_node);
for (size_t i = 1; i < return_node->size(); ++i) {
if (!opt::CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) {
continue;
}
auto node_name = return_node->input(i)->fullname_with_scope();
if (node_name.substr(node_name.size() - 5) != "_post") {
continue;
}
auto trans_cnode = return_node->input(i)->cast<CNodePtr>();
MS_ASSERT(trans_cnode != nullptr);
auto trans_input = trans_cnode->input(1);
auto trans_input_name = trans_input->fullname_with_scope();
if (utils::isa<ParameterPtr>(trans_input)) {
trans_input->cast<ParameterPtr>()->set_name(node_name);
} else if (utils::isa<CNodePtr>(trans_input)) {
trans_input->cast<CNodePtr>()->set_fullname_with_scope(node_name);
}
trans_input_name = trans_input_name.substr(0, trans_input_name.find_last_of("_")) + "_cnode";
trans_cnode->set_fullname_with_scope(trans_input_name);
}
return_node->set_inputs(origin_input);
}
void InsertTranspose::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto return_node = sub_graph->get_return();
auto origin_inputs = return_node->inputs();
lite::RemoveIfDepend(return_node);
lite::RemoveIfMakeTuple(return_node);
AbstractBasePtrList abstract_list;
bool infer_done = true;
for (size_t i = 1; i < return_node->size(); ++i) {
auto abstract_base = opt::GetCNodeInputAbstract(return_node, i);
MS_ASSERT(abstract_base != nullptr);
abstract_list.emplace_back(abstract_base->Clone());
auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
MS_ASSERT(abstract_tensor != nullptr);
auto shape_ptr = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape());
MS_ASSERT(shape_ptr != nullptr);
auto shape = shape_ptr->shape();
if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
infer_done = false;
}
if (utils::isa<CNodePtr>(return_node->input(i))) {
auto input_cnode = return_node->input(i)->cast<CNodePtr>();
if (opt::CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
input_cnode = input_cnode->input(1)->cast<CNodePtr>();
}
auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
if (input_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(opt::kInferDone))) {
infer_done = false;
}
}
}
return_node->set_inputs(origin_inputs);
if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) {
cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
} else {
if (abstract_list.size() != 1) {
MS_LOG(ERROR) << "cnode output is invalid.";
}
cnode->set_abstract(abstract_list.front());
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
prim->AddAttr(opt::kInferDone, MakeValue<bool>(infer_done));
}
bool InsertTranspose::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
MS_ASSERT(func_graph != nullptr);
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return false;
}
auto node_list = TopoSort(func_graph->get_return());
int status;
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (IsSpecialType(cnode)) {
continue;
}
if (main_graph) {
status = HandleGraphInput(func_graph, cnode);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
return false;
}
}
if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
(void)BasicProcess(sub_func_graph, false);
SetSubGraphOutput(cnode, sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
(void)BasicProcess(sub_func_graph, false);
SetSubGraphOutput(cnode, sub_func_graph);
SetSubGraphAbstract(cnode, sub_func_graph);
continue;
}
status = HandleGraphNode(func_graph, cnode);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
return false;
}
}
return true;
}
bool InsertTranspose::ResetFuncGraph(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return false;
}
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->GetAttr(opt::kInferDone) != nullptr) {
prim->EraseAttr(opt::kInferDone);
}
if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
return false;
}
(void)ResetFuncGraph(sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
if (sub_func_graph == nullptr) {
return false;
}
(void)ResetFuncGraph(sub_func_graph);
}
}
return true;
}
bool InsertTranspose::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
auto prim = GetValueNode<PrimitivePtr>(node);
if (prim == nullptr) {
continue;
}
}
// insert transpose for some ops whose format must be NHWC, which is depend on framework.
// In this process, tranpose can be fused, which the original graph may not be able to restored.
if (!BasicProcess(func_graph, true)) {
MS_LOG(ERROR) << "run framework transpose unify failed.";
return false;
}
ResetSubGraphInput();
return true;
}
} // namespace lite
} // namespace mindspore

View File

@ -35,7 +35,7 @@
#include "tools/converter/parser/onnx/onnx_pad_adjust.h"
#include "tools/converter/parser/parser_utils.h"
#include "ops/transpose.h"
#include "tools/converter/parser/insert_transpose.h"
#include "tools/converter/parser/unify_format.h"
using mindspore::lite::converter::FmkType_ONNX;
namespace mindspore {
@ -95,8 +95,8 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_ONNX, false);
if (!insert_transpose->Run(res_graph_)) {
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_ONNX, false);
if (!unify_format->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}

View File

@ -33,7 +33,7 @@
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/insert_transpose.h"
#include "tools/converter/parser/unify_format.h"
using mindspore::lite::converter::FmkType_TF;
namespace mindspore {
@ -576,8 +576,8 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_TF, false);
if (!insert_transpose->Run(res_graph_)) {
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_TF, false);
if (!unify_format->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}

View File

@ -30,7 +30,7 @@
#include "tools/converter/converter_flags.h"
#include "tools/converter/parser/tflite/tflite_inputs_adjust.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/parser/insert_transpose.h"
#include "tools/converter/parser/unify_format.h"
using mindspore::lite::converter::FmkType_TFLITE;
namespace mindspore::lite {
@ -105,8 +105,8 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_TFLITE, false);
if (!insert_transpose->Run(res_graph_)) {
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_TFLITE, false);
if (!unify_format->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}

View File

@ -0,0 +1,78 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/unify_format.h"
namespace mindspore {
namespace lite {
namespace {
constexpr int kInputChannal = 3;
}
void UnifyFormatToNHWC::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
MS_ASSERT(cnode != nullptr);
auto prim_node = cnode->input(0);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_ASSERT(prim != nullptr);
auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
auto &specify_nchw_op_map = opt::GetNCHWOpMap();
if (fmk_type_ == lite::converter::FmkType_TFLITE) {
if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
return;
}
trans_info->pre_ = opt::kNHWC2NCHW;
trans_info->post_ = opt::kNCHW2NHWC;
} else if (fmk_type_ == lite::converter::FmkType_TF) {
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && opt::GetFormat(cnode) == NCHW) {
trans_info->pre_ = opt::kNCHW2NHWC;
trans_info->post_ = opt::kNHWC2NCHW;
}
if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
trans_info->pre_ = opt::kNHWC2NCHW;
trans_info->post_ = opt::kNCHW2NHWC;
}
} else {
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
return;
}
trans_info->pre_ = opt::kNCHW2NHWC;
trans_info->post_ = opt::kNHWC2NCHW;
}
}
}
void UnifyFormatToNHWC::SetSensitiveOps() {
auto &sensitive_nhwc_ops = opt::GetNHWCOpMap();
auto &sensitive_nchw_ops = opt::GetNCHWOpMap();
sensitive_ops_.insert(sensitive_nhwc_ops.begin(), sensitive_nhwc_ops.end());
sensitive_ops_.insert(sensitive_nchw_ops.begin(), sensitive_nchw_ops.end());
}
bool UnifyFormatToNHWC::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) {
if (fmk_type_ == converter::FmkType_TF || fmk_type_ == converter::FmkType_TFLITE) {
return false;
}
if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX &&
shape[opt::kInputIndexThree] == kInputChannal && shape[1] == -1) {
return false;
}
return true;
}
bool UnifyFormatToNHWC::DecideWhetherInferShapeForNewNode() { return false; }
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_UNIFY_FORMAT_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_UNIFY_FORMAT_H_
#include "tools/optimizer/format/to_format_base.h"
using mindspore::lite::converter::FmkType;
namespace mindspore {
namespace lite {
class UnifyFormatToNHWC : public opt::ToFormatBase {
public:
explicit UnifyFormatToNHWC(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false)
: ToFormatBase(fmk_type, train_flag) {}
~UnifyFormatToNHWC() override = default;
private:
void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) override;
void SetSensitiveOps() override;
bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) override;
bool DecideWhetherInferShapeForNewNode() override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_UNIFY_FORMAT_H_

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H
#include <map>
#include <string>
#include <vector>
#include "include/registry/pass_registry.h"
namespace mindspore {
namespace opt {
std::map<std::string, PassPtr> &MS_API PassStoreRoomInfo();
std::map<PassPosition, std::vector<std::string>> &MS_API ExternalAssignedPassesInfo();
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H

View File

@ -15,31 +15,38 @@
*/
#include "include/registry/pass_registry.h"
#include <unordered_map>
#include <map>
#include <mutex>
#include <string>
#include <vector>
#include "tools/converter/registry/pass_content.h"
#include "src/common/log_adapter.h"
namespace mindspore {
namespace opt {
PassRegistry *PassRegistry::GetInstance() {
static PassRegistry instance;
return &instance;
}
void PassRegistry::RegPass(int position, const PassPtr &pass) {
namespace {
std::map<std::string, PassPtr> pass_store_room;
std::map<PassPosition, std::vector<std::string>> external_assigned_passes;
std::mutex pass_mutex;
void RegPass(const std::string &pass_name, const PassPtr &pass) {
if (pass == nullptr) {
MS_LOG(ERROR) << "pass is nullptr.";
return;
}
auto instance = PassRegistry::GetInstance();
std::unique_lock<std::mutex> lock(instance->mutex_);
instance->passes_[position] = pass;
std::unique_lock<std::mutex> lock(pass_mutex);
pass_store_room[pass_name] = pass;
}
} // namespace
PassRegistry::PassRegistry(const std::string &pass_name, const PassPtr &pass) { RegPass(pass_name, pass); }
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &assigned) {
std::unique_lock<std::mutex> lock(pass_mutex);
external_assigned_passes[position] = assigned;
}
const std::unordered_map<int, PassPtr> &PassRegistry::GetPasses() const {
auto instance = PassRegistry::GetInstance();
std::unique_lock<std::mutex> lock(instance->mutex_);
return instance->passes_;
}
std::map<std::string, PassPtr> &PassStoreRoomInfo() { return pass_store_room; }
std::map<PassPosition, std::vector<std::string>> &ExternalAssignedPassesInfo() { return external_assigned_passes; }
} // namespace opt
} // namespace mindspore

View File

@ -852,10 +852,10 @@ STATUS GetFilterDim(const std::vector<int64_t> &oriDims, kTransFilterType type,
int64_t *filterH, int64_t *filterW) {
MS_ASSERT(oriDims.size() == 4);
std::unordered_map<kTransFilterType, int> maps = {
{kKCHW2HWCK, 1}, {kKCHW2HWKC, 1}, {kKCHW2KHWC, 1}, {kKCHW2CKHW, 1}, {kCKHW2HWCK, 2},
{kCKHW2HWKC, 2}, {kCKHW2KHWC, 2}, {kHWCK2KCHW, 3}, {kHWCK2CKHW, 3}, {kHWCK2KHWC, 3},
{kHWKC2KCHW, 4}, {kHWKC2CKHW, 4}, {kHWKC2KHWC, 4}, {kNHWC2KCHW, 5}, {kNHWC2HWCK, 5},
{kNHWC2CKHW, 5}, {kCHWK2HWCK, 6}, {kCHWK2KHWC, 6}, {kKHWC2HWCK, 7}, {kKHWC2CHWK, 7},
{kKCHW2HWCK, 1}, {kKCHW2HWKC, 1}, {kKCHW2KHWC, 1}, {kKCHW2CKHW, 1}, {kCKHW2HWCK, 2}, {kCKHW2HWKC, 2},
{kCKHW2KHWC, 2}, {kHWCK2KCHW, 3}, {kHWCK2CKHW, 3}, {kHWCK2KHWC, 3}, {kHWKC2KCHW, 4}, {kHWKC2CKHW, 4},
{kHWKC2KHWC, 4}, {kNHWC2KCHW, 5}, {kNHWC2HWCK, 5}, {kNHWC2CKHW, 5}, {kKHWC2KCHW, 5}, {kCHWK2HWCK, 6},
{kCHWK2KHWC, 6}, {kKHWC2HWCK, 7}, {kKHWC2CHWK, 7},
};
if (maps.find(type) == maps.end()) {
MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
@ -915,10 +915,10 @@ STATUS SetFilterDim(const tensor::TensorPtr &tensor, kTransFilterType type, int3
int32_t filterH, int32_t filterW) {
MS_ASSERT(tensor != nullptr);
std::unordered_map<kTransFilterType, int> maps = {
{kKCHW2HWCK, 1}, {kCKHW2HWCK, 1}, {kNHWC2HWCK, 1}, {kKHWC2HWCK, 1}, {kCHWK2HWCK, 1},
{kKCHW2HWKC, 2}, {kCKHW2HWKC, 2}, {kHWCK2KCHW, 3}, {kHWKC2KCHW, 3}, {kNHWC2KCHW, 3},
{kHWCK2CKHW, 4}, {kHWKC2CKHW, 4}, {kNHWC2CKHW, 4}, {kKCHW2CKHW, 4}, {kKHWC2CHWK, 5},
{kKCHW2KHWC, 6}, {kCKHW2KHWC, 6}, {kCHWK2KHWC, 6}, {kHWCK2KHWC, 6}, {kHWKC2KHWC, 6},
{kKCHW2HWCK, 1}, {kCKHW2HWCK, 1}, {kNHWC2HWCK, 1}, {kKHWC2HWCK, 1}, {kCHWK2HWCK, 1}, {kKCHW2HWKC, 2},
{kCKHW2HWKC, 2}, {kHWCK2KCHW, 3}, {kHWKC2KCHW, 3}, {kNHWC2KCHW, 3}, {kKHWC2KCHW, 3}, {kHWCK2CKHW, 4},
{kHWKC2CKHW, 4}, {kNHWC2CKHW, 4}, {kKCHW2CKHW, 4}, {kKHWC2CHWK, 5}, {kKCHW2KHWC, 6}, {kCKHW2KHWC, 6},
{kCHWK2KHWC, 6}, {kHWCK2KHWC, 6}, {kHWKC2KHWC, 6},
};
if (maps.find(type) == maps.end()) {
MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
@ -1137,10 +1137,10 @@ static STATUS TransFilterData(const tensor::TensorPtr &tensor, kTransFilterType
T *p2Buff = nullptr;
std::unordered_map<kTransFilterType, int> maps = {
{kCHWK2HWCK, 1}, {kCHWK2KHWC, 1}, {kKHWC2HWCK, 2}, {kKCHW2HWCK, 3}, {kKCHW2CKHW, 3},
{kKCHW2KHWC, 3}, {kKCHW2HWKC, 3}, {kCKHW2HWCK, 4}, {kCKHW2KHWC, 4}, {kCKHW2HWKC, 4},
{kHWCK2KCHW, 5}, {kHWCK2CKHW, 5}, {kHWCK2KHWC, 5}, {kHWKC2KCHW, 6}, {kHWKC2KHWC, 6},
{kHWKC2CKHW, 6}, {kNHWC2HWCK, 7}, {kNHWC2KCHW, 7}, {kNHWC2CKHW, 7}, {kKHWC2CHWK, 8},
{kCHWK2HWCK, 1}, {kCHWK2KHWC, 1}, {kKHWC2HWCK, 2}, {kKCHW2HWCK, 3}, {kKCHW2CKHW, 3}, {kKCHW2KHWC, 3},
{kKCHW2HWKC, 3}, {kCKHW2HWCK, 4}, {kCKHW2KHWC, 4}, {kCKHW2HWKC, 4}, {kHWCK2KCHW, 5}, {kHWCK2CKHW, 5},
{kHWCK2KHWC, 5}, {kHWKC2KCHW, 6}, {kHWKC2KHWC, 6}, {kHWKC2CKHW, 6}, {kNHWC2HWCK, 7}, {kNHWC2KCHW, 7},
{kKHWC2KCHW, 7}, {kNHWC2CKHW, 7}, {kKHWC2CHWK, 8},
};
if (maps.find(type) == maps.end()) {
MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
@ -1510,5 +1510,23 @@ CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &inp
tuple_cnode->set_fullname_with_scope(input->fullname_with_scope() + "_getitem_" + std::to_string(index));
return tuple_cnode;
}
STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVector *shape) {
if (abstract == nullptr) {
MS_LOG(ERROR) << "abstract of cnode is invalid.";
return lite::RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensor>(abstract)) {
MS_LOG(ERROR) << "abstract of cnode is invalid.";
return lite::RET_ERROR;
}
auto abstract_tensor = abstract->cast<abstract::AbstractTensorPtr>();
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
MS_LOG(ERROR) << "shape of cnode's output is invalid.";
return lite::RET_ERROR;
}
*shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
return lite::RET_OK;
}
} // namespace opt
} // namespace mindspore

View File

@ -43,6 +43,8 @@ inline constexpr size_t kInputSizeTwo = 2;
inline constexpr size_t kInputSizeThree = 3;
inline constexpr size_t kInputSizeFour = 4;
inline constexpr size_t kInputSizeFive = 5;
inline const std::vector<int> kNH2NC = {0, 3, 1, 2};
inline const std::vector<int> kNC2NH = {0, 2, 3, 1};
inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple");
inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity");
inline const PrimitivePtr kPrimConv2DBackpropInputFusion =
@ -178,6 +180,8 @@ CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu
CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index);
STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVector *shape);
template <const PrimitivePtr *prim = nullptr>
inline bool IsSpecifiedNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {

View File

@ -0,0 +1,129 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/format/conv_weight_format.h"
#include <vector>
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/parser_utils.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kConvWeightIndex = 2;
} // namespace
STATUS ConvWeightFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
if (ConvWeightFormatTrans(sub_func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "transform conv weight format failed.";
return lite::RET_ERROR;
}
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
if (ConvWeightFormatTrans(sub_func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "transform conv weight format failed.";
return lite::RET_ERROR;
}
continue;
}
if (!CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
!CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
!CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
continue;
}
MS_ASSERT(cnode->inputs().size() > kConvWeightIndex);
auto weight_node = cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
if (utils::isa<CNodePtr>(weight_node)) {
if (lite::HandleWeightConst(graph, cnode, weight_node->cast<CNodePtr>(), src_format_, dst_format_) !=
lite::RET_OK) {
MS_LOG(ERROR) << "handle cnode weight failed.";
return RET_ERROR;
}
continue;
}
if (TransferConvWeight(weight_node) != lite::RET_OK) {
MS_LOG(ERROR) << "transfer weight format failed.";
return lite::RET_ERROR;
}
if (utils::isa<Parameter>(weight_node)) {
if (lite::HandleWeightSharing(graph, dst_format_, weight_node->cast<ParameterPtr>(), src_format_, dst_format_) !=
lite::RET_OK) {
MS_LOG(ERROR) << "handle weight-sharing failed.";
return RET_ERROR;
}
}
}
return RET_OK;
}
STATUS ConvWeightFormatBase::TransferConvWeight(const AnfNodePtr &weight_node) {
MS_ASSERT(weight_node != nullptr);
auto weight_value = GetTensorInfo(weight_node);
if (weight_value == nullptr) {
MS_LOG(ERROR) << "weight node must const value";
return lite::RET_ERROR;
}
auto status = TransFilterFormat(weight_value, src_format_, dst_format_);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "trans conv weight failed.";
return lite::RET_ERROR;
}
auto type_id = static_cast<TypeId>(weight_value->data_type());
auto shape = weight_value->shape();
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return lite::RET_ERROR;
}
weight_node->set_abstract(abstract);
return lite::RET_OK;
}
bool ConvWeightFormatBase::Run(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
if (src_format_ == dst_format_) {
return true;
}
auto manager = Manage(graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return false;
}
auto status = ConvWeightFormatTrans(graph);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status;
return false;
}
return true;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_CONV_WEIGHT_FORMAT_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_CONV_WEIGHT_FORMAT_H_
#include <string>
#include "backend/optimizer/common/pass.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace opt {
class ConvWeightFormatBase : public Pass {
public:
explicit ConvWeightFormatBase(const std::string &name = "ConvWeightFormatBase") : Pass(name) {}
~ConvWeightFormatBase() override = default;
bool Run(const FuncGraphPtr &graph) override;
private:
STATUS ConvWeightFormatTrans(const FuncGraphPtr &graph);
STATUS TransferConvWeight(const AnfNodePtr &weight_node);
protected:
schema::Format src_format_{schema::Format_KHWC};
schema::Format dst_format_{schema::Format_KHWC};
};
class ConvWeightToKHWC : public ConvWeightFormatBase {
public:
ConvWeightToKHWC() : ConvWeightFormatBase("ConvWeightToKHWC") { src_format_ = schema::Format_KCHW; }
~ConvWeightToKHWC() override = default;
};
class ConvWeightToKCHW : public ConvWeightFormatBase {
public:
ConvWeightToKCHW() : ConvWeightFormatBase("ConvWeightToKCHW") { dst_format_ = schema::Format_KCHW; }
~ConvWeightToKCHW() override = default;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_CONV_WEIGHT_FORMAT_H_

View File

@ -0,0 +1,150 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/format/delete_redundant_transpose.h"
#include <vector>
#include "tools/optimizer/common/format_utils.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kDimNumber = 4;
} // namespace
STATUS DeleteRedundantTranspose::DeleteNot4DTranspose(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return lite::RET_NULL_PTR;
}
if (DeleteNot4DTranspose(sub_func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "delete transpose failed.";
return lite::RET_ERROR;
}
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return lite::RET_NULL_PTR;
}
if (DeleteNot4DTranspose(sub_func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "delete transpose failed.";
return lite::RET_ERROR;
}
continue;
}
if (!CheckPrimitiveType(node, prim::kPrimTranspose)) {
continue;
}
auto abstract = GetCNodeInputAbstract(cnode, 1);
ShapeVector shape;
if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
MS_LOG(ERROR) << "fetch shape failed.";
return lite::RET_ERROR;
}
std::vector<int> perm;
if (GetTransposePerm(cnode, &perm) != lite::RET_OK) {
MS_LOG(ERROR) << "fetch transpose perm failed.";
return lite::RET_ERROR;
}
if (!shape.empty() && shape.size() != perm.size()) {
MS_LOG(DEBUG) << "transpose node need to be deleted.";
manager->Replace(node, cnode->input(1));
}
}
return lite::RET_OK;
}
STATUS DeleteRedundantTranspose::TransTransFusion(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto node_lite = TopoSort(func_graph->get_return());
for (auto &node : node_lite) {
if (!utils::isa<CNode>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return lite::RET_NULL_PTR;
}
if (TransTransFusion(sub_func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "delete transpose failed.";
return lite::RET_ERROR;
}
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return lite::RET_NULL_PTR;
}
if (TransTransFusion(sub_func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "delete transpose failed.";
return lite::RET_ERROR;
}
continue;
}
if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) ||
!CheckPrimitiveType(cnode->input(1), prim::kPrimTranspose)) {
continue;
}
std::vector<int> post_perm;
if (GetTransposePerm(cnode, &post_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "transpose rm cannot be obtained, " << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
std::vector<int> pre_perm;
auto pre_cnode = cnode->input(1)->cast<CNodePtr>();
MS_ASSERT(pre_cnode != nullptr);
if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "transpose rm cannot be obtained, " << pre_cnode->fullname_with_scope();
return lite::RET_ERROR;
}
if ((pre_perm == kNH2NC && post_perm == kNC2NH) || (pre_perm == kNC2NH && post_perm == kNH2NC)) {
func_graph->manager()->Replace(cnode, pre_cnode->input(1));
}
}
return lite::RET_OK;
}
bool DeleteRedundantTranspose::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return false;
}
if (TransTransFusion(func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "ranspose and transpose fusion failed.";
return false;
}
if (DeleteNot4DTranspose(func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "delete not 4D transpose failed.";
return false;
}
return true;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_
#include "backend/optimizer/common/pass.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace opt {
class DeleteRedundantTranspose : public Pass {
public:
DeleteRedundantTranspose() : Pass("delete_redundant_transpose") {}
~DeleteRedundantTranspose() = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
STATUS DeleteNot4DTranspose(const FuncGraphPtr &func_graph);
STATUS TransTransFusion(const FuncGraphPtr &func_graph);
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_

View File

@ -0,0 +1,315 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/format/to_format_base.h"
#include "ops/op_utils.h"
#include "src/common/common.h"
#include "src/common/utils.h"
#include "tools/common/tensor_util.h"
using mindspore::lite::NHWC_SHAPE;
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kDimNumber = 4;
} // namespace
STATUS ToFormatBase::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
bool before, size_t index) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
AnfNodePtr trans_input = before ? cnode->input(index) : cnode;
std::string trans_name = before ? cnode->fullname_with_scope() + "_pre_" + std::to_string(index - 1)
: cnode->fullname_with_scope() + "_post";
auto trans_cnode = opt::GenTransposeNode(func_graph, trans_input, perm, trans_name);
if (DecideWhetherInferShapeForNewNode()) {
auto status = node_infer_shape_->InferShape(trans_cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer generated trans node failed.";
return lite::RET_ERROR;
}
} else {
auto abstract = trans_input->abstract();
if (abstract != nullptr) {
trans_cnode->set_abstract(abstract->Clone());
}
}
auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->cast<CNodePtr>()->input(0));
if (perm == kNC2NH) {
trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
} else if (perm == kNH2NC) {
trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
}
auto manager = func_graph->manager();
if (manager == nullptr) {
manager = Manage(func_graph, true);
}
MS_ASSERT(manager != nullptr);
auto tr = manager->Transact();
if (before) {
tr.SetEdge(cnode, index, trans_cnode);
tr.Commit();
} else {
func_graph->manager()->Replace(cnode, trans_cnode);
}
return lite::RET_OK;
}
STATUS ToFormatBase::ModifyCNodeAbstract(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
auto abstract_base = cnode->abstract();
std::vector<AbstractBasePtr> abstracts;
if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract_base);
abstracts = abstract_tuple->elements();
} else {
abstracts.push_back(abstract_base);
}
for (auto &abstract : abstracts) {
ShapeVector shape;
if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
MS_LOG(ERROR) << "fetch shape failed, " << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
if (shape.size() != kDimNumber) {
MS_LOG(DEBUG) << "shape don't need to modify.";
continue;
}
if (format_ == mindspore::NCHW) {
ShapeVector transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
} else {
ShapeVector transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
}
}
return lite::RET_OK;
}
STATUS ToFormatBase::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto prim_node = cnode->input(0);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_ASSERT(prim != nullptr);
if (sensitive_ops_.find(prim->name()) == sensitive_ops_.end()) {
MS_LOG(ERROR) << "op don't meet condition.";
return lite::RET_ERROR;
}
auto insert_index = sensitive_ops_.at(prim->name());
if (insert_index.empty()) {
if (opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
insert_index.push_back(1);
} else {
for (size_t i = 1; i < cnode->size(); ++i) {
insert_index.push_back(i);
}
}
}
for (auto &index : insert_index) {
if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
MS_LOG(ERROR) << "generate a new input failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
STATUS ToFormatBase::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
MS_LOG(ERROR) << "generate a new input failed.";
return lite::RET_ERROR;
}
} else {
auto node_users = func_graph->manager()->node_users()[cnode];
for (auto &node_user : node_users) {
auto post_node = node_user.first;
CNodePtr tuple_get_item = nullptr;
if (!opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
if (!train_flag_) {
MS_LOG(ERROR) << "post node is invalid.";
return lite::RET_ERROR;
} else {
tuple_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0);
post_node = tuple_get_item;
func_graph->manager()->Replace(cnode, tuple_get_item);
}
}
if (func_graph->manager()->node_users()[post_node].empty()) {
continue;
}
auto post_cnode = post_node->cast<CNodePtr>();
if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
MS_LOG(ERROR) << "generate a new input failed.";
return lite::RET_ERROR;
}
if (tuple_get_item != nullptr) {
func_graph->manager()->Replace(tuple_get_item, tuple_get_item->input(1));
}
}
}
return lite::RET_OK;
}
STATUS ToFormatBase::HandleGraphInput(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto graph_input = func_graph->get_inputs();
for (auto &input : graph_input) {
auto input_param = input->cast<ParameterPtr>();
MS_ASSERT(input_param != nullptr);
auto abstract = input_param->abstract();
MS_ASSERT(abstract != nullptr);
ShapeVector shape;
if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
MS_LOG(ERROR) << "fetch shape failed." << input->fullname_with_scope();
return lite::RET_ERROR;
}
if (shape.size() != kDimNumber || !DecideWhetherHandleGraphInput(func_graph, shape)) {
continue;
}
ShapeVector transfer_shape;
if (format_ == mindspore::NCHW) {
transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
} else {
transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
}
CNodePtr trans_cnode;
if (format_ == mindspore::NCHW) {
trans_cnode = opt::GenTransposeNode(func_graph, input, kNC2NH, input->fullname_with_scope() + "_nc2nh");
} else {
trans_cnode = opt::GenTransposeNode(func_graph, input, kNH2NC, input->fullname_with_scope() + "_nh2nc");
}
if (trans_cnode == nullptr) {
MS_LOG(ERROR) << "create transpose cnode failed.";
return lite::RET_ERROR;
}
auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
MS_ASSERT(trans_prim != nullptr);
if (format_ == mindspore::NCHW) {
trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
} else {
trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
}
trans_cnode->set_abstract(abstract->Clone());
abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
func_graph->manager()->Replace(input, trans_cnode);
}
return lite::RET_OK;
}
STATUS ToFormatBase::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
opt::TransTypePair trans_info;
GetTransNodeFormatType(cnode, &trans_info);
if (trans_info.pre_ == opt::kNONE || trans_info.post_ == opt::kNONE) {
return lite::RET_NO_CHANGE;
}
auto before_perm = trans_info.pre_ == opt::kNHWC2NCHW ? kNH2NC : kNC2NH;
auto after_perm = trans_info.post_ == opt::kNCHW2NHWC ? kNC2NH : kNH2NC;
if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
return lite::RET_OK;
}
if (ModifyCNodeAbstract(cnode) != lite::RET_OK) {
MS_LOG(ERROR) << "adjust cnode's output shape failed, " << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
return lite::RET_OK;
}
bool ToFormatBase::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
MS_ASSERT(func_graph != nullptr);
auto node_list = TopoSort(func_graph->get_return());
int status;
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (IsSpecialType(cnode)) {
continue;
}
if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
(void)BasicProcess(sub_func_graph, false);
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
(void)BasicProcess(sub_func_graph, false);
continue;
}
status = HandleGraphNode(func_graph, cnode);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
return false;
}
}
if (main_graph) {
status = HandleGraphInput(func_graph);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
return false;
}
}
return true;
}
bool ToFormatBase::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
if (format_ != mindspore::NHWC && format_ != mindspore::NCHW) {
MS_LOG(ERROR) << "format transferring only support nc2nh or nh2nc.";
return false;
}
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return false;
}
node_infer_shape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
if (node_infer_shape_ == nullptr) {
MS_LOG(ERROR) << "create NodeInferShape object failed.";
return false;
}
SetSensitiveOps();
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
auto prim = GetValueNode<PrimitivePtr>(node);
if (prim == nullptr) {
continue;
}
}
if (!BasicProcess(func_graph, true)) {
MS_LOG(ERROR) << "transfer format failed.";
return false;
}
return true;
}
} // namespace opt
} // namespace mindspore

View File

@ -14,49 +14,51 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_FORMAT_BASE_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_FORMAT_BASE_H_
#include <vector>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include "utils/utils.h"
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/anf_exporter/fetch_content.h"
#include "tools/optimizer/graph/infershape_pass.h"
using mindspore::lite::converter::FmkType;
namespace mindspore {
namespace lite {
class InsertTranspose {
namespace opt {
class ToFormatBase : public Pass {
public:
InsertTranspose(FmkType fmk_type, bool train_flag) : fmk_type_(fmk_type), train_flag_(train_flag) {}
~InsertTranspose() = default;
bool Run(const FuncGraphPtr &func_graph);
explicit ToFormatBase(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false,
std::string pass_name = "to_format_base")
: Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag) {}
~ToFormatBase() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
STATUS HandleGraphInput(const FuncGraphPtr &func_graph);
STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
size_t index = 0);
STATUS ModifyCNodeAbstract(const CNodePtr &cnode);
private:
AnfNodePtr GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm, bool before, size_t index);
bool ResetFuncGraph(const FuncGraphPtr &func_graph);
bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info);
STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void ResetSubGraphInput();
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
protected:
virtual void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) = 0;
virtual void SetSensitiveOps() { sensitive_ops_ = opt::GetNHWCOpMap(); }
virtual bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) { return true; }
virtual bool DecideWhetherInferShapeForNewNode() { return true; }
FmkType fmk_type_{lite::converter::FmkType_MS};
bool train_flag_{false};
std::unordered_map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_;
mindspore::Format format_{mindspore::NHWC};
std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr};
std::unordered_map<std::string, std::vector<size_t>> sensitive_ops_;
};
} // namespace lite
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_FORMAT_BASE_H_

View File

@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/format/to_nchw_format.h"
namespace mindspore {
namespace opt {
void ToNCHWFormat::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
MS_ASSERT(cnode != nullptr);
auto prim_node = cnode->input(0);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_ASSERT(prim != nullptr);
if (sensitive_ops_.find(prim->name()) != sensitive_ops_.end()) {
trans_info->pre_ = opt::kNHWC2NCHW;
trans_info->post_ = opt::kNCHW2NHWC;
}
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NCHW_FORMAT_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NCHW_FORMAT_H_
#include "tools/optimizer/format/to_format_base.h"
namespace mindspore {
namespace opt {
class ToNCHWFormat : public ToFormatBase {
public:
explicit ToNCHWFormat(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false)
: ToFormatBase(fmk_type, train_flag, "to_nchw_format") {
format_ = mindspore::NCHW;
}
~ToNCHWFormat() = default;
private:
void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NCHW_FORMAT_H_

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/format/to_nhwc_format.h"
namespace mindspore {
namespace opt {
void ToNHWCFormat::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
MS_ASSERT(cnode != nullptr);
auto prim_node = cnode->input(0);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_ASSERT(prim != nullptr);
if (sensitive_ops_.find(prim->name()) != sensitive_ops_.end()) {
trans_info->pre_ = opt::kNCHW2NHWC;
trans_info->post_ = opt::kNHWC2NCHW;
}
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NHWC_FORMAT_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NHWC_FORMAT_H_
#include "tools/optimizer/format/to_format_base.h"
namespace mindspore {
namespace opt {
class ToNHWCFormat : public ToFormatBase {
public:
explicit ToNHWCFormat(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false)
: ToFormatBase(fmk_type, train_flag, "to_nhwc_format") {}
~ToNHWCFormat() = default;
private:
void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NHWC_FORMAT_H_

View File

@ -20,13 +20,9 @@
#include <vector>
#include "tools/converter/quant_param_holder.h"
#include "mindspore/core/ops/transpose.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/format_utils.h"
namespace mindspore::opt {
namespace {
const std::vector<int> NH2NC = {0, 3, 1, 2};
const std::vector<int> NC2NH = {0, 2, 3, 1};
} // namespace
bool IsBNCNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
auto anf_node = utils::cast<AnfNodePtr>(n);
@ -142,7 +138,7 @@ AnfNodePtr TransposeFusion::TransTransFusion(const mindspore::FuncGraphPtr &func
MS_LOG(ERROR) << "get tanspose perm failed.";
return nullptr;
}
if ((pre_perm == NH2NC && post_perm == NC2NH) || (pre_perm == NC2NH && post_perm == NH2NC)) {
if ((pre_perm == kNH2NC && post_perm == kNC2NH) || (pre_perm == kNC2NH && post_perm == kNH2NC)) {
return pre_cnode->input(1);
}
return nullptr;
@ -166,8 +162,10 @@ AnfNodePtr TransposeFusion::Process(const std::string &pattern_name, const minds
return nullptr;
}
const CNodePtr &transpose_cnode = transpose_node->cast<CNodePtr>();
auto perm_node = transpose_cnode->input(2);
auto perm_node = transpose_cnode->input(kInputIndexTwo);
auto trans_post_node = GenTransposeNode(func_graph, any_cnode, perm_node, any_cnode->fullname_with_scope() + "_post");
trans_post_node->set_abstract(any_cnode->abstract()->Clone());
any_cnode->set_abstract(transpose_cnode->input(1)->abstract()->Clone());
auto tr = func_graph->manager()->Transact();
tr.SetEdge(any_cnode, 1, transpose_cnode->input(1));
tr.Commit();

View File

@ -19,9 +19,7 @@
#include <string>
#include <unordered_map>
#include "tools/optimizer/graph/unify_format_pass.h"
#include "backend/optimizer/common/optimizer.h"
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
namespace mindspore {

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "tools/optimizer/graph/unify_format_pass.h"
#include "tools/optimizer/graph/decrease_transpose_algo.h"
#include <queue>
#include <set>
#include <unordered_map>
@ -24,14 +24,9 @@
#include "src/common/utils.h"
#include "tools/common/tensor_util.h"
using mindspore::lite::NCHW_SHAPE;
namespace mindspore {
namespace opt {
namespace {
constexpr int kInputChannel = 3;
const std::vector<int> NH2NC = {0, 3, 1, 2};
const std::vector<int> NC2NH = {0, 2, 3, 1};
STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNodePtr &root_node,
std::set<CNodePtr> *in_nodes, std::set<CNodePtr> *out_nodes,
std::set<CNodePtr> *middle_nodes) {
@ -101,22 +96,39 @@ STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNode
return lite::RET_OK;
}
bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<CNodePtr> &in_nodes,
const std::set<CNodePtr> &out_nodes, const std::set<CNodePtr> &middle_nodes) {
MS_ASSERT(func_graph != nullptr);
for (auto &in_cnode : in_nodes) {
void SetTransType(const std::set<CNodePtr> &cnodes, FormatTransNodeType *trans_type) {
MS_ASSERT(trans_type != nullptr);
FormatTransNodeType local_trans_type;
for (auto &cnode : cnodes) {
std::vector<int> perm;
if (!CheckPrimitiveType(in_cnode, prim::kPrimTranspose) || GetTransposePerm(in_cnode, &perm) != lite::RET_OK ||
perm != NH2NC) {
return false;
if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
(perm != kNH2NC && perm != kNC2NH)) {
*trans_type = kNONE;
return;
}
local_trans_type = perm == kNH2NC ? kNHWC2NCHW : kNCHW2NHWC;
*trans_type = *trans_type == kNONE ? local_trans_type : *trans_type;
if (*trans_type != local_trans_type) {
*trans_type = kNONE;
return;
}
}
for (auto &out_cnode : out_nodes) {
std::vector<int> perm;
if (!CheckPrimitiveType(out_cnode, prim::kPrimTranspose) || GetTransposePerm(out_cnode, &perm) != lite::RET_OK ||
perm != NC2NH) {
return false;
}
}
bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<CNodePtr> &in_nodes,
const std::set<CNodePtr> &out_nodes, const std::set<CNodePtr> &middle_nodes,
TransTypePair *trans_info) {
MS_ASSERT(func_graph != nullptr && trans_info != nullptr);
SetTransType(in_nodes, &trans_info->pre_);
if (trans_info->pre_ == kNONE) {
return false;
}
SetTransType(out_nodes, &trans_info->post_);
if (trans_info->post_ == kNONE) {
return false;
}
if (trans_info->pre_ == trans_info->post_) {
return false;
}
auto &dynamic_ops = GetDynamicFormatOpList();
TransposeStrategy transpose_strategy;
@ -133,8 +145,8 @@ bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<
return true;
}
void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
bool train_flag) {
void ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
bool train_flag, FormatTransNodeType trans_type) {
MS_ASSERT(cnode != nullptr);
if (utils::isa<CNodePtr>(cnode->input(index))) {
return;
@ -157,42 +169,24 @@ void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, s
(data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
return;
}
std::vector<int> new_shape = data_info.shape_;
ShapeVector expand_shape(data_info.shape_.begin(), data_info.shape_.end());
if (data_info.shape_.size() == 1) {
new_shape = {1, 1, 1, data_info.shape_[0]};
expand_shape = {1, 1, 1, data_info.shape_[0]};
} else if (data_info.shape_.size() == kInputSizeTwo) {
new_shape = {1, 1, data_info.shape_[0], data_info.shape_[1]};
expand_shape = {1, 1, data_info.shape_[0], data_info.shape_[1]};
} else if (data_info.shape_.size() == kInputSizeThree) {
new_shape = {1, data_info.shape_[0], data_info.shape_[1], data_info.shape_[kInputIndexTwo]};
expand_shape = {1, data_info.shape_[0], data_info.shape_[1], data_info.shape_[kInputIndexTwo]};
}
auto size = data_info.data_.size() / sizeof(float);
std::vector<float> new_data(size);
auto new_data_ptr = static_cast<float *>(new_data.data());
auto nchw_data = reinterpret_cast<float *>(data_info.data_.data());
// nchw to nhwc
auto batch = new_shape[lite::NCHW_N];
auto channel = new_shape[lite::NCHW_C];
auto area = new_shape[lite::NCHW_H] * new_shape[lite::NCHW_W];
for (auto i = 0; i < batch; i++) {
float *src_batch = nchw_data + i * channel * area;
float *dst_batch = new_data_ptr + i * channel * area;
for (int j = 0; j < area; ++j) {
float *src_area = src_batch + i;
float *dst_area = dst_batch + i * channel;
for (int k = 0; k < channel; ++k) {
dst_area[k] = src_area[k * area];
}
}
auto tensor = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), expand_shape,
data_info.data_.data(), data_info.data_.size());
if (trans_type == kNHWC2NCHW) {
(void)TransFilterFormat(tensor, schema::Format_KHWC, schema::Format_KCHW);
} else {
(void)TransFilterFormat(tensor, schema::Format_KCHW, schema::Format_KHWC);
}
auto param_node = func_graph->add_parameter();
param_node->set_name(cnode->input(index)->fullname_with_scope());
std::vector<int64_t> shape_vec{new_shape[0], new_shape[kInputIndexTwo], new_shape[kInputIndexThree], new_shape[1]};
auto tensor_info = lite::CreateTensorInfo(new_data.data(), size * sizeof(float), shape_vec, kNumberTypeFloat32);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Create tensor info failed";
return;
}
status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
status = lite::InitParameterFromTensorInfo(param_node, tensor);
if (status != RET_OK) {
MS_LOG(ERROR) << "init parameter from tensor info failed";
return;
@ -200,72 +194,10 @@ void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, s
auto tr = func_graph->manager()->Transact();
tr.SetEdge(cnode, index, param_node);
tr.Commit();
return;
}
} // namespace
void UnifyFormatPass::GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info) {
MS_ASSERT(cnode != nullptr);
auto prim_node = cnode->input(0);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_ASSERT(prim != nullptr);
auto &specify_nhwc_op_map = GetNHWCOpMap();
auto &specify_nchw_op_map = GetNCHWOpMap();
if (fmk_type_ == lite::converter::FmkType_TFLITE) {
if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
return;
}
trans_info->pre_ = kNHWC2NCHW;
trans_info->post_ = kNCHW2NHWC;
} else if (fmk_type_ == lite::converter::FmkType_TF) {
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && GetFormat(cnode) == NCHW) {
trans_info->pre_ = kNCHW2NHWC;
trans_info->post_ = kNHWC2NCHW;
}
if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
trans_info->pre_ = kNHWC2NCHW;
trans_info->post_ = kNCHW2NHWC;
}
} else {
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
return;
}
trans_info->pre_ = kNCHW2NHWC;
trans_info->post_ = kNHWC2NCHW;
}
}
}
bool UnifyFormatPass::TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || !CheckPrimitiveType(cnode->input(1), prim::kPrimTranspose)) {
return false;
}
std::vector<int> post_perm;
if (GetTransposePerm(cnode, &post_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get tanspose perm failed.";
return false;
}
std::vector<int> pre_perm;
auto pre_node = cnode->input(1);
auto pre_cnode = pre_node->cast<CNodePtr>();
if (pre_cnode == nullptr) {
return false;
}
if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get tanspose perm failed.";
return false;
}
if ((pre_perm == NH2NC && post_perm == NC2NH) || (pre_perm == NC2NH && post_perm == NH2NC)) {
func_graph->manager()->Replace(cnode, pre_cnode->input(1));
return true;
}
return false;
}
STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
STATUS DecreaseTransposeAlgo::PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (!CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
return lite::RET_OK;
@ -285,7 +217,7 @@ STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, cons
MS_LOG(ERROR) << "get post transpose node perm failed.";
return lite::RET_ERROR;
}
if ((cur_perm == NH2NC && post_trans_perm == NC2NH) || (cur_perm == NC2NH && post_trans_perm == NH2NC)) {
if ((cur_perm == kNH2NC && post_trans_perm == kNC2NH) || (cur_perm == kNC2NH && post_trans_perm == kNH2NC)) {
func_graph->manager()->Replace(post_node, cnode->input(1));
}
}
@ -293,15 +225,11 @@ STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, cons
return lite::RET_OK;
}
STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
bool before, size_t index) {
STATUS DecreaseTransposeAlgo::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
bool before, size_t index) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
AnfNodePtr new_input = nullptr;
if (need_reset_) {
new_input = transpose_strategy_.TransposeDependOnShape(func_graph, cnode, perm, before, index);
} else {
new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index);
}
new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index);
if (new_input == nullptr) {
MS_LOG(ERROR) << "generate a transpose node failed.";
return lite::RET_ERROR;
@ -312,13 +240,6 @@ STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodeP
auto new_cnode_input = new_input->cast<CNodePtr>();
int status = lite::RET_OK;
if (CheckPrimitiveType(new_cnode_input, prim::kPrimTranspose)) {
if (need_reset_) {
if (before) {
pre_insert_trans_.insert(new_cnode_input);
} else {
post_insert_trans_.insert(new_cnode_input);
}
}
status = node_infer_shape_.InferShape(new_cnode_input);
}
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
@ -337,7 +258,7 @@ STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodeP
tr.Commit();
} else {
func_graph->manager()->Replace(cnode, new_input);
if (!need_reset_ && PostTransposeFusion(func_graph, new_input->cast<CNodePtr>()) != lite::RET_OK) {
if (PostTransposeFusion(func_graph, new_input->cast<CNodePtr>()) != lite::RET_OK) {
MS_LOG(ERROR) << "post transpose fusion failed.";
return lite::RET_ERROR;
}
@ -345,8 +266,8 @@ STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodeP
return lite::RET_OK;
}
STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm) {
STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto prim_node = cnode->input(0);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
@ -380,8 +301,8 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const
return lite::RET_OK;
}
STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
TransTypePair *trans_insert_info) {
STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
TransTypePair *trans_insert_info) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
MS_ASSERT(trans_insert_info != nullptr);
TransTypePair trans_info;
@ -393,14 +314,14 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const
return lite::RET_NO_CHANGE;
}
cnode->set_inputs(origin_inputs);
auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode);
auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode, trans_insert_info->pre_);
if (status == lite::RET_NOT_SUPPORT) {
return lite::RET_NO_CHANGE;
} else if (status != lite::RET_OK) {
MS_LOG(ERROR) << "change op attr failed.";
return lite::RET_ERROR;
}
auto before_perm = trans_insert_info->pre_ == kNHWC2NCHW ? NH2NC : NC2NH;
auto before_perm = trans_insert_info->pre_ == kNHWC2NCHW ? kNH2NC : kNC2NH;
for (size_t i = 1; i < cnode->size(); ++i) {
if (IsMonadNode(cnode->input(i))) {
continue;
@ -431,8 +352,8 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const
return lite::RET_OK;
}
STATUS UnifyFormatPass::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm) {
STATUS DecreaseTransposeAlgo::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
@ -470,63 +391,8 @@ STATUS UnifyFormatPass::InsertPostTransNode(const FuncGraphPtr &func_graph, cons
return lite::RET_OK;
}
STATUS UnifyFormatPass::HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (fmk_type_ == lite::converter::FmkType_TF || fmk_type_ == lite::converter::FmkType_TFLITE) {
return lite::RET_NO_CHANGE;
}
for (size_t i = 1; i < cnode->size(); ++i) {
auto node = cnode->input(i);
if (!utils::isa<ParameterPtr>(node)) {
continue;
}
auto param_node = node->cast<ParameterPtr>();
if (param_node->has_default()) {
continue;
}
auto abstract_base = param_node->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
return lite::RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
return lite::RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
return lite::RET_ERROR;
}
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
if (shape_vector.size() != kInputSizeFour) {
continue;
}
if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX &&
shape_vector[kInputIndexThree] == kInputChannel && shape_vector[1] == -1) {
continue;
}
std::vector<int64_t> new_dims = {shape_vector[NCHW_SHAPE::NCHW_N], shape_vector[NCHW_SHAPE::NCHW_H],
shape_vector[NCHW_SHAPE::NCHW_W], shape_vector[NCHW_SHAPE::NCHW_C]};
abstract_tensor->set_shape(std::make_shared<abstract::Shape>(new_dims));
auto trans_cnode = GenTransposeNode(func_graph, param_node, NH2NC, param_node->fullname_with_scope() + "_pre");
if (trans_cnode == nullptr) {
MS_LOG(ERROR) << "generate a transpose node failed.";
return lite::RET_ERROR;
}
auto status = node_infer_shape_.InferShape(trans_cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer shape failed.";
return lite::RET_ERROR;
}
func_graph->manager()->Replace(param_node, trans_cnode);
}
return lite::RET_OK;
}
STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::set<CNodePtr> *visit_transposes) {
STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::set<CNodePtr> *visit_transposes) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr && visit_transposes != nullptr);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
@ -543,7 +409,8 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con
visit_transposes->insert(in_cnode);
}
}
if (!JudgeCanOptimizerForMultiOp(func_graph, in_nodes, out_nodes, middle_nodes)) {
TransTypePair trans_info;
if (!JudgeCanOptimizerForMultiOp(func_graph, in_nodes, out_nodes, middle_nodes, &trans_info)) {
return lite::RET_NO_CHANGE;
}
auto node_list = TopoSort(func_graph->get_return());
@ -568,9 +435,9 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con
continue;
}
for (size_t i = 1; i < middle_cnode->size(); ++i) {
ConvertNcTensor2Nh(func_graph, middle_cnode, i, fmk_type_, train_flag_);
ConvertTensorToNCOrNH(func_graph, middle_cnode, i, fmk_type_, train_flag_, trans_info.post_);
}
status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode);
status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode, trans_info.post_);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "change op attr failed.";
return lite::RET_ERROR;
@ -584,7 +451,7 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con
return lite::RET_OK;
}
void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
void DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto sub_inputs = sub_graph->get_inputs();
sub_inputs_map_[sub_graph] = sub_inputs;
@ -628,7 +495,7 @@ void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr
}
}
void UnifyFormatPass::ResetSubGraphInput() {
void DecreaseTransposeAlgo::ResetSubGraphInput() {
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
auto &sub_graph = iter->first;
auto &sub_inputs = iter->second;
@ -647,7 +514,7 @@ void UnifyFormatPass::ResetSubGraphInput() {
}
}
void UnifyFormatPass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
void DecreaseTransposeAlgo::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto return_node = sub_graph->get_return();
auto origin_input = return_node->inputs();
@ -676,7 +543,7 @@ void UnifyFormatPass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPt
return_node->set_inputs(origin_input);
}
void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto return_node = sub_graph->get_return();
auto origin_inputs = return_node->inputs();
@ -720,7 +587,7 @@ void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraph
prim->AddAttr(kInferDone, MakeValue<bool>(infer_done));
}
bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
auto manager = Manage(func_graph, true);
@ -771,7 +638,7 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
MS_LOG(ERROR) << "insert pre node failed.";
return false;
}
auto after_perm = trans_insert_info.post_ == kNHWC2NCHW ? NH2NC : NC2NH;
auto after_perm = trans_insert_info.post_ == kNHWC2NCHW ? kNH2NC : kNC2NH;
if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
return false;
@ -780,7 +647,7 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
return true;
}
bool UnifyFormatPass::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph) {
bool DecreaseTransposeAlgo::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
@ -811,7 +678,7 @@ bool UnifyFormatPass::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph
}
std::vector<int> perm;
if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
perm != NH2NC) {
perm != kNH2NC) {
continue;
}
auto status = HandleGraphMultiNode(func_graph, cnode, &visit_transposes);
@ -823,144 +690,7 @@ bool UnifyFormatPass::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph
return true;
}
bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return false;
}
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->GetAttr(kInferDone) != nullptr) {
prim->EraseAttr(kInferDone);
}
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
return false;
}
(void)ResetFuncGraph(sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
if (sub_func_graph == nullptr) {
return false;
}
(void)ResetFuncGraph(sub_func_graph);
}
}
return true;
}
bool UnifyFormatPass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto node_list = TopoSort(func_graph->get_return());
bool all_op_can_infer = true;
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (IsSpecialType(cnode)) {
continue;
}
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
all_op_can_infer = false;
} else {
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
}
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
all_op_can_infer = false;
} else {
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
}
continue;
}
auto cur_op_can_infer = node_infer_shape_.JudgeOpSupportInfer(cnode);
if (!cur_op_can_infer) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(prim != nullptr);
lite::NotSupportOp::GetInstance()->InsertOp(prim->name());
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_SUPPORT);
all_op_can_infer = false;
}
}
return all_op_can_infer;
}
bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
if (!JudgeAllOpsCanInfer(func_graph)) {
MS_LOG(ERROR) << "exist op cannot support infer shape.";
return false;
}
if (!RunNodeInferShape(func_graph)) {
MS_LOG(ERROR) << "RunNodeInferShape failed.";
return false;
}
ResetSubGraphInput();
ResetFuncGraph(func_graph);
return true;
}
bool UnifyFormatPass::RunNodeInferShape(const FuncGraphPtr &func_graph) {
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (IsSpecialType(cnode)) {
continue;
}
if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
if (!RunNodeInferShape(sub_func_graph)) {
MS_LOG(ERROR) << "subgraph infer shape failed.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
return false;
}
SetSubGraphOutput(cnode, sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
if (!RunNodeInferShape(sub_func_graph)) {
MS_LOG(ERROR) << "subgraph infer shape failed.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
return false;
}
SetSubGraphOutput(cnode, sub_func_graph);
SetSubGraphAbstract(cnode, sub_func_graph);
continue;
}
auto status = node_infer_shape_.InferShape(cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer shape failed." << cnode->fullname_with_scope();
return false;
}
}
return true;
}
bool UnifyFormatPass::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
bool DecreaseTransposeAlgo::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
auto prim_node = cnode->input(0);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
auto &nchw_op = GetNCHWOpMap();
@ -970,32 +700,14 @@ bool UnifyFormatPass::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNode
if (utils::isa<CNodePtr>(cnode->input(1))) {
auto format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
if (nchw_op.find(prim->name()) != nchw_op.end() && format != NCHW) {
InsertPreTransNode(func_graph, cnode, {0, 3, 1, 2});
InsertPostTransNode(func_graph, cnode, {0, 2, 3, 1});
}
}
{
if (CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto manager = func_graph->manager();
if (manager == nullptr) {
manager = Manage(func_graph, true);
}
auto shape = node_infer_shape_.GetInputShape(cnode, 1);
std::vector<int> perm;
auto status = GetTransposePerm(cnode, &perm);
if (status != RET_OK) {
return false;
}
if (!shape.empty() && shape.size() != perm.size()) {
manager->Replace(cnode, cnode->input(1));
}
InsertPreTransNode(func_graph, cnode, kNH2NC);
InsertPostTransNode(func_graph, cnode, kNC2NH);
}
}
return true;
}
bool UnifyFormatPass::DoFixFormat(const FuncGraphPtr &func_graph) {
bool DecreaseTransposeAlgo::DoFixFormat(const FuncGraphPtr &func_graph) {
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
@ -1041,8 +753,10 @@ bool UnifyFormatPass::DoFixFormat(const FuncGraphPtr &func_graph) {
return true;
}
bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
bool DecreaseTransposeAlgo::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
node_infer_shape_.Init(fmk_type_, train_flag_);
transpose_strategy_.Init(fmk_type_, train_flag_);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
auto prim = GetValueNode<PrimitivePtr>(node);
@ -1050,15 +764,6 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
continue;
}
}
if (!JudgeAllOpsCanInfer(func_graph)) {
MS_LOG(ERROR) << "exist op cannot support infer shape.";
return false;
}
if (!RunNodeInferShape(func_graph)) {
MS_LOG(ERROR) << "infer shape failed.";
return false;
}
ResetSubGraphInput();
if (!DoFixFormat(func_graph)) {
MS_LOG(ERROR) << "DoFixFormat failed.";

View File

@ -31,10 +31,11 @@
using mindspore::lite::converter::FmkType;
namespace mindspore {
namespace opt {
class UnifyFormatPass : public Pass {
class DecreaseTransposeAlgo : public Pass {
public:
UnifyFormatPass() : Pass("unify_format_pass") {}
~UnifyFormatPass() override = default;
explicit DecreaseTransposeAlgo(FmkType fmk_type = FmkType::FmkType_MS, bool train_flag = false)
: Pass("DecreaseTransposeAlgo"), fmk_type_(fmk_type), train_flag_(train_flag) {}
~DecreaseTransposeAlgo() override = default;
void Init(FmkType fmk_type, bool train_flag) {
fmk_type_ = fmk_type;
train_flag_ = train_flag;
@ -42,7 +43,6 @@ class UnifyFormatPass : public Pass {
transpose_strategy_.Init(fmk_type, train_flag);
}
bool Run(const FuncGraphPtr &func_graph) override;
bool RunOnlyForShape(const FuncGraphPtr &func_graph);
private:
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
@ -51,16 +51,10 @@ class UnifyFormatPass : public Pass {
size_t index = 0);
bool RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
bool DoFixFormat(const FuncGraphPtr &func_graph);
bool RunNodeInferShape(const FuncGraphPtr &func_graph);
bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
bool ResetFuncGraph(const FuncGraphPtr &func_graph);
bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph);
bool TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
void GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info);
STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::set<CNodePtr> *visit_transposes);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
@ -69,12 +63,9 @@ class UnifyFormatPass : public Pass {
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
FmkType fmk_type_{lite::converter::FmkType_MS};
bool need_reset_{false};
bool train_flag_{false};
NodeInferShape node_infer_shape_;
TransposeStrategy transpose_strategy_;
std::set<AnfNodePtr> pre_insert_trans_;
std::set<AnfNodePtr> post_insert_trans_;
std::unordered_map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_;
};
} // namespace opt

View File

@ -29,6 +29,10 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "create NodeInferShape object failed.";
return false;
}
if (!JudgeAllOpsCanInfer(func_graph)) {
MS_LOG(ERROR) << "exist op cannot support infer shape.";
return false;
}
if (InferProcess(func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "infer shape failed.";
return false;
@ -37,6 +41,47 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
return true;
}
bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto node_list = TopoSort(func_graph->get_return());
bool all_op_can_infer = true;
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (IsSpecialType(cnode)) {
continue;
}
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
all_op_can_infer = false;
} else {
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
}
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
all_op_can_infer = false;
} else {
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
}
continue;
}
auto cur_op_can_infer = node_infer_shape_->JudgeOpSupportInfer(cnode);
if (!cur_op_can_infer) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(prim != nullptr);
lite::NotSupportOp::GetInstance()->InsertOp(prim->name());
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_SUPPORT);
all_op_can_infer = false;
}
}
return all_op_can_infer;
}
STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto node_list = TopoSort(func_graph->get_return());
@ -55,7 +100,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
(void)InferProcess(sub_func_graph);
if (InferProcess(sub_func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "subgraph infer shape failed.";
return false;
}
SetSubGraphOutput(cnode, sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
if (sub_func_graph == nullptr) {
@ -63,7 +111,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
(void)InferProcess(sub_func_graph);
if (InferProcess(sub_func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "subgraph infer shape failed.";
return false;
}
SetSubGraphOutput(cnode, sub_func_graph);
SetSubGraphAbstract(cnode, sub_func_graph);
continue;
@ -132,7 +183,7 @@ void InferShapePass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr
continue;
}
auto node_name = return_node->input(i)->fullname_with_scope();
if (node_name.substr(node_name.size() - 5) != "_post") {
if (node_name.size() < kInputSizeFive || node_name.substr(node_name.size() - kInputSizeFive) != "_post") {
continue;
}
auto trans_cnode = return_node->input(i)->cast<CNodePtr>();

View File

@ -29,10 +29,11 @@ class InferShapePass : public Pass {
public:
explicit InferShapePass(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false)
: Pass("infer_shape"), fmk_type_(fmk_type), train_flag_(train_flag) {}
~InferShapePass() = default;
~InferShapePass() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
STATUS InferProcess(const FuncGraphPtr &func_graph);
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);

View File

@ -32,8 +32,6 @@ namespace {
constexpr size_t kFirstInput = 1;
constexpr size_t kHalfDivisor = 2;
constexpr size_t kOnnxStridedSlice = 6;
const std::vector<int> NH2NC = {0, 3, 1, 2};
const std::vector<int> NC2NH = {0, 2, 3, 1};
STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<AnfNodePtr> *out_nodes) {
auto manager = func_graph->manager();
if (manager == nullptr) {
@ -70,7 +68,7 @@ AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &fu
MS_LOG(ERROR) << "transpose perm get failed.";
return nullptr;
}
if ((perm == NH2NC && trans_perm == NC2NH) || (perm == NC2NH && trans_perm == NH2NC)) {
if ((perm == kNH2NC && trans_perm == kNC2NH) || (perm == kNC2NH && trans_perm == kNH2NC)) {
return input_cnode->input(kFirstInput);
}
}
@ -170,7 +168,8 @@ bool TransposeStrategy::CanChangeOpAxis(const FuncGraphPtr &func_graph, const CN
return true;
}
STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
FormatTransNodeType trans_type) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto shape = node_infer_shape_.GetInputShape(cnode, 1);
if (shape.size() != kInputSizeFour) {
@ -183,39 +182,17 @@ STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNo
return lite::RET_NOT_SUPPORT;
}
}
auto axis_map = GetNC2NHAxisMap();
if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->GetAttr(ops::kAxis) == nullptr) {
return lite::RET_NOT_SUPPORT;
}
auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
auto new_axis = axis_map[axis < 0 ? axis + kInputSizeFour : axis];
prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis));
return ChangeCommonOp(cnode, trans_type);
}
if (CheckPrimitiveType(cnode, prim::kPrimCrop)) {
auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0));
if (crop_prim == nullptr) {
return lite::RET_NULL_PTR;
}
auto axis = crop_prim->get_axis();
auto offsets = crop_prim->get_offsets();
auto new_axis = axis_map[axis < 0 ? axis + kInputSizeFour : axis];
if (new_axis == 0) {
offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]};
} else if (new_axis == kInputIndexThree) {
offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]};
} else {
offsets.push_back(0);
}
crop_prim->set_axis(new_axis);
crop_prim->set_offsets(offsets);
return ChangeOpCrop(cnode, trans_type);
}
if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) {
return ChangeOpSlice(func_graph, cnode);
return ChangeOpSlice(func_graph, cnode, trans_type);
}
if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) {
return ChangeOpStrideSlice(func_graph, cnode);
return ChangeOpStrideSlice(func_graph, cnode, trans_type);
}
return lite::RET_OK;
}
@ -242,7 +219,7 @@ STATUS TransposeStrategy::TransposeInsertDependOnShape(const FuncGraphPtr &func_
CNodePtr base_node = before ? cnode : node_users.front().first->cast<CNodePtr>();
size_t input_index = before ? index : node_users.front().second;
auto shape = node_infer_shape_.GetInputShape(base_node, input_index);
if (!shape.empty() && shape.size() != NH2NC.size()) {
if (!shape.empty() && shape.size() != kNH2NC.size()) {
return lite::RET_NO_CHANGE;
}
return lite::RET_OK;
@ -263,9 +240,9 @@ bool TransposeStrategy::IsInOutCanFuison(const FuncGraphPtr &func_graph, const s
if (GetTransposePerm(cnode, &perm) != lite::RET_OK) {
return false;
}
if (perm == NH2NC) {
if (perm == kNH2NC) {
cur_type = kNHWC2NCHW;
} else if (perm == NC2NH) {
} else if (perm == kNC2NH) {
cur_type = kNCHW2NHWC;
} else {
return false;
@ -297,8 +274,79 @@ void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, Tra
}
}
STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
STATUS TransposeStrategy::ChangeCommonOp(const CNodePtr &cnode, FormatTransNodeType trans_type) {
MS_ASSERT(cnode != nullptr);
if (trans_type == kNONE) {
MS_LOG(ERROR) << "trans_type is invalid.";
return lite::RET_ERROR;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(prim != nullptr);
if (prim->GetAttr(ops::kAxis) == nullptr) {
return lite::RET_NOT_SUPPORT;
}
auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
if (axis < 0) {
axis += kInputSizeFour;
}
auto new_axis = kNH2NC[axis];
if (trans_type == kNHWC2NCHW) {
new_axis = kNC2NH[axis];
}
prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis));
return lite::RET_OK;
}
STATUS TransposeStrategy::ChangeOpCrop(const CNodePtr &cnode, FormatTransNodeType trans_type) {
MS_ASSERT(cnode != nullptr);
if (trans_type == kNONE) {
MS_LOG(ERROR) << "trans_type is invalid.";
return lite::RET_ERROR;
}
auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0));
if (crop_prim == nullptr) {
MS_LOG(ERROR) << "cnode is invalid.";
return lite::RET_ERROR;
}
auto axis = crop_prim->get_axis();
if (axis < 0) {
axis += kInputSizeFour;
}
MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
auto offsets = crop_prim->get_offsets();
if (trans_type == kNCHW2NHWC) {
auto new_axis = kNH2NC[axis];
if (new_axis == 0) {
offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]};
} else if (new_axis == kInputIndexThree) {
offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]};
} else {
offsets.push_back(0);
}
crop_prim->set_axis(new_axis);
crop_prim->set_offsets(offsets);
} else {
auto new_axis = kNC2NH[axis];
if (new_axis == 0) {
offsets = {offsets[0], offsets[kInputIndexThree], offsets[1], offsets[kInputIndexTwo]};
} else if (new_axis == kInputIndexThree) {
offsets = {offsets[kInputIndexTwo], offsets[0], offsets[1]};
} else {
offsets.pop_back();
}
crop_prim->set_axis(new_axis);
crop_prim->set_offsets(offsets);
}
return lite::RET_OK;
}
STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
FormatTransNodeType trans_type) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (trans_type == kNONE) {
MS_LOG(ERROR) << "trans_type is invalid.";
return lite::RET_ERROR;
}
for (size_t i = 2; i < cnode->size(); ++i) {
if (utils::isa<CNodePtr>(cnode->input(i))) {
return lite::RET_NOT_SUPPORT;
@ -321,15 +369,21 @@ STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CN
[](int64_t v) { return static_cast<int>(v); });
}
for (size_t i = 2; i < cnode->size(); ++i) {
TransformAttrByAxes(func_graph, cnode, i, axes);
TransformAttrByAxes(func_graph, cnode, i, axes, trans_type);
}
auto tmp_axes = TransformOpAxesAttr(axes);
auto tmp_axes = TransformOpAxesAttr(axes, trans_type);
std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end());
prim->set_axes(new_axes);
return lite::RET_OK;
}
STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
FormatTransNodeType trans_type) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (trans_type == kNONE) {
MS_LOG(ERROR) << "trans_type is invalid.";
return lite::RET_ERROR;
}
if (cnode->size() != kOnnxStridedSlice) {
return lite::RET_NOT_SUPPORT;
}
@ -347,9 +401,9 @@ STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, co
if (index == kInputIndexFour) {
continue;
}
TransformAttrByAxes(func_graph, cnode, index, axes);
TransformAttrByAxes(func_graph, cnode, index, axes, trans_type);
}
auto cur_axes = TransformOpAxesAttr(axes);
auto cur_axes = TransformOpAxesAttr(axes, trans_type);
auto param_node =
BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope());
func_graph->manager()->Replace(cnode->input(kInputIndexFour), param_node);
@ -357,11 +411,10 @@ STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, co
}
void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
const std::vector<int> &axes) {
const std::vector<int> &axes, FormatTransNodeType trans_type) {
if (cnode == nullptr || input_index >= cnode->size() || axes.empty()) {
return;
}
auto axis_map = GetNC2NHAxisMap();
auto origin_input = node_infer_shape_.GetIntVecInput(cnode, input_index);
if (origin_input.size() != axes.size()) {
return;
@ -369,8 +422,16 @@ void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, cons
std::vector<int> cur_input;
for (int dim = 0; dim < static_cast<int>(kInputSizeFour); ++dim) {
for (size_t index = 0; index < axes.size(); ++index) {
int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + kInputSizeFour : axes[index]];
if (nhwc_dim == dim) {
int axis = axes[index];
if (axis < 0) {
axis += kInputSizeFour;
}
MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
int cur_axis = kNH2NC[axis];
if (trans_type == kNHWC2NCHW) {
cur_axis = kNC2NH[axis];
}
if (cur_axis == dim) {
cur_input.push_back(origin_input[index]);
}
}
@ -379,14 +440,23 @@ void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, cons
func_graph->manager()->Replace(cnode->input(input_index), param_node);
}
std::vector<int> TransposeStrategy::TransformOpAxesAttr(const std::vector<int> &origin_axes) {
auto axis_map = GetNC2NHAxisMap();
std::vector<int> cur_axis;
std::vector<int> TransposeStrategy::TransformOpAxesAttr(const std::vector<int> &origin_axes,
FormatTransNodeType trans_type) {
std::vector<int> cur_axes;
for (size_t i = 0; i < origin_axes.size(); ++i) {
cur_axis.push_back(axis_map[origin_axes[i] < 0 ? origin_axes[i] + kInputSizeFour : origin_axes[i]]);
int axis = origin_axes[i];
if (axis < 0) {
axis += kInputSizeFour;
}
MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
int cur_axis = kNH2NC[axis];
if (trans_type == kNHWC2NCHW) {
cur_axis = kNC2NH[axis];
}
cur_axes.push_back(cur_axis);
}
std::sort(cur_axis.begin(), cur_axis.end());
return cur_axis;
std::sort(cur_axes.begin(), cur_axes.end());
return cur_axes;
}
} // namespace opt
} // namespace mindspore

View File

@ -39,23 +39,25 @@ class TransposeStrategy {
}
AnfNodePtr TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &code,
const std::vector<int> &perm, bool before, size_t index);
AnfNodePtr TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm,
bool before, size_t index);
bool CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_info,
TransTypePair *trans_insert_info);
STATUS ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type);
bool CanChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
private:
AnfNodePtr TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm,
bool before, size_t index);
STATUS TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool before, size_t index);
bool IsInOutCanFuison(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, size_t *trans_count,
FormatTransNodeType *trans_type);
void DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info);
STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS ChangeCommonOp(const CNodePtr &cnode, FormatTransNodeType trans_type);
STATUS ChangeOpCrop(const CNodePtr &cnode, FormatTransNodeType trans_type);
STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type);
STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type);
void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
const std::vector<int> &axes);
std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes);
const std::vector<int> &axes, FormatTransNodeType trans_type);
std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type);
FmkType fmk_type_{lite::converter::FmkType_MS};
bool train_flag_{false};
NodeInferShape node_infer_shape_;