!7925 fix gpu momentum fusion

Merge pull request !7925 from chenweifeng/momentum-fusion-fix
This commit is contained in:
mindspore-ci-bot 2020-10-29 10:41:15 +08:00 committed by Gitee
commit ea10c7a146
4 changed files with 51 additions and 3 deletions

View File

@ -26,6 +26,28 @@
namespace mindspore {
namespace opt {
bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
auto shape = in->Shape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
if (shape->shape().size() != 0) {
return false;
}
auto dtype = in->Type();
if (dtype->type_id() != kObjectTypeTensorType) {
return false;
}
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
if (element_type != kNumberTypeFloat32) {
return false;
}
return true;
}
return false;
}
const BaseRef ApplyMomentumScaleFusion::DefinePattern() const {
VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_});
VectorRef apply_momentum =

View File

@ -18,13 +18,14 @@
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
class ApplyMomentumScaleFusion : public PatternProcessPass {
public:
explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) {
scale_ = std::make_shared<Var>();
scale_ = std::make_shared<CondVar>(IsScalar);
variable_ = std::make_shared<Var>();
accumulation_ = std::make_shared<Var>();
learning_rate_ = std::make_shared<Var>();
@ -36,6 +37,8 @@ class ApplyMomentumScaleFusion : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
static bool IsScalar(const BaseRef &n);
VarPtr scale_;
VarPtr variable_;
VarPtr accumulation_;

View File

@ -26,6 +26,28 @@
namespace mindspore {
namespace opt {
bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
auto shape = in->Shape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
if (shape->shape().size() != 0) {
return false;
}
auto dtype = in->Type();
if (dtype->type_id() != kObjectTypeTensorType) {
return false;
}
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
if (element_type != kNumberTypeFloat32) {
return false;
}
return true;
}
return false;
}
const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const {
VectorRef weight = VectorRef(
{prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})});

View File

@ -26,7 +26,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true)
: PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) {
weight_decay_ = std::make_shared<Var>();
scale_ = std::make_shared<Var>();
scale_ = std::make_shared<CondVar>(IsScalar);
variable_ = std::make_shared<Var>();
accumulation_ = std::make_shared<Var>();
learning_rate_ = std::make_shared<Var>();
@ -38,9 +38,10 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
static bool IsScalar(const BaseRef &n);
VarPtr weight_decay_;
VarPtr scale_;
VarPtr variable_;
VarPtr accumulation_;
VarPtr learning_rate_;