!44658 [lite]remove unused op whose output is equal to output
Merge pull request !44658 from 徐安越/master4
This commit is contained in:
commit
87c159febe
|
@ -111,6 +111,7 @@
|
|||
#include "tools/optimizer/fusion/concat_concat_fusion.h"
|
||||
#include "tools/optimizer/fusion/strided_slice_fusion.h"
|
||||
#include "tools/optimizer/fusion/reduce_stack_fusion.h"
|
||||
#include "tools/optimizer/fusion/remove_transitivity_op.h"
|
||||
#include "tools/converter/import/cast_op_adjust.h"
|
||||
#include "tools/converter/quantizer/quant_helper/remove_unused_quant_param.h"
|
||||
#include "tools/converter/adapter/acl/plugin/acl_pass_plugin.h"
|
||||
|
@ -338,7 +339,8 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared
|
|||
// the following pass needs to check the return value.
|
||||
fusions = {std::make_shared<opt::MulReduceFusion>(), std::make_shared<opt::ReshapeReduceFusion>(),
|
||||
std::make_shared<opt::AblateReshapeLikeOp>(), std::make_shared<opt::ConcatConcatFusion>(),
|
||||
std::make_shared<opt::ReduceStackFusion>(), std::make_shared<opt::StridedSliceFusion>()};
|
||||
std::make_shared<opt::ReduceStackFusion>(), std::make_shared<opt::RemoveTransitivityOp>(),
|
||||
std::make_shared<opt::StridedSliceFusion>(), std::make_shared<opt::RemoveTransitivityOp>()};
|
||||
for (auto &pass : fusions) {
|
||||
MS_CHECK_TRUE_MSG(pass != nullptr, RET_ERROR, "pass is a nullptr.");
|
||||
if (param->fusion_blacklists.find(pass->name()) != param->fusion_blacklists.end()) {
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#define USE_DEPRECATED_API
|
||||
#include "tools/optimizer/fusion/remove_transitivity_op.h"
|
||||
#include <vector>
|
||||
#include "tools/optimizer/fusion/strided_slice_checker.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "ops/op_name.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool RemoveTransitivityOp::Run(const FuncGraphPtr &func_graph) {
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "func_graph is a nullptr, cannot do RemoveTransitivityOp.";
|
||||
return false;
|
||||
}
|
||||
auto ret = preprocessor_.Run(func_graph);
|
||||
if (ret != lite::RET_OK && ret != lite::RET_NOT_SUPPORT) {
|
||||
MS_LOG(ERROR) << "Do dynamic-shape infer failed.";
|
||||
return false;
|
||||
}
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNode>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (IsMarkedTrainOp(cnode)) {
|
||||
continue;
|
||||
}
|
||||
int status = lite::RET_OK;
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) {
|
||||
status = HandleStridedSlice(func_graph, cnode);
|
||||
} else if (CheckPrimitiveType(cnode, prim::kPrimConcat)) {
|
||||
status = HandleConcat(func_graph, cnode);
|
||||
} else if (CheckPrimitiveType(cnode, prim::kPrimReduceFusion)) {
|
||||
status = HandleReduce(func_graph, cnode);
|
||||
}
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Do RemoveTransitivityOp failed, node is " << node->fullname_with_scope();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int RemoveTransitivityOp::HandleStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &strided_slice) {
|
||||
MS_ASSERT(func_graph != nullptr && strided_slice != nullptr);
|
||||
if (!StridedSliceChecker::CheckCommonInfo(strided_slice)) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
auto prim = GetCNodePrimitive(strided_slice);
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "StridedSlice's prim is a nullptr.");
|
||||
if (IsQuantParameterNode(prim)) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
std::vector<int> begin;
|
||||
auto ret = StridedSliceChecker::GetBegin(strided_slice, &begin);
|
||||
if (ret == lite::RET_NOT_SUPPORT) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get Strided_slice's begin failed, node is " << strided_slice->fullname_with_scope();
|
||||
return ret;
|
||||
}
|
||||
std::vector<int> end;
|
||||
ret = StridedSliceChecker::GetEnd(strided_slice, &end);
|
||||
if (ret == lite::RET_NOT_SUPPORT) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get Strided_slice's end failed, node is " << strided_slice->fullname_with_scope();
|
||||
return ret;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(begin.size() == end.size(), lite::RET_ERROR, "Strided_slice begin-size is not equal end-size");
|
||||
for (size_t i = 0; i < begin.size(); ++i) {
|
||||
if (begin[i] != 0 || end[i] != INT_MAX) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
}
|
||||
return DoReplace(func_graph, strided_slice);
|
||||
}
|
||||
|
||||
int RemoveTransitivityOp::HandleConcat(const FuncGraphPtr &func_graph, const CNodePtr &concat) {
|
||||
MS_ASSERT(func_graph != nullptr && concat != nullptr);
|
||||
auto prim = GetCNodePrimitive(concat);
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "Concat's prim is a nullptr.");
|
||||
if (IsQuantParameterNode(prim)) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (concat->size() != kInputSizeTwo || CheckPrimitiveType(concat->input(1), prim::kPrimMakeTuple) ||
|
||||
CheckPrimitiveType(concat->input(1), kPrimMakeTupleV2)) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
return DoReplace(func_graph, concat);
|
||||
}
|
||||
|
||||
int RemoveTransitivityOp::HandleReduce(const FuncGraphPtr &func_graph, const CNodePtr &reduce) {
|
||||
MS_ASSERT(func_graph != nullptr && reduce != nullptr);
|
||||
auto &shape_container = preprocessor_.GetShapeContainer();
|
||||
if (shape_container.find(reduce) == shape_container.end()) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
auto prim = GetCNodePrimitive(reduce);
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "Reduce's prim is a nullptr.");
|
||||
if (!IsReduceModeMeetOutEqualIn(prim)) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (IsQuantParameterNode(prim)) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
auto attr = prim->GetAttr(ops::kCoeff);
|
||||
if (attr != nullptr && fabs(GetValue<float>(attr) - 1.f) > FLT_EPSILON) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
auto &in_shapes = shape_container.at(reduce).first;
|
||||
auto &out_shapes = shape_container.at(reduce).second;
|
||||
if (in_shapes.empty() || out_shapes.empty()) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (in_shapes.front() != out_shapes.front()) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
return DoReplace(func_graph, reduce);
|
||||
}
|
||||
|
||||
int RemoveTransitivityOp::DoReplace(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
auto manager = func_graph->manager();
|
||||
MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_NULL_PTR, "Manager is a nullptr.");
|
||||
if (!manager->Replace(cnode, cnode->input(1))) {
|
||||
MS_LOG(ERROR) << "Do manager-Replace failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_REMOVE_TRANSITIVITY_OP_H
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_REMOVE_TRANSITIVITY_OP_H
|
||||
|
||||
#include "backend/common/optimizer/pass.h"
|
||||
#include "tools/optimizer/graph/preprocess_dynamic_shape.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// remove the op whose output is equal to its input.
|
||||
class RemoveTransitivityOp : public Pass {
|
||||
public:
|
||||
RemoveTransitivityOp() : Pass("RemoveTransitivityOp") {}
|
||||
~RemoveTransitivityOp() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
int HandleStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &strided_slice);
|
||||
int HandleConcat(const FuncGraphPtr &func_graph, const CNodePtr &concat);
|
||||
int HandleReduce(const FuncGraphPtr &func_graph, const CNodePtr &reduce);
|
||||
int DoReplace(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
DynamicShapePreprocessor preprocessor_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_REMOVE_TRANSITIVITY_OP_H
|
|
@ -24,7 +24,10 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool StridedSliceChecker::CheckCommonInfo(const CNodePtr &strided_slice) {
|
||||
if (strided_slice == nullptr) {
|
||||
if (strided_slice == nullptr || strided_slice->size() > kInputSizeFive) {
|
||||
return false;
|
||||
}
|
||||
if (IsMarkedTrainOp(strided_slice)) {
|
||||
return false;
|
||||
}
|
||||
auto prim = GetCNodePrimitive(strided_slice);
|
||||
|
@ -131,10 +134,11 @@ int StridedSliceChecker::GetConstTensor(const CNodePtr &strided_slice, size_t in
|
|||
<< strided_slice->fullname_with_scope();
|
||||
return lite::RET_NOT_SUPPORT;
|
||||
}
|
||||
if (lite::FetchConstData(strided_slice, index, converter::kFmkTypeMs, data_info, false) != lite::RET_OK) {
|
||||
if (lite::FetchConstData(strided_slice, index, converter::kFmkTypeMs, data_info, true) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get Strided_slice " << index << "-input failed, node is " << strided_slice->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
data_info->data_ptr_ = data_info->data_.data();
|
||||
if (data_info->data_ptr_ == nullptr ||
|
||||
(data_info->data_type_ != kNumberTypeInt && data_info->data_type_ != kNumberTypeInt32)) {
|
||||
MS_LOG(ERROR) << "Get Strided_slice's constant failed, node name is " << strided_slice->fullname_with_scope();
|
||||
|
|
Loading…
Reference in New Issue