diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc
index 6ce2d3a72ab..1c139c36602 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc
@@ -26,6 +26,28 @@
 
 namespace mindspore {
 namespace opt {
+bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) {
+  if (utils::isa<AnfNodePtr>(n)) {
+    AnfNodePtr in = utils::cast<AnfNodePtr>(n);
+    MS_EXCEPTION_IF_NULL(in);
+    auto shape = in->Shape()->cast<abstract::ShapePtr>();
+    MS_EXCEPTION_IF_NULL(shape);
+    if (shape->shape().size() != 0) {
+      return false;
+    }
+    auto dtype = in->Type();
+    if (dtype->type_id() != kObjectTypeTensorType) {
+      return false;
+    }
+    auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
+    if (element_type != kNumberTypeFloat32) {
+      return false;
+    }
+    return true;
+  }
+  return false;
+}
+
 const BaseRef ApplyMomentumScaleFusion::DefinePattern() const {
   VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_});
   VectorRef apply_momentum =
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h
index c9112ab6e95..8888f40c7b1 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h
+++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h
@@ -18,13 +18,14 @@
 #include <memory>
 
 #include "backend/optimizer/common/optimizer.h"
+#include "backend/session/anf_runtime_algorithm.h"
 
 namespace mindspore {
 namespace opt {
 class ApplyMomentumScaleFusion : public PatternProcessPass {
  public:
   explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) {
-    scale_ = std::make_shared<Var>();
+    scale_ = std::make_shared<CondVar>(IsScalar);
     variable_ = std::make_shared<Var>();
     accumulation_ = std::make_shared<Var>();
     learning_rate_ = std::make_shared<Var>();
@@ -36,6 +37,8 @@ class ApplyMomentumScaleFusion : public PatternProcessPass {
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
+  static bool IsScalar(const BaseRef &n);
+
   VarPtr scale_;
   VarPtr variable_;
   VarPtr accumulation_;
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc
index 9e235a756f2..743015c50cd 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc
@@ -26,6 +26,28 @@
 
 namespace mindspore {
 namespace opt {
+bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) {
+  if (utils::isa<AnfNodePtr>(n)) {
+    AnfNodePtr in = utils::cast<AnfNodePtr>(n);
+    MS_EXCEPTION_IF_NULL(in);
+    auto shape = in->Shape()->cast<abstract::ShapePtr>();
+    MS_EXCEPTION_IF_NULL(shape);
+    if (shape->shape().size() != 0) {
+      return false;
+    }
+    auto dtype = in->Type();
+    if (dtype->type_id() != kObjectTypeTensorType) {
+      return false;
+    }
+    auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
+    if (element_type != kNumberTypeFloat32) {
+      return false;
+    }
+    return true;
+  }
+  return false;
+}
+
 const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const {
   VectorRef weight = VectorRef(
     {prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})});
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h
index f047881d810..c1b92c8242b 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h
+++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h
@@ -26,7 +26,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
   explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true)
       : PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) {
     weight_decay_ = std::make_shared<Var>();
-    scale_ = std::make_shared<Var>();
+    scale_ = std::make_shared<CondVar>(IsScalar);
     variable_ = std::make_shared<Var>();
     accumulation_ = std::make_shared<Var>();
     learning_rate_ = std::make_shared<Var>();
@@ -38,9 +38,11 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
+  static bool IsScalar(const BaseRef &n);
+
   VarPtr weight_decay_;
   VarPtr scale_;
   VarPtr variable_;
   VarPtr accumulation_;
   VarPtr learning_rate_;