forked from mindspore-Ecosystem/mindspore

module constant-fold optimizer

parent b93494c6bb
commit ae1116a724
@@ -23,7 +23,7 @@
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
 #include "tools/converter/anf_transform.h"
-#include "tools/optimizer/fusion/constant_folding_fusion.h"
+#include "tools/optimizer/const_fold/constant_folding_fusion.h"
 #include "tools/anf_exporter/anf_exporter.h"
 #include "test/common/import_from_meta_graphT.h"
 
@@ -54,12 +54,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ${SRC_DIR}/common/dynamic_library_loader.cc
         ${SRC_DIR}/train/train_populate_parameter.cc
 
-        ../optimizer/common/*.cc
-        ../optimizer/format/*.cc
-        ../optimizer/fusion/*.cc
-        ../optimizer/fisson/*.cc
-        ../optimizer/parallel/*.cc
-        ../optimizer/graph/*.cc
+        ../optimizer/*.cc
         )
 
 add_subdirectory(../anf_exporter anf_exporter)
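A note on this CMake hunk: since file(GLOB_RECURSE ...) already descends into subdirectories, the single pattern ../optimizer/*.cc matches everything the six removed per-subdirectory patterns did, and it also picks up the sources in the new const_fold/ directory added by this commit.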
@@ -33,7 +33,7 @@
 #include "tools/optimizer/fusion/conv_scale_fusion.h"
 #include "tools/optimizer/fusion/conv_bn_fusion.h"
 #include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
-#include "tools/optimizer/fusion/constant_folding_fusion.h"
+#include "tools/optimizer/const_fold/constant_folding_fusion.h"
 #include "tools/optimizer/fusion/norm_fusion.h"
 #include "tools/optimizer/fusion/batchmatmul_fusion.h"
 #include "tools/optimizer/fusion/batchnorm_to_scale_fusion.h"
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_CONSTANT_FOLDING_FUSION_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_CONSTANT_FOLDING_FUSION_H_
+
+#include "backend/optimizer/common/pass.h"
+#include "include/registry/converter_context.h"
+#include "tools/optimizer/const_fold/fold_along_infershape.h"
+#include "tools/optimizer/const_fold/fold_with_infershape.h"
+#include "tools/optimizer/graph/update_conv2d_param_pass.h"
+
+namespace mindspore {
+namespace opt {
+class ConstFoldPass : public Pass {
+ public:
+  explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
+      : Pass("ConstFoldPass"), fmk_type_(fmk_type), train_flag_(train_flag) {}
+  ~ConstFoldPass() override = default;
+  bool Run(const FuncGraphPtr &func_graph) override {
+    if (func_graph == nullptr) {
+      MS_LOG(ERROR) << "func_graph is nullptr, do constant fold failed.";
+      return false;
+    }
+    // current infer-shape cannot support the control-flow model of MindIR.
+    if (fmk_type_ == converter::kFmkTypeMs) {
+      auto fold_schedule = ConstFoldWithInferShape(fmk_type_, train_flag_);
+      if (!fold_schedule.Run(func_graph)) {
+        MS_LOG(ERROR) << "Do constant fold failed.";
+        return false;
+      }
+    } else {
+      auto fold_schedule = ConstFoldAlongInferShape(fmk_type_, train_flag_);
+      if (!fold_schedule.Run(func_graph)) {
+        MS_LOG(ERROR) << "Do constant fold failed.";
+        return false;
+      }
+    }
+
+    // the attrs of convolution can only be updated after constant fold.
+    auto update_attrs = UpdateConv2DParamPass();
+    if (!update_attrs.Run(func_graph)) {
+      MS_LOG(ERROR) << "update attrs failed.";
+      return false;
+    }
+    return true;
+  }
+
+ private:
+  FmkType fmk_type_{converter::kFmkTypeMs};
+  bool train_flag_{false};
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_CONSTANT_FOLDING_FUSION_H_
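For orientation, a minimal driver sketch for the new facade pass. The helper name FoldGraphConstants is hypothetical and not from this commit; only the constructor and Run() shown above are assumed:

// Hypothetical caller sketch. ConstFoldPass itself picks the strategy:
// MindIR graphs go through ConstFoldWithInferShape, every other frontend
// through ConstFoldAlongInferShape, and Conv2D attributes are refreshed
// afterwards by UpdateConv2DParamPass.
#include <memory>
#include "tools/optimizer/const_fold/constant_folding_fusion.h"

bool FoldGraphConstants(const mindspore::FuncGraphPtr &graph,
                        mindspore::converter::FmkType fmk_type, bool train_flag) {
  auto pass = std::make_shared<mindspore::opt::ConstFoldPass>(fmk_type, train_flag);
  return pass != nullptr && pass->Run(graph);
}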
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/optimizer/const_fold/fold_along_infershape.h"
+#include <memory>
+#include "nnacl/op_base.h"
+
+namespace mindspore {
+namespace opt {
+STATUS ConstFoldAlongInferShape::PostProcess(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
+  if (!CheckCanFold(func_graph, cnode)) {
+    return lite::RET_OK;
+  }
+  if (const_fold_processor_ == nullptr) {
+    const_fold_processor_ = std::make_shared<ConstFoldProcessor>(fmk_type_, train_flag_);
+  }
+  MS_CHECK_TRUE_MSG(const_fold_processor_ != nullptr, lite::RET_NULL_PTR, "const fold processor is nullptr");
+  auto status = const_fold_processor_->DoConstantFold(func_graph, cnode);
+  if (status != lite::RET_OK) {
+    MS_LOG(ERROR) << "do constant fold failed, the node is " << cnode->fullname_with_scope();
+  }
+  return status;
+}
+
+bool ConstFoldAlongInferShape::CheckCanFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
+  if (IsSpecialType(cnode) || CheckPrimitiveType(cnode, prim::kPrimCustom) || IsMarkedTrainOp(cnode)) {
+    return false;
+  }
+  auto prim = GetCNodePrimitive(cnode);
+  if (prim == nullptr) {
+    return false;
+  }
+  auto is_inferred = prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(prim->GetAttr(kInferDone));
+  if (!is_inferred) {
+    return false;
+  }
+  if (CheckPrimitiveType(cnode, prim::kPrimShape)) {
+    return lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() != 0;
+  }
+  auto inputs = cnode->inputs();
+  auto graph_inputs =
+    sub_inputs_map_.find(func_graph) != sub_inputs_map_.end() ? sub_inputs_map_[func_graph] : func_graph->get_inputs();
+  return std::all_of(inputs.begin(), inputs.end(), [&graph_inputs](const AnfNodePtr &node) {
+    return (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) ||
+           (node->isa<Parameter>() && node->cast<ParameterPtr>()->has_default() &&
+            std::find(graph_inputs.begin(), graph_inputs.end(), node) == graph_inputs.end());
+  });
+}
+}  // namespace opt
+}  // namespace mindspore
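Read together with the infershape_pass.cc hunks further down, this subclass turns shape inference into an eager folding sweep. A condensed paraphrase of the per-node control flow (illustrative only, not code from the commit):

// Paraphrase of the sweep driven by InferShapePass::InferProcess (see below):
//   infer shape of cnode            -> on success the primitive is marked kInferDone
//   PostProcess(func_graph, cnode)  -> ConstFoldAlongInferShape folds iff
//     - the node is not Custom / special / marked as a training op,
//     - kInferDone is set, and
//     - every input is a non-FuncGraph ValueNode, or a defaulted Parameter
//       that is not one of the (sub)graph's live inputs.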
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_ALONG_INFERSHAPE_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_ALONG_INFERSHAPE_H_
+
+#include <memory>
+#include "tools/optimizer/graph/infershape_pass.h"
+#include "tools/optimizer/const_fold/fold_utils.h"
+
+namespace mindspore {
+namespace opt {
+class ConstFoldAlongInferShape : public InferShapePass {
+ public:
+  explicit ConstFoldAlongInferShape(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
+      : InferShapePass(fmk_type, train_flag, "ConstFoldAlongInferShape") {}
+  ~ConstFoldAlongInferShape() override = default;
+
+ private:
+  STATUS PostProcess(const FuncGraphPtr &func_graph, const CNodePtr &cnode) override;
+  bool CheckCanFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
+  std::shared_ptr<ConstFoldProcessor> const_fold_processor_{nullptr};
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_ALONG_INFERSHAPE_H_
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.

@@ -14,13 +14,12 @@
  * limitations under the License.
  */
 
-#include "tools/optimizer/fusion/constant_folding_fusion.h"
+#include "tools/optimizer/const_fold/fold_utils.h"
 #include <algorithm>
 #include <memory>
-#include <set>
 #include <vector>
 #include "backend/optimizer/common/helper.h"
 #include "tools/anf_exporter/fetch_content.h"
 #include "ir/anf.h"
 #include "tools/converter/quant_param_holder.h"
 #include "tools/optimizer/common/format_utils.h"
 #include "tools/common/node_util.h"

@@ -31,15 +30,15 @@
 #include "src/kernel_registry.h"
 #include "src/inner_context.h"
 #include "src/tensor.h"
 #include "src/ops/ops_utils.h"
 #include "src/runtime/infer_manager.h"
 #include "tools/optimizer/graph/lite_tensor_extractor.h"
 
 using mindspore::lite::KernelRegistry;
 using mindspore::lite::Tensor;
-namespace mindspore::opt {
+namespace mindspore {
+namespace opt {
 namespace {
 constexpr size_t INITIAL_SIZE = 1024;
-constexpr auto kIsLinkWithControlFlow = "link_with_control_flow";
 ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
   MS_ASSERT(func_graph != nullptr);
   MS_ASSERT(tensor != nullptr);

@@ -47,8 +46,8 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
   MS_CHECK_TRUE_RET(parameter != nullptr, nullptr);
   std::vector<int> shape(tensor->shape());
   std::vector<int64_t> shape_vector;
-  std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
-                 [](const int32_t &value) { return static_cast<int64_t>(value); });
+  (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
+                       [](const int32_t &value) { return static_cast<int64_t>(value); });
 
   auto tensor_info = std::make_shared<tensor::Tensor>(tensor->data_type(), shape_vector);
   if (tensor_info == nullptr) {

@@ -75,7 +74,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
   MS_ASSERT(outputs != nullptr && cnode != nullptr && context != nullptr && ms_context != nullptr);
   OpParameter *parameter;
   auto ret = lite::FetchOpParameterFromNode(cnode->input(0), &parameter);
-  if (ret != lite::RET_OK) {
+  if (ret != RET_OK) {
     MS_LOG(ERROR) << cnode->fullname_with_scope() << " FetchOpParameterFromNode failed. ";
     return nullptr;
   }

@@ -91,7 +90,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
   }
   auto data_type = inputs.front()->data_type();
   kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(parameter->type_)};
-  kernel::LiteKernel *lite_kernel;
+  kernel::LiteKernel *lite_kernel = nullptr;
   ret = lite::KernelRegistry::GetInstance()->GetKernel(inputs, *outputs, context, ms_context, desc, parameter,
                                                        &lite_kernel);
   if (ret != lite::RET_OK) {

@@ -104,8 +103,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
   ret = lite_kernel->Prepare();
   if (ret != lite::RET_OK) {
     MS_LOG(ERROR) << "init failed.";
-    // op_parameter is free by lite_kernel destructor
-    delete lite_kernel;
+    delete lite_kernel;  // parameter will be freed in destructor of lite-kernel.
     return nullptr;
   }
   return lite_kernel;

@@ -194,27 +192,7 @@ lite::STATUS CopyQuantParams(const CNodePtr &cnode, const std::vector<Tensor *>
 }
 }  // namespace
 
-bool ConstFoldPass::Run(const FuncGraphPtr &func_graph) {
-  MS_CHECK_TRUE_RET(func_graph != nullptr, false);
-  manager_ = Manage(func_graph);
-  MS_CHECK_TRUE_RET(manager_ != nullptr, false);
-  if (!Init()) {
-    MS_LOG(ERROR) << "initial constant fold pass failed.";
-    return false;
-  }
-  std::set<FuncGraphPtr> has_visited;
-  if (HandleCommonFold(func_graph, &has_visited) != lite::RET_OK) {
-    MS_LOG(ERROR) << "do constant fold pass failed,";
-    return false;
-  }
-  if (HandleSpecialFold(func_graph) != lite::RET_OK) {
-    MS_LOG(ERROR) << "do constant fold pass failed,";
-    return false;
-  }
-  return true;
-}
-
-bool ConstFoldPass::Init() {
+bool ConstFoldProcessor::Init() {
   if (context_ == nullptr) {
     context_ = std::make_shared<lite::InnerContext>();
     MS_CHECK_TRUE_RET(context_ != nullptr, false);

@@ -230,119 +208,15 @@ bool ConstFoldPass::Init() {
   return true;
 }
 
-int ConstFoldPass::HandleCommonFold(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited) {
-  MS_ASSERT(func_graph != nullptr);
-  if (has_visited->find(func_graph) != has_visited->end()) {
-    return lite::RET_OK;
+int ConstFoldProcessor::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+  if (func_graph == nullptr || cnode == nullptr) {
+    MS_LOG(ERROR) << "input param is nullptr.";
+    return lite::RET_NULL_PTR;
   }
-  has_visited->insert(func_graph);
-  MS_ASSERT(manager_ != nullptr);
-  manager_->AddFuncGraph(func_graph);
-  auto node_list = TopoSort(func_graph->get_return());
-  for (auto &node : node_list) {
-    if (!utils::isa<CNode>(node)) {
-      continue;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    for (size_t i = 0; i < cnode->size(); ++i) {
-      if (IsValueNode<FuncGraph>(cnode->input(i))) {
-        auto sub_graph = GetValueNode<FuncGraphPtr>(cnode->input(i));
-        MS_ASSERT(sub_graph != nullptr);
-        if (HandleCommonFold(sub_graph, has_visited) != lite::RET_OK) {
-          MS_LOG(ERROR) << "do subgraph const-fold failed.";
-          return lite::RET_ERROR;
-        }
-      }
-    }
-    if (!CheckCanCommonFold(cnode)) {
-      continue;
-    }
-    if (DoConstantFold(func_graph, cnode) != lite::RET_OK) {
-      MS_LOG(ERROR) << "do constant fold failed.";
-      return lite::RET_ERROR;
-    }
+  if (!Init()) {
+    MS_LOG(ERROR) << "initial context failed.";
+    return lite::RET_ERROR;
   }
-  return lite::RET_OK;
-}
-
-bool ConstFoldPass::CheckCanCommonFold(const CNodePtr &cnode) const {
-  MS_CHECK_TRUE_RET(cnode != nullptr, false);
-  if (IsSpecialType(cnode)) {
-    return false;
-  }
-  if (IsMarkedTrainOp(cnode) || CheckPrimitiveType(cnode, prim::kPrimCustom)) {
-    return false;
-  }
-  auto inputs = cnode->inputs();
-  return std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) {
-    return (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) ||
-           (node->isa<Parameter>() && node->cast<ParameterPtr>()->has_default());
-  });
-}
-
-int ConstFoldPass::HandleSpecialFold(const FuncGraphPtr &func_graph) {
-  MS_ASSERT(func_graph != nullptr);
-  if (lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() == 0) {
-    return lite::RET_OK;
-  }
-  if (node_infershape_ == nullptr) {
-    node_infershape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
-    MS_CHECK_TRUE_RET(node_infershape_ != nullptr, lite::RET_ERROR);
-  }
-  MS_ASSERT(manager_ != nullptr);
-  auto node_list = TopoSort(func_graph->get_return());
-  for (auto &node : node_list) {
-    if (!utils::isa<CNode>(node)) {
-      continue;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    if (!CheckCanSpecialFold(cnode)) {
-      continue;
-    }
-    if (DoConstantFold(func_graph, cnode) != lite::RET_OK) {
-      MS_LOG(ERROR) << "do constant fold failed.";
-      return lite::RET_ERROR;
-    }
-  }
-  return lite::RET_OK;
-}
-
-bool ConstFoldPass::CheckCanSpecialFold(const CNodePtr &cnode) const {
-  MS_CHECK_TRUE_RET(cnode != nullptr, false);
-  for (size_t i = 0; i < cnode->size(); ++i) {
-    auto input_node = cnode->input(i);
-    MS_CHECK_TRUE_RET(input_node != nullptr, false);
-    if (IsValueNode<FuncGraph>(input_node)) {
-      return false;
-    }
-    if (!input_node->isa<CNode>()) {
-      continue;
-    }
-    auto input_cnode = input_node->cast<CNodePtr>();
-    auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
-    MS_CHECK_TRUE_RET(input_prim != nullptr, false);
-    bool is_link_with_control_flow = input_prim->GetAttr(kIsLinkWithControlFlow) == nullptr ||
-                                     GetValue<bool>(input_prim->GetAttr(kIsLinkWithControlFlow));
-    if (is_link_with_control_flow) {
-      return false;
-    }
-  }
-  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-  MS_CHECK_TRUE_RET(prim != nullptr, false);
-  prim->AddAttr(kIsLinkWithControlFlow, MakeValue(false));
-  if (IsSpecialType(cnode)) {
-    return false;
-  }
-  MS_ASSERT(node_infershape_ != nullptr);
-  auto status = node_infershape_->InferShape(cnode);
-  if (CheckPrimitiveType(cnode, prim::kPrimShape)) {
-    return status == lite::RET_OK;
-  }
-  return CheckCanCommonFold(cnode);
-}
-
-int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
-  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
   std::vector<TensorPtr> inputs_ptr;
   if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &inputs_ptr, fmk_type_, train_flag_, true) != lite::RET_OK) {
     MS_LOG(ERROR) << "extract input tensor from cnode failed. " << cnode->fullname_with_scope();

@@ -363,12 +237,12 @@ int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr
     MS_LOG(DEBUG) << "this op is control flow op, which is not supported now.";
     return lite::RET_OK;
   }
-  std::vector<Tensor *> input_tensors;
-  std::transform(inputs_ptr.begin(), inputs_ptr.end(), std::back_inserter(input_tensors),
-                 [](const TensorPtr &input) { return input.get(); });
-  std::vector<Tensor *> output_tensors;
-  std::transform(outputs_ptr.begin(), outputs_ptr.end(), std::back_inserter(output_tensors),
-                 [](const TensorPtr &output) { return output.get(); });
+  std::vector<lite::Tensor *> input_tensors;
+  (void)std::transform(inputs_ptr.begin(), inputs_ptr.end(), std::back_inserter(input_tensors),
+                       [](const TensorPtr &input) { return input.get(); });
+  std::vector<lite::Tensor *> output_tensors;
+  (void)std::transform(outputs_ptr.begin(), outputs_ptr.end(), std::back_inserter(output_tensors),
+                       [](const TensorPtr &output) { return output.get(); });
   if (CopyQuantParams(cnode, input_tensors, output_tensors) != lite::RET_OK) {
     MS_LOG(ERROR) << "copy quant params failed.";
     return lite::RET_ERROR;

@@ -402,4 +276,5 @@ int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr
   }
   return status;
 }
-}  // namespace mindspore::opt
+}  // namespace opt
+}  // namespace mindspore
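Net effect of the hunks above: this file (apparently renamed from tools/optimizer/fusion/constant_folding_fusion.cc to tools/optimizer/const_fold/fold_utils.cc, judging by the include swap) keeps only the per-node folding machinery as ConstFoldProcessor. The graph traversal it used to own, Run/HandleCommonFold/HandleSpecialFold and their checks, reappears in fold_with_infershape.cc below.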
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_UTILS_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_UTILS_H_
+
+#include <memory>
+#include "ir/anf.h"
+#include "include/api/context.h"
+#include "include/registry/converter_context.h"
+#include "src/inner_context.h"
+
+namespace mindspore {
+namespace opt {
+class ConstFoldProcessor {
+ public:
+  explicit ConstFoldProcessor(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
+      : fmk_type_(fmk_type), train_flag_(train_flag) {}
+  int DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
+  ~ConstFoldProcessor() = default;
+
+ private:
+  bool Init();
+  converter::FmkType fmk_type_{converter::kFmkTypeMs};
+  bool train_flag_{false};
+  std::shared_ptr<lite::InnerContext> context_{nullptr};
+  std::shared_ptr<mindspore::Context> ms_context_{nullptr};
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_UTILS_H_
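ConstFoldProcessor is the shared engine behind both strategies. A condensed paraphrase of DoConstantFold, stitched together from the fold_utils.cc hunks above (an illustrative summary, not verbatim code; the kernel-execution step itself sits in code not shown in this diff):

// int ConstFoldProcessor::DoConstantFold(func_graph, cnode):
//   1. null-check the arguments, lazily Init() a lite::InnerContext / ms Context
//   2. LiteTensorExtractor::GetCNodeInputTensors -> lite::Tensor inputs
//      (bails out with RET_OK on control-flow ops, which are not supported yet)
//   3. CopyQuantParams carries quant params onto the input/output tensors
//   4. GetLiteKernel builds a CPU kernel from the node's OpParameter and
//      Prepare()s it, then the kernel is executed on the constant inputs
//   5. CreateNewParamter wraps each output tensor as a Parameter node that
//      replaces the folded CNode in the graph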
@@ -0,0 +1,158 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/optimizer/const_fold/fold_with_infershape.h"
+#include <memory>
+#include <set>
+#include "tools/optimizer/common/format_utils.h"
+#include "nnacl/op_base.h"
+
+namespace mindspore::opt {
+namespace {
+constexpr auto kIsLinkWithControlFlow = "link_with_control_flow";
+}  // namespace
+
+bool ConstFoldWithInferShape::Run(const FuncGraphPtr &func_graph) {
+  MS_CHECK_TRUE_RET(func_graph != nullptr, false);
+  manager_ = Manage(func_graph);
+  MS_CHECK_TRUE_RET(manager_ != nullptr, false);
+  if (const_fold_processor_ == nullptr) {
+    const_fold_processor_ = std::make_shared<ConstFoldProcessor>(fmk_type_, train_flag_);
+  }
+  MS_CHECK_TRUE_RET(const_fold_processor_ != nullptr, false);
+  std::set<FuncGraphPtr> has_visited;
+  if (HandleCommonFold(func_graph, &has_visited) != lite::RET_OK) {
+    MS_LOG(ERROR) << "do constant fold pass failed,";
+    return false;
+  }
+  if (HandleSpecialFold(func_graph) != lite::RET_OK) {
+    MS_LOG(ERROR) << "do constant fold pass failed,";
+    return false;
+  }
+  return true;
+}
+
+int ConstFoldWithInferShape::HandleCommonFold(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited) {
+  MS_ASSERT(func_graph != nullptr);
+  if (has_visited->find(func_graph) != has_visited->end()) {
+    return lite::RET_OK;
+  }
+  has_visited->insert(func_graph);
+  MS_ASSERT(manager_ != nullptr);
+  manager_->AddFuncGraph(func_graph);
+  auto node_list = TopoSort(func_graph->get_return());
+  for (auto &node : node_list) {
+    if (!utils::isa<CNode>(node)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    for (size_t i = 0; i < cnode->size(); ++i) {
+      if (IsValueNode<FuncGraph>(cnode->input(i))) {
+        auto sub_graph = GetValueNode<FuncGraphPtr>(cnode->input(i));
+        MS_ASSERT(sub_graph != nullptr);
+        if (HandleCommonFold(sub_graph, has_visited) != lite::RET_OK) {
+          MS_LOG(ERROR) << "do subgraph const-fold failed.";
+          return lite::RET_ERROR;
+        }
+      }
+    }
+    if (!CheckCanCommonFold(cnode)) {
+      continue;
+    }
+    if (const_fold_processor_->DoConstantFold(func_graph, cnode) != lite::RET_OK) {
+      MS_LOG(ERROR) << "do constant fold failed.";
+      return lite::RET_ERROR;
+    }
+  }
+  return lite::RET_OK;
+}
+
+bool ConstFoldWithInferShape::CheckCanCommonFold(const CNodePtr &cnode) const {
+  MS_CHECK_TRUE_RET(cnode != nullptr, false);
+  if (IsSpecialType(cnode)) {
+    return false;
+  }
+  if (IsMarkedTrainOp(cnode) || CheckPrimitiveType(cnode, prim::kPrimCustom)) {
+    return false;
+  }
+  auto inputs = cnode->inputs();
+  return std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) {
+    return (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) ||
+           (node->isa<Parameter>() && node->cast<ParameterPtr>()->has_default());
+  });
+}
+
+int ConstFoldWithInferShape::HandleSpecialFold(const FuncGraphPtr &func_graph) {
+  MS_ASSERT(func_graph != nullptr);
+  if (lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() == 0) {
+    return lite::RET_OK;
+  }
+  if (node_infershape_ == nullptr) {
+    node_infershape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
+    MS_CHECK_TRUE_RET(node_infershape_ != nullptr, lite::RET_ERROR);
+  }
+  MS_ASSERT(manager_ != nullptr);
+  auto node_list = TopoSort(func_graph->get_return());
+  for (auto &node : node_list) {
+    if (!utils::isa<CNode>(node)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (!CheckCanSpecialFold(cnode)) {
+      continue;
+    }
+    if (const_fold_processor_->DoConstantFold(func_graph, cnode) != lite::RET_OK) {
+      MS_LOG(ERROR) << "do constant fold failed.";
+      return lite::RET_ERROR;
+    }
+  }
+  return lite::RET_OK;
+}
+
+bool ConstFoldWithInferShape::CheckCanSpecialFold(const CNodePtr &cnode) const {
+  MS_CHECK_TRUE_RET(cnode != nullptr, false);
+  for (size_t i = 0; i < cnode->size(); ++i) {
+    auto input_node = cnode->input(i);
+    MS_CHECK_TRUE_RET(input_node != nullptr, false);
+    if (IsValueNode<FuncGraph>(input_node)) {
+      return false;
+    }
+    if (!input_node->isa<CNode>()) {
+      continue;
+    }
+    auto input_cnode = input_node->cast<CNodePtr>();
+    auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
+    MS_CHECK_TRUE_RET(input_prim != nullptr, false);
+    bool is_link_with_control_flow = input_prim->GetAttr(kIsLinkWithControlFlow) == nullptr ||
+                                     GetValue<bool>(input_prim->GetAttr(kIsLinkWithControlFlow));
+    if (is_link_with_control_flow) {
+      return false;
+    }
+  }
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  MS_CHECK_TRUE_RET(prim != nullptr, false);
+  prim->AddAttr(kIsLinkWithControlFlow, MakeValue(false));
+  if (IsSpecialType(cnode)) {
+    return false;
+  }
+  MS_ASSERT(node_infershape_ != nullptr);
+  auto status = node_infershape_->InferShape(cnode);
+  if (CheckPrimitiveType(cnode, prim::kPrimShape)) {
+    return status == lite::RET_OK;
+  }
+  return CheckCanCommonFold(cnode);
+}
+}  // namespace mindspore::opt
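The special-fold path deserves a note. CheckCanSpecialFold treats a missing link_with_control_flow attribute as "possibly linked" (GetAttr(...) == nullptr || GetValue<bool>(...)), refuses to fold any node with such an input, and stamps AddAttr(kIsLinkWithControlFlow, MakeValue(false)) on every node that passes the check. Because nodes are visited in topological order, the cleared flag propagates downstream, so an entire region is skipped for special folding as soon as one ancestor touches control flow.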
@@ -14,41 +14,37 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
-#define MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_WITH_INFERSHAPE_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_WITH_INFERSHAPE_H_
 
 #include <set>
 #include <utility>
 #include <memory>
 #include "backend/optimizer/common/pass.h"
-#include "include/api/context.h"
 #include "include/registry/converter_context.h"
-#include "src/inner_context.h"
 #include "tools/optimizer/graph/node_infershape.h"
+#include "tools/optimizer/const_fold/fold_utils.h"
 
 namespace mindspore {
 namespace opt {
-class ConstFoldPass : public Pass {
+class ConstFoldWithInferShape : public Pass {
  public:
-  explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
-      : Pass("ConstFoldPass"), fmk_type_(fmk_type), train_flag_(train_flag) {}
-  ~ConstFoldPass() override = default;
+  explicit ConstFoldWithInferShape(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
+      : Pass("ConstFoldWithInferShape"), fmk_type_(fmk_type), train_flag_(train_flag) {}
+  ~ConstFoldWithInferShape() override = default;
   bool Run(const FuncGraphPtr &func_graph) override;
 
  private:
-  bool Init();
   int HandleCommonFold(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited);
   bool CheckCanCommonFold(const CNodePtr &cnode) const;
   int HandleSpecialFold(const FuncGraphPtr &func_graph);
   bool CheckCanSpecialFold(const CNodePtr &cnode) const;
-  int DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
   converter::FmkType fmk_type_{converter::kFmkTypeMs};
   bool train_flag_{false};
-  std::shared_ptr<lite::InnerContext> context_{nullptr};
-  std::shared_ptr<mindspore::Context> ms_context_{nullptr};
+  std::shared_ptr<ConstFoldProcessor> const_fold_processor_{nullptr};
   std::shared_ptr<NodeInferShape> node_infershape_{nullptr};
   FuncGraphManagerPtr manager_{nullptr};
 };
 }  // namespace opt
 }  // namespace mindspore
-#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_CONST_FOLD_FOLD_WITH_INFERSHAPE_H_
@@ -16,6 +16,7 @@
 
 #include "tools/optimizer/graph/infershape_pass.h"
 #include "tools/common/node_util.h"
+#include "tools/common/tensor_util.h"
 #include "nnacl/op_base.h"
 #include "src/common/log_util.h"
 

@@ -47,8 +48,8 @@ int GetCNodeCertainInputFormat(const CNodePtr cnode, int index, mindspore::Forma
   auto primitive = GetValueNode<PrimitivePtr>(real_cnode->input(0));
   MS_CHECK_TRUE_MSG(primitive != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
   if (primitive->GetAttr(ops::kFormat) == nullptr) {
-    MS_LOG(ERROR) << "cnode has no format attr. " << real_cnode->fullname_with_scope();
-    return lite::RET_ERROR;
+    MS_LOG(DEBUG) << "cnode has no format attr. " << real_cnode->fullname_with_scope();
+    return lite::RET_NO_CHANGE;
   }
   auto format_attr = primitive->GetAttr(ops::kFormat);
   MS_CHECK_TRUE_MSG(format_attr != nullptr, lite::RET_NULL_PTR, "GetAttr Failed");

@@ -110,11 +111,20 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
     MS_LOG(WARNING) << "exist op cannot support infer shape.";
     return false;
   }
+  manager_ = Manage(func_graph, true);
+  if (manager_ == nullptr) {
+    MS_LOG(ERROR) << "generate a manager for func_graph failed.";
+    return false;
+  }
   if (InferProcess(func_graph) != lite::RET_OK) {
     MS_LOG(ERROR) << "infer shape failed.";
     return false;
   }
-  return ResetSubGraphInput();
+  if (ResetSubGraphInput() != lite::RET_OK) {
+    MS_LOG(ERROR) << "ResetSubGraphInput failed.";
+    return false;
+  }
+  return true;
 }
 
 bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {

@@ -165,6 +175,7 @@ bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {
 
 STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
   MS_ASSERT(func_graph != nullptr);
+  manager_->AddFuncGraph(func_graph);
   auto node_list = TopoSort(func_graph->get_return());
   for (auto &node : node_list) {
     if (!utils::isa<CNode>(node)) {

@@ -184,32 +195,38 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
       auto ret = SetSubGraphInput(cnode, sub_func_graph);
       if (ret != lite::RET_OK) {
         MS_LOG(ERROR) << "SetSubGraphInput failed: " << ret;
-        return lite::RET_ERROR;
+        return RET_ERROR;
       }
       if (InferProcess(sub_func_graph) != lite::RET_OK) {
         MS_LOG(ERROR) << "subgraph infer shape failed.";
-        return lite::RET_ERROR;
+        return RET_ERROR;
       }
-      SetSubGraphOutput(sub_func_graph);
+      if (SetSubGraphOutput(sub_func_graph) != lite::RET_OK) {
+        MS_LOG(ERROR) << "SetSubGraphOutput failed.";
+        return RET_ERROR;
+      }
       sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
       if (sub_func_graph == nullptr) {
         lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
-        return lite::RET_ERROR;
+        return RET_ERROR;
       }
       ret = SetSubGraphInput(cnode, sub_func_graph);
-      if (ret != lite::RET_OK) {
+      if (ret != RET_OK) {
         MS_LOG(ERROR) << "SetSubGraphInput failed: " << ret;
-        return lite::RET_ERROR;
+        return RET_ERROR;
       }
       if (InferProcess(sub_func_graph) != lite::RET_OK) {
         MS_LOG(ERROR) << "subgraph infer shape failed.";
-        return lite::RET_ERROR;
+        return RET_ERROR;
      }
-      SetSubGraphOutput(sub_func_graph);
+      if (SetSubGraphOutput(sub_func_graph) != lite::RET_OK) {
+        MS_LOG(ERROR) << "SetSubGraphOutput failed.";
+        return RET_ERROR;
+      }
       ret = SetSubGraphAbstract(cnode, sub_func_graph);
-      if (ret != lite::RET_OK) {
+      if (ret != RET_OK) {
         MS_LOG(ERROR) << "SetSubGraphAbstract failed: " << ret;
-        return lite::RET_ERROR;
+        return RET_ERROR;
       }
       continue;
     }

@@ -218,6 +235,11 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
       MS_LOG(ERROR) << "node infer shape failed, node is " << node->fullname_with_scope();
       return lite::RET_ERROR;
     }
+    status = PostProcess(func_graph, cnode);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "post process current node failed, node is " << node->fullname_with_scope();
+      return lite::RET_ERROR;
+    }
   }
   return lite::RET_OK;
 }

@@ -260,35 +282,27 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
       if (ModifySubGraphInputCNodeFormat(sub_graph, param_node, format) != lite::RET_OK) {
         MS_LOG(DEBUG) << "modify subgraph input cnode format failed." << cnode->func_graph_as_var();
       }
-    } else {
-      if (utils::isa<Parameter>(cnode->input(index))) {
-        param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
-      }
-      if (utils::isa<ValueNode>(cnode->input(index))) {
-        lite::DataInfo data_info;
-        auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info, false);
-        if (status != lite::RET_OK) {
-          continue;
-        }
-        ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
-        if (data_info.data_.empty()) {
-          auto tensor_info = std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec);
-          CHECK_NULL_RETURN(tensor_info);
-          param_node->set_default_param(tensor_info);
-        } else {
-          auto tensor_info = std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
-                                                              data_info.data_.data(), data_info.data_.size());
-          CHECK_NULL_RETURN(tensor_info);
-          param_node->set_default_param(tensor_info);
-        }
-      }
+    }
+    if (utils::isa<ParameterPtr>(cnode->input(index))) {
+      if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
+        param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
+      }
+      continue;
+    }
+    lite::DataInfo data_info;
+    auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info, false);
+    if (status != lite::RET_OK) {
+      continue;
+    }
+    ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
+    auto tensor_info =
+      lite::CreateTensorInfo(data_info.data_.data(), data_info.data_.size(), shape_vec, (TypeId)data_info.data_type_);
+    MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_ERROR, "created tensor is a nullptr.");
+    param_node->set_default_param(tensor_info);
   }
   return RET_OK;
 }
 
-void InferShapePass::SetSubGraphOutput(const FuncGraphPtr &sub_graph) {
+STATUS InferShapePass::SetSubGraphOutput(const FuncGraphPtr &sub_graph) {
   MS_ASSERT(sub_graph != nullptr);
   auto return_node = sub_graph->get_return();
   MS_ASSERT(return_node != nullptr);

@@ -317,6 +331,7 @@ void InferShapePass::SetSubGraphOutput(const FuncGraphPtr &sub_graph) {
     trans_cnode->set_fullname_with_scope(trans_input_name);
   }
   return_node->set_inputs(origin_input);
+  return lite::RET_OK;
 }
 
 STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {

@@ -371,24 +386,26 @@ STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGrap
   return RET_OK;
 }
 
-bool InferShapePass::ResetSubGraphInput() {
-  for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
-    auto &sub_graph = iter->first;
-    auto &sub_inputs = iter->second;
-    auto manager = sub_graph->manager();
-    MS_ASSERT(manager != nullptr);
+int InferShapePass::ResetSubGraphInput() {
+  for (auto &iter : sub_inputs_map_) {
+    auto &sub_graph = iter.first;
+    auto &sub_inputs = iter.second;
+    MS_ASSERT(manager_ != nullptr);
     for (auto &sub_input : sub_inputs) {
       auto param_node = sub_graph->add_parameter();
-      MS_CHECK_TRUE_MSG(param_node != nullptr, false, "Add parameter Failed");
+      MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "Add parameter Failed");
       param_node->set_abstract(sub_input->abstract()->Clone());
      param_node->set_name(sub_input->fullname_with_scope());
-      manager->Replace(sub_input, param_node);
+      if (!manager_->Replace(sub_input, param_node)) {
+        MS_LOG(ERROR) << "replace cnode failed.";
+        return RET_ERROR;
+      }
      auto sub_param_input = sub_input->cast<ParameterPtr>();
      MS_ASSERT(sub_param_input != nullptr);
      sub_param_input->set_default_param(nullptr);
    }
  }
-  return true;
+  return lite::RET_OK;
 }
 }  // namespace opt
 }  // namespace mindspore

@@ -19,6 +19,7 @@
 
 #include <map>
 #include <memory>
+#include <string>
 #include <vector>
 #include "backend/optimizer/common/pass.h"
 #include "tools/optimizer/graph/node_infershape.h"

@@ -27,23 +28,29 @@ namespace mindspore {
 namespace opt {
 class InferShapePass : public Pass {
  public:
-  explicit InferShapePass(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
-      : Pass("infer_shape"), fmk_type_(fmk_type), train_flag_(train_flag) {}
+  explicit InferShapePass(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
+                          const std::string &name = "InferShapePass")
+      : Pass(name), fmk_type_(fmk_type), train_flag_(train_flag) {}
   ~InferShapePass() override = default;
   bool Run(const FuncGraphPtr &func_graph) override;
 
+ protected:
+  virtual STATUS PostProcess(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { return lite::RET_OK; }
+
  private:
   bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
   STATUS InferProcess(const FuncGraphPtr &func_graph);
   STATUS SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
-  void SetSubGraphOutput(const FuncGraphPtr &sub_graph);
+  STATUS SetSubGraphOutput(const FuncGraphPtr &sub_graph);
   STATUS SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
-  bool ResetSubGraphInput();
+  int ResetSubGraphInput();
 
  protected:
   FmkType fmk_type_{converter::kFmkTypeMs};
   bool train_flag_{false};
   std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr};
   std::map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_{};
+  FuncGraphManagerPtr manager_{nullptr};
 };
 }  // namespace opt
 }  // namespace mindspore
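The key extension point in these last hunks is the new protected virtual PostProcess. A minimal sketch of how a subclass plugs in (a hypothetical subclass for illustration, assuming the surrounding mindspore/opt namespaces and macros; ConstFoldAlongInferShape above is the real example):

// Hypothetical subclass sketch: runs a callback on every node whose shape
// was just inferred, mirroring how ConstFoldAlongInferShape specializes the pass.
class MyNodeAuditPass : public mindspore::opt::InferShapePass {
 public:
  explicit MyNodeAuditPass(FmkType fmk_type = mindspore::converter::kFmkTypeMs, bool train_flag = false)
      : InferShapePass(fmk_type, train_flag, "MyNodeAuditPass") {}

 protected:
  STATUS PostProcess(const FuncGraphPtr &func_graph, const CNodePtr &cnode) override {
    MS_LOG(INFO) << "shape inferred for " << cnode->fullname_with_scope();
    return lite::RET_OK;  // non-RET_OK would abort InferProcess
  }
};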