module constant-fold optimizer

This commit is contained in:
xuanyue 2021-12-15 17:42:20 +08:00
parent b93494c6bb
commit ae1116a724
12 changed files with 481 additions and 217 deletions

View File

@ -23,7 +23,7 @@
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "tools/converter/anf_transform.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/const_fold/constant_folding_fusion.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "test/common/import_from_meta_graphT.h"

View File

@ -54,12 +54,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${SRC_DIR}/common/dynamic_library_loader.cc
${SRC_DIR}/train/train_populate_parameter.cc
../optimizer/common/*.cc
../optimizer/format/*.cc
../optimizer/fusion/*.cc
../optimizer/fisson/*.cc
../optimizer/parallel/*.cc
../optimizer/graph/*.cc
../optimizer/*.cc
)
add_subdirectory(../anf_exporter anf_exporter)

View File

@ -33,7 +33,7 @@
#include "tools/optimizer/fusion/conv_scale_fusion.h"
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/const_fold/constant_folding_fusion.h"
#include "tools/optimizer/fusion/norm_fusion.h"
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include "tools/optimizer/fusion/batchnorm_to_scale_fusion.h"

View File

@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_CONSTANT_FOLDING_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_CONSTANT_FOLDING_FUSION_H_
#include "backend/optimizer/common/pass.h"
#include "include/registry/converter_context.h"
#include "tools/optimizer/const_fold/fold_along_infershape.h"
#include "tools/optimizer/const_fold/fold_with_infershape.h"
#include "tools/optimizer/graph/update_conv2d_param_pass.h"
namespace mindspore {
namespace opt {
class ConstFoldPass : public Pass {
public:
explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
: Pass("ConstFoldPass"), fmk_type_(fmk_type), train_flag_(train_flag) {}
~ConstFoldPass() override = default;
bool Run(const FuncGraphPtr &func_graph) override {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func_graph is nullptr, do constant fold failed.";
return false;
}
// current infer-shape cannot support the control-flow model of Mindir.
if (fmk_type_ == converter::kFmkTypeMs) {
auto fold_schedule = ConstFoldWithInferShape(fmk_type_, train_flag_);
if (!fold_schedule.Run(func_graph)) {
MS_LOG(ERROR) << "Do constant fold failed.";
return false;
}
} else {
auto fold_schedule = ConstFoldAlongInferShape(fmk_type_, train_flag_);
if (!fold_schedule.Run(func_graph)) {
MS_LOG(ERROR) << "Do constant fold failed.";
return false;
}
}
// the attrs of convolution only can be update after constant fold.
auto update_attrs = UpdateConv2DParamPass();
if (!update_attrs.Run(func_graph)) {
MS_LOG(ERROR) << "update attrs failed.";
return false;
}
return true;
}
private:
FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_CONSTANT_FOLDING_FUSION_H_

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/const_fold/fold_along_infershape.h"
#include <memory>
#include "nnacl/op_base.h"
namespace mindspore {
namespace opt {
STATUS ConstFoldAlongInferShape::PostProcess(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (!CheckCanFold(func_graph, cnode)) {
return lite::RET_OK;
}
if (const_fold_processor_ == nullptr) {
const_fold_processor_ = std::make_shared<ConstFoldProcessor>(fmk_type_, train_flag_);
}
MS_CHECK_TRUE_MSG(const_fold_processor_ != nullptr, lite::RET_NULL_PTR, "const fold processor is nullptr");
auto status = const_fold_processor_->DoConstantFold(func_graph, cnode);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "do constant fold failed, the node is " << cnode->fullname_with_scope();
}
return status;
}
bool ConstFoldAlongInferShape::CheckCanFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (IsSpecialType(cnode) || CheckPrimitiveType(cnode, prim::kPrimCustom) || IsMarkedTrainOp(cnode)) {
return false;
}
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
return false;
}
auto is_inferred = prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(prim->GetAttr(kInferDone));
if (!is_inferred) {
return false;
}
if (CheckPrimitiveType(cnode, prim::kPrimShape)) {
return lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() != 0;
}
auto inputs = cnode->inputs();
auto graph_inputs =
sub_inputs_map_.find(func_graph) != sub_inputs_map_.end() ? sub_inputs_map_[func_graph] : func_graph->get_inputs();
return std::all_of(inputs.begin(), inputs.end(), [&graph_inputs](const AnfNodePtr &node) {
return (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) ||
(node->isa<Parameter>() && node->cast<ParameterPtr>()->has_default() &&
std::find(graph_inputs.begin(), graph_inputs.end(), node) == graph_inputs.end());
});
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_ALONG_INFERSHAPE_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_ALONG_INFERSHAPE_H_
#include <memory>
#include "tools/optimizer/graph/infershape_pass.h"
#include "tools/optimizer/const_fold/fold_utils.h"
namespace mindspore {
namespace opt {
class ConstFoldAlongInferShape : public InferShapePass {
public:
explicit ConstFoldAlongInferShape(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
: InferShapePass(fmk_type, train_flag, "ConstFoldAlongInferShape") {}
~ConstFoldAlongInferShape() override = default;
private:
STATUS PostProcess(const FuncGraphPtr &func_graph, const CNodePtr &cnode) override;
bool CheckCanFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
std::shared_ptr<ConstFoldProcessor> const_fold_processor_{nullptr};
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_ALONG_INFERSHAPE_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,13 +14,12 @@
* limitations under the License.
*/
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/const_fold/fold_utils.h"
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "backend/optimizer/common/helper.h"
#include "tools/anf_exporter/fetch_content.h"
#include "ir/anf.h"
#include "tools/converter/quant_param_holder.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/common/node_util.h"
@ -31,15 +30,15 @@
#include "src/kernel_registry.h"
#include "src/inner_context.h"
#include "src/tensor.h"
#include "src/ops/ops_utils.h"
#include "src/runtime/infer_manager.h"
#include "tools/optimizer/graph/lite_tensor_extractor.h"
using mindspore::lite::KernelRegistry;
using mindspore::lite::Tensor;
namespace mindspore::opt {
namespace mindspore {
namespace opt {
namespace {
constexpr size_t INITIAL_SIZE = 1024;
constexpr auto kIsLinkWithControlFlow = "link_with_control_flow";
ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(tensor != nullptr);
@ -47,8 +46,8 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
MS_CHECK_TRUE_RET(parameter != nullptr, nullptr);
std::vector<int> shape(tensor->shape());
std::vector<int64_t> shape_vector;
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto tensor_info = std::make_shared<tensor::Tensor>(tensor->data_type(), shape_vector);
if (tensor_info == nullptr) {
@ -75,7 +74,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
MS_ASSERT(outputs != nullptr && cnode != nullptr && context != nullptr && ms_context != nullptr);
OpParameter *parameter;
auto ret = lite::FetchOpParameterFromNode(cnode->input(0), &parameter);
if (ret != lite::RET_OK) {
if (ret != RET_OK) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << " FetchOpParameterFromNode failed. ";
return nullptr;
}
@ -91,7 +90,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
}
auto data_type = inputs.front()->data_type();
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(parameter->type_)};
kernel::LiteKernel *lite_kernel;
kernel::LiteKernel *lite_kernel = nullptr;
ret = lite::KernelRegistry::GetInstance()->GetKernel(inputs, *outputs, context, ms_context, desc, parameter,
&lite_kernel);
if (ret != lite::RET_OK) {
@ -104,8 +103,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
ret = lite_kernel->Prepare();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "init failed.";
// op_parameter is free by lite_kernel destructor
delete lite_kernel;
delete lite_kernel; // parameter will be freed in destructor of lite-kernel.
return nullptr;
}
return lite_kernel;
@ -194,27 +192,7 @@ lite::STATUS CopyQuantParams(const CNodePtr &cnode, const std::vector<Tensor *>
}
} // namespace
bool ConstFoldPass::Run(const FuncGraphPtr &func_graph) {
MS_CHECK_TRUE_RET(func_graph != nullptr, false);
manager_ = Manage(func_graph);
MS_CHECK_TRUE_RET(manager_ != nullptr, false);
if (!Init()) {
MS_LOG(ERROR) << "initial constant fold pass failed.";
return false;
}
std::set<FuncGraphPtr> has_visited;
if (HandleCommonFold(func_graph, &has_visited) != lite::RET_OK) {
MS_LOG(ERROR) << "do constant fold pass failed,";
return false;
}
if (HandleSpecialFold(func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "do constant fold pass failed,";
return false;
}
return true;
}
bool ConstFoldPass::Init() {
bool ConstFoldProcessor::Init() {
if (context_ == nullptr) {
context_ = std::make_shared<lite::InnerContext>();
MS_CHECK_TRUE_RET(context_ != nullptr, false);
@ -230,119 +208,15 @@ bool ConstFoldPass::Init() {
return true;
}
int ConstFoldPass::HandleCommonFold(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited) {
MS_ASSERT(func_graph != nullptr);
if (has_visited->find(func_graph) != has_visited->end()) {
return lite::RET_OK;
int ConstFoldProcessor::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
if (func_graph == nullptr || cnode == nullptr) {
MS_LOG(ERROR) << "input param is nullptr.";
return lite::RET_NULL_PTR;
}
has_visited->insert(func_graph);
MS_ASSERT(manager_ != nullptr);
manager_->AddFuncGraph(func_graph);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
for (size_t i = 0; i < cnode->size(); ++i) {
if (IsValueNode<FuncGraph>(cnode->input(i))) {
auto sub_graph = GetValueNode<FuncGraphPtr>(cnode->input(i));
MS_ASSERT(sub_graph != nullptr);
if (HandleCommonFold(sub_graph, has_visited) != lite::RET_OK) {
MS_LOG(ERROR) << "do subgraph const-fold failed.";
return lite::RET_ERROR;
}
}
}
if (!CheckCanCommonFold(cnode)) {
continue;
}
if (DoConstantFold(func_graph, cnode) != lite::RET_OK) {
MS_LOG(ERROR) << "do constant fold failed.";
return lite::RET_ERROR;
}
if (!Init()) {
MS_LOG(ERROR) << "initial context failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
bool ConstFoldPass::CheckCanCommonFold(const CNodePtr &cnode) const {
MS_CHECK_TRUE_RET(cnode != nullptr, false);
if (IsSpecialType(cnode)) {
return false;
}
if (IsMarkedTrainOp(cnode) || CheckPrimitiveType(cnode, prim::kPrimCustom)) {
return false;
}
auto inputs = cnode->inputs();
return std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) {
return (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) ||
(node->isa<Parameter>() && node->cast<ParameterPtr>()->has_default());
});
}
int ConstFoldPass::HandleSpecialFold(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
if (lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() == 0) {
return lite::RET_OK;
}
if (node_infershape_ == nullptr) {
node_infershape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
MS_CHECK_TRUE_RET(node_infershape_ != nullptr, lite::RET_ERROR);
}
MS_ASSERT(manager_ != nullptr);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (!CheckCanSpecialFold(cnode)) {
continue;
}
if (DoConstantFold(func_graph, cnode) != lite::RET_OK) {
MS_LOG(ERROR) << "do constant fold failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
bool ConstFoldPass::CheckCanSpecialFold(const CNodePtr &cnode) const {
MS_CHECK_TRUE_RET(cnode != nullptr, false);
for (size_t i = 0; i < cnode->size(); ++i) {
auto input_node = cnode->input(i);
MS_CHECK_TRUE_RET(input_node != nullptr, false);
if (IsValueNode<FuncGraph>(input_node)) {
return false;
}
if (!input_node->isa<CNode>()) {
continue;
}
auto input_cnode = input_node->cast<CNodePtr>();
auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
MS_CHECK_TRUE_RET(input_prim != nullptr, false);
bool is_link_with_control_flow = input_prim->GetAttr(kIsLinkWithControlFlow) == nullptr ||
GetValue<bool>(input_prim->GetAttr(kIsLinkWithControlFlow));
if (is_link_with_control_flow) {
return false;
}
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_CHECK_TRUE_RET(prim != nullptr, false);
prim->AddAttr(kIsLinkWithControlFlow, MakeValue(false));
if (IsSpecialType(cnode)) {
return false;
}
MS_ASSERT(node_infershape_ != nullptr);
auto status = node_infershape_->InferShape(cnode);
if (CheckPrimitiveType(cnode, prim::kPrimShape)) {
return status == lite::RET_OK;
}
return CheckCanCommonFold(cnode);
}
int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
std::vector<TensorPtr> inputs_ptr;
if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &inputs_ptr, fmk_type_, train_flag_, true) != lite::RET_OK) {
MS_LOG(ERROR) << "extract input tensor from cnode failed. " << cnode->fullname_with_scope();
@ -363,12 +237,12 @@ int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr
MS_LOG(DEBUG) << "this op is control flow op, which is not supported now.";
return lite::RET_OK;
}
std::vector<Tensor *> input_tensors;
std::transform(inputs_ptr.begin(), inputs_ptr.end(), std::back_inserter(input_tensors),
[](const TensorPtr &input) { return input.get(); });
std::vector<Tensor *> output_tensors;
std::transform(outputs_ptr.begin(), outputs_ptr.end(), std::back_inserter(output_tensors),
[](const TensorPtr &output) { return output.get(); });
std::vector<lite::Tensor *> input_tensors;
(void)std::transform(inputs_ptr.begin(), inputs_ptr.end(), std::back_inserter(input_tensors),
[](const TensorPtr &input) { return input.get(); });
std::vector<lite::Tensor *> output_tensors;
(void)std::transform(outputs_ptr.begin(), outputs_ptr.end(), std::back_inserter(output_tensors),
[](const TensorPtr &output) { return output.get(); });
if (CopyQuantParams(cnode, input_tensors, output_tensors) != lite::RET_OK) {
MS_LOG(ERROR) << "copy quant params failed.";
return lite::RET_ERROR;
@ -402,4 +276,5 @@ int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr
}
return status;
}
} // namespace mindspore::opt
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_UTILS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_UTILS_H_
#include <memory>
#include "ir/anf.h"
#include "include/api/context.h"
#include "include/registry/converter_context.h"
#include "src/inner_context.h"
namespace mindspore {
namespace opt {
class ConstFoldProcessor {
public:
explicit ConstFoldProcessor(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
: fmk_type_(fmk_type), train_flag_(train_flag) {}
int DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
~ConstFoldProcessor() = default;
private:
bool Init();
converter::FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
std::shared_ptr<lite::InnerContext> context_{nullptr};
std::shared_ptr<mindspore::Context> ms_context_{nullptr};
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_UTILS_H_

View File

@ -0,0 +1,158 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/const_fold/fold_with_infershape.h"
#include <memory>
#include <set>
#include "tools/optimizer/common/format_utils.h"
#include "nnacl/op_base.h"
namespace mindspore::opt {
namespace {
constexpr auto kIsLinkWithControlFlow = "link_with_control_flow";
} // namespace
bool ConstFoldWithInferShape::Run(const FuncGraphPtr &func_graph) {
MS_CHECK_TRUE_RET(func_graph != nullptr, false);
manager_ = Manage(func_graph);
MS_CHECK_TRUE_RET(manager_ != nullptr, false);
if (const_fold_processor_ == nullptr) {
const_fold_processor_ = std::make_shared<ConstFoldProcessor>(fmk_type_, train_flag_);
}
MS_CHECK_TRUE_RET(const_fold_processor_ != nullptr, false);
std::set<FuncGraphPtr> has_visited;
if (HandleCommonFold(func_graph, &has_visited) != lite::RET_OK) {
MS_LOG(ERROR) << "do constant fold pass failed,";
return false;
}
if (HandleSpecialFold(func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "do constant fold pass failed,";
return false;
}
return true;
}
int ConstFoldWithInferShape::HandleCommonFold(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited) {
MS_ASSERT(func_graph != nullptr);
if (has_visited->find(func_graph) != has_visited->end()) {
return lite::RET_OK;
}
has_visited->insert(func_graph);
MS_ASSERT(manager_ != nullptr);
manager_->AddFuncGraph(func_graph);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
for (size_t i = 0; i < cnode->size(); ++i) {
if (IsValueNode<FuncGraph>(cnode->input(i))) {
auto sub_graph = GetValueNode<FuncGraphPtr>(cnode->input(i));
MS_ASSERT(sub_graph != nullptr);
if (HandleCommonFold(sub_graph, has_visited) != lite::RET_OK) {
MS_LOG(ERROR) << "do subgraph const-fold failed.";
return lite::RET_ERROR;
}
}
}
if (!CheckCanCommonFold(cnode)) {
continue;
}
if (const_fold_processor_->DoConstantFold(func_graph, cnode) != lite::RET_OK) {
MS_LOG(ERROR) << "do constant fold failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
bool ConstFoldWithInferShape::CheckCanCommonFold(const CNodePtr &cnode) const {
MS_CHECK_TRUE_RET(cnode != nullptr, false);
if (IsSpecialType(cnode)) {
return false;
}
if (IsMarkedTrainOp(cnode) || CheckPrimitiveType(cnode, prim::kPrimCustom)) {
return false;
}
auto inputs = cnode->inputs();
return std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) {
return (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) ||
(node->isa<Parameter>() && node->cast<ParameterPtr>()->has_default());
});
}
int ConstFoldWithInferShape::HandleSpecialFold(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
if (lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() == 0) {
return lite::RET_OK;
}
if (node_infershape_ == nullptr) {
node_infershape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
MS_CHECK_TRUE_RET(node_infershape_ != nullptr, lite::RET_ERROR);
}
MS_ASSERT(manager_ != nullptr);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (!CheckCanSpecialFold(cnode)) {
continue;
}
if (const_fold_processor_->DoConstantFold(func_graph, cnode) != lite::RET_OK) {
MS_LOG(ERROR) << "do constant fold failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
bool ConstFoldWithInferShape::CheckCanSpecialFold(const CNodePtr &cnode) const {
MS_CHECK_TRUE_RET(cnode != nullptr, false);
for (size_t i = 0; i < cnode->size(); ++i) {
auto input_node = cnode->input(i);
MS_CHECK_TRUE_RET(input_node != nullptr, false);
if (IsValueNode<FuncGraph>(input_node)) {
return false;
}
if (!input_node->isa<CNode>()) {
continue;
}
auto input_cnode = input_node->cast<CNodePtr>();
auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
MS_CHECK_TRUE_RET(input_prim != nullptr, false);
bool is_link_with_control_flow = input_prim->GetAttr(kIsLinkWithControlFlow) == nullptr ||
GetValue<bool>(input_prim->GetAttr(kIsLinkWithControlFlow));
if (is_link_with_control_flow) {
return false;
}
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_CHECK_TRUE_RET(prim != nullptr, false);
prim->AddAttr(kIsLinkWithControlFlow, MakeValue(false));
if (IsSpecialType(cnode)) {
return false;
}
MS_ASSERT(node_infershape_ != nullptr);
auto status = node_infershape_->InferShape(cnode);
if (CheckPrimitiveType(cnode, prim::kPrimShape)) {
return status == lite::RET_OK;
}
return CheckCanCommonFold(cnode);
}
} // namespace mindspore::opt

View File

@ -14,41 +14,37 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_WITH_INFERSHAPE_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_WITH_INFERSHAPE_H_
#include <set>
#include <utility>
#include <memory>
#include "backend/optimizer/common/pass.h"
#include "include/api/context.h"
#include "include/registry/converter_context.h"
#include "src/inner_context.h"
#include "tools/optimizer/graph/node_infershape.h"
#include "tools/optimizer/const_fold/fold_utils.h"
namespace mindspore {
namespace opt {
class ConstFoldPass : public Pass {
class ConstFoldWithInferShape : public Pass {
public:
explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
: Pass("ConstFoldPass"), fmk_type_(fmk_type), train_flag_(train_flag) {}
~ConstFoldPass() override = default;
explicit ConstFoldWithInferShape(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
: Pass("ConstFoldWithInferShape"), fmk_type_(fmk_type), train_flag_(train_flag) {}
~ConstFoldWithInferShape() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
bool Init();
int HandleCommonFold(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited);
bool CheckCanCommonFold(const CNodePtr &cnode) const;
int HandleSpecialFold(const FuncGraphPtr &func_graph);
bool CheckCanSpecialFold(const CNodePtr &cnode) const;
int DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
converter::FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
std::shared_ptr<lite::InnerContext> context_{nullptr};
std::shared_ptr<mindspore::Context> ms_context_{nullptr};
std::shared_ptr<ConstFoldProcessor> const_fold_processor_{nullptr};
std::shared_ptr<NodeInferShape> node_infershape_{nullptr};
FuncGraphManagerPtr manager_{nullptr};
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_WITH_INFERSHAPE_H_

View File

@ -16,6 +16,7 @@
#include "tools/optimizer/graph/infershape_pass.h"
#include "tools/common/node_util.h"
#include "tools/common/tensor_util.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"
@ -47,8 +48,8 @@ int GetCNodeCertainInputFormat(const CNodePtr cnode, int index, mindspore::Forma
auto primitive = GetValueNode<PrimitivePtr>(real_cnode->input(0));
MS_CHECK_TRUE_MSG(primitive != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
if (primitive->GetAttr(ops::kFormat) == nullptr) {
MS_LOG(ERROR) << "cnode has no format attr. " << real_cnode->fullname_with_scope();
return lite::RET_ERROR;
MS_LOG(DEBUG) << "cnode has no format attr. " << real_cnode->fullname_with_scope();
return lite::RET_NO_CHANGE;
}
auto format_attr = primitive->GetAttr(ops::kFormat);
MS_CHECK_TRUE_MSG(format_attr != nullptr, lite::RET_NULL_PTR, "GetAttr Failed");
@ -110,11 +111,20 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
MS_LOG(WARNING) << "exist op cannot support infer shape.";
return false;
}
manager_ = Manage(func_graph, true);
if (manager_ == nullptr) {
MS_LOG(ERROR) << "generate a manager for func_graph failed.";
return false;
}
if (InferProcess(func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "infer shape failed.";
return false;
}
return ResetSubGraphInput();
if (ResetSubGraphInput() != lite::RET_OK) {
MS_LOG(ERROR) << "ResetSubGraphInput failed.";
return false;
}
return true;
}
bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {
@ -165,6 +175,7 @@ bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {
STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
manager_->AddFuncGraph(func_graph);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
@ -184,32 +195,38 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
auto ret = SetSubGraphInput(cnode, sub_func_graph);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphInput failed: " << ret;
return lite::RET_ERROR;
return RET_ERROR;
}
if (InferProcess(sub_func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "subgraph infer shape failed.";
return lite::RET_ERROR;
return RET_ERROR;
}
if (SetSubGraphOutput(sub_func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphOutput failed.";
return RET_ERROR;
}
SetSubGraphOutput(sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return lite::RET_ERROR;
return RET_ERROR;
}
ret = SetSubGraphInput(cnode, sub_func_graph);
if (ret != lite::RET_OK) {
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetSubGraphInput failed: " << ret;
return lite::RET_ERROR;
return RET_ERROR;
}
if (InferProcess(sub_func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "subgraph infer shape failed.";
return lite::RET_ERROR;
return RET_ERROR;
}
if (SetSubGraphOutput(sub_func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphOutput failed.";
return RET_ERROR;
}
SetSubGraphOutput(sub_func_graph);
ret = SetSubGraphAbstract(cnode, sub_func_graph);
if (ret != lite::RET_OK) {
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetSubGraphAbstract failed: " << ret;
return lite::RET_ERROR;
return RET_ERROR;
}
continue;
}
@ -218,6 +235,11 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "node infer shape failed, node is " << node->fullname_with_scope();
return lite::RET_ERROR;
}
status = PostProcess(func_graph, cnode);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "post process current node failed, node is " << node->fullname_with_scope();
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
@ -260,35 +282,27 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
if (ModifySubGraphInputCNodeFormat(sub_graph, param_node, format) != lite::RET_OK) {
MS_LOG(DEBUG) << "modify subgraph input cnode format failed." << cnode->func_graph_as_var();
}
} else {
}
if (utils::isa<Parameter>(cnode->input(index))) {
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
}
if (utils::isa<ValueNode>(cnode->input(index))) {
lite::DataInfo data_info;
if (utils::isa<ParameterPtr>(cnode->input(index))) {
if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
}
continue;
}
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info, false);
if (status != lite::RET_OK) {
continue;
}
ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
if (data_info.data_.empty()) {
auto tensor_info = std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec);
CHECK_NULL_RETURN(tensor_info);
param_node->set_default_param(tensor_info);
} else {
auto tensor_info = std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
data_info.data_.data(), data_info.data_.size());
CHECK_NULL_RETURN(tensor_info);
param_node->set_default_param(tensor_info);
}
auto tensor_info =
lite::CreateTensorInfo(data_info.data_.data(), data_info.data_.size(), shape_vec, (TypeId)data_info.data_type_);
MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_ERROR, "created tensor is a nullptr.");
param_node->set_default_param(tensor_info);
}
}
return RET_OK;
}
void InferShapePass::SetSubGraphOutput(const FuncGraphPtr &sub_graph) {
STATUS InferShapePass::SetSubGraphOutput(const FuncGraphPtr &sub_graph) {
MS_ASSERT(sub_graph != nullptr);
auto return_node = sub_graph->get_return();
MS_ASSERT(return_node != nullptr);
@ -317,6 +331,7 @@ void InferShapePass::SetSubGraphOutput(const FuncGraphPtr &sub_graph) {
trans_cnode->set_fullname_with_scope(trans_input_name);
}
return_node->set_inputs(origin_input);
return lite::RET_OK;
}
STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
@ -371,24 +386,26 @@ STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGrap
return RET_OK;
}
bool InferShapePass::ResetSubGraphInput() {
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
auto &sub_graph = iter->first;
auto &sub_inputs = iter->second;
auto manager = sub_graph->manager();
MS_ASSERT(manager != nullptr);
int InferShapePass::ResetSubGraphInput() {
for (auto &iter : sub_inputs_map_) {
auto &sub_graph = iter.first;
auto &sub_inputs = iter.second;
MS_ASSERT(manager_ != nullptr);
for (auto &sub_input : sub_inputs) {
auto param_node = sub_graph->add_parameter();
MS_CHECK_TRUE_MSG(param_node != nullptr, false, "Add parameter Failed");
MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "Add parameter Failed");
param_node->set_abstract(sub_input->abstract()->Clone());
param_node->set_name(sub_input->fullname_with_scope());
manager->Replace(sub_input, param_node);
if (!manager_->Replace(sub_input, param_node)) {
MS_LOG(ERROR) << "replace cnode failed.";
return RET_ERROR;
}
auto sub_param_input = sub_input->cast<ParameterPtr>();
MS_ASSERT(sub_param_input != nullptr);
sub_param_input->set_default_param(nullptr);
}
}
return true;
return lite::RET_OK;
}
} // namespace opt
} // namespace mindspore

View File

@ -19,6 +19,7 @@
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "tools/optimizer/graph/node_infershape.h"
@ -27,23 +28,29 @@ namespace mindspore {
namespace opt {
class InferShapePass : public Pass {
public:
explicit InferShapePass(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
: Pass("infer_shape"), fmk_type_(fmk_type), train_flag_(train_flag) {}
explicit InferShapePass(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
const std::string &name = "InferShapePass")
: Pass(name), fmk_type_(fmk_type), train_flag_(train_flag) {}
~InferShapePass() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
protected:
virtual STATUS PostProcess(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { return lite::RET_OK; }
private:
bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
STATUS InferProcess(const FuncGraphPtr &func_graph);
STATUS SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void SetSubGraphOutput(const FuncGraphPtr &sub_graph);
STATUS SetSubGraphOutput(const FuncGraphPtr &sub_graph);
STATUS SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
bool ResetSubGraphInput();
int ResetSubGraphInput();
protected:
FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr};
std::map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_{};
FuncGraphManagerPtr manager_{nullptr};
};
} // namespace opt
} // namespace mindspore