!2585 Replace TransformFuncType with OptimizerCaller

Merge pull request !2585 from Giancarlo/remove_transformfunc
This commit is contained in:
mindspore-ci-bot 2020-06-28 10:12:18 +08:00 committed by Gitee
commit 262e4fc041
13 changed files with 227 additions and 208 deletions

View File

@ -17,13 +17,23 @@
#ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#include <memory>
#include "ir/anf.h"
#include "optimizer/opt.h"
namespace mindspore {
namespace opt {
class Optimizer;
using OptimizerPtr = std::shared_ptr<Optimizer>;
using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
} // namespace opt
class OptimizerCaller {
public:
virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; }
};
using OptimizerCallerPtr = std::shared_ptr<OptimizerCaller>;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_

View File

@ -14,140 +14,154 @@
* limitations under the License.
*/
#include "optimizer/irpass.h"
#include <string>
#include "optimizer/irpass/symbol_resolver.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/arithmetic_simplify.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/item_tuple_eliminate.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/branch_culling.h"
#include "optimizer/irpass/gradient_eliminate.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/inline.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/convert.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/incorporate_call.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/grad_var_prepare.h"
#include "optimizer/irpass/param_replace.h"
#include "optimizer/irpass/gradient_eliminate.h"
#include "optimizer/irpass/inline.h"
#include "optimizer/irpass/incorporate_call.h"
#include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/item_tuple_eliminate.h"
#include "optimizer/irpass/mark_interface_fusion.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/param_replace.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/symbol_resolver.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/opt.h"
namespace mindspore {
namespace opt {
namespace irpass {
OptimizeIRPassLib::OptimizeIRPassLib() {
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify",
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul});
arithmetic_simplify2_ =
MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul});
special_op_eliminate_ =
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike);
adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
zero_like_fill_zero_ =
MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
adjust_all_reduce_mul_add_ =
MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
// ops eliminate
item_tuple_eliminate_ =
MakeSubstitution(ItemTupleEliminater(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem});
tile_eliminate_ = MakeSubstitution(TileMultiplyByOne(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(CastEliminater(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(ReshapeEliminater(), "reshape_eliminate", prim::kPrimReshape);
transpose_eliminate_ = MakeSubstitution(TransposeSameIOEliminater(), "transpose_eliminate", prim::kPrimTranspose);
item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate",
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem});
tile_eliminate_ = MakeSubstitution(std::make_shared<TileMultiplyByOne>(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);
transpose_eliminate_ =
MakeSubstitution(std::make_shared<TransposeSameIOEliminater>(), "transpose_eliminate", prim::kPrimTranspose);
reduce_eliminate_ = MakeSubstitution(
ReduceOneEliminater(), "reduce_eliminate",
std::make_shared<ReduceOneEliminater>(), "reduce_eliminate",
{prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup);
same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape);
check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);
depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend);
partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
check_bprop_eliminate_ =
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ =
MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
// Env Item Eliminate
env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem);
env_get_item_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_ =
MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ =
MakeSubstitution(IncorporateEnvGetitemSwitch(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
"incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
// Ref eliminate
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate",
make_ref_eliminate_ =
MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate",
{prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
replace_refkey_by_param_ =
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam);
replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
IsValueNode<RefKey>, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
// Gradient transforms
expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ);
minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem);
expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ);
minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
// branch culling
switch_simplify_ = MakeSubstitution(SwitchSimplify(), "switch_simplify", prim::kPrimSwitch);
float_tuple_getitem_switch_ =
MakeSubstitution(FloatTupleGetItemSwitch(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem);
switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch);
float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(),
"float_tuple_getitem_switch", prim::kPrimTupleGetItem);
float_env_getitem_switch_ =
MakeSubstitution(FloatEnvGetItemSwitch(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
convert_switch_replacement_ = MakeSubstitution(ConvertSwitchReplacement(), "convert_switch_replacement", IsCNodeDup);
MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
convert_switch_replacement_ =
MakeSubstitution(std::make_shared<ConvertSwitchReplacement>(), "convert_switch_replacement", IsCNodeDup);
// Addn
merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN);
addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN);
merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);
// inline
inline_ = MakeSubstitution(Inliner(), "inline", IsCNodeGraph);
replace_applicator_ = MakeSubstitution(ReplaceApplicator(), "replace_applicator", IsValueNode<FuncGraph>);
specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph);
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
replace_applicator_ =
MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
specialize_transform_ =
MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph);
// Incorporation
incorporate_getitem_set_ =
MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
incorporate_getitem_from_param_ =
MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel);
incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup);
incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup);
MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(),
"incorporate_getitem_from_param", IsCNodeGraphKernel);
incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
incorporate_call_switch_ =
MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
// Virtual Dataset
virtual_dataset_eliminate_ =
MakeSubstitution(VirtualDatasetEliminater(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
"virtual_dataset_eliminate", prim::kPrimVirtualDataset);
// Convert
print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint);
print_tuple_wrapper_ =
MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
// Unused parameter eliminate
unused_parameter_eliminate_ =
MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel);
unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel);
MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel);
unused_output_eliminate_ =
MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);
// AddN eliminate
addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel);
addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);
// Mark interface fusion
mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect);
mark_interface_fusion_ =
MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
}
ResolveIRPassLib::ResolveIRPassLib() {
resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr);
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
}
InferenceOptPrepareLib::InferenceOptPrepareLib() {
grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode);
grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode);
}
} // namespace irpass
} // namespace opt

View File

@ -17,15 +17,16 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#include <vector>
#include <memory>
#include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
FuncGraphPtr all_reduce_fg_{nullptr};
};
class ArithmeticSimplify {
class ArithmeticSimplify : public OptimizerCaller {
public:
ArithmeticSimplify()
: multiply_by_zero_or_one_(),
tensor_multiply_by_one_(),
add_by_zero_(),
tensor_add_by_zero_(),
identity_(prim::kPrimIdentity),
opt_update_zero_tensor_(),
constant_duplicate_mul_(),
power_one_() {
: multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()),
tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()),
add_by_zero_(std::make_shared<AddByZero>()),
tensor_add_by_zero_(std::make_shared<TensorAddByZero>()),
identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)),
opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()),
constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()),
power_one_(std::make_shared<PowerOneEliminate>()) {
eliminaters_.emplace_back(multiply_by_zero_or_one_);
eliminaters_.emplace_back(tensor_multiply_by_one_);
eliminaters_.emplace_back(add_by_zero_);
@ -761,10 +762,10 @@ class ArithmeticSimplify {
}
~ArithmeticSimplify() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
@ -773,15 +774,9 @@ class ArithmeticSimplify {
}
private:
MultiplyByZeroOrOne multiply_by_zero_or_one_;
TensorMultiplyByOne tensor_multiply_by_one_;
AddByZero add_by_zero_;
TensorAddByZero tensor_add_by_zero_;
PrimEliminater identity_;
OptUpdateZeroTensor opt_update_zero_tensor_;
ConstantDuplicateMul constant_duplicate_mul_;
PowerOneEliminate power_one_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr multiply_by_zero_or_one_, tensor_multiply_by_one_, add_by_zero_, tensor_add_by_zero_, identity_,
opt_update_zero_tensor_, constant_duplicate_mul_, power_one_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
// Arithmetic Simplifications should be done after step_parallel.
@ -789,15 +784,17 @@ class ArithmeticSimplify {
// with shape(weight), but after step_parallel, shape of weight may be changed, so the
// shape of the constant tensor should also be changed. So this pass is seperated from
// ArithmeticSimplify and deferred until step_parallel.
class ArithmeticSimplify2 {
class ArithmeticSimplify2 : public OptimizerCaller {
public:
ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); }
ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) {
eliminaters_.emplace_back(tensor_multiply_by_zero_);
}
~ArithmeticSimplify2() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
@ -806,8 +803,8 @@ class ArithmeticSimplify2 {
}
private:
TensorMultiplyByZero tensor_multiply_by_zero_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr tensor_multiply_by_zero_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
} // namespace irpass
} // namespace opt

View File

@ -17,9 +17,9 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#include "ir/visitor.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
namespace mindspore {
namespace opt {
@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr}, t_{nullptr};
};
class CastEliminater {
class CastEliminater : public OptimizerCaller {
public:
CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {}
~CastEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
auto new_node = cast_same_type_eliminater_(optimizer, node);
if (new_node != nullptr) {
return new_node;

View File

@ -17,18 +17,19 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#include <vector>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "utils/symbolic.h"
namespace mindspore {
@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor {
bool is_match_{false};
};
class EnvGetItemEliminater {
class EnvGetItemEliminater : public OptimizerCaller {
public:
EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() {
EnvGetItemEliminater()
: new_env_get_item_(std::make_shared<NewEnvGetItem>()),
add_env_get_item_(std::make_shared<AddEnvGetItem>()),
env_get_set_item_(std::make_shared<EnvGetSetItem>()) {
eliminaters_.emplace_back(new_env_get_item_);
eliminaters_.emplace_back(add_env_get_item_);
eliminaters_.emplace_back(env_get_set_item_);
}
~EnvGetItemEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
@ -246,10 +250,8 @@ class EnvGetItemEliminater {
}
private:
NewEnvGetItem new_env_get_item_;
AddEnvGetItem add_env_get_item_;
EnvGetSetItem env_get_set_item_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}

View File

@ -17,18 +17,20 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#include <vector>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
namespace irpass {
@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor {
internal::GetitemTransform getitem_transform_;
};
class IncorporateGetitemSet {
class IncorporateGetitemSet : public OptimizerCaller {
public:
IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() {
IncorporateGetitemSet()
: incorporate_getitem_(std::make_shared<IncorporateGetitem>()),
incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()) {
eliminaters_.emplace_back(incorporate_getitem_);
eliminaters_.emplace_back(incorporate_getitem_switch_);
}
~IncorporateGetitemSet() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
@ -403,9 +407,8 @@ class IncorporateGetitemSet {
}
private:
IncorporateGetitem incorporate_getitem_;
IncorporateGetitemSwitch incorporate_getitem_switch_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
} // namespace irpass
} // namespace opt

View File

@ -17,13 +17,15 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#include <vector>
#include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
};
class ItemTupleEliminater {
class ItemTupleEliminater : public OptimizerCaller {
public:
ItemTupleEliminater()
: get_item_eliminater_(),
get_item_const_eliminater_(),
set_item_eliminater_(),
get_set_item_eliminater_(),
get_item_depend_reorder_() {
: get_item_eliminater_(std::make_shared<GetitemEliminater>()),
get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()),
set_item_eliminater_(std::make_shared<SetitemEliminater>()),
get_set_item_eliminater_(std::make_shared<GetSetitemEliminater>()),
get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()) {
eliminaters_.emplace_back(get_item_eliminater_);
eliminaters_.emplace_back(get_item_const_eliminater_);
eliminaters_.emplace_back(set_item_eliminater_);
@ -277,10 +279,10 @@ class ItemTupleEliminater {
}
~ItemTupleEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
@ -289,12 +291,9 @@ class ItemTupleEliminater {
}
private:
GetitemEliminater get_item_eliminater_;
GetitemConstEliminater get_item_const_eliminater_;
SetitemEliminater set_item_eliminater_;
GetSetitemEliminater get_set_item_eliminater_;
GetitemDependReorder get_item_depend_reorder_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_,
get_item_depend_reorder_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
} // namespace irpass
} // namespace opt

View File

@ -19,9 +19,9 @@
#include <memory>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/pattern_matcher.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {

View File

@ -19,11 +19,12 @@
#include <vector>
#include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "operator/ops.h"
#include "pipeline/static_analysis/dshape.h"
namespace mindspore {
@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr}, shape_{nullptr};
};
class ReshapeEliminater {
class ReshapeEliminater : public OptimizerCaller {
public:
ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {}
~ReshapeEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
auto new_node = reshape_same_shape_eliminater_(optimizer, node);
if (new_node != nullptr) {
return new_node;

View File

@ -18,31 +18,31 @@
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_
#include <securec.h>
#include <vector>
#include <memory>
#include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/optimizer_caller.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "ir/pattern_matcher.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "ir/pattern_matcher.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
namespace irpass {
class SpecialOpEliminater {
class SpecialOpEliminater : public OptimizerCaller {
public:
SpecialOpEliminater()
: insert_gradient_of_(prim::kPrimInsertGradientOf),
stop_gradient_(prim::kPrimStopGradient),
hook_backward_(prim::kPrimHookBackward),
print_shape_type_(prim::kPrimPrintShapeType),
get_ref_value_(prim::kPrimGetRefValue),
mirror_(prim::kPrimMirror),
virtual_div_(prim::kPrimVirtualDiv) {
: insert_gradient_of_(std::make_shared<PrimEliminater>(prim::kPrimInsertGradientOf)),
stop_gradient_(std::make_shared<PrimEliminater>(prim::kPrimStopGradient)),
hook_backward_(std::make_shared<PrimEliminater>(prim::kPrimHookBackward)),
print_shape_type_(std::make_shared<PrimEliminater>(prim::kPrimPrintShapeType)),
get_ref_value_(std::make_shared<PrimEliminater>(prim::kPrimGetRefValue)),
mirror_(std::make_shared<PrimEliminater>(prim::kPrimMirror)),
virtual_div_(std::make_shared<PrimEliminater>(prim::kPrimVirtualDiv)) {
eliminaters_.emplace_back(insert_gradient_of_);
eliminaters_.emplace_back(stop_gradient_);
eliminaters_.emplace_back(hook_backward_);
@ -53,10 +53,10 @@ class SpecialOpEliminater {
}
~SpecialOpEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
@ -65,9 +65,9 @@ class SpecialOpEliminater {
}
private:
PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_,
OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_,
virtual_div_;
std::vector<TransformFuncType> eliminaters_{};
std::vector<OptimizerCallerPtr> eliminaters_{};
};
// {PrimVirtualDataset, X} -> X

View File

@ -16,28 +16,27 @@
#include "optimizer/opt.h"
#include <algorithm>
#include <deque>
#include <memory>
#include <unordered_set>
#include <deque>
#include <algorithm>
#include "ir/anf.h"
#include "ir/manager.h"
#include "utils/ordered_set.h"
#include "utils/log_adapter.h"
#include "optimizer/optimizer.h"
#include "utils/log_adapter.h"
#include "utils/ordered_set.h"
namespace mindspore {
/* namespace to support opt */
namespace opt {
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &renorm_action) {
auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
}
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) {
auto fn = [prims](const AnfNodePtr &node) -> bool {
if (!node->isa<CNode>()) {
@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
}
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &renorm_action) {
return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
}
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const {
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
#ifdef ENABLE_PROFILE
double t = GetTime();
#endif
AnfNodePtr result = transform_(optimizer, node);
AnfNodePtr result = (*transform_)(optimizer, node);
#ifdef ENABLE_PROFILE
if (optimizer != nullptr) {
auto time = GetTime();

View File

@ -17,24 +17,18 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#include <vector>
#include <string>
#include <memory>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "operator/ops.h"
namespace mindspore {
/* namespace to support opt */
namespace opt {
class Optimizer;
using OptimizerPtr = std::shared_ptr<Optimizer>;
using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>;
// Define the interaction mode between an Optimize pass and Renormalize pass
// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed
@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };
class Substitution {
public:
TransformFuncType transform_{nullptr};
OptimizerCallerPtr transform_;
std::string name_;
PredicateFuncType predicate_{nullptr};
// an enum to mark this Substitution relation to renormalize pass
RenormAction renorm_action_;
Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate,
Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
const RenormAction &renorm_action)
: transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
~Substitution() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node);
};
using SubstitutionPtr = std::shared_ptr<Substitution>;
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims,
const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
class SubstitutionList {

View File

@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common {
};
void SetUp() {
elim_Z = MakeSubstitution(irpass::AddByZero(), "elim_Z", prim::kPrimScalarAdd);
elim_R = MakeSubstitution(irpass::PrimEliminater(R), "elim_R", R);
idempotent_P = MakeSubstitution(IdempotentEliminater(), "idempotent_P", P);
Qct_to_P = MakeSubstitution(QctToP(), "Qct_to_P", Q);
elim_Z = MakeSubstitution(std::make_shared<irpass::AddByZero>(), "elim_Z", prim::kPrimScalarAdd);
elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
}
bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {