Add optimization pass to remove redundant Select, fix uninitiated parameter in test_bert_train.py script

Create a new pass named ValueBasedEliminate to reduce the load of Arithmetic Simplify

Code review
This commit is contained in:
Hoai Linh Tran 2020-07-21 11:07:02 -04:00
parent 16079e6356
commit 2861e5462d
6 changed files with 104 additions and 10 deletions

View File

@ -41,6 +41,7 @@
#include "frontend/optimizer/irpass/symbol_resolver.h"
#include "frontend/optimizer/irpass/tile_eliminate.h"
#include "frontend/optimizer/irpass/transpose_eliminate.h"
#include "frontend/optimizer/irpass/value_based_eliminate.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/irpass/indexed_slices_eliminate.h"
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
@ -165,6 +166,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
sparse_tensor_eliminate_ = MakeSubstitution(
std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate",
{prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape});
// Value_Based Eliminate
value_based_eliminate_ =
MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate", {prim::kPrimSelect});
}
ResolveIRPassLib::ResolveIRPassLib() {

View File

@ -110,6 +110,9 @@ class OptimizeIRPassLib {
// SparseTensor Eliminate
SubstitutionPtr sparse_tensor_eliminate_;
// Value_Based Eliminate
SubstitutionPtr value_based_eliminate_;
};
// the collection of irpass for resolve action

View File

@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/optimizer/irpass/value_based_eliminate.h"
namespace mindspore {
namespace opt {
namespace irpass {
bool IsCNodePositive(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) {
return IsCNodePositive(node->cast<CNodePtr>()->input(1));
}
if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) {
return true;
}
return false;
}
AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
PatternNode x, y, z;
PConstant zero_(node, false, 0);
PConstant zero_scalar_(node, false, 0, true);
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_), y, z), y,
IsCNodePositive(x.GetNode(node)));
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y,
IsCNodePositive(x.GetNode(node)));
return nullptr;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_
#include <algorithm>
#include <memory>
#include <vector>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/prim_eliminate.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/pattern_matcher.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimSelect, {prim::kPrimGreater, X, 0}, Y, Z}} -> Y when X is always greater than 0
class ValueBasedEliminate : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_

View File

@ -162,15 +162,10 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
}
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig b_1 = opt::OptPassConfig({
irpass.zero_like_fill_zero_,
irpass.item_tuple_eliminate_,
irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_,
irpass.inline_,
irpass.special_op_eliminate_,
irpass.get_make_ref_eliminate_,
});
opt::OptPassConfig b_1 =
opt::OptPassConfig({irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_,
irpass.get_make_ref_eliminate_, irpass.value_based_eliminate_});
opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_,

View File

@ -22,14 +22,15 @@ import os
import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.nn.optim import AdamWeightDecay
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.nn import learning_rate_schedule as lr_schedules
from mindspore.ops import operations as P
from model_zoo.official.nlp.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from ...dataset_mock import MindData
from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph
_current_dir = os.path.dirname(os.path.realpath(__file__)) + "/../python/test_data"
context.set_context(mode=context.GRAPH_MODE)