!44658 [lite]remove unused op whose output is equal to output

Merge pull request !44658 from 徐安越/master4
This commit is contained in:
i-robot 2022-12-02 08:13:33 +00:00 committed by Gitee
commit 87c159febe
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 202 additions and 3 deletions

View File

@ -111,6 +111,7 @@
#include "tools/optimizer/fusion/concat_concat_fusion.h"
#include "tools/optimizer/fusion/strided_slice_fusion.h"
#include "tools/optimizer/fusion/reduce_stack_fusion.h"
#include "tools/optimizer/fusion/remove_transitivity_op.h"
#include "tools/converter/import/cast_op_adjust.h"
#include "tools/converter/quantizer/quant_helper/remove_unused_quant_param.h"
#include "tools/converter/adapter/acl/plugin/acl_pass_plugin.h"
@ -338,7 +339,8 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared
// the following pass needs to check the return value.
fusions = {std::make_shared<opt::MulReduceFusion>(), std::make_shared<opt::ReshapeReduceFusion>(),
std::make_shared<opt::AblateReshapeLikeOp>(), std::make_shared<opt::ConcatConcatFusion>(),
std::make_shared<opt::ReduceStackFusion>(), std::make_shared<opt::StridedSliceFusion>()};
std::make_shared<opt::ReduceStackFusion>(), std::make_shared<opt::RemoveTransitivityOp>(),
std::make_shared<opt::StridedSliceFusion>(), std::make_shared<opt::RemoveTransitivityOp>()};
for (auto &pass : fusions) {
MS_CHECK_TRUE_MSG(pass != nullptr, RET_ERROR, "pass is a nullptr.");
if (param->fusion_blacklists.find(pass->name()) != param->fusion_blacklists.end()) {

View File

@ -0,0 +1,152 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define USE_DEPRECATED_API
#include "tools/optimizer/fusion/remove_transitivity_op.h"
#include <vector>
#include "tools/optimizer/fusion/strided_slice_checker.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "ops/op_name.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace opt {
bool RemoveTransitivityOp::Run(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func_graph is a nullptr, cannot do RemoveTransitivityOp.";
return false;
}
auto ret = preprocessor_.Run(func_graph);
if (ret != lite::RET_OK && ret != lite::RET_NOT_SUPPORT) {
MS_LOG(ERROR) << "Do dynamic-shape infer failed.";
return false;
}
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (IsMarkedTrainOp(cnode)) {
continue;
}
int status = lite::RET_OK;
if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) {
status = HandleStridedSlice(func_graph, cnode);
} else if (CheckPrimitiveType(cnode, prim::kPrimConcat)) {
status = HandleConcat(func_graph, cnode);
} else if (CheckPrimitiveType(cnode, prim::kPrimReduceFusion)) {
status = HandleReduce(func_graph, cnode);
}
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Do RemoveTransitivityOp failed, node is " << node->fullname_with_scope();
return false;
}
}
return true;
}
int RemoveTransitivityOp::HandleStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &strided_slice) {
MS_ASSERT(func_graph != nullptr && strided_slice != nullptr);
if (!StridedSliceChecker::CheckCommonInfo(strided_slice)) {
return lite::RET_OK;
}
auto prim = GetCNodePrimitive(strided_slice);
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "StridedSlice's prim is a nullptr.");
if (IsQuantParameterNode(prim)) {
return lite::RET_OK;
}
std::vector<int> begin;
auto ret = StridedSliceChecker::GetBegin(strided_slice, &begin);
if (ret == lite::RET_NOT_SUPPORT) {
return lite::RET_OK;
}
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Get Strided_slice's begin failed, node is " << strided_slice->fullname_with_scope();
return ret;
}
std::vector<int> end;
ret = StridedSliceChecker::GetEnd(strided_slice, &end);
if (ret == lite::RET_NOT_SUPPORT) {
return lite::RET_OK;
}
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Get Strided_slice's end failed, node is " << strided_slice->fullname_with_scope();
return ret;
}
MS_CHECK_TRUE_MSG(begin.size() == end.size(), lite::RET_ERROR, "Strided_slice begin-size is not equal end-size");
for (size_t i = 0; i < begin.size(); ++i) {
if (begin[i] != 0 || end[i] != INT_MAX) {
return lite::RET_OK;
}
}
return DoReplace(func_graph, strided_slice);
}
int RemoveTransitivityOp::HandleConcat(const FuncGraphPtr &func_graph, const CNodePtr &concat) {
MS_ASSERT(func_graph != nullptr && concat != nullptr);
auto prim = GetCNodePrimitive(concat);
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "Concat's prim is a nullptr.");
if (IsQuantParameterNode(prim)) {
return lite::RET_OK;
}
if (concat->size() != kInputSizeTwo || CheckPrimitiveType(concat->input(1), prim::kPrimMakeTuple) ||
CheckPrimitiveType(concat->input(1), kPrimMakeTupleV2)) {
return lite::RET_OK;
}
return DoReplace(func_graph, concat);
}
int RemoveTransitivityOp::HandleReduce(const FuncGraphPtr &func_graph, const CNodePtr &reduce) {
MS_ASSERT(func_graph != nullptr && reduce != nullptr);
auto &shape_container = preprocessor_.GetShapeContainer();
if (shape_container.find(reduce) == shape_container.end()) {
return lite::RET_OK;
}
auto prim = GetCNodePrimitive(reduce);
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "Reduce's prim is a nullptr.");
if (!IsReduceModeMeetOutEqualIn(prim)) {
return lite::RET_OK;
}
if (IsQuantParameterNode(prim)) {
return lite::RET_OK;
}
auto attr = prim->GetAttr(ops::kCoeff);
if (attr != nullptr && fabs(GetValue<float>(attr) - 1.f) > FLT_EPSILON) {
return lite::RET_OK;
}
auto &in_shapes = shape_container.at(reduce).first;
auto &out_shapes = shape_container.at(reduce).second;
if (in_shapes.empty() || out_shapes.empty()) {
return lite::RET_OK;
}
if (in_shapes.front() != out_shapes.front()) {
return lite::RET_OK;
}
return DoReplace(func_graph, reduce);
}
int RemoveTransitivityOp::DoReplace(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
auto manager = func_graph->manager();
MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_NULL_PTR, "Manager is a nullptr.");
if (!manager->Replace(cnode, cnode->input(1))) {
MS_LOG(ERROR) << "Do manager-Replace failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_REMOVE_TRANSITIVITY_OP_H
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_REMOVE_TRANSITIVITY_OP_H
#include "backend/common/optimizer/pass.h"
#include "tools/optimizer/graph/preprocess_dynamic_shape.h"
namespace mindspore {
namespace opt {
// remove the op whose output is equal to its input.
class RemoveTransitivityOp : public Pass {
public:
RemoveTransitivityOp() : Pass("RemoveTransitivityOp") {}
~RemoveTransitivityOp() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
int HandleStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &strided_slice);
int HandleConcat(const FuncGraphPtr &func_graph, const CNodePtr &concat);
int HandleReduce(const FuncGraphPtr &func_graph, const CNodePtr &reduce);
int DoReplace(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
DynamicShapePreprocessor preprocessor_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_REMOVE_TRANSITIVITY_OP_H

View File

@ -24,7 +24,10 @@
namespace mindspore {
namespace opt {
bool StridedSliceChecker::CheckCommonInfo(const CNodePtr &strided_slice) {
if (strided_slice == nullptr) {
if (strided_slice == nullptr || strided_slice->size() > kInputSizeFive) {
return false;
}
if (IsMarkedTrainOp(strided_slice)) {
return false;
}
auto prim = GetCNodePrimitive(strided_slice);
@ -131,10 +134,11 @@ int StridedSliceChecker::GetConstTensor(const CNodePtr &strided_slice, size_t in
<< strided_slice->fullname_with_scope();
return lite::RET_NOT_SUPPORT;
}
if (lite::FetchConstData(strided_slice, index, converter::kFmkTypeMs, data_info, false) != lite::RET_OK) {
if (lite::FetchConstData(strided_slice, index, converter::kFmkTypeMs, data_info, true) != lite::RET_OK) {
MS_LOG(ERROR) << "Get Strided_slice " << index << "-input failed, node is " << strided_slice->fullname_with_scope();
return lite::RET_ERROR;
}
data_info->data_ptr_ = data_info->data_.data();
if (data_info->data_ptr_ == nullptr ||
(data_info->data_type_ != kNumberTypeInt && data_info->data_type_ != kNumberTypeInt32)) {
MS_LOG(ERROR) << "Get Strided_slice's constant failed, node name is " << strided_slice->fullname_with_scope();