From ae1116a724ac5674a6d6b1a60eed346fb3d53846 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Wed, 15 Dec 2021 17:42:20 +0800 Subject: [PATCH] module constant-fold optimizer --- .../fusion/constant_folding_fusion_test.cc | 2 +- mindspore/lite/tools/converter/CMakeLists.txt | 7 +- .../lite/tools/converter/anf_transform.cc | 2 +- .../const_fold/constant_folding_fusion.h | 68 +++++++ .../const_fold/fold_along_infershape.cc | 65 +++++++ .../const_fold/fold_along_infershape.h | 39 ++++ .../fold_utils.cc} | 179 +++--------------- .../tools/optimizer/const_fold/fold_utils.h | 44 +++++ .../const_fold/fold_with_infershape.cc | 158 ++++++++++++++++ .../fold_with_infershape.h} | 22 +-- .../tools/optimizer/graph/infershape_pass.cc | 97 ++++++---- .../tools/optimizer/graph/infershape_pass.h | 15 +- 12 files changed, 481 insertions(+), 217 deletions(-) create mode 100644 mindspore/lite/tools/optimizer/const_fold/constant_folding_fusion.h create mode 100644 mindspore/lite/tools/optimizer/const_fold/fold_along_infershape.cc create mode 100644 mindspore/lite/tools/optimizer/const_fold/fold_along_infershape.h rename mindspore/lite/tools/optimizer/{fusion/constant_folding_fusion.cc => const_fold/fold_utils.cc} (64%) create mode 100644 mindspore/lite/tools/optimizer/const_fold/fold_utils.h create mode 100644 mindspore/lite/tools/optimizer/const_fold/fold_with_infershape.cc rename mindspore/lite/tools/optimizer/{fusion/constant_folding_fusion.h => const_fold/fold_with_infershape.h} (65%) diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc index c1e50b67a19..91e605f1496 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc @@ -23,7 +23,7 @@ #include "include/errorcode.h" #include "src/common/log_adapter.h" #include "tools/converter/anf_transform.h" -#include "tools/optimizer/fusion/constant_folding_fusion.h" +#include "tools/optimizer/const_fold/constant_folding_fusion.h" #include "tools/anf_exporter/anf_exporter.h" #include "test/common/import_from_meta_graphT.h" diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 646e38d73d0..344ab927a8a 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -54,12 +54,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${SRC_DIR}/common/dynamic_library_loader.cc ${SRC_DIR}/train/train_populate_parameter.cc - ../optimizer/common/*.cc - ../optimizer/format/*.cc - ../optimizer/fusion/*.cc - ../optimizer/fisson/*.cc - ../optimizer/parallel/*.cc - ../optimizer/graph/*.cc + ../optimizer/*.cc ) add_subdirectory(../anf_exporter anf_exporter) diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 2492b94c653..285ef02330a 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -33,7 +33,7 @@ #include "tools/optimizer/fusion/conv_scale_fusion.h" #include "tools/optimizer/fusion/conv_bn_fusion.h" #include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h" -#include "tools/optimizer/fusion/constant_folding_fusion.h" +#include "tools/optimizer/const_fold/constant_folding_fusion.h" #include "tools/optimizer/fusion/norm_fusion.h" #include "tools/optimizer/fusion/batchmatmul_fusion.h" #include "tools/optimizer/fusion/batchnorm_to_scale_fusion.h" diff --git a/mindspore/lite/tools/optimizer/const_fold/constant_folding_fusion.h b/mindspore/lite/tools/optimizer/const_fold/constant_folding_fusion.h new file mode 100644 index 00000000000..3b63226f92d --- /dev/null +++ b/mindspore/lite/tools/optimizer/const_fold/constant_folding_fusion.h @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_CONSTANT_FOLDING_FUSION_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_CONSTANT_FOLDING_FUSION_H_ + +#include "backend/optimizer/common/pass.h" +#include "include/registry/converter_context.h" +#include "tools/optimizer/const_fold/fold_along_infershape.h" +#include "tools/optimizer/const_fold/fold_with_infershape.h" +#include "tools/optimizer/graph/update_conv2d_param_pass.h" + +namespace mindspore { +namespace opt { +class ConstFoldPass : public Pass { + public: + explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false) + : Pass("ConstFoldPass"), fmk_type_(fmk_type), train_flag_(train_flag) {} + ~ConstFoldPass() override = default; + bool Run(const FuncGraphPtr &func_graph) override { + if (func_graph == nullptr) { + MS_LOG(ERROR) << "func_graph is nullptr, do constant fold failed."; + return false; + } + // current infer-shape cannot support the control-flow model of Mindir. + if (fmk_type_ == converter::kFmkTypeMs) { + auto fold_schedule = ConstFoldWithInferShape(fmk_type_, train_flag_); + if (!fold_schedule.Run(func_graph)) { + MS_LOG(ERROR) << "Do constant fold failed."; + return false; + } + } else { + auto fold_schedule = ConstFoldAlongInferShape(fmk_type_, train_flag_); + if (!fold_schedule.Run(func_graph)) { + MS_LOG(ERROR) << "Do constant fold failed."; + return false; + } + } + + // the attrs of convolution only can be update after constant fold. + auto update_attrs = UpdateConv2DParamPass(); + if (!update_attrs.Run(func_graph)) { + MS_LOG(ERROR) << "update attrs failed."; + return false; + } + return true; + } + + private: + FmkType fmk_type_{converter::kFmkTypeMs}; + bool train_flag_{false}; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_CONSTANT_FOLDING_FUSION_H_ diff --git a/mindspore/lite/tools/optimizer/const_fold/fold_along_infershape.cc b/mindspore/lite/tools/optimizer/const_fold/fold_along_infershape.cc new file mode 100644 index 00000000000..1a11c0af1a7 --- /dev/null +++ b/mindspore/lite/tools/optimizer/const_fold/fold_along_infershape.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/optimizer/const_fold/fold_along_infershape.h" +#include +#include "nnacl/op_base.h" + +namespace mindspore { +namespace opt { +STATUS ConstFoldAlongInferShape::PostProcess(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_ASSERT(func_graph != nullptr && cnode != nullptr); + if (!CheckCanFold(func_graph, cnode)) { + return lite::RET_OK; + } + if (const_fold_processor_ == nullptr) { + const_fold_processor_ = std::make_shared(fmk_type_, train_flag_); + } + MS_CHECK_TRUE_MSG(const_fold_processor_ != nullptr, lite::RET_NULL_PTR, "const fold processor is nullptr"); + auto status = const_fold_processor_->DoConstantFold(func_graph, cnode); + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "do constant fold failed, the node is " << cnode->fullname_with_scope(); + } + return status; +} + +bool ConstFoldAlongInferShape::CheckCanFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_ASSERT(func_graph != nullptr && cnode != nullptr); + if (IsSpecialType(cnode) || CheckPrimitiveType(cnode, prim::kPrimCustom) || IsMarkedTrainOp(cnode)) { + return false; + } + auto prim = GetCNodePrimitive(cnode); + if (prim == nullptr) { + return false; + } + auto is_inferred = prim->GetAttr(kInferDone) != nullptr && GetValue(prim->GetAttr(kInferDone)); + if (!is_inferred) { + return false; + } + if (CheckPrimitiveType(cnode, prim::kPrimShape)) { + return lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() != 0; + } + auto inputs = cnode->inputs(); + auto graph_inputs = + sub_inputs_map_.find(func_graph) != sub_inputs_map_.end() ? sub_inputs_map_[func_graph] : func_graph->get_inputs(); + return std::all_of(inputs.begin(), inputs.end(), [&graph_inputs](const AnfNodePtr &node) { + return (node->isa() && !IsValueNode(node)) || + (node->isa() && node->cast()->has_default() && + std::find(graph_inputs.begin(), graph_inputs.end(), node) == graph_inputs.end()); + }); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/const_fold/fold_along_infershape.h b/mindspore/lite/tools/optimizer/const_fold/fold_along_infershape.h new file mode 100644 index 00000000000..0b9e0e67f76 --- /dev/null +++ b/mindspore/lite/tools/optimizer/const_fold/fold_along_infershape.h @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_ALONG_INFERSHAPE_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_ALONG_INFERSHAPE_H_ + +#include +#include "tools/optimizer/graph/infershape_pass.h" +#include "tools/optimizer/const_fold/fold_utils.h" + +namespace mindspore { +namespace opt { +class ConstFoldAlongInferShape : public InferShapePass { + public: + explicit ConstFoldAlongInferShape(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false) + : InferShapePass(fmk_type, train_flag, "ConstFoldAlongInferShape") {} + ~ConstFoldAlongInferShape() override = default; + + private: + STATUS PostProcess(const FuncGraphPtr &func_graph, const CNodePtr &cnode) override; + bool CheckCanFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode); + std::shared_ptr const_fold_processor_{nullptr}; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_ALONG_INFERSHAPE_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/const_fold/fold_utils.cc similarity index 64% rename from mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc rename to mindspore/lite/tools/optimizer/const_fold/fold_utils.cc index 8ddb6166179..23c228452ad 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/const_fold/fold_utils.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,13 +14,12 @@ * limitations under the License. */ -#include "tools/optimizer/fusion/constant_folding_fusion.h" +#include "tools/optimizer/const_fold/fold_utils.h" #include #include -#include #include #include "backend/optimizer/common/helper.h" -#include "tools/anf_exporter/fetch_content.h" +#include "ir/anf.h" #include "tools/converter/quant_param_holder.h" #include "tools/optimizer/common/format_utils.h" #include "tools/common/node_util.h" @@ -31,15 +30,15 @@ #include "src/kernel_registry.h" #include "src/inner_context.h" #include "src/tensor.h" +#include "src/ops/ops_utils.h" #include "src/runtime/infer_manager.h" #include "tools/optimizer/graph/lite_tensor_extractor.h" using mindspore::lite::KernelRegistry; using mindspore::lite::Tensor; -namespace mindspore::opt { +namespace mindspore { +namespace opt { namespace { -constexpr size_t INITIAL_SIZE = 1024; -constexpr auto kIsLinkWithControlFlow = "link_with_control_flow"; ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { MS_ASSERT(func_graph != nullptr); MS_ASSERT(tensor != nullptr); @@ -47,8 +46,8 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { MS_CHECK_TRUE_RET(parameter != nullptr, nullptr); std::vector shape(tensor->shape()); std::vector shape_vector; - std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), + [](const int32_t &value) { return static_cast(value); }); auto tensor_info = std::make_shared(tensor->data_type(), shape_vector); if (tensor_info == nullptr) { @@ -75,7 +74,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector inputs, std::vectorinput(0), ¶meter); - if (ret != lite::RET_OK) { + if (ret != RET_OK) { MS_LOG(ERROR) << cnode->fullname_with_scope() << " FetchOpParameterFromNode failed. "; return nullptr; } @@ -91,7 +90,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector inputs, std::vectordata_type(); kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast(parameter->type_)}; - kernel::LiteKernel *lite_kernel; + kernel::LiteKernel *lite_kernel = nullptr; ret = lite::KernelRegistry::GetInstance()->GetKernel(inputs, *outputs, context, ms_context, desc, parameter, &lite_kernel); if (ret != lite::RET_OK) { @@ -104,8 +103,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector inputs, std::vectorPrepare(); if (ret != lite::RET_OK) { MS_LOG(ERROR) << "init failed."; - // op_parameter is free by lite_kernel destructor - delete lite_kernel; + delete lite_kernel; // parameter will be freed in destructor of lite-kernel. return nullptr; } return lite_kernel; @@ -194,27 +192,7 @@ lite::STATUS CopyQuantParams(const CNodePtr &cnode, const std::vector } } // namespace -bool ConstFoldPass::Run(const FuncGraphPtr &func_graph) { - MS_CHECK_TRUE_RET(func_graph != nullptr, false); - manager_ = Manage(func_graph); - MS_CHECK_TRUE_RET(manager_ != nullptr, false); - if (!Init()) { - MS_LOG(ERROR) << "initial constant fold pass failed."; - return false; - } - std::set has_visited; - if (HandleCommonFold(func_graph, &has_visited) != lite::RET_OK) { - MS_LOG(ERROR) << "do constant fold pass failed,"; - return false; - } - if (HandleSpecialFold(func_graph) != lite::RET_OK) { - MS_LOG(ERROR) << "do constant fold pass failed,"; - return false; - } - return true; -} - -bool ConstFoldPass::Init() { +bool ConstFoldProcessor::Init() { if (context_ == nullptr) { context_ = std::make_shared(); MS_CHECK_TRUE_RET(context_ != nullptr, false); @@ -230,119 +208,15 @@ bool ConstFoldPass::Init() { return true; } -int ConstFoldPass::HandleCommonFold(const FuncGraphPtr &func_graph, std::set *has_visited) { - MS_ASSERT(func_graph != nullptr); - if (has_visited->find(func_graph) != has_visited->end()) { - return lite::RET_OK; +int ConstFoldProcessor::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + if (func_graph == nullptr || cnode == nullptr) { + MS_LOG(ERROR) << "input param is nullptr."; + return lite::RET_NULL_PTR; } - has_visited->insert(func_graph); - MS_ASSERT(manager_ != nullptr); - manager_->AddFuncGraph(func_graph); - auto node_list = TopoSort(func_graph->get_return()); - for (auto &node : node_list) { - if (!utils::isa(node)) { - continue; - } - auto cnode = node->cast(); - for (size_t i = 0; i < cnode->size(); ++i) { - if (IsValueNode(cnode->input(i))) { - auto sub_graph = GetValueNode(cnode->input(i)); - MS_ASSERT(sub_graph != nullptr); - if (HandleCommonFold(sub_graph, has_visited) != lite::RET_OK) { - MS_LOG(ERROR) << "do subgraph const-fold failed."; - return lite::RET_ERROR; - } - } - } - if (!CheckCanCommonFold(cnode)) { - continue; - } - if (DoConstantFold(func_graph, cnode) != lite::RET_OK) { - MS_LOG(ERROR) << "do constant fold failed."; - return lite::RET_ERROR; - } + if (!Init()) { + MS_LOG(ERROR) << "initial context failed."; + return lite::RET_ERROR; } - return lite::RET_OK; -} - -bool ConstFoldPass::CheckCanCommonFold(const CNodePtr &cnode) const { - MS_CHECK_TRUE_RET(cnode != nullptr, false); - if (IsSpecialType(cnode)) { - return false; - } - if (IsMarkedTrainOp(cnode) || CheckPrimitiveType(cnode, prim::kPrimCustom)) { - return false; - } - auto inputs = cnode->inputs(); - return std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { - return (node->isa() && !IsValueNode(node)) || - (node->isa() && node->cast()->has_default()); - }); -} - -int ConstFoldPass::HandleSpecialFold(const FuncGraphPtr &func_graph) { - MS_ASSERT(func_graph != nullptr); - if (lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() == 0) { - return lite::RET_OK; - } - if (node_infershape_ == nullptr) { - node_infershape_ = std::make_shared(fmk_type_, train_flag_); - MS_CHECK_TRUE_RET(node_infershape_ != nullptr, lite::RET_ERROR); - } - MS_ASSERT(manager_ != nullptr); - auto node_list = TopoSort(func_graph->get_return()); - for (auto &node : node_list) { - if (!utils::isa(node)) { - continue; - } - auto cnode = node->cast(); - if (!CheckCanSpecialFold(cnode)) { - continue; - } - if (DoConstantFold(func_graph, cnode) != lite::RET_OK) { - MS_LOG(ERROR) << "do constant fold failed."; - return lite::RET_ERROR; - } - } - return lite::RET_OK; -} - -bool ConstFoldPass::CheckCanSpecialFold(const CNodePtr &cnode) const { - MS_CHECK_TRUE_RET(cnode != nullptr, false); - for (size_t i = 0; i < cnode->size(); ++i) { - auto input_node = cnode->input(i); - MS_CHECK_TRUE_RET(input_node != nullptr, false); - if (IsValueNode(input_node)) { - return false; - } - if (!input_node->isa()) { - continue; - } - auto input_cnode = input_node->cast(); - auto input_prim = GetValueNode(input_cnode->input(0)); - MS_CHECK_TRUE_RET(input_prim != nullptr, false); - bool is_link_with_control_flow = input_prim->GetAttr(kIsLinkWithControlFlow) == nullptr || - GetValue(input_prim->GetAttr(kIsLinkWithControlFlow)); - if (is_link_with_control_flow) { - return false; - } - } - auto prim = GetValueNode(cnode->input(0)); - MS_CHECK_TRUE_RET(prim != nullptr, false); - prim->AddAttr(kIsLinkWithControlFlow, MakeValue(false)); - if (IsSpecialType(cnode)) { - return false; - } - MS_ASSERT(node_infershape_ != nullptr); - auto status = node_infershape_->InferShape(cnode); - if (CheckPrimitiveType(cnode, prim::kPrimShape)) { - return status == lite::RET_OK; - } - return CheckCanCommonFold(cnode); -} - -int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const { - MS_ASSERT(func_graph != nullptr && cnode != nullptr); std::vector inputs_ptr; if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &inputs_ptr, fmk_type_, train_flag_, true) != lite::RET_OK) { MS_LOG(ERROR) << "extract input tensor from cnode failed. " << cnode->fullname_with_scope(); @@ -363,12 +237,12 @@ int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr MS_LOG(DEBUG) << "this op is control flow op, which is not supported now."; return lite::RET_OK; } - std::vector input_tensors; - std::transform(inputs_ptr.begin(), inputs_ptr.end(), std::back_inserter(input_tensors), - [](const TensorPtr &input) { return input.get(); }); - std::vector output_tensors; - std::transform(outputs_ptr.begin(), outputs_ptr.end(), std::back_inserter(output_tensors), - [](const TensorPtr &output) { return output.get(); }); + std::vector input_tensors; + (void)std::transform(inputs_ptr.begin(), inputs_ptr.end(), std::back_inserter(input_tensors), + [](const TensorPtr &input) { return input.get(); }); + std::vector output_tensors; + (void)std::transform(outputs_ptr.begin(), outputs_ptr.end(), std::back_inserter(output_tensors), + [](const TensorPtr &output) { return output.get(); }); if (CopyQuantParams(cnode, input_tensors, output_tensors) != lite::RET_OK) { MS_LOG(ERROR) << "copy quant params failed."; return lite::RET_ERROR; @@ -402,4 +276,5 @@ int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr } return status; } -} // namespace mindspore::opt +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/const_fold/fold_utils.h b/mindspore/lite/tools/optimizer/const_fold/fold_utils.h new file mode 100644 index 00000000000..eff97e1efbf --- /dev/null +++ b/mindspore/lite/tools/optimizer/const_fold/fold_utils.h @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_UTILS_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_UTILS_H_ + +#include +#include "ir/anf.h" +#include "include/api/context.h" +#include "include/registry/converter_context.h" +#include "src/inner_context.h" + +namespace mindspore { +namespace opt { +class ConstFoldProcessor { + public: + explicit ConstFoldProcessor(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false) + : fmk_type_(fmk_type), train_flag_(train_flag) {} + int DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode); + ~ConstFoldProcessor() = default; + + private: + bool Init(); + converter::FmkType fmk_type_{converter::kFmkTypeMs}; + bool train_flag_{false}; + std::shared_ptr context_{nullptr}; + std::shared_ptr ms_context_{nullptr}; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_UTILS_H_ diff --git a/mindspore/lite/tools/optimizer/const_fold/fold_with_infershape.cc b/mindspore/lite/tools/optimizer/const_fold/fold_with_infershape.cc new file mode 100644 index 00000000000..50dbbe2cbab --- /dev/null +++ b/mindspore/lite/tools/optimizer/const_fold/fold_with_infershape.cc @@ -0,0 +1,158 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/optimizer/const_fold/fold_with_infershape.h" +#include +#include +#include "tools/optimizer/common/format_utils.h" +#include "nnacl/op_base.h" + +namespace mindspore::opt { +namespace { +constexpr auto kIsLinkWithControlFlow = "link_with_control_flow"; +} // namespace + +bool ConstFoldWithInferShape::Run(const FuncGraphPtr &func_graph) { + MS_CHECK_TRUE_RET(func_graph != nullptr, false); + manager_ = Manage(func_graph); + MS_CHECK_TRUE_RET(manager_ != nullptr, false); + if (const_fold_processor_ == nullptr) { + const_fold_processor_ = std::make_shared(fmk_type_, train_flag_); + } + MS_CHECK_TRUE_RET(const_fold_processor_ != nullptr, false); + std::set has_visited; + if (HandleCommonFold(func_graph, &has_visited) != lite::RET_OK) { + MS_LOG(ERROR) << "do constant fold pass failed,"; + return false; + } + if (HandleSpecialFold(func_graph) != lite::RET_OK) { + MS_LOG(ERROR) << "do constant fold pass failed,"; + return false; + } + return true; +} + +int ConstFoldWithInferShape::HandleCommonFold(const FuncGraphPtr &func_graph, std::set *has_visited) { + MS_ASSERT(func_graph != nullptr); + if (has_visited->find(func_graph) != has_visited->end()) { + return lite::RET_OK; + } + has_visited->insert(func_graph); + MS_ASSERT(manager_ != nullptr); + manager_->AddFuncGraph(func_graph); + auto node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + auto cnode = node->cast(); + for (size_t i = 0; i < cnode->size(); ++i) { + if (IsValueNode(cnode->input(i))) { + auto sub_graph = GetValueNode(cnode->input(i)); + MS_ASSERT(sub_graph != nullptr); + if (HandleCommonFold(sub_graph, has_visited) != lite::RET_OK) { + MS_LOG(ERROR) << "do subgraph const-fold failed."; + return lite::RET_ERROR; + } + } + } + if (!CheckCanCommonFold(cnode)) { + continue; + } + if (const_fold_processor_->DoConstantFold(func_graph, cnode) != lite::RET_OK) { + MS_LOG(ERROR) << "do constant fold failed."; + return lite::RET_ERROR; + } + } + return lite::RET_OK; +} + +bool ConstFoldWithInferShape::CheckCanCommonFold(const CNodePtr &cnode) const { + MS_CHECK_TRUE_RET(cnode != nullptr, false); + if (IsSpecialType(cnode)) { + return false; + } + if (IsMarkedTrainOp(cnode) || CheckPrimitiveType(cnode, prim::kPrimCustom)) { + return false; + } + auto inputs = cnode->inputs(); + return std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { + return (node->isa() && !IsValueNode(node)) || + (node->isa() && node->cast()->has_default()); + }); +} + +int ConstFoldWithInferShape::HandleSpecialFold(const FuncGraphPtr &func_graph) { + MS_ASSERT(func_graph != nullptr); + if (lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() == 0) { + return lite::RET_OK; + } + if (node_infershape_ == nullptr) { + node_infershape_ = std::make_shared(fmk_type_, train_flag_); + MS_CHECK_TRUE_RET(node_infershape_ != nullptr, lite::RET_ERROR); + } + MS_ASSERT(manager_ != nullptr); + auto node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + auto cnode = node->cast(); + if (!CheckCanSpecialFold(cnode)) { + continue; + } + if (const_fold_processor_->DoConstantFold(func_graph, cnode) != lite::RET_OK) { + MS_LOG(ERROR) << "do constant fold failed."; + return lite::RET_ERROR; + } + } + return lite::RET_OK; +} + +bool ConstFoldWithInferShape::CheckCanSpecialFold(const CNodePtr &cnode) const { + MS_CHECK_TRUE_RET(cnode != nullptr, false); + for (size_t i = 0; i < cnode->size(); ++i) { + auto input_node = cnode->input(i); + MS_CHECK_TRUE_RET(input_node != nullptr, false); + if (IsValueNode(input_node)) { + return false; + } + if (!input_node->isa()) { + continue; + } + auto input_cnode = input_node->cast(); + auto input_prim = GetValueNode(input_cnode->input(0)); + MS_CHECK_TRUE_RET(input_prim != nullptr, false); + bool is_link_with_control_flow = input_prim->GetAttr(kIsLinkWithControlFlow) == nullptr || + GetValue(input_prim->GetAttr(kIsLinkWithControlFlow)); + if (is_link_with_control_flow) { + return false; + } + } + auto prim = GetValueNode(cnode->input(0)); + MS_CHECK_TRUE_RET(prim != nullptr, false); + prim->AddAttr(kIsLinkWithControlFlow, MakeValue(false)); + if (IsSpecialType(cnode)) { + return false; + } + MS_ASSERT(node_infershape_ != nullptr); + auto status = node_infershape_->InferShape(cnode); + if (CheckPrimitiveType(cnode, prim::kPrimShape)) { + return status == lite::RET_OK; + } + return CheckCanCommonFold(cnode); +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h b/mindspore/lite/tools/optimizer/const_fold/fold_with_infershape.h similarity index 65% rename from mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h rename to mindspore/lite/tools/optimizer/const_fold/fold_with_infershape.h index 6147b756fb0..598d022850f 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h +++ b/mindspore/lite/tools/optimizer/const_fold/fold_with_infershape.h @@ -14,41 +14,37 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_ -#define MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_ +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_WITH_INFERSHAPE_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_WITH_INFERSHAPE_H_ #include #include #include #include "backend/optimizer/common/pass.h" -#include "include/api/context.h" #include "include/registry/converter_context.h" -#include "src/inner_context.h" #include "tools/optimizer/graph/node_infershape.h" +#include "tools/optimizer/const_fold/fold_utils.h" namespace mindspore { namespace opt { -class ConstFoldPass : public Pass { +class ConstFoldWithInferShape : public Pass { public: - explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false) - : Pass("ConstFoldPass"), fmk_type_(fmk_type), train_flag_(train_flag) {} - ~ConstFoldPass() override = default; + explicit ConstFoldWithInferShape(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false) + : Pass("ConstFoldWithInferShape"), fmk_type_(fmk_type), train_flag_(train_flag) {} + ~ConstFoldWithInferShape() override = default; bool Run(const FuncGraphPtr &func_graph) override; private: - bool Init(); int HandleCommonFold(const FuncGraphPtr &func_graph, std::set *has_visited); bool CheckCanCommonFold(const CNodePtr &cnode) const; int HandleSpecialFold(const FuncGraphPtr &func_graph); bool CheckCanSpecialFold(const CNodePtr &cnode) const; - int DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; converter::FmkType fmk_type_{converter::kFmkTypeMs}; bool train_flag_{false}; - std::shared_ptr context_{nullptr}; - std::shared_ptr ms_context_{nullptr}; + std::shared_ptr const_fold_processor_{nullptr}; std::shared_ptr node_infershape_{nullptr}; FuncGraphManagerPtr manager_{nullptr}; }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_ +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_WITH_INFERSHAPE_H_ diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc index 413c60fda9c..a3ef10a293f 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc @@ -16,6 +16,7 @@ #include "tools/optimizer/graph/infershape_pass.h" #include "tools/common/node_util.h" +#include "tools/common/tensor_util.h" #include "nnacl/op_base.h" #include "src/common/log_util.h" @@ -47,8 +48,8 @@ int GetCNodeCertainInputFormat(const CNodePtr cnode, int index, mindspore::Forma auto primitive = GetValueNode(real_cnode->input(0)); MS_CHECK_TRUE_MSG(primitive != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed"); if (primitive->GetAttr(ops::kFormat) == nullptr) { - MS_LOG(ERROR) << "cnode has no format attr. " << real_cnode->fullname_with_scope(); - return lite::RET_ERROR; + MS_LOG(DEBUG) << "cnode has no format attr. " << real_cnode->fullname_with_scope(); + return lite::RET_NO_CHANGE; } auto format_attr = primitive->GetAttr(ops::kFormat); MS_CHECK_TRUE_MSG(format_attr != nullptr, lite::RET_NULL_PTR, "GetAttr Failed"); @@ -110,11 +111,20 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { MS_LOG(WARNING) << "exist op cannot support infer shape."; return false; } + manager_ = Manage(func_graph, true); + if (manager_ == nullptr) { + MS_LOG(ERROR) << "generate a manager for func_graph failed."; + return false; + } if (InferProcess(func_graph) != lite::RET_OK) { MS_LOG(ERROR) << "infer shape failed."; return false; } - return ResetSubGraphInput(); + if (ResetSubGraphInput() != lite::RET_OK) { + MS_LOG(ERROR) << "ResetSubGraphInput failed."; + return false; + } + return true; } bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) { @@ -165,6 +175,7 @@ bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) { STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) { MS_ASSERT(func_graph != nullptr); + manager_->AddFuncGraph(func_graph); auto node_list = TopoSort(func_graph->get_return()); for (auto &node : node_list) { if (!utils::isa(node)) { @@ -184,32 +195,38 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) { auto ret = SetSubGraphInput(cnode, sub_func_graph); if (ret != lite::RET_OK) { MS_LOG(ERROR) << "SetSubGraphInput failed: " << ret; - return lite::RET_ERROR; + return RET_ERROR; } if (InferProcess(sub_func_graph) != lite::RET_OK) { MS_LOG(ERROR) << "subgraph infer shape failed."; - return lite::RET_ERROR; + return RET_ERROR; + } + if (SetSubGraphOutput(sub_func_graph) != lite::RET_OK) { + MS_LOG(ERROR) << "SetSubGraphOutput failed."; + return RET_ERROR; } - SetSubGraphOutput(sub_func_graph); sub_func_graph = GetValueNode(cnode->input(kInputIndexTwo)); if (sub_func_graph == nullptr) { lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); - return lite::RET_ERROR; + return RET_ERROR; } ret = SetSubGraphInput(cnode, sub_func_graph); - if (ret != lite::RET_OK) { + if (ret != RET_OK) { MS_LOG(ERROR) << "SetSubGraphInput failed: " << ret; - return lite::RET_ERROR; + return RET_ERROR; } if (InferProcess(sub_func_graph) != lite::RET_OK) { MS_LOG(ERROR) << "subgraph infer shape failed."; - return lite::RET_ERROR; + return RET_ERROR; + } + if (SetSubGraphOutput(sub_func_graph) != lite::RET_OK) { + MS_LOG(ERROR) << "SetSubGraphOutput failed."; + return RET_ERROR; } - SetSubGraphOutput(sub_func_graph); ret = SetSubGraphAbstract(cnode, sub_func_graph); - if (ret != lite::RET_OK) { + if (ret != RET_OK) { MS_LOG(ERROR) << "SetSubGraphAbstract failed: " << ret; - return lite::RET_ERROR; + return RET_ERROR; } continue; } @@ -218,6 +235,11 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) { MS_LOG(ERROR) << "node infer shape failed, node is " << node->fullname_with_scope(); return lite::RET_ERROR; } + status = PostProcess(func_graph, cnode); + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "post process current node failed, node is " << node->fullname_with_scope(); + return lite::RET_ERROR; + } } return lite::RET_OK; } @@ -260,35 +282,27 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt if (ModifySubGraphInputCNodeFormat(sub_graph, param_node, format) != lite::RET_OK) { MS_LOG(DEBUG) << "modify subgraph input cnode format failed." << cnode->func_graph_as_var(); } - } else { + } + if (utils::isa(cnode->input(index))) { + param_node->set_default_param(cnode->input(index)->cast()->default_param()); + } + if (utils::isa(cnode->input(index))) { lite::DataInfo data_info; - if (utils::isa(cnode->input(index))) { - if (cnode->input(index)->cast()->has_default()) { - param_node->set_default_param(cnode->input(index)->cast()->default_param()); - } - continue; - } auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info, false); if (status != lite::RET_OK) { continue; } ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end()); - if (data_info.data_.empty()) { - auto tensor_info = std::make_shared((TypeId)data_info.data_type_, shape_vec); - CHECK_NULL_RETURN(tensor_info); - param_node->set_default_param(tensor_info); - } else { - auto tensor_info = std::make_shared((TypeId)data_info.data_type_, shape_vec, - data_info.data_.data(), data_info.data_.size()); - CHECK_NULL_RETURN(tensor_info); - param_node->set_default_param(tensor_info); - } + auto tensor_info = + lite::CreateTensorInfo(data_info.data_.data(), data_info.data_.size(), shape_vec, (TypeId)data_info.data_type_); + MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_ERROR, "created tensor is a nullptr."); + param_node->set_default_param(tensor_info); } } return RET_OK; } -void InferShapePass::SetSubGraphOutput(const FuncGraphPtr &sub_graph) { +STATUS InferShapePass::SetSubGraphOutput(const FuncGraphPtr &sub_graph) { MS_ASSERT(sub_graph != nullptr); auto return_node = sub_graph->get_return(); MS_ASSERT(return_node != nullptr); @@ -317,6 +331,7 @@ void InferShapePass::SetSubGraphOutput(const FuncGraphPtr &sub_graph) { trans_cnode->set_fullname_with_scope(trans_input_name); } return_node->set_inputs(origin_input); + return lite::RET_OK; } STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { @@ -371,24 +386,26 @@ STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGrap return RET_OK; } -bool InferShapePass::ResetSubGraphInput() { - for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) { - auto &sub_graph = iter->first; - auto &sub_inputs = iter->second; - auto manager = sub_graph->manager(); - MS_ASSERT(manager != nullptr); +int InferShapePass::ResetSubGraphInput() { + for (auto &iter : sub_inputs_map_) { + auto &sub_graph = iter.first; + auto &sub_inputs = iter.second; + MS_ASSERT(manager_ != nullptr); for (auto &sub_input : sub_inputs) { auto param_node = sub_graph->add_parameter(); - MS_CHECK_TRUE_MSG(param_node != nullptr, false, "Add parameter Failed"); + MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "Add parameter Failed"); param_node->set_abstract(sub_input->abstract()->Clone()); param_node->set_name(sub_input->fullname_with_scope()); - manager->Replace(sub_input, param_node); + if (!manager_->Replace(sub_input, param_node)) { + MS_LOG(ERROR) << "replace cnode failed."; + return RET_ERROR; + } auto sub_param_input = sub_input->cast(); MS_ASSERT(sub_param_input != nullptr); sub_param_input->set_default_param(nullptr); } } - return true; + return lite::RET_OK; } } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.h b/mindspore/lite/tools/optimizer/graph/infershape_pass.h index 9d9f2f668d5..46a6ab4a4eb 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.h +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.h @@ -19,6 +19,7 @@ #include #include +#include #include #include "backend/optimizer/common/pass.h" #include "tools/optimizer/graph/node_infershape.h" @@ -27,23 +28,29 @@ namespace mindspore { namespace opt { class InferShapePass : public Pass { public: - explicit InferShapePass(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false) - : Pass("infer_shape"), fmk_type_(fmk_type), train_flag_(train_flag) {} + explicit InferShapePass(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false, + const std::string &name = "InferShapePass") + : Pass(name), fmk_type_(fmk_type), train_flag_(train_flag) {} ~InferShapePass() override = default; bool Run(const FuncGraphPtr &func_graph) override; + protected: + virtual STATUS PostProcess(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { return lite::RET_OK; } + private: bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph); STATUS InferProcess(const FuncGraphPtr &func_graph); STATUS SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); - void SetSubGraphOutput(const FuncGraphPtr &sub_graph); + STATUS SetSubGraphOutput(const FuncGraphPtr &sub_graph); STATUS SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); - bool ResetSubGraphInput(); + int ResetSubGraphInput(); + protected: FmkType fmk_type_{converter::kFmkTypeMs}; bool train_flag_{false}; std::shared_ptr node_infer_shape_{nullptr}; std::map> sub_inputs_map_{}; + FuncGraphManagerPtr manager_{nullptr}; }; } // namespace opt } // namespace mindspore