forked from OSSInnovation/mindspore
!2585 Replace TransformFuncType with OptimizerCaller
Merge pull request !2585 from Giancarlo/remove_transformfunc
This commit is contained in:
commit
262e4fc041
|
@ -17,13 +17,23 @@
|
|||
#ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
|
||||
#define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "optimizer/opt.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class Optimizer;
|
||||
using OptimizerPtr = std::shared_ptr<Optimizer>;
|
||||
using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
|
||||
|
||||
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
|
||||
} // namespace opt
|
||||
|
||||
class OptimizerCaller {
|
||||
public:
|
||||
virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; }
|
||||
};
|
||||
using OptimizerCallerPtr = std::shared_ptr<OptimizerCaller>;
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
|
||||
|
|
|
@ -14,140 +14,154 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "optimizer/irpass.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "optimizer/irpass/symbol_resolver.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/irpass/arithmetic_simplify.h"
|
||||
#include "optimizer/irpass/special_op_eliminate.h"
|
||||
#include "optimizer/irpass/item_tuple_eliminate.h"
|
||||
#include "optimizer/irpass/env_item_eliminate.h"
|
||||
#include "optimizer/irpass/tile_eliminate.h"
|
||||
#include "optimizer/irpass/cast_eliminate.h"
|
||||
#include "optimizer/irpass/reshape_eliminate.h"
|
||||
#include "optimizer/irpass/transpose_eliminate.h"
|
||||
#include "optimizer/irpass/reduce_eliminate.h"
|
||||
#include "optimizer/irpass/partial_eliminate.h"
|
||||
#include "optimizer/irpass/ref_eliminate.h"
|
||||
#include "optimizer/irpass/merge_addn.h"
|
||||
#include "optimizer/irpass/branch_culling.h"
|
||||
#include "optimizer/irpass/gradient_eliminate.h"
|
||||
#include "optimizer/irpass/minmax_grad.h"
|
||||
#include "optimizer/irpass/inline.h"
|
||||
#include "optimizer/irpass/cast_eliminate.h"
|
||||
#include "optimizer/irpass/convert.h"
|
||||
#include "optimizer/irpass/specialize_transform.h"
|
||||
#include "optimizer/irpass/incorporate_getitem.h"
|
||||
#include "optimizer/irpass/incorporate_call.h"
|
||||
#include "optimizer/irpass/env_item_eliminate.h"
|
||||
#include "optimizer/irpass/grad_var_prepare.h"
|
||||
#include "optimizer/irpass/param_replace.h"
|
||||
#include "optimizer/irpass/gradient_eliminate.h"
|
||||
#include "optimizer/irpass/inline.h"
|
||||
#include "optimizer/irpass/incorporate_call.h"
|
||||
#include "optimizer/irpass/incorporate_getitem.h"
|
||||
#include "optimizer/irpass/item_tuple_eliminate.h"
|
||||
#include "optimizer/irpass/mark_interface_fusion.h"
|
||||
#include "optimizer/irpass/merge_addn.h"
|
||||
#include "optimizer/irpass/minmax_grad.h"
|
||||
#include "optimizer/irpass/param_replace.h"
|
||||
#include "optimizer/irpass/partial_eliminate.h"
|
||||
#include "optimizer/irpass/reduce_eliminate.h"
|
||||
#include "optimizer/irpass/ref_eliminate.h"
|
||||
#include "optimizer/irpass/reshape_eliminate.h"
|
||||
#include "optimizer/irpass/special_op_eliminate.h"
|
||||
#include "optimizer/irpass/specialize_transform.h"
|
||||
#include "optimizer/irpass/symbol_resolver.h"
|
||||
#include "optimizer/irpass/tile_eliminate.h"
|
||||
#include "optimizer/irpass/transpose_eliminate.h"
|
||||
#include "optimizer/opt.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
|
||||
arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify",
|
||||
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
|
||||
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
|
||||
arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul});
|
||||
arithmetic_simplify2_ =
|
||||
MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul});
|
||||
special_op_eliminate_ =
|
||||
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
|
||||
MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
|
||||
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
|
||||
prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
|
||||
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike);
|
||||
adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
|
||||
zero_like_fill_zero_ =
|
||||
MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
|
||||
adjust_all_reduce_mul_add_ =
|
||||
MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
|
||||
|
||||
// ops eliminate
|
||||
item_tuple_eliminate_ =
|
||||
MakeSubstitution(ItemTupleEliminater(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem});
|
||||
tile_eliminate_ = MakeSubstitution(TileMultiplyByOne(), "tile_eliminate", prim::kPrimTile);
|
||||
cast_eliminate_ = MakeSubstitution(CastEliminater(), "cast_eliminate", prim::kPrimCast);
|
||||
reshape_eliminate_ = MakeSubstitution(ReshapeEliminater(), "reshape_eliminate", prim::kPrimReshape);
|
||||
transpose_eliminate_ = MakeSubstitution(TransposeSameIOEliminater(), "transpose_eliminate", prim::kPrimTranspose);
|
||||
item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate",
|
||||
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem});
|
||||
tile_eliminate_ = MakeSubstitution(std::make_shared<TileMultiplyByOne>(), "tile_eliminate", prim::kPrimTile);
|
||||
cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
|
||||
reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);
|
||||
transpose_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<TransposeSameIOEliminater>(), "transpose_eliminate", prim::kPrimTranspose);
|
||||
reduce_eliminate_ = MakeSubstitution(
|
||||
ReduceOneEliminater(), "reduce_eliminate",
|
||||
std::make_shared<ReduceOneEliminater>(), "reduce_eliminate",
|
||||
{prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
|
||||
partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup);
|
||||
same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape);
|
||||
check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop);
|
||||
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);
|
||||
depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend);
|
||||
partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
|
||||
same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
|
||||
check_bprop_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
|
||||
reset_defer_inline_ =
|
||||
MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
|
||||
depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
|
||||
|
||||
// Env Item Eliminate
|
||||
env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
|
||||
new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem);
|
||||
env_get_item_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
|
||||
new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem);
|
||||
incorporate_env_getitem_ =
|
||||
MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
|
||||
incorporate_env_getitem_switch_ =
|
||||
MakeSubstitution(IncorporateEnvGetitemSwitch(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
|
||||
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
|
||||
incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
|
||||
"incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
|
||||
|
||||
// Ref eliminate
|
||||
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
|
||||
get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate",
|
||||
make_ref_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
|
||||
get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate",
|
||||
{prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
|
||||
get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
|
||||
get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
|
||||
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
|
||||
|
||||
replace_refkey_by_param_ =
|
||||
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
|
||||
replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam);
|
||||
replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
|
||||
IsValueNode<RefKey>, opt::FORCE_RENORM);
|
||||
replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
|
||||
// Gradient transforms
|
||||
expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ);
|
||||
minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem);
|
||||
expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ);
|
||||
minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
|
||||
|
||||
// branch culling
|
||||
switch_simplify_ = MakeSubstitution(SwitchSimplify(), "switch_simplify", prim::kPrimSwitch);
|
||||
float_tuple_getitem_switch_ =
|
||||
MakeSubstitution(FloatTupleGetItemSwitch(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem);
|
||||
switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch);
|
||||
float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(),
|
||||
"float_tuple_getitem_switch", prim::kPrimTupleGetItem);
|
||||
float_env_getitem_switch_ =
|
||||
MakeSubstitution(FloatEnvGetItemSwitch(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
|
||||
convert_switch_replacement_ = MakeSubstitution(ConvertSwitchReplacement(), "convert_switch_replacement", IsCNodeDup);
|
||||
MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
|
||||
convert_switch_replacement_ =
|
||||
MakeSubstitution(std::make_shared<ConvertSwitchReplacement>(), "convert_switch_replacement", IsCNodeDup);
|
||||
|
||||
// Addn
|
||||
merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN);
|
||||
addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN);
|
||||
merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
|
||||
addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);
|
||||
|
||||
// inline
|
||||
inline_ = MakeSubstitution(Inliner(), "inline", IsCNodeGraph);
|
||||
replace_applicator_ = MakeSubstitution(ReplaceApplicator(), "replace_applicator", IsValueNode<FuncGraph>);
|
||||
specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph);
|
||||
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
|
||||
replace_applicator_ =
|
||||
MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
|
||||
specialize_transform_ =
|
||||
MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph);
|
||||
|
||||
// Incorporation
|
||||
incorporate_getitem_set_ =
|
||||
MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
|
||||
incorporate_getitem_from_param_ =
|
||||
MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel);
|
||||
incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup);
|
||||
incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup);
|
||||
MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
|
||||
incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(),
|
||||
"incorporate_getitem_from_param", IsCNodeGraphKernel);
|
||||
incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
|
||||
incorporate_call_switch_ =
|
||||
MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
|
||||
|
||||
// Virtual Dataset
|
||||
virtual_dataset_eliminate_ =
|
||||
MakeSubstitution(VirtualDatasetEliminater(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
|
||||
virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
|
||||
"virtual_dataset_eliminate", prim::kPrimVirtualDataset);
|
||||
|
||||
// Convert
|
||||
print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint);
|
||||
print_tuple_wrapper_ =
|
||||
MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
|
||||
|
||||
// Unused parameter eliminate
|
||||
unused_parameter_eliminate_ =
|
||||
MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel);
|
||||
unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel);
|
||||
MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel);
|
||||
unused_output_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);
|
||||
|
||||
// AddN eliminate
|
||||
addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel);
|
||||
addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);
|
||||
|
||||
// Mark interface fusion
|
||||
mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect);
|
||||
mark_interface_fusion_ =
|
||||
MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
|
||||
}
|
||||
|
||||
ResolveIRPassLib::ResolveIRPassLib() {
|
||||
resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve);
|
||||
resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr);
|
||||
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
|
||||
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
|
||||
}
|
||||
|
||||
InferenceOptPrepareLib::InferenceOptPrepareLib() {
|
||||
grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode);
|
||||
grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode);
|
||||
}
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
|
|
|
@ -17,15 +17,16 @@
|
|||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/irpass/prim_eliminate.h"
|
||||
#include "ir/optimizer_caller.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "operator/ops.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/irpass/prim_eliminate.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
|
|||
FuncGraphPtr all_reduce_fg_{nullptr};
|
||||
};
|
||||
|
||||
class ArithmeticSimplify {
|
||||
class ArithmeticSimplify : public OptimizerCaller {
|
||||
public:
|
||||
ArithmeticSimplify()
|
||||
: multiply_by_zero_or_one_(),
|
||||
tensor_multiply_by_one_(),
|
||||
add_by_zero_(),
|
||||
tensor_add_by_zero_(),
|
||||
identity_(prim::kPrimIdentity),
|
||||
opt_update_zero_tensor_(),
|
||||
constant_duplicate_mul_(),
|
||||
power_one_() {
|
||||
: multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()),
|
||||
tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()),
|
||||
add_by_zero_(std::make_shared<AddByZero>()),
|
||||
tensor_add_by_zero_(std::make_shared<TensorAddByZero>()),
|
||||
identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)),
|
||||
opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()),
|
||||
constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()),
|
||||
power_one_(std::make_shared<PowerOneEliminate>()) {
|
||||
eliminaters_.emplace_back(multiply_by_zero_or_one_);
|
||||
eliminaters_.emplace_back(tensor_multiply_by_one_);
|
||||
eliminaters_.emplace_back(add_by_zero_);
|
||||
|
@ -761,10 +762,10 @@ class ArithmeticSimplify {
|
|||
}
|
||||
~ArithmeticSimplify() = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
AnfNodePtr new_node;
|
||||
for (auto &eliminater : eliminaters_) {
|
||||
new_node = eliminater(optimizer, node);
|
||||
new_node = (*eliminater)(optimizer, node);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
|
@ -773,15 +774,9 @@ class ArithmeticSimplify {
|
|||
}
|
||||
|
||||
private:
|
||||
MultiplyByZeroOrOne multiply_by_zero_or_one_;
|
||||
TensorMultiplyByOne tensor_multiply_by_one_;
|
||||
AddByZero add_by_zero_;
|
||||
TensorAddByZero tensor_add_by_zero_;
|
||||
PrimEliminater identity_;
|
||||
OptUpdateZeroTensor opt_update_zero_tensor_;
|
||||
ConstantDuplicateMul constant_duplicate_mul_;
|
||||
PowerOneEliminate power_one_;
|
||||
std::vector<TransformFuncType> eliminaters_{};
|
||||
OptimizerCallerPtr multiply_by_zero_or_one_, tensor_multiply_by_one_, add_by_zero_, tensor_add_by_zero_, identity_,
|
||||
opt_update_zero_tensor_, constant_duplicate_mul_, power_one_;
|
||||
std::vector<OptimizerCallerPtr> eliminaters_{};
|
||||
};
|
||||
|
||||
// Arithmetic Simplifications should be done after step_parallel.
|
||||
|
@ -789,15 +784,17 @@ class ArithmeticSimplify {
|
|||
// with shape(weight), but after step_parallel, shape of weight may be changed, so the
|
||||
// shape of the constant tensor should also be changed. So this pass is seperated from
|
||||
// ArithmeticSimplify and deferred until step_parallel.
|
||||
class ArithmeticSimplify2 {
|
||||
class ArithmeticSimplify2 : public OptimizerCaller {
|
||||
public:
|
||||
ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); }
|
||||
ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) {
|
||||
eliminaters_.emplace_back(tensor_multiply_by_zero_);
|
||||
}
|
||||
~ArithmeticSimplify2() = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
AnfNodePtr new_node;
|
||||
for (auto &eliminater : eliminaters_) {
|
||||
new_node = eliminater(optimizer, node);
|
||||
new_node = (*eliminater)(optimizer, node);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
|
@ -806,8 +803,8 @@ class ArithmeticSimplify2 {
|
|||
}
|
||||
|
||||
private:
|
||||
TensorMultiplyByZero tensor_multiply_by_zero_;
|
||||
std::vector<TransformFuncType> eliminaters_{};
|
||||
OptimizerCallerPtr tensor_multiply_by_zero_;
|
||||
std::vector<OptimizerCallerPtr> eliminaters_{};
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
|
|
|
@ -17,9 +17,9 @@
|
|||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
|
||||
|
||||
#include "ir/visitor.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "ir/visitor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor {
|
|||
AnfNodePtr x_{nullptr}, t_{nullptr};
|
||||
};
|
||||
|
||||
class CastEliminater {
|
||||
class CastEliminater : public OptimizerCaller {
|
||||
public:
|
||||
CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {}
|
||||
~CastEliminater() = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
auto new_node = cast_same_type_eliminater_(optimizer, node);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
|
|
|
@ -17,18 +17,19 @@
|
|||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "ir/optimizer_caller.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "operator/ops.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "utils/symbolic.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor {
|
|||
bool is_match_{false};
|
||||
};
|
||||
|
||||
class EnvGetItemEliminater {
|
||||
class EnvGetItemEliminater : public OptimizerCaller {
|
||||
public:
|
||||
EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() {
|
||||
EnvGetItemEliminater()
|
||||
: new_env_get_item_(std::make_shared<NewEnvGetItem>()),
|
||||
add_env_get_item_(std::make_shared<AddEnvGetItem>()),
|
||||
env_get_set_item_(std::make_shared<EnvGetSetItem>()) {
|
||||
eliminaters_.emplace_back(new_env_get_item_);
|
||||
eliminaters_.emplace_back(add_env_get_item_);
|
||||
eliminaters_.emplace_back(env_get_set_item_);
|
||||
}
|
||||
~EnvGetItemEliminater() = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
AnfNodePtr new_node;
|
||||
for (auto &eliminater : eliminaters_) {
|
||||
new_node = eliminater(optimizer, node);
|
||||
new_node = (*eliminater)(optimizer, node);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
|
@ -246,10 +250,8 @@ class EnvGetItemEliminater {
|
|||
}
|
||||
|
||||
private:
|
||||
NewEnvGetItem new_env_get_item_;
|
||||
AddEnvGetItem add_env_get_item_;
|
||||
EnvGetSetItem env_get_set_item_;
|
||||
std::vector<TransformFuncType> eliminaters_{};
|
||||
OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_;
|
||||
std::vector<OptimizerCallerPtr> eliminaters_{};
|
||||
};
|
||||
|
||||
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
|
||||
|
|
|
@ -17,18 +17,20 @@
|
|||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "ir/optimizer_caller.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "operator/ops.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
|
@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor {
|
|||
internal::GetitemTransform getitem_transform_;
|
||||
};
|
||||
|
||||
class IncorporateGetitemSet {
|
||||
class IncorporateGetitemSet : public OptimizerCaller {
|
||||
public:
|
||||
IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() {
|
||||
IncorporateGetitemSet()
|
||||
: incorporate_getitem_(std::make_shared<IncorporateGetitem>()),
|
||||
incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()) {
|
||||
eliminaters_.emplace_back(incorporate_getitem_);
|
||||
eliminaters_.emplace_back(incorporate_getitem_switch_);
|
||||
}
|
||||
~IncorporateGetitemSet() = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
AnfNodePtr new_node;
|
||||
for (auto &eliminater : eliminaters_) {
|
||||
new_node = eliminater(optimizer, node);
|
||||
new_node = (*eliminater)(optimizer, node);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
|
@ -403,9 +407,8 @@ class IncorporateGetitemSet {
|
|||
}
|
||||
|
||||
private:
|
||||
IncorporateGetitem incorporate_getitem_;
|
||||
IncorporateGetitemSwitch incorporate_getitem_switch_;
|
||||
std::vector<TransformFuncType> eliminaters_{};
|
||||
OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_;
|
||||
std::vector<OptimizerCallerPtr> eliminaters_{};
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
|
|
|
@ -17,13 +17,15 @@
|
|||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "ir/optimizer_caller.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "operator/ops.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor {
|
|||
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
|
||||
};
|
||||
|
||||
class ItemTupleEliminater {
|
||||
class ItemTupleEliminater : public OptimizerCaller {
|
||||
public:
|
||||
ItemTupleEliminater()
|
||||
: get_item_eliminater_(),
|
||||
get_item_const_eliminater_(),
|
||||
set_item_eliminater_(),
|
||||
get_set_item_eliminater_(),
|
||||
get_item_depend_reorder_() {
|
||||
: get_item_eliminater_(std::make_shared<GetitemEliminater>()),
|
||||
get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()),
|
||||
set_item_eliminater_(std::make_shared<SetitemEliminater>()),
|
||||
get_set_item_eliminater_(std::make_shared<GetSetitemEliminater>()),
|
||||
get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()) {
|
||||
eliminaters_.emplace_back(get_item_eliminater_);
|
||||
eliminaters_.emplace_back(get_item_const_eliminater_);
|
||||
eliminaters_.emplace_back(set_item_eliminater_);
|
||||
|
@ -277,10 +279,10 @@ class ItemTupleEliminater {
|
|||
}
|
||||
~ItemTupleEliminater() = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
AnfNodePtr new_node;
|
||||
for (auto &eliminater : eliminaters_) {
|
||||
new_node = eliminater(optimizer, node);
|
||||
new_node = (*eliminater)(optimizer, node);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
|
@ -289,12 +291,9 @@ class ItemTupleEliminater {
|
|||
}
|
||||
|
||||
private:
|
||||
GetitemEliminater get_item_eliminater_;
|
||||
GetitemConstEliminater get_item_const_eliminater_;
|
||||
SetitemEliminater set_item_eliminater_;
|
||||
GetSetitemEliminater get_set_item_eliminater_;
|
||||
GetitemDependReorder get_item_depend_reorder_;
|
||||
std::vector<TransformFuncType> eliminaters_{};
|
||||
OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_,
|
||||
get_item_depend_reorder_;
|
||||
std::vector<OptimizerCallerPtr> eliminaters_{};
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
|
|
|
@ -19,9 +19,9 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "ir/pattern_matcher.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
|
|
@ -19,11 +19,12 @@
|
|||
|
||||
#include <vector>
|
||||
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/optimizer_caller.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "operator/ops.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "operator/ops.h"
|
||||
#include "pipeline/static_analysis/dshape.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor {
|
|||
AnfNodePtr x_{nullptr}, shape_{nullptr};
|
||||
};
|
||||
|
||||
class ReshapeEliminater {
|
||||
class ReshapeEliminater : public OptimizerCaller {
|
||||
public:
|
||||
ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {}
|
||||
~ReshapeEliminater() = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
auto new_node = reshape_same_shape_eliminater_(optimizer, node);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
|
|
|
@ -18,31 +18,31 @@
|
|||
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_
|
||||
|
||||
#include <securec.h>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "ir/optimizer_caller.h"
|
||||
#include "optimizer/irpass/prim_eliminate.h"
|
||||
#include "ir/pattern_matcher.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "operator/ops.h"
|
||||
#include "ir/pattern_matcher.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/irpass/prim_eliminate.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
class SpecialOpEliminater {
|
||||
class SpecialOpEliminater : public OptimizerCaller {
|
||||
public:
|
||||
SpecialOpEliminater()
|
||||
: insert_gradient_of_(prim::kPrimInsertGradientOf),
|
||||
stop_gradient_(prim::kPrimStopGradient),
|
||||
hook_backward_(prim::kPrimHookBackward),
|
||||
print_shape_type_(prim::kPrimPrintShapeType),
|
||||
get_ref_value_(prim::kPrimGetRefValue),
|
||||
mirror_(prim::kPrimMirror),
|
||||
virtual_div_(prim::kPrimVirtualDiv) {
|
||||
: insert_gradient_of_(std::make_shared<PrimEliminater>(prim::kPrimInsertGradientOf)),
|
||||
stop_gradient_(std::make_shared<PrimEliminater>(prim::kPrimStopGradient)),
|
||||
hook_backward_(std::make_shared<PrimEliminater>(prim::kPrimHookBackward)),
|
||||
print_shape_type_(std::make_shared<PrimEliminater>(prim::kPrimPrintShapeType)),
|
||||
get_ref_value_(std::make_shared<PrimEliminater>(prim::kPrimGetRefValue)),
|
||||
mirror_(std::make_shared<PrimEliminater>(prim::kPrimMirror)),
|
||||
virtual_div_(std::make_shared<PrimEliminater>(prim::kPrimVirtualDiv)) {
|
||||
eliminaters_.emplace_back(insert_gradient_of_);
|
||||
eliminaters_.emplace_back(stop_gradient_);
|
||||
eliminaters_.emplace_back(hook_backward_);
|
||||
|
@ -53,10 +53,10 @@ class SpecialOpEliminater {
|
|||
}
|
||||
~SpecialOpEliminater() = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
AnfNodePtr new_node;
|
||||
for (auto &eliminater : eliminaters_) {
|
||||
new_node = eliminater(optimizer, node);
|
||||
new_node = (*eliminater)(optimizer, node);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
|
@ -65,9 +65,9 @@ class SpecialOpEliminater {
|
|||
}
|
||||
|
||||
private:
|
||||
PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_,
|
||||
OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_,
|
||||
virtual_div_;
|
||||
std::vector<TransformFuncType> eliminaters_{};
|
||||
std::vector<OptimizerCallerPtr> eliminaters_{};
|
||||
};
|
||||
|
||||
// {PrimVirtualDataset, X} -> X
|
||||
|
|
|
@ -16,28 +16,27 @@
|
|||
|
||||
#include "optimizer/opt.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <deque>
|
||||
#include <algorithm>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "ir/manager.h"
|
||||
#include "utils/ordered_set.h"
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ordered_set.h"
|
||||
|
||||
namespace mindspore {
|
||||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
|
||||
const RenormAction &renorm_action) {
|
||||
auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
|
||||
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
|
||||
}
|
||||
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) {
|
||||
auto fn = [prims](const AnfNodePtr &node) -> bool {
|
||||
if (!node->isa<CNode>()) {
|
||||
|
@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::
|
|||
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
|
||||
}
|
||||
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const PredicateFuncType &predicate, const RenormAction &renorm_action) {
|
||||
return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
|
||||
}
|
||||
|
||||
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const {
|
||||
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
#ifdef ENABLE_PROFILE
|
||||
double t = GetTime();
|
||||
#endif
|
||||
AnfNodePtr result = transform_(optimizer, node);
|
||||
AnfNodePtr result = (*transform_)(optimizer, node);
|
||||
#ifdef ENABLE_PROFILE
|
||||
if (optimizer != nullptr) {
|
||||
auto time = GetTime();
|
||||
|
|
|
@ -17,24 +17,18 @@
|
|||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/optimizer_caller.h"
|
||||
#include "operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
class Optimizer;
|
||||
|
||||
using OptimizerPtr = std::shared_ptr<Optimizer>;
|
||||
using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
|
||||
|
||||
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
|
||||
using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>;
|
||||
|
||||
// Define the interaction mode between an Optimize pass and Renormalize pass
|
||||
// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed
|
||||
|
@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };
|
|||
|
||||
class Substitution {
|
||||
public:
|
||||
TransformFuncType transform_{nullptr};
|
||||
OptimizerCallerPtr transform_;
|
||||
std::string name_;
|
||||
PredicateFuncType predicate_{nullptr};
|
||||
// an enum to mark this Substitution relation to renormalize pass
|
||||
RenormAction renorm_action_;
|
||||
Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate,
|
||||
Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
|
||||
const RenormAction &renorm_action)
|
||||
: transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
|
||||
~Substitution() = default;
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const;
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node);
|
||||
};
|
||||
|
||||
using SubstitutionPtr = std::shared_ptr<Substitution>;
|
||||
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
|
||||
const RenormAction &action_renorm = CHECK_RENORM);
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const std::vector<PrimitivePtr> &prims,
|
||||
const RenormAction &action_renorm = CHECK_RENORM);
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
|
||||
|
||||
class SubstitutionList {
|
||||
|
|
|
@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common {
|
|||
};
|
||||
|
||||
void SetUp() {
|
||||
elim_Z = MakeSubstitution(irpass::AddByZero(), "elim_Z", prim::kPrimScalarAdd);
|
||||
elim_R = MakeSubstitution(irpass::PrimEliminater(R), "elim_R", R);
|
||||
idempotent_P = MakeSubstitution(IdempotentEliminater(), "idempotent_P", P);
|
||||
Qct_to_P = MakeSubstitution(QctToP(), "Qct_to_P", Q);
|
||||
elim_Z = MakeSubstitution(std::make_shared<irpass::AddByZero>(), "elim_Z", prim::kPrimScalarAdd);
|
||||
elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
|
||||
idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
|
||||
Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
|
||||
}
|
||||
|
||||
bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {
|
||||
|
|
Loading…
Reference in New Issue