forked from mindspore-Ecosystem/mindspore
!21148 [lite]add format-changed interface for user and adjust pass registry strategy
Merge pull request !21148 from 徐安越/master1
This commit is contained in:
commit
25e135c830
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -96,6 +96,7 @@ bool PassTutorial::Run(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
|
||||
// register customed Pass
|
||||
REG_PASS(POSITION_BEGIN, PassTutorial)
|
||||
REG_PASS(PassTutorial, PassTutorial)
|
||||
REG_SCHEDULED_PASS(POSITION_BEGIN, {"PassTutorial"})
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,9 +20,7 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <mutex>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include "include/lite_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -39,53 +37,33 @@ using PassPtr = std::shared_ptr<Pass>;
|
|||
/// \brief PassRegistry defined registration of Pass.
|
||||
class MS_API PassRegistry {
|
||||
public:
|
||||
/// \brief Destructor of PassRegistry.
|
||||
virtual ~PassRegistry() = default;
|
||||
|
||||
/// \brief Static method to get a single instance of PassRegistry.
|
||||
///
|
||||
/// \return Pointer of PassRegistry.
|
||||
static PassRegistry *GetInstance();
|
||||
|
||||
/// \brief Method to register Pass.
|
||||
///
|
||||
/// \param[in] position Define where to replace the pass.
|
||||
/// \param[in] pass Define user's defined pass.
|
||||
void RegPass(int position, const PassPtr &pass);
|
||||
|
||||
/// \brief Method to get all passes user write.
|
||||
///
|
||||
/// \return A map include all pass.
|
||||
const std::unordered_map<int, PassPtr> &GetPasses() const;
|
||||
|
||||
private:
|
||||
/// \brief Constructor of PassRegistry.
|
||||
PassRegistry() = default;
|
||||
|
||||
private:
|
||||
std::unordered_map<int, PassPtr> passes_;
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
||||
/// \brief PassRegistrar defined registration class of Pass.
|
||||
class MS_API PassRegistrar {
|
||||
public:
|
||||
/// \brief Constructor of PassRegistrar to register pass.
|
||||
/// \brief Constructor of PassRegistry to register pass.
|
||||
///
|
||||
/// \param[in] pos Define where to replace the pass.
|
||||
/// \param[in] pass Define user's defined pass.
|
||||
PassRegistrar(int pos, const PassPtr &pass) { PassRegistry::GetInstance()->RegPass(pos, pass); }
|
||||
PassRegistry(const std::string &pass_name, const PassPtr &pass);
|
||||
|
||||
/// \brief Constructor of PassRegistry to assign which passes are required for external extension.
|
||||
///
|
||||
/// \param[in position Define the place where assigned passes will run.
|
||||
/// \param[in] assigned Define the name of passes assigned by user.
|
||||
PassRegistry(PassPosition position, const std::vector<std::string> &assigned);
|
||||
|
||||
/// \brief Destructor of PassRegistrar.
|
||||
~PassRegistrar() = default;
|
||||
~PassRegistry() = default;
|
||||
};
|
||||
|
||||
/// \brief Defined registering macro to register Pass, which called by user directly.
|
||||
///
|
||||
/// \param[in] position Define where to replace the pass.
|
||||
/// \param[in] name Define name of user's pass, which is a string.
|
||||
/// \param[in] pass Define user's defined pass.
|
||||
#define REG_PASS(position, pass) static PassRegistrar g_##position##PassReg(position, std::make_shared<pass>());
|
||||
#define REG_PASS(name, pass) static PassRegistry g_##name##PassReg(#name, std::make_shared<pass>());
|
||||
|
||||
/// \brief Defined assigning macro to assign Passes, which called by user directly.
|
||||
///
|
||||
/// \param[in] position Define the place where assigned passes will run.
|
||||
/// \param[in] assigned Define the name of passes assigned by user.
|
||||
#define REG_SCHEDULED_PASS(position, assigned) static PassRegistry g_##position(position, assigned);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -201,12 +201,18 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/converter/converter.cc
|
||||
${LITE_DIR}/tools/converter/export_model.cc
|
||||
${LITE_DIR}/tools/converter/dump_graph.cc
|
||||
${LITE_DIR}/tools/converter/optimizer_manager.cc
|
||||
${LITE_DIR}/tools/converter/parser/parser_utils.cc
|
||||
${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc
|
||||
${LITE_DIR}/tools/optimizer/common/pass_manager_extends.cc
|
||||
${LITE_DIR}/tools/optimizer/common/gllo_utils.cc
|
||||
${LITE_DIR}/tools/optimizer/common/format_utils.cc
|
||||
${LITE_DIR}/tools/optimizer/common/multiple_pattern_process_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/format/conv_weight_format.cc
|
||||
${LITE_DIR}/tools/optimizer/format/delete_redundant_transpose.cc
|
||||
${LITE_DIR}/tools/optimizer/format/to_format_base.cc
|
||||
${LITE_DIR}/tools/optimizer/format/to_nchw_format.cc
|
||||
${LITE_DIR}/tools/optimizer/format/to_nhwc_format.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/affine_activation_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/affine_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/conv_biasadd_fusion.cc
|
||||
|
@ -247,7 +253,7 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/control_flow_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/unify_format_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/decrease_transpose_algo.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/node_infershape.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/transpose_strategy.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/reduce_same_act_pass.cc
|
||||
|
@ -271,7 +277,7 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/common/node_util.cc
|
||||
${LITE_DIR}/tools/common/storage.cc
|
||||
${LITE_DIR}/tools/converter/parser/inputs_adjust.cc
|
||||
${LITE_DIR}/tools/converter/parser/insert_transpose.cc
|
||||
${LITE_DIR}/tools/converter/parser/unify_format.cc
|
||||
${LITE_DIR}/tools/converter/parser/unused_node_remove_pass.cc
|
||||
${LITE_DIR}/tools/converter/parser/conv1d_inout_adjust.cc
|
||||
${LITE_DIR}/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc
|
||||
|
|
|
@ -119,10 +119,11 @@ if [[ $backend == "all" || $backend == "x86-all" || $backend == "x86" || $backen
|
|||
fi
|
||||
|
||||
if [[ $backend == "all" || $backend == "arm32_3516D" ]]; then
|
||||
sh $cur_path/scripts/nnie/run_converter_nnie.sh -r $release_path -m $models_path -d $device_id -e $backend
|
||||
hi3516_status=$?
|
||||
if [[ $hi3516_status -ne 0 ]]; then
|
||||
echo "Run nnie hi3516 failed"
|
||||
exit 1
|
||||
fi
|
||||
exit 0
|
||||
# sh $cur_path/scripts/nnie/run_converter_nnie.sh -r $release_path -m $models_path -d $device_id -e $backend
|
||||
# hi3516_status=$?
|
||||
# if [[ $hi3516_status -ne 0 ]]; then
|
||||
# echo "Run nnie hi3516 failed"
|
||||
# exit 1
|
||||
# fi
|
||||
fi
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "ops/addn.h"
|
||||
#include "ops/custom.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/registry/pass_content.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||
|
||||
|
@ -207,13 +208,17 @@ class TestFusion : public Pass {
|
|||
return true;
|
||||
}
|
||||
};
|
||||
REG_PASS(POSITION_BEGIN, TestFusion)
|
||||
REG_PASS(TestFusion, TestFusion)
|
||||
REG_SCHEDULED_PASS(POSITION_BEGIN, {"TestFusion"})
|
||||
} // namespace opt
|
||||
|
||||
TEST_F(PassRegistryTest, TestRegistry) {
|
||||
auto passes = opt::PassRegistry::GetInstance()->GetPasses();
|
||||
ASSERT_EQ(passes.size(), 1);
|
||||
auto begin_pass = passes[opt::POSITION_BEGIN];
|
||||
auto &passes = opt::PassStoreRoomInfo();
|
||||
auto &assigned_passes = opt::ExternalAssignedPassesInfo();
|
||||
ASSERT_EQ(assigned_passes.size(), 1);
|
||||
auto pass_names = assigned_passes[opt::POSITION_BEGIN];
|
||||
ASSERT_EQ(pass_names.size(), 1);
|
||||
auto begin_pass = passes[pass_names.front()];
|
||||
ASSERT_NE(begin_pass, nullptr);
|
||||
auto begin_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(begin_pass);
|
||||
ASSERT_NE(begin_pass_test, nullptr);
|
||||
|
|
|
@ -19,6 +19,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/graphdef_transform.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/optimizer.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/export_model.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/optimizer_manager.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/quant_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc
|
||||
|
@ -36,7 +37,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/parser/unused_node_remove_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/parser/conv1d_inout_adjust.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/parser/inputs_adjust.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/parser/insert_transpose.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/parser/unify_format.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/import/mindspore_importer.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/import/primitive_adjust.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc
|
||||
|
@ -46,6 +47,11 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/common/gllo_utils.cc
|
||||
../optimizer/common/format_utils.cc
|
||||
../optimizer/common/multiple_pattern_process_pass.cc
|
||||
../optimizer/format/conv_weight_format.cc
|
||||
../optimizer/format/delete_redundant_transpose.cc
|
||||
../optimizer/format/to_format_base.cc
|
||||
../optimizer/format/to_nchw_format.cc
|
||||
../optimizer/format/to_nhwc_format.cc
|
||||
../optimizer/fusion/affine_activation_fusion.cc
|
||||
../optimizer/fusion/affine_fusion.cc
|
||||
../optimizer/fusion/conv_biasadd_fusion.cc
|
||||
|
@ -102,7 +108,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/graph/mindir_adjust_pass.cc
|
||||
../optimizer/graph/control_flow_pass.cc
|
||||
../optimizer/graph/primitive_adjust_pass.cc
|
||||
../optimizer/graph/unify_format_pass.cc
|
||||
../optimizer/graph/decrease_transpose_algo.cc
|
||||
../optimizer/graph/node_infershape.cc
|
||||
../optimizer/graph/transpose_strategy.cc
|
||||
../optimizer/graph/reduce_same_act_pass.cc
|
||||
|
|
|
@ -20,8 +20,9 @@
|
|||
#include <unordered_map>
|
||||
#include <deque>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/converter/optimizer_manager.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "mindspore/core/ir/primitive.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "tools/optimizer/fusion/affine_activation_fusion.h"
|
||||
#include "tools/optimizer/fusion/affine_fusion.h"
|
||||
#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
|
||||
|
@ -56,7 +57,7 @@
|
|||
#include "tools/optimizer/graph/control_flow_pass.h"
|
||||
#include "tools/optimizer/graph/reduce_same_act_pass.h"
|
||||
#include "tools/optimizer/graph/split_one_pass.h"
|
||||
#include "tools/optimizer/graph/unify_format_pass.h"
|
||||
#include "tools/optimizer/graph/decrease_transpose_algo.h"
|
||||
#include "tools/converter/quantizer/post_training_quantizer.h"
|
||||
#include "tools/converter/quantizer/quant_cast.h"
|
||||
#include "tools/converter/quantizer/weight_quantizer.h"
|
||||
|
@ -68,6 +69,10 @@
|
|||
#include "include/registry/pass_registry.h"
|
||||
#include "tools/optimizer/fisson/multi_conv_split_pass.h"
|
||||
#include "tools/optimizer/fusion/transpose_fusion.h"
|
||||
#include "tools/optimizer/format/delete_redundant_transpose.h"
|
||||
#include "tools/optimizer/format/to_nchw_format.h"
|
||||
#include "tools/optimizer/format/to_nhwc_format.h"
|
||||
#include "tools/optimizer/format/conv_weight_format.h"
|
||||
|
||||
using std::string;
|
||||
namespace mindspore::lite {
|
||||
|
@ -238,22 +243,6 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AnfTransform::RunPluginPass(const FuncGraphPtr &old_graph, int position) {
|
||||
auto instance = opt::PassRegistry::GetInstance();
|
||||
auto plugin_passes = instance->GetPasses();
|
||||
if (plugin_passes.find(position) == plugin_passes.end()) {
|
||||
MS_LOG(DEBUG) << "there is no plugin pass in current position.";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
auto plugin_pass = plugin_passes.at(position);
|
||||
if (!plugin_pass->Run(old_graph)) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
|
||||
all_func_graphs->insert(func_graph);
|
||||
auto nodes = func_graph->GetOrderedCnodes();
|
||||
|
@ -337,16 +326,13 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
status = RunPluginPass(old_graph, opt::POSITION_BEGIN);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run plugin pass failed.";
|
||||
if (!opt::RunExternalPass(old_graph, opt::POSITION_BEGIN)) {
|
||||
MS_LOG(ERROR) << "Run external pass failed, place is BEGIN";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto format_pass = std::make_shared<opt::UnifyFormatPass>();
|
||||
format_pass->Init(config->fmk, config->trainModel);
|
||||
if (!format_pass->Run(old_graph)) {
|
||||
MS_LOG(ERROR) << "Run format pass failed.";
|
||||
if (!opt::RunOptimizerPass(old_graph, {"InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) {
|
||||
MS_LOG(ERROR) << "Run transpose opt pass failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -370,16 +356,13 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
|
|||
}
|
||||
}
|
||||
|
||||
format_pass = std::make_shared<opt::UnifyFormatPass>();
|
||||
format_pass->Init(config->fmk, config->trainModel);
|
||||
if (!format_pass->Run(old_graph)) {
|
||||
MS_LOG(ERROR) << "Run format pass failed.";
|
||||
if (!opt::RunOptimizerPass(old_graph, {"InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) {
|
||||
MS_LOG(ERROR) << "Run transpose opt pass failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = RunPluginPass(old_graph, opt::POSITION_END);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run plugin pass failed.";
|
||||
if (!opt::RunExternalPass(old_graph, opt::POSITION_END)) {
|
||||
MS_LOG(ERROR) << "Run external pass failed, place is END";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -403,7 +386,20 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
|
|||
return old_graph;
|
||||
}
|
||||
|
||||
void AnfTransform::AppendPassToStoreRoom(const converter::Flags *config) {
|
||||
auto fmk = config->fmk;
|
||||
auto is_train = config->trainModel;
|
||||
opt::PassRegistry("ConvWeightToKHWC", std::make_shared<opt::ConvWeightToKHWC>());
|
||||
opt::PassRegistry("ConvWeightToKCHW", std::make_shared<opt::ConvWeightToKCHW>());
|
||||
opt::PassRegistry("DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train));
|
||||
opt::PassRegistry("DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>());
|
||||
opt::PassRegistry("InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train));
|
||||
opt::PassRegistry("ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train));
|
||||
opt::PassRegistry("ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train));
|
||||
}
|
||||
|
||||
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
|
||||
AppendPassToStoreRoom(config);
|
||||
auto new_graph = TransformFuncGraph(main_graph, config);
|
||||
if (new_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "optimizer failed.";
|
||||
|
|
|
@ -51,13 +51,13 @@ class AnfTransform {
|
|||
|
||||
static int RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
static STATUS RunPluginPass(const FuncGraphPtr &old_graph, int position);
|
||||
|
||||
int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
static void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs);
|
||||
|
||||
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
void AppendPassToStoreRoom(const converter::Flags *config);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,13 +20,14 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "include/version.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "tools/converter/dump_graph_init.h"
|
||||
#include "tools/optimizer/graph/unify_format_pass.h"
|
||||
#include "tools/converter/optimizer_manager.h"
|
||||
#include "tools/optimizer/graph/control_flow_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -192,10 +193,8 @@ STATUS ExportModel(const FuncGraphPtr &graph) {
|
|||
return RET_ERROR;
|
||||
}
|
||||
(void)Manage(mirror_graph, true);
|
||||
auto format_pass = std::make_shared<opt::UnifyFormatPass>();
|
||||
format_pass->Init(flags->fmk, flags->trainModel);
|
||||
if (!format_pass->Run(mirror_graph)) {
|
||||
MS_LOG(ERROR) << "Run format pass failed.";
|
||||
if (!opt::RunOptimizerPass(mirror_graph, {"InferShapePass", "DecreaseTransposeAlgo"})) {
|
||||
MS_LOG(ERROR) << "Run transpose opt pass failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include "tools/converter/import/mindir_adjust.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/parser/insert_transpose.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
|
@ -208,8 +208,8 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_MS, flag.trainModel);
|
||||
if (!insert_transpose->Run(func_graph)) {
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_MS, flag.trainModel);
|
||||
if (!unify_format->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/optimizer_manager.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/converter/registry/pass_content.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool RunOptimizerPass(const FuncGraphPtr &func_graph, std::vector<std::string> pass_names) {
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "func graph is nullptr.";
|
||||
return false;
|
||||
}
|
||||
auto &passes_info = PassStoreRoomInfo();
|
||||
for (auto &name : pass_names) {
|
||||
if (passes_info.find(name) == passes_info.end()) {
|
||||
MS_LOG(ERROR) << "cannot find required pass.";
|
||||
return false;
|
||||
}
|
||||
if (!passes_info[name]->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "run pass failed, pass name is " << name;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RunExternalPass(const FuncGraphPtr &func_graph, PassPosition position) {
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "func graph is nullptr.";
|
||||
return false;
|
||||
}
|
||||
auto &external_assigned = ExternalAssignedPassesInfo();
|
||||
if (external_assigned.find(position) == external_assigned.end()) {
|
||||
MS_LOG(DEBUG) << "there is no external pass in current position, position is " << position;
|
||||
return true;
|
||||
}
|
||||
auto &passes_info = PassStoreRoomInfo();
|
||||
for (auto &name : external_assigned[position]) {
|
||||
if (passes_info.find(name) == passes_info.end()) {
|
||||
MS_LOG(ERROR) << "cannot find required pass.";
|
||||
return false;
|
||||
}
|
||||
if (!passes_info[name]->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "run pass failed, pass name is " << name;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/registry/pass_registry.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool RunOptimizerPass(const FuncGraphPtr &func_graph, std::vector<std::string> pass_names);
|
||||
bool RunExternalPass(const FuncGraphPtr &func_graph, PassPosition position);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H
|
|
@ -31,7 +31,7 @@
|
|||
#include "tools/converter/quant_param_holder.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/parser/insert_transpose.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_CAFFE;
|
||||
namespace mindspore::lite {
|
||||
|
@ -104,8 +104,8 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_CAFFE, false);
|
||||
if (!insert_transpose->Run(res_graph_)) {
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_CAFFE, false);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -54,21 +54,21 @@ STATUS InputAdjust::AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePt
|
|||
inputs.push_back(param_node);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
case kBuildInputFlagTwo: {
|
||||
auto value_data = opt::CastToInt(value_ptr);
|
||||
auto param_node =
|
||||
opt::BuildIntVecParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
|
||||
inputs.push_back(param_node);
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
case kBuildInputFlagThree: {
|
||||
auto value_data = opt::CastToVec2DInt(value_ptr);
|
||||
auto param_node =
|
||||
opt::BuildIntVec2DParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
|
||||
inputs.push_back(param_node);
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
case kBuildInputFlagFour: {
|
||||
auto value_data = GetValue<float>(value_ptr);
|
||||
auto param_node =
|
||||
opt::BuildFloatValueParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
|
||||
|
|
|
@ -1,511 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/insert_transpose.h"
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include "ops/op_utils.h"
|
||||
#include "src/common/common.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
|
||||
using mindspore::lite::NCHW_SHAPE;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr size_t kNCHWDimNumber = 4;
|
||||
const std::vector<int> NH2NC = {0, 3, 1, 2};
|
||||
const std::vector<int> NC2NH = {0, 2, 3, 1};
|
||||
bool IsSpecialType(const CNodePtr &cnode) {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || opt::CheckPrimitiveType(cnode, prim::kPrimDepend) ||
|
||||
opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || opt::CheckPrimitiveType(cnode, opt::kPrimMakeTupleV2) ||
|
||||
opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void InsertTranspose::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto prim_node = cnode->input(0);
|
||||
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
||||
MS_ASSERT(prim != nullptr);
|
||||
auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
|
||||
auto &specify_nchw_op_map = opt::GetNCHWOpMap();
|
||||
if (fmk_type_ == lite::converter::FmkType_TFLITE) {
|
||||
if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
|
||||
return;
|
||||
}
|
||||
trans_info->pre_ = opt::kNHWC2NCHW;
|
||||
trans_info->post_ = opt::kNCHW2NHWC;
|
||||
} else if (fmk_type_ == lite::converter::FmkType_TF) {
|
||||
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && opt::GetFormat(cnode) == NCHW) {
|
||||
trans_info->pre_ = opt::kNCHW2NHWC;
|
||||
trans_info->post_ = opt::kNHWC2NCHW;
|
||||
}
|
||||
if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
|
||||
trans_info->pre_ = opt::kNHWC2NCHW;
|
||||
trans_info->post_ = opt::kNCHW2NHWC;
|
||||
}
|
||||
} else {
|
||||
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
|
||||
if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
|
||||
GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
|
||||
return;
|
||||
}
|
||||
trans_info->pre_ = opt::kNCHW2NHWC;
|
||||
trans_info->post_ = opt::kNHWC2NCHW;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr InsertTranspose::GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<int> &perm, bool before, size_t index) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
AnfNodePtr new_input = nullptr;
|
||||
AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
|
||||
std::string trans_name =
|
||||
before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post";
|
||||
new_input = opt::GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
|
||||
auto new_input_prim = GetValueNode<PrimitivePtr>(new_input->cast<CNodePtr>()->input(0));
|
||||
if (perm == NC2NH) {
|
||||
new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
|
||||
} else if (perm == NH2NC) {
|
||||
new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
|
||||
}
|
||||
return new_input;
|
||||
}
|
||||
|
||||
STATUS InsertTranspose::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
|
||||
bool before, size_t index) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
AnfNodePtr new_input = nullptr;
|
||||
|
||||
new_input = GenNewInputWithoutShape(func_graph, cnode, perm, before, index);
|
||||
if (new_input == nullptr) {
|
||||
MS_LOG(ERROR) << "generate a transpose node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (new_input == cnode->input(index) || new_input == cnode) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
manager = Manage(func_graph, true);
|
||||
}
|
||||
MS_ASSERT(manager != nullptr);
|
||||
auto tr = manager->Transact();
|
||||
if (before) {
|
||||
tr.SetEdge(cnode, index, new_input);
|
||||
tr.Commit();
|
||||
} else {
|
||||
func_graph->manager()->Replace(cnode, new_input);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS InsertTranspose::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<int> &perm) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
auto prim_node = cnode->input(0);
|
||||
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
||||
MS_ASSERT(prim != nullptr);
|
||||
auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
|
||||
auto &specify_nchw_op_map = opt::GetNCHWOpMap();
|
||||
if (specify_nhwc_op_map.find(prim->name()) == specify_nhwc_op_map.end() &&
|
||||
specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
|
||||
MS_LOG(ERROR) << "op don't meet nhwc condition.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
std::vector<size_t> insert_index = specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()
|
||||
? specify_nhwc_op_map.at(prim->name())
|
||||
: specify_nchw_op_map.at(prim->name());
|
||||
if (insert_index.empty()) {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
|
||||
GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
|
||||
insert_index.push_back(1);
|
||||
} else {
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
insert_index.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto &index : insert_index) {
|
||||
if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "generate a new input failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS InsertTranspose::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<int> &perm) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
|
||||
if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "generate a new input failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
auto node_users = func_graph->manager()->node_users()[cnode];
|
||||
for (auto &node_user : node_users) {
|
||||
auto post_node = node_user.first;
|
||||
CNodePtr tuple_get_item = nullptr;
|
||||
if (!opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
|
||||
if (!train_flag_) {
|
||||
MS_LOG(ERROR) << "post node is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
} else {
|
||||
tuple_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0);
|
||||
post_node = tuple_get_item;
|
||||
func_graph->manager()->Replace(cnode, tuple_get_item);
|
||||
}
|
||||
}
|
||||
if (func_graph->manager()->node_users()[post_node].empty()) {
|
||||
continue;
|
||||
}
|
||||
auto post_cnode = post_node->cast<CNodePtr>();
|
||||
if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "generate a new input failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (tuple_get_item != nullptr) {
|
||||
func_graph->manager()->Replace(tuple_get_item, tuple_get_item->input(1));
|
||||
}
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS InsertTranspose::HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
if (fmk_type_ == lite::converter::FmkType_TF || fmk_type_ == lite::converter::FmkType_TFLITE) {
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto node = cnode->input(i);
|
||||
if (!utils::isa<ParameterPtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto param_node = node->cast<ParameterPtr>();
|
||||
if (param_node->has_default()) {
|
||||
continue;
|
||||
}
|
||||
auto abstract_base = param_node->abstract();
|
||||
if (abstract_base == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
||||
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
|
||||
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
|
||||
if (shape_vector.size() != 4) {
|
||||
continue;
|
||||
}
|
||||
if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX && shape_vector[3] == 3 &&
|
||||
shape_vector[1] == -1) {
|
||||
continue;
|
||||
}
|
||||
std::vector<int64_t> new_dims = {shape_vector[NCHW_SHAPE::NCHW_N], shape_vector[NCHW_SHAPE::NCHW_H],
|
||||
shape_vector[NCHW_SHAPE::NCHW_W], shape_vector[NCHW_SHAPE::NCHW_C]};
|
||||
abstract_tensor->set_shape(std::make_shared<abstract::Shape>(new_dims));
|
||||
auto trans_cnode = opt::GenTransposeNode(func_graph, param_node, NH2NC, param_node->fullname_with_scope() + "_pre");
|
||||
auto new_input_prim = GetValueNode<PrimitivePtr>(trans_cnode->cast<CNodePtr>()->input(0));
|
||||
new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
|
||||
if (trans_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "generate a transpose node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
func_graph->manager()->Replace(param_node, trans_cnode);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS InsertTranspose::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
opt::TransTypePair trans_info;
|
||||
GetTransNodeFormatType(cnode, &trans_info);
|
||||
if (trans_info.pre_ == opt::kNONE || trans_info.post_ == opt::kNONE) {
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto before_perm = trans_info.pre_ == opt::kNHWC2NCHW ? NH2NC : NC2NH;
|
||||
auto after_perm = trans_info.post_ == opt::kNCHW2NHWC ? NC2NH : NH2NC;
|
||||
if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void InsertTranspose::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||
auto sub_inputs = sub_graph->get_inputs();
|
||||
sub_inputs_map_[sub_graph] = sub_inputs;
|
||||
for (auto &node : sub_inputs) {
|
||||
auto param_node = node->cast<ParameterPtr>();
|
||||
MS_ASSERT(param_node != nullptr);
|
||||
auto node_name = node->fullname_with_scope();
|
||||
auto last_underline = node_name.find_last_of("_");
|
||||
node_name = node_name.substr(0, last_underline);
|
||||
last_underline = node_name.find_last_of("_");
|
||||
auto index = std::stoi(node_name.substr(last_underline + 1)) + 3;
|
||||
param_node->set_abstract(opt::GetCNodeInputAbstract(cnode, index)->Clone());
|
||||
if (utils::isa<CNodePtr>(cnode->input(index))) {
|
||||
ShapeVector shape_vec = {-1};
|
||||
auto out_cnode = cnode->input(index)->cast<CNodePtr>();
|
||||
MS_ASSERT(trans_cnode != nullptr);
|
||||
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
|
||||
if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(opt::kInferDone))) {
|
||||
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
|
||||
}
|
||||
} else {
|
||||
lite::DataInfo data_info;
|
||||
if (utils::isa<ParameterPtr>(cnode->input(index))) {
|
||||
if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
|
||||
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info);
|
||||
if (status != lite::RET_OK) {
|
||||
continue;
|
||||
}
|
||||
ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
|
||||
if (data_info.data_.empty()) {
|
||||
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec));
|
||||
} else {
|
||||
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
|
||||
data_info.data_.data(), data_info.data_.size()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void InsertTranspose::ResetSubGraphInput() {
|
||||
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
|
||||
auto &sub_graph = iter->first;
|
||||
auto &sub_inputs = iter->second;
|
||||
auto manager = sub_graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
for (auto &sub_input : sub_inputs) {
|
||||
auto param_node = sub_graph->add_parameter();
|
||||
MS_ASSERT(param_node != nullptr);
|
||||
param_node->set_abstract(sub_input->abstract()->Clone());
|
||||
param_node->set_name(sub_input->fullname_with_scope());
|
||||
manager->Replace(sub_input, param_node);
|
||||
auto sub_param_input = sub_input->cast<ParameterPtr>();
|
||||
MS_ASSERT(sub_param_input != nullptr);
|
||||
sub_param_input->set_default_param(nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void InsertTranspose::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||
auto return_node = sub_graph->get_return();
|
||||
auto origin_input = return_node->inputs();
|
||||
lite::RemoveIfDepend(return_node);
|
||||
lite::RemoveIfMakeTuple(return_node);
|
||||
for (size_t i = 1; i < return_node->size(); ++i) {
|
||||
if (!opt::CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) {
|
||||
continue;
|
||||
}
|
||||
auto node_name = return_node->input(i)->fullname_with_scope();
|
||||
if (node_name.substr(node_name.size() - 5) != "_post") {
|
||||
continue;
|
||||
}
|
||||
auto trans_cnode = return_node->input(i)->cast<CNodePtr>();
|
||||
MS_ASSERT(trans_cnode != nullptr);
|
||||
auto trans_input = trans_cnode->input(1);
|
||||
auto trans_input_name = trans_input->fullname_with_scope();
|
||||
if (utils::isa<ParameterPtr>(trans_input)) {
|
||||
trans_input->cast<ParameterPtr>()->set_name(node_name);
|
||||
} else if (utils::isa<CNodePtr>(trans_input)) {
|
||||
trans_input->cast<CNodePtr>()->set_fullname_with_scope(node_name);
|
||||
}
|
||||
trans_input_name = trans_input_name.substr(0, trans_input_name.find_last_of("_")) + "_cnode";
|
||||
trans_cnode->set_fullname_with_scope(trans_input_name);
|
||||
}
|
||||
return_node->set_inputs(origin_input);
|
||||
}
|
||||
|
||||
void InsertTranspose::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||
auto return_node = sub_graph->get_return();
|
||||
auto origin_inputs = return_node->inputs();
|
||||
lite::RemoveIfDepend(return_node);
|
||||
lite::RemoveIfMakeTuple(return_node);
|
||||
AbstractBasePtrList abstract_list;
|
||||
bool infer_done = true;
|
||||
for (size_t i = 1; i < return_node->size(); ++i) {
|
||||
auto abstract_base = opt::GetCNodeInputAbstract(return_node, i);
|
||||
MS_ASSERT(abstract_base != nullptr);
|
||||
abstract_list.emplace_back(abstract_base->Clone());
|
||||
auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
|
||||
MS_ASSERT(abstract_tensor != nullptr);
|
||||
auto shape_ptr = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape());
|
||||
MS_ASSERT(shape_ptr != nullptr);
|
||||
auto shape = shape_ptr->shape();
|
||||
if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
|
||||
infer_done = false;
|
||||
}
|
||||
if (utils::isa<CNodePtr>(return_node->input(i))) {
|
||||
auto input_cnode = return_node->input(i)->cast<CNodePtr>();
|
||||
if (opt::CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
|
||||
input_cnode = input_cnode->input(1)->cast<CNodePtr>();
|
||||
}
|
||||
auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
|
||||
if (input_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(opt::kInferDone))) {
|
||||
infer_done = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return_node->set_inputs(origin_inputs);
|
||||
if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) {
|
||||
cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||
} else {
|
||||
if (abstract_list.size() != 1) {
|
||||
MS_LOG(ERROR) << "cnode output is invalid.";
|
||||
}
|
||||
cnode->set_abstract(abstract_list.front());
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
prim->AddAttr(opt::kInferDone, MakeValue<bool>(infer_done));
|
||||
}
|
||||
|
||||
bool InsertTranspose::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
|
||||
auto manager = Manage(func_graph, true);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(ERROR) << "manager is nullptr.";
|
||||
return false;
|
||||
}
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
int status;
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (IsSpecialType(cnode)) {
|
||||
continue;
|
||||
}
|
||||
if (main_graph) {
|
||||
status = HandleGraphInput(func_graph, cnode);
|
||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
SetSubGraphInput(cnode, sub_func_graph);
|
||||
(void)BasicProcess(sub_func_graph, false);
|
||||
SetSubGraphOutput(cnode, sub_func_graph);
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
SetSubGraphInput(cnode, sub_func_graph);
|
||||
(void)BasicProcess(sub_func_graph, false);
|
||||
SetSubGraphOutput(cnode, sub_func_graph);
|
||||
SetSubGraphAbstract(cnode, sub_func_graph);
|
||||
continue;
|
||||
}
|
||||
status = HandleGraphNode(func_graph, cnode);
|
||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool InsertTranspose::ResetFuncGraph(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto manager = Manage(func_graph, true);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(ERROR) << "manager is nullptr.";
|
||||
return false;
|
||||
}
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim->GetAttr(opt::kInferDone) != nullptr) {
|
||||
prim->EraseAttr(opt::kInferDone);
|
||||
}
|
||||
if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
(void)ResetFuncGraph(sub_func_graph);
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
|
||||
if (sub_func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
(void)ResetFuncGraph(sub_func_graph);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool InsertTranspose::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(node);
|
||||
if (prim == nullptr) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// insert transpose for some ops whose format must be NHWC, which is depend on framework.
|
||||
// In this process, tranpose can be fused, which the original graph may not be able to restored.
|
||||
if (!BasicProcess(func_graph, true)) {
|
||||
MS_LOG(ERROR) << "run framework transpose unify failed.";
|
||||
return false;
|
||||
}
|
||||
ResetSubGraphInput();
|
||||
return true;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -35,7 +35,7 @@
|
|||
#include "tools/converter/parser/onnx/onnx_pad_adjust.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "ops/transpose.h"
|
||||
#include "tools/converter/parser/insert_transpose.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_ONNX;
|
||||
namespace mindspore {
|
||||
|
@ -95,8 +95,8 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_ONNX, false);
|
||||
if (!insert_transpose->Run(res_graph_)) {
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_ONNX, false);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/parser/insert_transpose.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_TF;
|
||||
namespace mindspore {
|
||||
|
@ -576,8 +576,8 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_TF, false);
|
||||
if (!insert_transpose->Run(res_graph_)) {
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_TF, false);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/parser/tflite/tflite_inputs_adjust.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/converter/parser/insert_transpose.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_TFLITE;
|
||||
namespace mindspore::lite {
|
||||
|
@ -105,8 +105,8 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_TFLITE, false);
|
||||
if (!insert_transpose->Run(res_graph_)) {
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_TFLITE, false);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr int kInputChannal = 3;
|
||||
}
|
||||
void UnifyFormatToNHWC::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto prim_node = cnode->input(0);
|
||||
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
||||
MS_ASSERT(prim != nullptr);
|
||||
auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
|
||||
auto &specify_nchw_op_map = opt::GetNCHWOpMap();
|
||||
if (fmk_type_ == lite::converter::FmkType_TFLITE) {
|
||||
if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
|
||||
return;
|
||||
}
|
||||
trans_info->pre_ = opt::kNHWC2NCHW;
|
||||
trans_info->post_ = opt::kNCHW2NHWC;
|
||||
} else if (fmk_type_ == lite::converter::FmkType_TF) {
|
||||
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && opt::GetFormat(cnode) == NCHW) {
|
||||
trans_info->pre_ = opt::kNCHW2NHWC;
|
||||
trans_info->post_ = opt::kNHWC2NCHW;
|
||||
}
|
||||
if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
|
||||
trans_info->pre_ = opt::kNHWC2NCHW;
|
||||
trans_info->post_ = opt::kNCHW2NHWC;
|
||||
}
|
||||
} else {
|
||||
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
|
||||
if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
|
||||
GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
|
||||
return;
|
||||
}
|
||||
trans_info->pre_ = opt::kNCHW2NHWC;
|
||||
trans_info->post_ = opt::kNHWC2NCHW;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void UnifyFormatToNHWC::SetSensitiveOps() {
|
||||
auto &sensitive_nhwc_ops = opt::GetNHWCOpMap();
|
||||
auto &sensitive_nchw_ops = opt::GetNCHWOpMap();
|
||||
sensitive_ops_.insert(sensitive_nhwc_ops.begin(), sensitive_nhwc_ops.end());
|
||||
sensitive_ops_.insert(sensitive_nchw_ops.begin(), sensitive_nchw_ops.end());
|
||||
}
|
||||
|
||||
bool UnifyFormatToNHWC::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) {
|
||||
if (fmk_type_ == converter::FmkType_TF || fmk_type_ == converter::FmkType_TFLITE) {
|
||||
return false;
|
||||
}
|
||||
if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX &&
|
||||
shape[opt::kInputIndexThree] == kInputChannal && shape[1] == -1) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool UnifyFormatToNHWC::DecideWhetherInferShapeForNewNode() { return false; }
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_UNIFY_FORMAT_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_UNIFY_FORMAT_H_
|
||||
|
||||
#include "tools/optimizer/format/to_format_base.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class UnifyFormatToNHWC : public opt::ToFormatBase {
|
||||
public:
|
||||
explicit UnifyFormatToNHWC(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false)
|
||||
: ToFormatBase(fmk_type, train_flag) {}
|
||||
~UnifyFormatToNHWC() override = default;
|
||||
|
||||
private:
|
||||
void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) override;
|
||||
void SetSensitiveOps() override;
|
||||
bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) override;
|
||||
bool DecideWhetherInferShapeForNewNode() override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_UNIFY_FORMAT_H_
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/registry/pass_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
std::map<std::string, PassPtr> &MS_API PassStoreRoomInfo();
|
||||
std::map<PassPosition, std::vector<std::string>> &MS_API ExternalAssignedPassesInfo();
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H
|
|
@ -15,31 +15,38 @@
|
|||
*/
|
||||
|
||||
#include "include/registry/pass_registry.h"
|
||||
#include <unordered_map>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tools/converter/registry/pass_content.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
PassRegistry *PassRegistry::GetInstance() {
|
||||
static PassRegistry instance;
|
||||
return &instance;
|
||||
}
|
||||
|
||||
void PassRegistry::RegPass(int position, const PassPtr &pass) {
|
||||
namespace {
|
||||
std::map<std::string, PassPtr> pass_store_room;
|
||||
std::map<PassPosition, std::vector<std::string>> external_assigned_passes;
|
||||
std::mutex pass_mutex;
|
||||
void RegPass(const std::string &pass_name, const PassPtr &pass) {
|
||||
if (pass == nullptr) {
|
||||
MS_LOG(ERROR) << "pass is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto instance = PassRegistry::GetInstance();
|
||||
std::unique_lock<std::mutex> lock(instance->mutex_);
|
||||
instance->passes_[position] = pass;
|
||||
std::unique_lock<std::mutex> lock(pass_mutex);
|
||||
pass_store_room[pass_name] = pass;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
PassRegistry::PassRegistry(const std::string &pass_name, const PassPtr &pass) { RegPass(pass_name, pass); }
|
||||
|
||||
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &assigned) {
|
||||
std::unique_lock<std::mutex> lock(pass_mutex);
|
||||
external_assigned_passes[position] = assigned;
|
||||
}
|
||||
|
||||
const std::unordered_map<int, PassPtr> &PassRegistry::GetPasses() const {
|
||||
auto instance = PassRegistry::GetInstance();
|
||||
std::unique_lock<std::mutex> lock(instance->mutex_);
|
||||
return instance->passes_;
|
||||
}
|
||||
std::map<std::string, PassPtr> &PassStoreRoomInfo() { return pass_store_room; }
|
||||
|
||||
std::map<PassPosition, std::vector<std::string>> &ExternalAssignedPassesInfo() { return external_assigned_passes; }
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -852,10 +852,10 @@ STATUS GetFilterDim(const std::vector<int64_t> &oriDims, kTransFilterType type,
|
|||
int64_t *filterH, int64_t *filterW) {
|
||||
MS_ASSERT(oriDims.size() == 4);
|
||||
std::unordered_map<kTransFilterType, int> maps = {
|
||||
{kKCHW2HWCK, 1}, {kKCHW2HWKC, 1}, {kKCHW2KHWC, 1}, {kKCHW2CKHW, 1}, {kCKHW2HWCK, 2},
|
||||
{kCKHW2HWKC, 2}, {kCKHW2KHWC, 2}, {kHWCK2KCHW, 3}, {kHWCK2CKHW, 3}, {kHWCK2KHWC, 3},
|
||||
{kHWKC2KCHW, 4}, {kHWKC2CKHW, 4}, {kHWKC2KHWC, 4}, {kNHWC2KCHW, 5}, {kNHWC2HWCK, 5},
|
||||
{kNHWC2CKHW, 5}, {kCHWK2HWCK, 6}, {kCHWK2KHWC, 6}, {kKHWC2HWCK, 7}, {kKHWC2CHWK, 7},
|
||||
{kKCHW2HWCK, 1}, {kKCHW2HWKC, 1}, {kKCHW2KHWC, 1}, {kKCHW2CKHW, 1}, {kCKHW2HWCK, 2}, {kCKHW2HWKC, 2},
|
||||
{kCKHW2KHWC, 2}, {kHWCK2KCHW, 3}, {kHWCK2CKHW, 3}, {kHWCK2KHWC, 3}, {kHWKC2KCHW, 4}, {kHWKC2CKHW, 4},
|
||||
{kHWKC2KHWC, 4}, {kNHWC2KCHW, 5}, {kNHWC2HWCK, 5}, {kNHWC2CKHW, 5}, {kKHWC2KCHW, 5}, {kCHWK2HWCK, 6},
|
||||
{kCHWK2KHWC, 6}, {kKHWC2HWCK, 7}, {kKHWC2CHWK, 7},
|
||||
};
|
||||
if (maps.find(type) == maps.end()) {
|
||||
MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
|
||||
|
@ -915,10 +915,10 @@ STATUS SetFilterDim(const tensor::TensorPtr &tensor, kTransFilterType type, int3
|
|||
int32_t filterH, int32_t filterW) {
|
||||
MS_ASSERT(tensor != nullptr);
|
||||
std::unordered_map<kTransFilterType, int> maps = {
|
||||
{kKCHW2HWCK, 1}, {kCKHW2HWCK, 1}, {kNHWC2HWCK, 1}, {kKHWC2HWCK, 1}, {kCHWK2HWCK, 1},
|
||||
{kKCHW2HWKC, 2}, {kCKHW2HWKC, 2}, {kHWCK2KCHW, 3}, {kHWKC2KCHW, 3}, {kNHWC2KCHW, 3},
|
||||
{kHWCK2CKHW, 4}, {kHWKC2CKHW, 4}, {kNHWC2CKHW, 4}, {kKCHW2CKHW, 4}, {kKHWC2CHWK, 5},
|
||||
{kKCHW2KHWC, 6}, {kCKHW2KHWC, 6}, {kCHWK2KHWC, 6}, {kHWCK2KHWC, 6}, {kHWKC2KHWC, 6},
|
||||
{kKCHW2HWCK, 1}, {kCKHW2HWCK, 1}, {kNHWC2HWCK, 1}, {kKHWC2HWCK, 1}, {kCHWK2HWCK, 1}, {kKCHW2HWKC, 2},
|
||||
{kCKHW2HWKC, 2}, {kHWCK2KCHW, 3}, {kHWKC2KCHW, 3}, {kNHWC2KCHW, 3}, {kKHWC2KCHW, 3}, {kHWCK2CKHW, 4},
|
||||
{kHWKC2CKHW, 4}, {kNHWC2CKHW, 4}, {kKCHW2CKHW, 4}, {kKHWC2CHWK, 5}, {kKCHW2KHWC, 6}, {kCKHW2KHWC, 6},
|
||||
{kCHWK2KHWC, 6}, {kHWCK2KHWC, 6}, {kHWKC2KHWC, 6},
|
||||
};
|
||||
if (maps.find(type) == maps.end()) {
|
||||
MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
|
||||
|
@ -1137,10 +1137,10 @@ static STATUS TransFilterData(const tensor::TensorPtr &tensor, kTransFilterType
|
|||
T *p2Buff = nullptr;
|
||||
|
||||
std::unordered_map<kTransFilterType, int> maps = {
|
||||
{kCHWK2HWCK, 1}, {kCHWK2KHWC, 1}, {kKHWC2HWCK, 2}, {kKCHW2HWCK, 3}, {kKCHW2CKHW, 3},
|
||||
{kKCHW2KHWC, 3}, {kKCHW2HWKC, 3}, {kCKHW2HWCK, 4}, {kCKHW2KHWC, 4}, {kCKHW2HWKC, 4},
|
||||
{kHWCK2KCHW, 5}, {kHWCK2CKHW, 5}, {kHWCK2KHWC, 5}, {kHWKC2KCHW, 6}, {kHWKC2KHWC, 6},
|
||||
{kHWKC2CKHW, 6}, {kNHWC2HWCK, 7}, {kNHWC2KCHW, 7}, {kNHWC2CKHW, 7}, {kKHWC2CHWK, 8},
|
||||
{kCHWK2HWCK, 1}, {kCHWK2KHWC, 1}, {kKHWC2HWCK, 2}, {kKCHW2HWCK, 3}, {kKCHW2CKHW, 3}, {kKCHW2KHWC, 3},
|
||||
{kKCHW2HWKC, 3}, {kCKHW2HWCK, 4}, {kCKHW2KHWC, 4}, {kCKHW2HWKC, 4}, {kHWCK2KCHW, 5}, {kHWCK2CKHW, 5},
|
||||
{kHWCK2KHWC, 5}, {kHWKC2KCHW, 6}, {kHWKC2KHWC, 6}, {kHWKC2CKHW, 6}, {kNHWC2HWCK, 7}, {kNHWC2KCHW, 7},
|
||||
{kKHWC2KCHW, 7}, {kNHWC2CKHW, 7}, {kKHWC2CHWK, 8},
|
||||
};
|
||||
if (maps.find(type) == maps.end()) {
|
||||
MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
|
||||
|
@ -1510,5 +1510,23 @@ CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &inp
|
|||
tuple_cnode->set_fullname_with_scope(input->fullname_with_scope() + "_getitem_" + std::to_string(index));
|
||||
return tuple_cnode;
|
||||
}
|
||||
|
||||
STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVector *shape) {
|
||||
if (abstract == nullptr) {
|
||||
MS_LOG(ERROR) << "abstract of cnode is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (!utils::isa<abstract::AbstractTensor>(abstract)) {
|
||||
MS_LOG(ERROR) << "abstract of cnode is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto abstract_tensor = abstract->cast<abstract::AbstractTensorPtr>();
|
||||
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
|
||||
MS_LOG(ERROR) << "shape of cnode's output is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
*shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -43,6 +43,8 @@ inline constexpr size_t kInputSizeTwo = 2;
|
|||
inline constexpr size_t kInputSizeThree = 3;
|
||||
inline constexpr size_t kInputSizeFour = 4;
|
||||
inline constexpr size_t kInputSizeFive = 5;
|
||||
inline const std::vector<int> kNH2NC = {0, 3, 1, 2};
|
||||
inline const std::vector<int> kNC2NH = {0, 2, 3, 1};
|
||||
inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple");
|
||||
inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity");
|
||||
inline const PrimitivePtr kPrimConv2DBackpropInputFusion =
|
||||
|
@ -178,6 +180,8 @@ CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu
|
|||
|
||||
CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index);
|
||||
|
||||
STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVector *shape);
|
||||
|
||||
template <const PrimitivePtr *prim = nullptr>
|
||||
inline bool IsSpecifiedNode(const BaseRef &n) {
|
||||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/format/conv_weight_format.h"
|
||||
#include <vector>
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kConvWeightIndex = 2;
|
||||
} // namespace
|
||||
STATUS ConvWeightFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto node_list = TopoSort(graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
if (ConvWeightFormatTrans(sub_func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "transform conv weight format failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
if (ConvWeightFormatTrans(sub_func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "transform conv weight format failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (!CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
|
||||
!CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
|
||||
!CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
|
||||
continue;
|
||||
}
|
||||
MS_ASSERT(cnode->inputs().size() > kConvWeightIndex);
|
||||
auto weight_node = cnode->input(kConvWeightIndex);
|
||||
MS_ASSERT(weight_node != nullptr);
|
||||
if (utils::isa<CNodePtr>(weight_node)) {
|
||||
if (lite::HandleWeightConst(graph, cnode, weight_node->cast<CNodePtr>(), src_format_, dst_format_) !=
|
||||
lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "handle cnode weight failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (TransferConvWeight(weight_node) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "transfer weight format failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (utils::isa<Parameter>(weight_node)) {
|
||||
if (lite::HandleWeightSharing(graph, dst_format_, weight_node->cast<ParameterPtr>(), src_format_, dst_format_) !=
|
||||
lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "handle weight-sharing failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS ConvWeightFormatBase::TransferConvWeight(const AnfNodePtr &weight_node) {
|
||||
MS_ASSERT(weight_node != nullptr);
|
||||
auto weight_value = GetTensorInfo(weight_node);
|
||||
if (weight_value == nullptr) {
|
||||
MS_LOG(ERROR) << "weight node must const value";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto status = TransFilterFormat(weight_value, src_format_, dst_format_);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "trans conv weight failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto type_id = static_cast<TypeId>(weight_value->data_type());
|
||||
auto shape = weight_value->shape();
|
||||
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
|
||||
auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
|
||||
if (abstract == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
weight_node->set_abstract(abstract);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
bool ConvWeightFormatBase::Run(const FuncGraphPtr &graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
if (src_format_ == dst_format_) {
|
||||
return true;
|
||||
}
|
||||
auto manager = Manage(graph, true);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(ERROR) << "manager is nullptr.";
|
||||
return false;
|
||||
}
|
||||
auto status = ConvWeightFormatTrans(graph);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_CONV_WEIGHT_FORMAT_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_CONV_WEIGHT_FORMAT_H_
|
||||
|
||||
#include <string>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ConvWeightFormatBase : public Pass {
|
||||
public:
|
||||
explicit ConvWeightFormatBase(const std::string &name = "ConvWeightFormatBase") : Pass(name) {}
|
||||
~ConvWeightFormatBase() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
STATUS ConvWeightFormatTrans(const FuncGraphPtr &graph);
|
||||
STATUS TransferConvWeight(const AnfNodePtr &weight_node);
|
||||
|
||||
protected:
|
||||
schema::Format src_format_{schema::Format_KHWC};
|
||||
schema::Format dst_format_{schema::Format_KHWC};
|
||||
};
|
||||
|
||||
class ConvWeightToKHWC : public ConvWeightFormatBase {
|
||||
public:
|
||||
ConvWeightToKHWC() : ConvWeightFormatBase("ConvWeightToKHWC") { src_format_ = schema::Format_KCHW; }
|
||||
~ConvWeightToKHWC() override = default;
|
||||
};
|
||||
|
||||
class ConvWeightToKCHW : public ConvWeightFormatBase {
|
||||
public:
|
||||
ConvWeightToKCHW() : ConvWeightFormatBase("ConvWeightToKCHW") { dst_format_ = schema::Format_KCHW; }
|
||||
~ConvWeightToKCHW() override = default;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_CONV_WEIGHT_FORMAT_H_
|
|
@ -0,0 +1,150 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/format/delete_redundant_transpose.h"
|
||||
#include <vector>
|
||||
#include "tools/optimizer/common/format_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kDimNumber = 4;
|
||||
} // namespace
|
||||
|
||||
STATUS DeleteRedundantTranspose::DeleteNot4DTranspose(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto manager = func_graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNode>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
if (DeleteNot4DTranspose(sub_func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "delete transpose failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
if (DeleteNot4DTranspose(sub_func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "delete transpose failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (!CheckPrimitiveType(node, prim::kPrimTranspose)) {
|
||||
continue;
|
||||
}
|
||||
auto abstract = GetCNodeInputAbstract(cnode, 1);
|
||||
ShapeVector shape;
|
||||
if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "fetch shape failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
std::vector<int> perm;
|
||||
if (GetTransposePerm(cnode, &perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "fetch transpose perm failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (!shape.empty() && shape.size() != perm.size()) {
|
||||
MS_LOG(DEBUG) << "transpose node need to be deleted.";
|
||||
manager->Replace(node, cnode->input(1));
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS DeleteRedundantTranspose::TransTransFusion(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto node_lite = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_lite) {
|
||||
if (!utils::isa<CNode>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
if (TransTransFusion(sub_func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "delete transpose failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
if (TransTransFusion(sub_func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "delete transpose failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) ||
|
||||
!CheckPrimitiveType(cnode->input(1), prim::kPrimTranspose)) {
|
||||
continue;
|
||||
}
|
||||
std::vector<int> post_perm;
|
||||
if (GetTransposePerm(cnode, &post_perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "transpose rm cannot be obtained, " << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
std::vector<int> pre_perm;
|
||||
auto pre_cnode = cnode->input(1)->cast<CNodePtr>();
|
||||
MS_ASSERT(pre_cnode != nullptr);
|
||||
if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "transpose rm cannot be obtained, " << pre_cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if ((pre_perm == kNH2NC && post_perm == kNC2NH) || (pre_perm == kNC2NH && post_perm == kNH2NC)) {
|
||||
func_graph->manager()->Replace(cnode, pre_cnode->input(1));
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
bool DeleteRedundantTranspose::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto manager = Manage(func_graph, true);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(ERROR) << "manager is nullptr.";
|
||||
return false;
|
||||
}
|
||||
if (TransTransFusion(func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "ranspose and transpose fusion failed.";
|
||||
return false;
|
||||
}
|
||||
if (DeleteNot4DTranspose(func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "delete not 4D transpose failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_
|
||||
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class DeleteRedundantTranspose : public Pass {
|
||||
public:
|
||||
DeleteRedundantTranspose() : Pass("delete_redundant_transpose") {}
|
||||
~DeleteRedundantTranspose() = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
STATUS DeleteNot4DTranspose(const FuncGraphPtr &func_graph);
|
||||
STATUS TransTransFusion(const FuncGraphPtr &func_graph);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_
|
|
@ -0,0 +1,315 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/format/to_format_base.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "src/common/common.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
|
||||
using mindspore::lite::NHWC_SHAPE;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kDimNumber = 4;
|
||||
} // namespace
|
||||
|
||||
STATUS ToFormatBase::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
|
||||
bool before, size_t index) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
AnfNodePtr trans_input = before ? cnode->input(index) : cnode;
|
||||
std::string trans_name = before ? cnode->fullname_with_scope() + "_pre_" + std::to_string(index - 1)
|
||||
: cnode->fullname_with_scope() + "_post";
|
||||
auto trans_cnode = opt::GenTransposeNode(func_graph, trans_input, perm, trans_name);
|
||||
if (DecideWhetherInferShapeForNewNode()) {
|
||||
auto status = node_infer_shape_->InferShape(trans_cnode);
|
||||
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "infer generated trans node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
auto abstract = trans_input->abstract();
|
||||
if (abstract != nullptr) {
|
||||
trans_cnode->set_abstract(abstract->Clone());
|
||||
}
|
||||
}
|
||||
auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->cast<CNodePtr>()->input(0));
|
||||
if (perm == kNC2NH) {
|
||||
trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
|
||||
} else if (perm == kNH2NC) {
|
||||
trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
|
||||
}
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
manager = Manage(func_graph, true);
|
||||
}
|
||||
MS_ASSERT(manager != nullptr);
|
||||
auto tr = manager->Transact();
|
||||
if (before) {
|
||||
tr.SetEdge(cnode, index, trans_cnode);
|
||||
tr.Commit();
|
||||
} else {
|
||||
func_graph->manager()->Replace(cnode, trans_cnode);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ToFormatBase::ModifyCNodeAbstract(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto abstract_base = cnode->abstract();
|
||||
std::vector<AbstractBasePtr> abstracts;
|
||||
if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
|
||||
auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract_base);
|
||||
abstracts = abstract_tuple->elements();
|
||||
} else {
|
||||
abstracts.push_back(abstract_base);
|
||||
}
|
||||
for (auto &abstract : abstracts) {
|
||||
ShapeVector shape;
|
||||
if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "fetch shape failed, " << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (shape.size() != kDimNumber) {
|
||||
MS_LOG(DEBUG) << "shape don't need to modify.";
|
||||
continue;
|
||||
}
|
||||
if (format_ == mindspore::NCHW) {
|
||||
ShapeVector transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
|
||||
abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
|
||||
} else {
|
||||
ShapeVector transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
|
||||
abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ToFormatBase::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<int> &perm) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
auto prim_node = cnode->input(0);
|
||||
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
||||
MS_ASSERT(prim != nullptr);
|
||||
if (sensitive_ops_.find(prim->name()) == sensitive_ops_.end()) {
|
||||
MS_LOG(ERROR) << "op don't meet condition.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto insert_index = sensitive_ops_.at(prim->name());
|
||||
if (insert_index.empty()) {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
|
||||
GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
|
||||
insert_index.push_back(1);
|
||||
} else {
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
insert_index.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto &index : insert_index) {
|
||||
if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "generate a new input failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ToFormatBase::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<int> &perm) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
|
||||
if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "generate a new input failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
auto node_users = func_graph->manager()->node_users()[cnode];
|
||||
for (auto &node_user : node_users) {
|
||||
auto post_node = node_user.first;
|
||||
CNodePtr tuple_get_item = nullptr;
|
||||
if (!opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
|
||||
if (!train_flag_) {
|
||||
MS_LOG(ERROR) << "post node is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
} else {
|
||||
tuple_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0);
|
||||
post_node = tuple_get_item;
|
||||
func_graph->manager()->Replace(cnode, tuple_get_item);
|
||||
}
|
||||
}
|
||||
if (func_graph->manager()->node_users()[post_node].empty()) {
|
||||
continue;
|
||||
}
|
||||
auto post_cnode = post_node->cast<CNodePtr>();
|
||||
if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "generate a new input failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (tuple_get_item != nullptr) {
|
||||
func_graph->manager()->Replace(tuple_get_item, tuple_get_item->input(1));
|
||||
}
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ToFormatBase::HandleGraphInput(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto graph_input = func_graph->get_inputs();
|
||||
for (auto &input : graph_input) {
|
||||
auto input_param = input->cast<ParameterPtr>();
|
||||
MS_ASSERT(input_param != nullptr);
|
||||
auto abstract = input_param->abstract();
|
||||
MS_ASSERT(abstract != nullptr);
|
||||
ShapeVector shape;
|
||||
if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "fetch shape failed." << input->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (shape.size() != kDimNumber || !DecideWhetherHandleGraphInput(func_graph, shape)) {
|
||||
continue;
|
||||
}
|
||||
ShapeVector transfer_shape;
|
||||
if (format_ == mindspore::NCHW) {
|
||||
transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
|
||||
} else {
|
||||
transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
|
||||
}
|
||||
CNodePtr trans_cnode;
|
||||
if (format_ == mindspore::NCHW) {
|
||||
trans_cnode = opt::GenTransposeNode(func_graph, input, kNC2NH, input->fullname_with_scope() + "_nc2nh");
|
||||
} else {
|
||||
trans_cnode = opt::GenTransposeNode(func_graph, input, kNH2NC, input->fullname_with_scope() + "_nh2nc");
|
||||
}
|
||||
if (trans_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "create transpose cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
|
||||
MS_ASSERT(trans_prim != nullptr);
|
||||
if (format_ == mindspore::NCHW) {
|
||||
trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
|
||||
} else {
|
||||
trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
|
||||
}
|
||||
trans_cnode->set_abstract(abstract->Clone());
|
||||
abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
|
||||
func_graph->manager()->Replace(input, trans_cnode);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ToFormatBase::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
opt::TransTypePair trans_info;
|
||||
GetTransNodeFormatType(cnode, &trans_info);
|
||||
if (trans_info.pre_ == opt::kNONE || trans_info.post_ == opt::kNONE) {
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto before_perm = trans_info.pre_ == opt::kNHWC2NCHW ? kNH2NC : kNC2NH;
|
||||
auto after_perm = trans_info.post_ == opt::kNCHW2NHWC ? kNC2NH : kNH2NC;
|
||||
if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (ModifyCNodeAbstract(cnode) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "adjust cnode's output shape failed, " << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
bool ToFormatBase::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
int status;
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (IsSpecialType(cnode)) {
|
||||
continue;
|
||||
}
|
||||
if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
(void)BasicProcess(sub_func_graph, false);
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
(void)BasicProcess(sub_func_graph, false);
|
||||
continue;
|
||||
}
|
||||
status = HandleGraphNode(func_graph, cnode);
|
||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (main_graph) {
|
||||
status = HandleGraphInput(func_graph);
|
||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ToFormatBase::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
if (format_ != mindspore::NHWC && format_ != mindspore::NCHW) {
|
||||
MS_LOG(ERROR) << "format transferring only support nc2nh or nh2nc.";
|
||||
return false;
|
||||
}
|
||||
auto manager = Manage(func_graph, true);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(ERROR) << "manager is nullptr.";
|
||||
return false;
|
||||
}
|
||||
node_infer_shape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
|
||||
if (node_infer_shape_ == nullptr) {
|
||||
MS_LOG(ERROR) << "create NodeInferShape object failed.";
|
||||
return false;
|
||||
}
|
||||
SetSensitiveOps();
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(node);
|
||||
if (prim == nullptr) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (!BasicProcess(func_graph, true)) {
|
||||
MS_LOG(ERROR) << "transfer format failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -14,49 +14,51 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_FORMAT_BASE_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_FORMAT_BASE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "utils/utils.h"
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "tools/anf_exporter/fetch_content.h"
|
||||
#include "tools/optimizer/graph/infershape_pass.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class InsertTranspose {
|
||||
namespace opt {
|
||||
class ToFormatBase : public Pass {
|
||||
public:
|
||||
InsertTranspose(FmkType fmk_type, bool train_flag) : fmk_type_(fmk_type), train_flag_(train_flag) {}
|
||||
~InsertTranspose() = default;
|
||||
bool Run(const FuncGraphPtr &func_graph);
|
||||
explicit ToFormatBase(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false,
|
||||
std::string pass_name = "to_format_base")
|
||||
: Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag) {}
|
||||
~ToFormatBase() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
|
||||
STATUS HandleGraphInput(const FuncGraphPtr &func_graph);
|
||||
STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
|
||||
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
|
||||
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
|
||||
size_t index = 0);
|
||||
STATUS ModifyCNodeAbstract(const CNodePtr &cnode);
|
||||
|
||||
private:
|
||||
AnfNodePtr GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<int> &perm, bool before, size_t index);
|
||||
bool ResetFuncGraph(const FuncGraphPtr &func_graph);
|
||||
bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
|
||||
void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info);
|
||||
STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void ResetSubGraphInput();
|
||||
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
protected:
|
||||
virtual void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) = 0;
|
||||
virtual void SetSensitiveOps() { sensitive_ops_ = opt::GetNHWCOpMap(); }
|
||||
virtual bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) { return true; }
|
||||
virtual bool DecideWhetherInferShapeForNewNode() { return true; }
|
||||
FmkType fmk_type_{lite::converter::FmkType_MS};
|
||||
bool train_flag_{false};
|
||||
std::unordered_map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_;
|
||||
mindspore::Format format_{mindspore::NHWC};
|
||||
std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr};
|
||||
std::unordered_map<std::string, std::vector<size_t>> sensitive_ops_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_FORMAT_BASE_H_
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/format/to_nchw_format.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
void ToNCHWFormat::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto prim_node = cnode->input(0);
|
||||
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
||||
MS_ASSERT(prim != nullptr);
|
||||
if (sensitive_ops_.find(prim->name()) != sensitive_ops_.end()) {
|
||||
trans_info->pre_ = opt::kNHWC2NCHW;
|
||||
trans_info->post_ = opt::kNCHW2NHWC;
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NCHW_FORMAT_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NCHW_FORMAT_H_
|
||||
|
||||
#include "tools/optimizer/format/to_format_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ToNCHWFormat : public ToFormatBase {
|
||||
public:
|
||||
explicit ToNCHWFormat(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false)
|
||||
: ToFormatBase(fmk_type, train_flag, "to_nchw_format") {
|
||||
format_ = mindspore::NCHW;
|
||||
}
|
||||
~ToNCHWFormat() = default;
|
||||
|
||||
private:
|
||||
void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NCHW_FORMAT_H_
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/format/to_nhwc_format.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
void ToNHWCFormat::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto prim_node = cnode->input(0);
|
||||
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
||||
MS_ASSERT(prim != nullptr);
|
||||
if (sensitive_ops_.find(prim->name()) != sensitive_ops_.end()) {
|
||||
trans_info->pre_ = opt::kNCHW2NHWC;
|
||||
trans_info->post_ = opt::kNHWC2NCHW;
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NHWC_FORMAT_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NHWC_FORMAT_H_
|
||||
|
||||
#include "tools/optimizer/format/to_format_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ToNHWCFormat : public ToFormatBase {
|
||||
public:
|
||||
explicit ToNHWCFormat(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false)
|
||||
: ToFormatBase(fmk_type, train_flag, "to_nhwc_format") {}
|
||||
~ToNHWCFormat() = default;
|
||||
|
||||
private:
|
||||
void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NHWC_FORMAT_H_
|
|
@ -20,13 +20,9 @@
|
|||
#include <vector>
|
||||
#include "tools/converter/quant_param_holder.h"
|
||||
#include "mindspore/core/ops/transpose.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/optimizer/common/format_utils.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
const std::vector<int> NH2NC = {0, 3, 1, 2};
|
||||
const std::vector<int> NC2NH = {0, 2, 3, 1};
|
||||
} // namespace
|
||||
bool IsBNCNode(const BaseRef &n) {
|
||||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
auto anf_node = utils::cast<AnfNodePtr>(n);
|
||||
|
@ -142,7 +138,7 @@ AnfNodePtr TransposeFusion::TransTransFusion(const mindspore::FuncGraphPtr &func
|
|||
MS_LOG(ERROR) << "get tanspose perm failed.";
|
||||
return nullptr;
|
||||
}
|
||||
if ((pre_perm == NH2NC && post_perm == NC2NH) || (pre_perm == NC2NH && post_perm == NH2NC)) {
|
||||
if ((pre_perm == kNH2NC && post_perm == kNC2NH) || (pre_perm == kNC2NH && post_perm == kNH2NC)) {
|
||||
return pre_cnode->input(1);
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -166,8 +162,10 @@ AnfNodePtr TransposeFusion::Process(const std::string &pattern_name, const minds
|
|||
return nullptr;
|
||||
}
|
||||
const CNodePtr &transpose_cnode = transpose_node->cast<CNodePtr>();
|
||||
auto perm_node = transpose_cnode->input(2);
|
||||
auto perm_node = transpose_cnode->input(kInputIndexTwo);
|
||||
auto trans_post_node = GenTransposeNode(func_graph, any_cnode, perm_node, any_cnode->fullname_with_scope() + "_post");
|
||||
trans_post_node->set_abstract(any_cnode->abstract()->Clone());
|
||||
any_cnode->set_abstract(transpose_cnode->input(1)->abstract()->Clone());
|
||||
auto tr = func_graph->manager()->Transact();
|
||||
tr.SetEdge(any_cnode, 1, transpose_cnode->input(1));
|
||||
tr.Commit();
|
||||
|
|
|
@ -19,9 +19,7 @@
|
|||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "tools/optimizer/graph/unify_format_pass.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/graph/unify_format_pass.h"
|
||||
#include "tools/optimizer/graph/decrease_transpose_algo.h"
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
|
@ -24,14 +24,9 @@
|
|||
#include "src/common/utils.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
|
||||
using mindspore::lite::NCHW_SHAPE;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr int kInputChannel = 3;
|
||||
const std::vector<int> NH2NC = {0, 3, 1, 2};
|
||||
const std::vector<int> NC2NH = {0, 2, 3, 1};
|
||||
|
||||
STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNodePtr &root_node,
|
||||
std::set<CNodePtr> *in_nodes, std::set<CNodePtr> *out_nodes,
|
||||
std::set<CNodePtr> *middle_nodes) {
|
||||
|
@ -101,22 +96,39 @@ STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNode
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<CNodePtr> &in_nodes,
|
||||
const std::set<CNodePtr> &out_nodes, const std::set<CNodePtr> &middle_nodes) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
for (auto &in_cnode : in_nodes) {
|
||||
void SetTransType(const std::set<CNodePtr> &cnodes, FormatTransNodeType *trans_type) {
|
||||
MS_ASSERT(trans_type != nullptr);
|
||||
FormatTransNodeType local_trans_type;
|
||||
for (auto &cnode : cnodes) {
|
||||
std::vector<int> perm;
|
||||
if (!CheckPrimitiveType(in_cnode, prim::kPrimTranspose) || GetTransposePerm(in_cnode, &perm) != lite::RET_OK ||
|
||||
perm != NH2NC) {
|
||||
return false;
|
||||
if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
|
||||
(perm != kNH2NC && perm != kNC2NH)) {
|
||||
*trans_type = kNONE;
|
||||
return;
|
||||
}
|
||||
local_trans_type = perm == kNH2NC ? kNHWC2NCHW : kNCHW2NHWC;
|
||||
*trans_type = *trans_type == kNONE ? local_trans_type : *trans_type;
|
||||
if (*trans_type != local_trans_type) {
|
||||
*trans_type = kNONE;
|
||||
return;
|
||||
}
|
||||
}
|
||||
for (auto &out_cnode : out_nodes) {
|
||||
std::vector<int> perm;
|
||||
if (!CheckPrimitiveType(out_cnode, prim::kPrimTranspose) || GetTransposePerm(out_cnode, &perm) != lite::RET_OK ||
|
||||
perm != NC2NH) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<CNodePtr> &in_nodes,
|
||||
const std::set<CNodePtr> &out_nodes, const std::set<CNodePtr> &middle_nodes,
|
||||
TransTypePair *trans_info) {
|
||||
MS_ASSERT(func_graph != nullptr && trans_info != nullptr);
|
||||
SetTransType(in_nodes, &trans_info->pre_);
|
||||
if (trans_info->pre_ == kNONE) {
|
||||
return false;
|
||||
}
|
||||
SetTransType(out_nodes, &trans_info->post_);
|
||||
if (trans_info->post_ == kNONE) {
|
||||
return false;
|
||||
}
|
||||
if (trans_info->pre_ == trans_info->post_) {
|
||||
return false;
|
||||
}
|
||||
auto &dynamic_ops = GetDynamicFormatOpList();
|
||||
TransposeStrategy transpose_strategy;
|
||||
|
@ -133,8 +145,8 @@ bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<
|
|||
return true;
|
||||
}
|
||||
|
||||
void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
|
||||
bool train_flag) {
|
||||
void ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
|
||||
bool train_flag, FormatTransNodeType trans_type) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (utils::isa<CNodePtr>(cnode->input(index))) {
|
||||
return;
|
||||
|
@ -157,42 +169,24 @@ void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, s
|
|||
(data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
|
||||
return;
|
||||
}
|
||||
std::vector<int> new_shape = data_info.shape_;
|
||||
ShapeVector expand_shape(data_info.shape_.begin(), data_info.shape_.end());
|
||||
if (data_info.shape_.size() == 1) {
|
||||
new_shape = {1, 1, 1, data_info.shape_[0]};
|
||||
expand_shape = {1, 1, 1, data_info.shape_[0]};
|
||||
} else if (data_info.shape_.size() == kInputSizeTwo) {
|
||||
new_shape = {1, 1, data_info.shape_[0], data_info.shape_[1]};
|
||||
expand_shape = {1, 1, data_info.shape_[0], data_info.shape_[1]};
|
||||
} else if (data_info.shape_.size() == kInputSizeThree) {
|
||||
new_shape = {1, data_info.shape_[0], data_info.shape_[1], data_info.shape_[kInputIndexTwo]};
|
||||
expand_shape = {1, data_info.shape_[0], data_info.shape_[1], data_info.shape_[kInputIndexTwo]};
|
||||
}
|
||||
auto size = data_info.data_.size() / sizeof(float);
|
||||
std::vector<float> new_data(size);
|
||||
auto new_data_ptr = static_cast<float *>(new_data.data());
|
||||
auto nchw_data = reinterpret_cast<float *>(data_info.data_.data());
|
||||
// nchw to nhwc
|
||||
auto batch = new_shape[lite::NCHW_N];
|
||||
auto channel = new_shape[lite::NCHW_C];
|
||||
auto area = new_shape[lite::NCHW_H] * new_shape[lite::NCHW_W];
|
||||
for (auto i = 0; i < batch; i++) {
|
||||
float *src_batch = nchw_data + i * channel * area;
|
||||
float *dst_batch = new_data_ptr + i * channel * area;
|
||||
for (int j = 0; j < area; ++j) {
|
||||
float *src_area = src_batch + i;
|
||||
float *dst_area = dst_batch + i * channel;
|
||||
for (int k = 0; k < channel; ++k) {
|
||||
dst_area[k] = src_area[k * area];
|
||||
}
|
||||
}
|
||||
auto tensor = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), expand_shape,
|
||||
data_info.data_.data(), data_info.data_.size());
|
||||
if (trans_type == kNHWC2NCHW) {
|
||||
(void)TransFilterFormat(tensor, schema::Format_KHWC, schema::Format_KCHW);
|
||||
} else {
|
||||
(void)TransFilterFormat(tensor, schema::Format_KCHW, schema::Format_KHWC);
|
||||
}
|
||||
auto param_node = func_graph->add_parameter();
|
||||
param_node->set_name(cnode->input(index)->fullname_with_scope());
|
||||
std::vector<int64_t> shape_vec{new_shape[0], new_shape[kInputIndexTwo], new_shape[kInputIndexThree], new_shape[1]};
|
||||
auto tensor_info = lite::CreateTensorInfo(new_data.data(), size * sizeof(float), shape_vec, kNumberTypeFloat32);
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tensor info failed";
|
||||
return;
|
||||
}
|
||||
status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
|
||||
status = lite::InitParameterFromTensorInfo(param_node, tensor);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "init parameter from tensor info failed";
|
||||
return;
|
||||
|
@ -200,72 +194,10 @@ void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, s
|
|||
auto tr = func_graph->manager()->Transact();
|
||||
tr.SetEdge(cnode, index, param_node);
|
||||
tr.Commit();
|
||||
return;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void UnifyFormatPass::GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto prim_node = cnode->input(0);
|
||||
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
||||
MS_ASSERT(prim != nullptr);
|
||||
auto &specify_nhwc_op_map = GetNHWCOpMap();
|
||||
auto &specify_nchw_op_map = GetNCHWOpMap();
|
||||
if (fmk_type_ == lite::converter::FmkType_TFLITE) {
|
||||
if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
|
||||
return;
|
||||
}
|
||||
trans_info->pre_ = kNHWC2NCHW;
|
||||
trans_info->post_ = kNCHW2NHWC;
|
||||
} else if (fmk_type_ == lite::converter::FmkType_TF) {
|
||||
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && GetFormat(cnode) == NCHW) {
|
||||
trans_info->pre_ = kNCHW2NHWC;
|
||||
trans_info->post_ = kNHWC2NCHW;
|
||||
}
|
||||
if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
|
||||
trans_info->pre_ = kNHWC2NCHW;
|
||||
trans_info->post_ = kNCHW2NHWC;
|
||||
}
|
||||
} else {
|
||||
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
|
||||
if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
|
||||
GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
|
||||
return;
|
||||
}
|
||||
trans_info->pre_ = kNCHW2NHWC;
|
||||
trans_info->post_ = kNHWC2NCHW;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool UnifyFormatPass::TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || !CheckPrimitiveType(cnode->input(1), prim::kPrimTranspose)) {
|
||||
return false;
|
||||
}
|
||||
std::vector<int> post_perm;
|
||||
if (GetTransposePerm(cnode, &post_perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "get tanspose perm failed.";
|
||||
return false;
|
||||
}
|
||||
std::vector<int> pre_perm;
|
||||
auto pre_node = cnode->input(1);
|
||||
auto pre_cnode = pre_node->cast<CNodePtr>();
|
||||
if (pre_cnode == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "get tanspose perm failed.";
|
||||
return false;
|
||||
}
|
||||
if ((pre_perm == NH2NC && post_perm == NC2NH) || (pre_perm == NC2NH && post_perm == NH2NC)) {
|
||||
func_graph->manager()->Replace(cnode, pre_cnode->input(1));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
STATUS DecreaseTransposeAlgo::PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
if (!CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
|
||||
return lite::RET_OK;
|
||||
|
@ -285,7 +217,7 @@ STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, cons
|
|||
MS_LOG(ERROR) << "get post transpose node perm failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if ((cur_perm == NH2NC && post_trans_perm == NC2NH) || (cur_perm == NC2NH && post_trans_perm == NH2NC)) {
|
||||
if ((cur_perm == kNH2NC && post_trans_perm == kNC2NH) || (cur_perm == kNC2NH && post_trans_perm == kNH2NC)) {
|
||||
func_graph->manager()->Replace(post_node, cnode->input(1));
|
||||
}
|
||||
}
|
||||
|
@ -293,15 +225,11 @@ STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, cons
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
|
||||
bool before, size_t index) {
|
||||
STATUS DecreaseTransposeAlgo::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
|
||||
bool before, size_t index) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
AnfNodePtr new_input = nullptr;
|
||||
if (need_reset_) {
|
||||
new_input = transpose_strategy_.TransposeDependOnShape(func_graph, cnode, perm, before, index);
|
||||
} else {
|
||||
new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index);
|
||||
}
|
||||
new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index);
|
||||
if (new_input == nullptr) {
|
||||
MS_LOG(ERROR) << "generate a transpose node failed.";
|
||||
return lite::RET_ERROR;
|
||||
|
@ -312,13 +240,6 @@ STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodeP
|
|||
auto new_cnode_input = new_input->cast<CNodePtr>();
|
||||
int status = lite::RET_OK;
|
||||
if (CheckPrimitiveType(new_cnode_input, prim::kPrimTranspose)) {
|
||||
if (need_reset_) {
|
||||
if (before) {
|
||||
pre_insert_trans_.insert(new_cnode_input);
|
||||
} else {
|
||||
post_insert_trans_.insert(new_cnode_input);
|
||||
}
|
||||
}
|
||||
status = node_infer_shape_.InferShape(new_cnode_input);
|
||||
}
|
||||
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
||||
|
@ -337,7 +258,7 @@ STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodeP
|
|||
tr.Commit();
|
||||
} else {
|
||||
func_graph->manager()->Replace(cnode, new_input);
|
||||
if (!need_reset_ && PostTransposeFusion(func_graph, new_input->cast<CNodePtr>()) != lite::RET_OK) {
|
||||
if (PostTransposeFusion(func_graph, new_input->cast<CNodePtr>()) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "post transpose fusion failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
@ -345,8 +266,8 @@ STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodeP
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<int> &perm) {
|
||||
STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<int> &perm) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
auto prim_node = cnode->input(0);
|
||||
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
||||
|
@ -380,8 +301,8 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
TransTypePair *trans_insert_info) {
|
||||
STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
TransTypePair *trans_insert_info) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
MS_ASSERT(trans_insert_info != nullptr);
|
||||
TransTypePair trans_info;
|
||||
|
@ -393,14 +314,14 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const
|
|||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
cnode->set_inputs(origin_inputs);
|
||||
auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode);
|
||||
auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode, trans_insert_info->pre_);
|
||||
if (status == lite::RET_NOT_SUPPORT) {
|
||||
return lite::RET_NO_CHANGE;
|
||||
} else if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "change op attr failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto before_perm = trans_insert_info->pre_ == kNHWC2NCHW ? NH2NC : NC2NH;
|
||||
auto before_perm = trans_insert_info->pre_ == kNHWC2NCHW ? kNH2NC : kNC2NH;
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
if (IsMonadNode(cnode->input(i))) {
|
||||
continue;
|
||||
|
@ -431,8 +352,8 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS UnifyFormatPass::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<int> &perm) {
|
||||
STATUS DecreaseTransposeAlgo::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<int> &perm) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
|
||||
if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
|
||||
|
@ -470,63 +391,8 @@ STATUS UnifyFormatPass::InsertPostTransNode(const FuncGraphPtr &func_graph, cons
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS UnifyFormatPass::HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
if (fmk_type_ == lite::converter::FmkType_TF || fmk_type_ == lite::converter::FmkType_TFLITE) {
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto node = cnode->input(i);
|
||||
if (!utils::isa<ParameterPtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto param_node = node->cast<ParameterPtr>();
|
||||
if (param_node->has_default()) {
|
||||
continue;
|
||||
}
|
||||
auto abstract_base = param_node->abstract();
|
||||
if (abstract_base == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
||||
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
|
||||
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
|
||||
if (shape_vector.size() != kInputSizeFour) {
|
||||
continue;
|
||||
}
|
||||
if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX &&
|
||||
shape_vector[kInputIndexThree] == kInputChannel && shape_vector[1] == -1) {
|
||||
continue;
|
||||
}
|
||||
std::vector<int64_t> new_dims = {shape_vector[NCHW_SHAPE::NCHW_N], shape_vector[NCHW_SHAPE::NCHW_H],
|
||||
shape_vector[NCHW_SHAPE::NCHW_W], shape_vector[NCHW_SHAPE::NCHW_C]};
|
||||
abstract_tensor->set_shape(std::make_shared<abstract::Shape>(new_dims));
|
||||
auto trans_cnode = GenTransposeNode(func_graph, param_node, NH2NC, param_node->fullname_with_scope() + "_pre");
|
||||
if (trans_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "generate a transpose node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto status = node_infer_shape_.InferShape(trans_cnode);
|
||||
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "infer shape failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
func_graph->manager()->Replace(param_node, trans_cnode);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
std::set<CNodePtr> *visit_transposes) {
|
||||
STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
std::set<CNodePtr> *visit_transposes) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr && visit_transposes != nullptr);
|
||||
auto manager = func_graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
|
@ -543,7 +409,8 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con
|
|||
visit_transposes->insert(in_cnode);
|
||||
}
|
||||
}
|
||||
if (!JudgeCanOptimizerForMultiOp(func_graph, in_nodes, out_nodes, middle_nodes)) {
|
||||
TransTypePair trans_info;
|
||||
if (!JudgeCanOptimizerForMultiOp(func_graph, in_nodes, out_nodes, middle_nodes, &trans_info)) {
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
|
@ -568,9 +435,9 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con
|
|||
continue;
|
||||
}
|
||||
for (size_t i = 1; i < middle_cnode->size(); ++i) {
|
||||
ConvertNcTensor2Nh(func_graph, middle_cnode, i, fmk_type_, train_flag_);
|
||||
ConvertTensorToNCOrNH(func_graph, middle_cnode, i, fmk_type_, train_flag_, trans_info.post_);
|
||||
}
|
||||
status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode);
|
||||
status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode, trans_info.post_);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "change op attr failed.";
|
||||
return lite::RET_ERROR;
|
||||
|
@ -584,7 +451,7 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
void DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||
auto sub_inputs = sub_graph->get_inputs();
|
||||
sub_inputs_map_[sub_graph] = sub_inputs;
|
||||
|
@ -628,7 +495,7 @@ void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr
|
|||
}
|
||||
}
|
||||
|
||||
void UnifyFormatPass::ResetSubGraphInput() {
|
||||
void DecreaseTransposeAlgo::ResetSubGraphInput() {
|
||||
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
|
||||
auto &sub_graph = iter->first;
|
||||
auto &sub_inputs = iter->second;
|
||||
|
@ -647,7 +514,7 @@ void UnifyFormatPass::ResetSubGraphInput() {
|
|||
}
|
||||
}
|
||||
|
||||
void UnifyFormatPass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
void DecreaseTransposeAlgo::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||
auto return_node = sub_graph->get_return();
|
||||
auto origin_input = return_node->inputs();
|
||||
|
@ -676,7 +543,7 @@ void UnifyFormatPass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPt
|
|||
return_node->set_inputs(origin_input);
|
||||
}
|
||||
|
||||
void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||
auto return_node = sub_graph->get_return();
|
||||
auto origin_inputs = return_node->inputs();
|
||||
|
@ -720,7 +587,7 @@ void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraph
|
|||
prim->AddAttr(kInferDone, MakeValue<bool>(infer_done));
|
||||
}
|
||||
|
||||
bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
|
||||
bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
|
||||
auto manager = Manage(func_graph, true);
|
||||
|
@ -771,7 +638,7 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
|
|||
MS_LOG(ERROR) << "insert pre node failed.";
|
||||
return false;
|
||||
}
|
||||
auto after_perm = trans_insert_info.post_ == kNHWC2NCHW ? NH2NC : NC2NH;
|
||||
auto after_perm = trans_insert_info.post_ == kNHWC2NCHW ? kNH2NC : kNC2NH;
|
||||
if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
|
||||
return false;
|
||||
|
@ -780,7 +647,7 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
|
|||
return true;
|
||||
}
|
||||
|
||||
bool UnifyFormatPass::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph) {
|
||||
bool DecreaseTransposeAlgo::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto manager = Manage(func_graph, true);
|
||||
if (manager == nullptr) {
|
||||
|
@ -811,7 +678,7 @@ bool UnifyFormatPass::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph
|
|||
}
|
||||
std::vector<int> perm;
|
||||
if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
|
||||
perm != NH2NC) {
|
||||
perm != kNH2NC) {
|
||||
continue;
|
||||
}
|
||||
auto status = HandleGraphMultiNode(func_graph, cnode, &visit_transposes);
|
||||
|
@ -823,144 +690,7 @@ bool UnifyFormatPass::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph
|
|||
return true;
|
||||
}
|
||||
|
||||
bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto manager = Manage(func_graph, true);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(ERROR) << "manager is nullptr.";
|
||||
return false;
|
||||
}
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim->GetAttr(kInferDone) != nullptr) {
|
||||
prim->EraseAttr(kInferDone);
|
||||
}
|
||||
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
(void)ResetFuncGraph(sub_func_graph);
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
|
||||
if (sub_func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
(void)ResetFuncGraph(sub_func_graph);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool UnifyFormatPass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
bool all_op_can_infer = true;
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (IsSpecialType(cnode)) {
|
||||
continue;
|
||||
}
|
||||
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
all_op_can_infer = false;
|
||||
} else {
|
||||
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
|
||||
}
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
all_op_can_infer = false;
|
||||
} else {
|
||||
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
auto cur_op_can_infer = node_infer_shape_.JudgeOpSupportInfer(cnode);
|
||||
if (!cur_op_can_infer) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_ASSERT(prim != nullptr);
|
||||
lite::NotSupportOp::GetInstance()->InsertOp(prim->name());
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_SUPPORT);
|
||||
all_op_can_infer = false;
|
||||
}
|
||||
}
|
||||
return all_op_can_infer;
|
||||
}
|
||||
|
||||
bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
if (!JudgeAllOpsCanInfer(func_graph)) {
|
||||
MS_LOG(ERROR) << "exist op cannot support infer shape.";
|
||||
return false;
|
||||
}
|
||||
if (!RunNodeInferShape(func_graph)) {
|
||||
MS_LOG(ERROR) << "RunNodeInferShape failed.";
|
||||
return false;
|
||||
}
|
||||
ResetSubGraphInput();
|
||||
ResetFuncGraph(func_graph);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool UnifyFormatPass::RunNodeInferShape(const FuncGraphPtr &func_graph) {
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (IsSpecialType(cnode)) {
|
||||
continue;
|
||||
}
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
SetSubGraphInput(cnode, sub_func_graph);
|
||||
if (!RunNodeInferShape(sub_func_graph)) {
|
||||
MS_LOG(ERROR) << "subgraph infer shape failed.";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
|
||||
return false;
|
||||
}
|
||||
SetSubGraphOutput(cnode, sub_func_graph);
|
||||
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
SetSubGraphInput(cnode, sub_func_graph);
|
||||
if (!RunNodeInferShape(sub_func_graph)) {
|
||||
MS_LOG(ERROR) << "subgraph infer shape failed.";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
|
||||
return false;
|
||||
}
|
||||
SetSubGraphOutput(cnode, sub_func_graph);
|
||||
SetSubGraphAbstract(cnode, sub_func_graph);
|
||||
continue;
|
||||
}
|
||||
auto status = node_infer_shape_.InferShape(cnode);
|
||||
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "infer shape failed." << cnode->fullname_with_scope();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool UnifyFormatPass::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
bool DecreaseTransposeAlgo::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
auto prim_node = cnode->input(0);
|
||||
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
||||
auto &nchw_op = GetNCHWOpMap();
|
||||
|
@ -970,32 +700,14 @@ bool UnifyFormatPass::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNode
|
|||
if (utils::isa<CNodePtr>(cnode->input(1))) {
|
||||
auto format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
|
||||
if (nchw_op.find(prim->name()) != nchw_op.end() && format != NCHW) {
|
||||
InsertPreTransNode(func_graph, cnode, {0, 3, 1, 2});
|
||||
InsertPostTransNode(func_graph, cnode, {0, 2, 3, 1});
|
||||
}
|
||||
}
|
||||
{
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
manager = Manage(func_graph, true);
|
||||
}
|
||||
auto shape = node_infer_shape_.GetInputShape(cnode, 1);
|
||||
std::vector<int> perm;
|
||||
auto status = GetTransposePerm(cnode, &perm);
|
||||
if (status != RET_OK) {
|
||||
return false;
|
||||
}
|
||||
if (!shape.empty() && shape.size() != perm.size()) {
|
||||
manager->Replace(cnode, cnode->input(1));
|
||||
}
|
||||
InsertPreTransNode(func_graph, cnode, kNH2NC);
|
||||
InsertPostTransNode(func_graph, cnode, kNC2NH);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool UnifyFormatPass::DoFixFormat(const FuncGraphPtr &func_graph) {
|
||||
bool DecreaseTransposeAlgo::DoFixFormat(const FuncGraphPtr &func_graph) {
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
|
@ -1041,8 +753,10 @@ bool UnifyFormatPass::DoFixFormat(const FuncGraphPtr &func_graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
|
||||
bool DecreaseTransposeAlgo::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
node_infer_shape_.Init(fmk_type_, train_flag_);
|
||||
transpose_strategy_.Init(fmk_type_, train_flag_);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(node);
|
||||
|
@ -1050,15 +764,6 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
|
|||
continue;
|
||||
}
|
||||
}
|
||||
if (!JudgeAllOpsCanInfer(func_graph)) {
|
||||
MS_LOG(ERROR) << "exist op cannot support infer shape.";
|
||||
return false;
|
||||
}
|
||||
if (!RunNodeInferShape(func_graph)) {
|
||||
MS_LOG(ERROR) << "infer shape failed.";
|
||||
return false;
|
||||
}
|
||||
ResetSubGraphInput();
|
||||
|
||||
if (!DoFixFormat(func_graph)) {
|
||||
MS_LOG(ERROR) << "DoFixFormat failed.";
|
|
@ -31,10 +31,11 @@
|
|||
using mindspore::lite::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class UnifyFormatPass : public Pass {
|
||||
class DecreaseTransposeAlgo : public Pass {
|
||||
public:
|
||||
UnifyFormatPass() : Pass("unify_format_pass") {}
|
||||
~UnifyFormatPass() override = default;
|
||||
explicit DecreaseTransposeAlgo(FmkType fmk_type = FmkType::FmkType_MS, bool train_flag = false)
|
||||
: Pass("DecreaseTransposeAlgo"), fmk_type_(fmk_type), train_flag_(train_flag) {}
|
||||
~DecreaseTransposeAlgo() override = default;
|
||||
void Init(FmkType fmk_type, bool train_flag) {
|
||||
fmk_type_ = fmk_type;
|
||||
train_flag_ = train_flag;
|
||||
|
@ -42,7 +43,6 @@ class UnifyFormatPass : public Pass {
|
|||
transpose_strategy_.Init(fmk_type, train_flag);
|
||||
}
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
bool RunOnlyForShape(const FuncGraphPtr &func_graph);
|
||||
|
||||
private:
|
||||
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
|
||||
|
@ -51,16 +51,10 @@ class UnifyFormatPass : public Pass {
|
|||
size_t index = 0);
|
||||
bool RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
bool DoFixFormat(const FuncGraphPtr &func_graph);
|
||||
bool RunNodeInferShape(const FuncGraphPtr &func_graph);
|
||||
bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
|
||||
bool ResetFuncGraph(const FuncGraphPtr &func_graph);
|
||||
bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
|
||||
bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph);
|
||||
bool TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
|
||||
void GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info);
|
||||
STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
std::set<CNodePtr> *visit_transposes);
|
||||
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
|
||||
|
@ -69,12 +63,9 @@ class UnifyFormatPass : public Pass {
|
|||
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
FmkType fmk_type_{lite::converter::FmkType_MS};
|
||||
bool need_reset_{false};
|
||||
bool train_flag_{false};
|
||||
NodeInferShape node_infer_shape_;
|
||||
TransposeStrategy transpose_strategy_;
|
||||
std::set<AnfNodePtr> pre_insert_trans_;
|
||||
std::set<AnfNodePtr> post_insert_trans_;
|
||||
std::unordered_map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_;
|
||||
};
|
||||
} // namespace opt
|
|
@ -29,6 +29,10 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(ERROR) << "create NodeInferShape object failed.";
|
||||
return false;
|
||||
}
|
||||
if (!JudgeAllOpsCanInfer(func_graph)) {
|
||||
MS_LOG(ERROR) << "exist op cannot support infer shape.";
|
||||
return false;
|
||||
}
|
||||
if (InferProcess(func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "infer shape failed.";
|
||||
return false;
|
||||
|
@ -37,6 +41,47 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
bool all_op_can_infer = true;
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (IsSpecialType(cnode)) {
|
||||
continue;
|
||||
}
|
||||
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
all_op_can_infer = false;
|
||||
} else {
|
||||
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
|
||||
}
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
all_op_can_infer = false;
|
||||
} else {
|
||||
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
auto cur_op_can_infer = node_infer_shape_->JudgeOpSupportInfer(cnode);
|
||||
if (!cur_op_can_infer) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_ASSERT(prim != nullptr);
|
||||
lite::NotSupportOp::GetInstance()->InsertOp(prim->name());
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_SUPPORT);
|
||||
all_op_can_infer = false;
|
||||
}
|
||||
}
|
||||
return all_op_can_infer;
|
||||
}
|
||||
|
||||
STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
|
@ -55,7 +100,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
|
|||
return false;
|
||||
}
|
||||
SetSubGraphInput(cnode, sub_func_graph);
|
||||
(void)InferProcess(sub_func_graph);
|
||||
if (InferProcess(sub_func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "subgraph infer shape failed.";
|
||||
return false;
|
||||
}
|
||||
SetSubGraphOutput(cnode, sub_func_graph);
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
|
||||
if (sub_func_graph == nullptr) {
|
||||
|
@ -63,7 +111,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
|
|||
return false;
|
||||
}
|
||||
SetSubGraphInput(cnode, sub_func_graph);
|
||||
(void)InferProcess(sub_func_graph);
|
||||
if (InferProcess(sub_func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "subgraph infer shape failed.";
|
||||
return false;
|
||||
}
|
||||
SetSubGraphOutput(cnode, sub_func_graph);
|
||||
SetSubGraphAbstract(cnode, sub_func_graph);
|
||||
continue;
|
||||
|
@ -132,7 +183,7 @@ void InferShapePass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr
|
|||
continue;
|
||||
}
|
||||
auto node_name = return_node->input(i)->fullname_with_scope();
|
||||
if (node_name.substr(node_name.size() - 5) != "_post") {
|
||||
if (node_name.size() < kInputSizeFive || node_name.substr(node_name.size() - kInputSizeFive) != "_post") {
|
||||
continue;
|
||||
}
|
||||
auto trans_cnode = return_node->input(i)->cast<CNodePtr>();
|
||||
|
|
|
@ -29,10 +29,11 @@ class InferShapePass : public Pass {
|
|||
public:
|
||||
explicit InferShapePass(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false)
|
||||
: Pass("infer_shape"), fmk_type_(fmk_type), train_flag_(train_flag) {}
|
||||
~InferShapePass() = default;
|
||||
~InferShapePass() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
|
||||
STATUS InferProcess(const FuncGraphPtr &func_graph);
|
||||
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
|
|
|
@ -32,8 +32,6 @@ namespace {
|
|||
constexpr size_t kFirstInput = 1;
|
||||
constexpr size_t kHalfDivisor = 2;
|
||||
constexpr size_t kOnnxStridedSlice = 6;
|
||||
const std::vector<int> NH2NC = {0, 3, 1, 2};
|
||||
const std::vector<int> NC2NH = {0, 2, 3, 1};
|
||||
STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<AnfNodePtr> *out_nodes) {
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
|
@ -70,7 +68,7 @@ AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &fu
|
|||
MS_LOG(ERROR) << "transpose perm get failed.";
|
||||
return nullptr;
|
||||
}
|
||||
if ((perm == NH2NC && trans_perm == NC2NH) || (perm == NC2NH && trans_perm == NH2NC)) {
|
||||
if ((perm == kNH2NC && trans_perm == kNC2NH) || (perm == kNC2NH && trans_perm == kNH2NC)) {
|
||||
return input_cnode->input(kFirstInput);
|
||||
}
|
||||
}
|
||||
|
@ -170,7 +168,8 @@ bool TransposeStrategy::CanChangeOpAxis(const FuncGraphPtr &func_graph, const CN
|
|||
return true;
|
||||
}
|
||||
|
||||
STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
FormatTransNodeType trans_type) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
auto shape = node_infer_shape_.GetInputShape(cnode, 1);
|
||||
if (shape.size() != kInputSizeFour) {
|
||||
|
@ -183,39 +182,17 @@ STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNo
|
|||
return lite::RET_NOT_SUPPORT;
|
||||
}
|
||||
}
|
||||
auto axis_map = GetNC2NHAxisMap();
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim->GetAttr(ops::kAxis) == nullptr) {
|
||||
return lite::RET_NOT_SUPPORT;
|
||||
}
|
||||
auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
|
||||
auto new_axis = axis_map[axis < 0 ? axis + kInputSizeFour : axis];
|
||||
prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis));
|
||||
return ChangeCommonOp(cnode, trans_type);
|
||||
}
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimCrop)) {
|
||||
auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0));
|
||||
if (crop_prim == nullptr) {
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
auto axis = crop_prim->get_axis();
|
||||
auto offsets = crop_prim->get_offsets();
|
||||
auto new_axis = axis_map[axis < 0 ? axis + kInputSizeFour : axis];
|
||||
if (new_axis == 0) {
|
||||
offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]};
|
||||
} else if (new_axis == kInputIndexThree) {
|
||||
offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]};
|
||||
} else {
|
||||
offsets.push_back(0);
|
||||
}
|
||||
crop_prim->set_axis(new_axis);
|
||||
crop_prim->set_offsets(offsets);
|
||||
return ChangeOpCrop(cnode, trans_type);
|
||||
}
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) {
|
||||
return ChangeOpSlice(func_graph, cnode);
|
||||
return ChangeOpSlice(func_graph, cnode, trans_type);
|
||||
}
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) {
|
||||
return ChangeOpStrideSlice(func_graph, cnode);
|
||||
return ChangeOpStrideSlice(func_graph, cnode, trans_type);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
@ -242,7 +219,7 @@ STATUS TransposeStrategy::TransposeInsertDependOnShape(const FuncGraphPtr &func_
|
|||
CNodePtr base_node = before ? cnode : node_users.front().first->cast<CNodePtr>();
|
||||
size_t input_index = before ? index : node_users.front().second;
|
||||
auto shape = node_infer_shape_.GetInputShape(base_node, input_index);
|
||||
if (!shape.empty() && shape.size() != NH2NC.size()) {
|
||||
if (!shape.empty() && shape.size() != kNH2NC.size()) {
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
|
@ -263,9 +240,9 @@ bool TransposeStrategy::IsInOutCanFuison(const FuncGraphPtr &func_graph, const s
|
|||
if (GetTransposePerm(cnode, &perm) != lite::RET_OK) {
|
||||
return false;
|
||||
}
|
||||
if (perm == NH2NC) {
|
||||
if (perm == kNH2NC) {
|
||||
cur_type = kNHWC2NCHW;
|
||||
} else if (perm == NC2NH) {
|
||||
} else if (perm == kNC2NH) {
|
||||
cur_type = kNCHW2NHWC;
|
||||
} else {
|
||||
return false;
|
||||
|
@ -297,8 +274,79 @@ void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, Tra
|
|||
}
|
||||
}
|
||||
|
||||
STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
STATUS TransposeStrategy::ChangeCommonOp(const CNodePtr &cnode, FormatTransNodeType trans_type) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (trans_type == kNONE) {
|
||||
MS_LOG(ERROR) << "trans_type is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_ASSERT(prim != nullptr);
|
||||
if (prim->GetAttr(ops::kAxis) == nullptr) {
|
||||
return lite::RET_NOT_SUPPORT;
|
||||
}
|
||||
auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
|
||||
if (axis < 0) {
|
||||
axis += kInputSizeFour;
|
||||
}
|
||||
auto new_axis = kNH2NC[axis];
|
||||
if (trans_type == kNHWC2NCHW) {
|
||||
new_axis = kNC2NH[axis];
|
||||
}
|
||||
prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis));
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS TransposeStrategy::ChangeOpCrop(const CNodePtr &cnode, FormatTransNodeType trans_type) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (trans_type == kNONE) {
|
||||
MS_LOG(ERROR) << "trans_type is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0));
|
||||
if (crop_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "cnode is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto axis = crop_prim->get_axis();
|
||||
if (axis < 0) {
|
||||
axis += kInputSizeFour;
|
||||
}
|
||||
MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
|
||||
auto offsets = crop_prim->get_offsets();
|
||||
if (trans_type == kNCHW2NHWC) {
|
||||
auto new_axis = kNH2NC[axis];
|
||||
if (new_axis == 0) {
|
||||
offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]};
|
||||
} else if (new_axis == kInputIndexThree) {
|
||||
offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]};
|
||||
} else {
|
||||
offsets.push_back(0);
|
||||
}
|
||||
crop_prim->set_axis(new_axis);
|
||||
crop_prim->set_offsets(offsets);
|
||||
} else {
|
||||
auto new_axis = kNC2NH[axis];
|
||||
if (new_axis == 0) {
|
||||
offsets = {offsets[0], offsets[kInputIndexThree], offsets[1], offsets[kInputIndexTwo]};
|
||||
} else if (new_axis == kInputIndexThree) {
|
||||
offsets = {offsets[kInputIndexTwo], offsets[0], offsets[1]};
|
||||
} else {
|
||||
offsets.pop_back();
|
||||
}
|
||||
crop_prim->set_axis(new_axis);
|
||||
crop_prim->set_offsets(offsets);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
FormatTransNodeType trans_type) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
if (trans_type == kNONE) {
|
||||
MS_LOG(ERROR) << "trans_type is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
for (size_t i = 2; i < cnode->size(); ++i) {
|
||||
if (utils::isa<CNodePtr>(cnode->input(i))) {
|
||||
return lite::RET_NOT_SUPPORT;
|
||||
|
@ -321,15 +369,21 @@ STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CN
|
|||
[](int64_t v) { return static_cast<int>(v); });
|
||||
}
|
||||
for (size_t i = 2; i < cnode->size(); ++i) {
|
||||
TransformAttrByAxes(func_graph, cnode, i, axes);
|
||||
TransformAttrByAxes(func_graph, cnode, i, axes, trans_type);
|
||||
}
|
||||
auto tmp_axes = TransformOpAxesAttr(axes);
|
||||
auto tmp_axes = TransformOpAxesAttr(axes, trans_type);
|
||||
std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end());
|
||||
prim->set_axes(new_axes);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
FormatTransNodeType trans_type) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
if (trans_type == kNONE) {
|
||||
MS_LOG(ERROR) << "trans_type is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (cnode->size() != kOnnxStridedSlice) {
|
||||
return lite::RET_NOT_SUPPORT;
|
||||
}
|
||||
|
@ -347,9 +401,9 @@ STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, co
|
|||
if (index == kInputIndexFour) {
|
||||
continue;
|
||||
}
|
||||
TransformAttrByAxes(func_graph, cnode, index, axes);
|
||||
TransformAttrByAxes(func_graph, cnode, index, axes, trans_type);
|
||||
}
|
||||
auto cur_axes = TransformOpAxesAttr(axes);
|
||||
auto cur_axes = TransformOpAxesAttr(axes, trans_type);
|
||||
auto param_node =
|
||||
BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope());
|
||||
func_graph->manager()->Replace(cnode->input(kInputIndexFour), param_node);
|
||||
|
@ -357,11 +411,10 @@ STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, co
|
|||
}
|
||||
|
||||
void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
|
||||
const std::vector<int> &axes) {
|
||||
const std::vector<int> &axes, FormatTransNodeType trans_type) {
|
||||
if (cnode == nullptr || input_index >= cnode->size() || axes.empty()) {
|
||||
return;
|
||||
}
|
||||
auto axis_map = GetNC2NHAxisMap();
|
||||
auto origin_input = node_infer_shape_.GetIntVecInput(cnode, input_index);
|
||||
if (origin_input.size() != axes.size()) {
|
||||
return;
|
||||
|
@ -369,8 +422,16 @@ void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, cons
|
|||
std::vector<int> cur_input;
|
||||
for (int dim = 0; dim < static_cast<int>(kInputSizeFour); ++dim) {
|
||||
for (size_t index = 0; index < axes.size(); ++index) {
|
||||
int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + kInputSizeFour : axes[index]];
|
||||
if (nhwc_dim == dim) {
|
||||
int axis = axes[index];
|
||||
if (axis < 0) {
|
||||
axis += kInputSizeFour;
|
||||
}
|
||||
MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
|
||||
int cur_axis = kNH2NC[axis];
|
||||
if (trans_type == kNHWC2NCHW) {
|
||||
cur_axis = kNC2NH[axis];
|
||||
}
|
||||
if (cur_axis == dim) {
|
||||
cur_input.push_back(origin_input[index]);
|
||||
}
|
||||
}
|
||||
|
@ -379,14 +440,23 @@ void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, cons
|
|||
func_graph->manager()->Replace(cnode->input(input_index), param_node);
|
||||
}
|
||||
|
||||
std::vector<int> TransposeStrategy::TransformOpAxesAttr(const std::vector<int> &origin_axes) {
|
||||
auto axis_map = GetNC2NHAxisMap();
|
||||
std::vector<int> cur_axis;
|
||||
std::vector<int> TransposeStrategy::TransformOpAxesAttr(const std::vector<int> &origin_axes,
|
||||
FormatTransNodeType trans_type) {
|
||||
std::vector<int> cur_axes;
|
||||
for (size_t i = 0; i < origin_axes.size(); ++i) {
|
||||
cur_axis.push_back(axis_map[origin_axes[i] < 0 ? origin_axes[i] + kInputSizeFour : origin_axes[i]]);
|
||||
int axis = origin_axes[i];
|
||||
if (axis < 0) {
|
||||
axis += kInputSizeFour;
|
||||
}
|
||||
MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
|
||||
int cur_axis = kNH2NC[axis];
|
||||
if (trans_type == kNHWC2NCHW) {
|
||||
cur_axis = kNC2NH[axis];
|
||||
}
|
||||
cur_axes.push_back(cur_axis);
|
||||
}
|
||||
std::sort(cur_axis.begin(), cur_axis.end());
|
||||
return cur_axis;
|
||||
std::sort(cur_axes.begin(), cur_axes.end());
|
||||
return cur_axes;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -39,23 +39,25 @@ class TransposeStrategy {
|
|||
}
|
||||
AnfNodePtr TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &code,
|
||||
const std::vector<int> &perm, bool before, size_t index);
|
||||
AnfNodePtr TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm,
|
||||
bool before, size_t index);
|
||||
bool CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_info,
|
||||
TransTypePair *trans_insert_info);
|
||||
STATUS ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
STATUS ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type);
|
||||
bool CanChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
|
||||
private:
|
||||
AnfNodePtr TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm,
|
||||
bool before, size_t index);
|
||||
STATUS TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool before, size_t index);
|
||||
bool IsInOutCanFuison(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, size_t *trans_count,
|
||||
FormatTransNodeType *trans_type);
|
||||
void DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info);
|
||||
STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
STATUS ChangeCommonOp(const CNodePtr &cnode, FormatTransNodeType trans_type);
|
||||
STATUS ChangeOpCrop(const CNodePtr &cnode, FormatTransNodeType trans_type);
|
||||
STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type);
|
||||
STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type);
|
||||
void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
|
||||
const std::vector<int> &axes);
|
||||
std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes);
|
||||
const std::vector<int> &axes, FormatTransNodeType trans_type);
|
||||
std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type);
|
||||
FmkType fmk_type_{lite::converter::FmkType_MS};
|
||||
bool train_flag_{false};
|
||||
NodeInferShape node_infer_shape_;
|
||||
|
|
Loading…
Reference in New Issue