From 2861e5462d99f706792018c2f12bdf48da33753c Mon Sep 17 00:00:00 2001 From: Hoai Linh Tran Date: Tue, 21 Jul 2020 11:07:02 -0400 Subject: [PATCH] Add optimization pass to remove redundant Select, fix uninitiated parameter in test_bert_train.py script Create a new pass named ValueBasedEliminate to reduce the load of Arithmetic Simplify Code review --- mindspore/ccsrc/frontend/optimizer/irpass.cc | 5 ++ mindspore/ccsrc/frontend/optimizer/irpass.h | 3 ++ .../optimizer/irpass/value_based_eliminate.cc | 48 +++++++++++++++++++ .../optimizer/irpass/value_based_eliminate.h | 42 ++++++++++++++++ mindspore/ccsrc/pipeline/jit/pass.cc | 13 ++--- tests/perf_test/bert/test_bert_train.py | 3 +- 6 files changed, 104 insertions(+), 10 deletions(-) create mode 100644 mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc create mode 100644 mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index c0242ccacb..0b2f4bd540 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -41,6 +41,7 @@ #include "frontend/optimizer/irpass/symbol_resolver.h" #include "frontend/optimizer/irpass/tile_eliminate.h" #include "frontend/optimizer/irpass/transpose_eliminate.h" +#include "frontend/optimizer/irpass/value_based_eliminate.h" #include "frontend/optimizer/opt.h" #include "frontend/optimizer/irpass/indexed_slices_eliminate.h" #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" @@ -165,6 +166,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { sparse_tensor_eliminate_ = MakeSubstitution( std::make_shared(), "sparse_tensor_eliminate", {prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape}); + + // Value_Based Eliminate + value_based_eliminate_ = + MakeSubstitution(std::make_shared(), "value_based_eliminate", {prim::kPrimSelect}); } ResolveIRPassLib::ResolveIRPassLib() { diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 31aaeac781..d20afe7a79 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -110,6 +110,9 @@ class OptimizeIRPassLib { // SparseTensor Eliminate SubstitutionPtr sparse_tensor_eliminate_; + + // Value_Based Eliminate + SubstitutionPtr value_based_eliminate_; }; // the collection of irpass for resolve action diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc new file mode 100644 index 0000000000..365859ab4f --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/irpass/value_based_eliminate.h" + +namespace mindspore { +namespace opt { +namespace irpass { +bool IsCNodePositive(const AnfNodePtr &node) { + if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) { + return IsCNodePositive(node->cast()->input(1)); + } + if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) { + return true; + } + return false; +} + +AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + PatternNode x, y, z; + PConstant zero_(node, false, 0); + PConstant zero_scalar_(node, false, 0, true); + + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_), y, z), y, + IsCNodePositive(x.GetNode(node))); + + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y, + IsCNodePositive(x.GetNode(node))); + + return nullptr; +} + +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h new file mode 100644 index 0000000000..eca5fb4dbd --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_ + +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/prim_eliminate.h" +#include "frontend/optimizer/optimizer_caller.h" +#include "frontend/optimizer/anf_visitor.h" +#include "ir/pattern_matcher.h" + +namespace mindspore { +namespace opt { +namespace irpass { + +// {prim::kPrimSelect, {prim::kPrimGreater, X, 0}, Y, Z}} -> Y when X is always greater than 0 +class ValueBasedEliminate : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index ee2e910c07..7bd5bc12ee 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -162,15 +162,10 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { } OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig b_1 = opt::OptPassConfig({ - irpass.zero_like_fill_zero_, - irpass.item_tuple_eliminate_, - irpass.float_tuple_getitem_switch_, - irpass.reset_defer_inline_, - irpass.inline_, - irpass.special_op_eliminate_, - irpass.get_make_ref_eliminate_, - }); + opt::OptPassConfig b_1 = + opt::OptPassConfig({irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, + irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, + irpass.get_make_ref_eliminate_, irpass.value_based_eliminate_}); opt::OptPassConfig b_2 = opt::OptPassConfig({ irpass.replace_refkey_by_param_, irpass.make_ref_eliminate_, diff --git a/tests/perf_test/bert/test_bert_train.py b/tests/perf_test/bert/test_bert_train.py index e4cd2f4a75..705318c283 100644 --- a/tests/perf_test/bert/test_bert_train.py +++ b/tests/perf_test/bert/test_bert_train.py @@ -22,14 +22,15 @@ import os import mindspore.common.dtype as mstype import mindspore.context as context from mindspore import Tensor +from mindspore.ops import operations as P from mindspore.nn.optim import AdamWeightDecay from mindspore.train.loss_scale_manager import DynamicLossScaleManager from mindspore.nn import learning_rate_schedule as lr_schedules -from mindspore.ops import operations as P from model_zoo.official.nlp.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell from ...dataset_mock import MindData from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph + _current_dir = os.path.dirname(os.path.realpath(__file__)) + "/../python/test_data" context.set_context(mode=context.GRAPH_MODE)