forked from mindspore-Ecosystem/mindspore
!7925 fix gpu momentum fusion
Merge pull request !7925 from chenweifeng/momentum-fusion-fix
This commit is contained in:
commit
ea10c7a146
|
@ -26,6 +26,28 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) {
|
||||||
|
if (utils::isa<AnfNodePtr>(n)) {
|
||||||
|
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
|
||||||
|
MS_EXCEPTION_IF_NULL(in);
|
||||||
|
auto shape = in->Shape()->cast<abstract::ShapePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(shape);
|
||||||
|
if (shape->shape().size() != 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto dtype = in->Type();
|
||||||
|
if (dtype->type_id() != kObjectTypeTensorType) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
|
||||||
|
if (element_type != kNumberTypeFloat32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
const BaseRef ApplyMomentumScaleFusion::DefinePattern() const {
|
const BaseRef ApplyMomentumScaleFusion::DefinePattern() const {
|
||||||
VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_});
|
VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_});
|
||||||
VectorRef apply_momentum =
|
VectorRef apply_momentum =
|
||||||
|
|
|
@ -18,13 +18,14 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "backend/optimizer/common/optimizer.h"
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class ApplyMomentumScaleFusion : public PatternProcessPass {
|
class ApplyMomentumScaleFusion : public PatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) {
|
explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) {
|
||||||
scale_ = std::make_shared<Var>();
|
scale_ = std::make_shared<CondVar>(IsScalar);
|
||||||
variable_ = std::make_shared<Var>();
|
variable_ = std::make_shared<Var>();
|
||||||
accumulation_ = std::make_shared<Var>();
|
accumulation_ = std::make_shared<Var>();
|
||||||
learning_rate_ = std::make_shared<Var>();
|
learning_rate_ = std::make_shared<Var>();
|
||||||
|
@ -36,6 +37,8 @@ class ApplyMomentumScaleFusion : public PatternProcessPass {
|
||||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
static bool IsScalar(const BaseRef &n);
|
||||||
|
|
||||||
VarPtr scale_;
|
VarPtr scale_;
|
||||||
VarPtr variable_;
|
VarPtr variable_;
|
||||||
VarPtr accumulation_;
|
VarPtr accumulation_;
|
||||||
|
|
|
@ -26,6 +26,28 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) {
|
||||||
|
if (utils::isa<AnfNodePtr>(n)) {
|
||||||
|
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
|
||||||
|
MS_EXCEPTION_IF_NULL(in);
|
||||||
|
auto shape = in->Shape()->cast<abstract::ShapePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(shape);
|
||||||
|
if (shape->shape().size() != 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto dtype = in->Type();
|
||||||
|
if (dtype->type_id() != kObjectTypeTensorType) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
|
||||||
|
if (element_type != kNumberTypeFloat32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const {
|
const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const {
|
||||||
VectorRef weight = VectorRef(
|
VectorRef weight = VectorRef(
|
||||||
{prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})});
|
{prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})});
|
||||||
|
|
|
@ -26,7 +26,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
|
||||||
explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true)
|
explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true)
|
||||||
: PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) {
|
: PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) {
|
||||||
weight_decay_ = std::make_shared<Var>();
|
weight_decay_ = std::make_shared<Var>();
|
||||||
scale_ = std::make_shared<Var>();
|
scale_ = std::make_shared<CondVar>(IsScalar);
|
||||||
variable_ = std::make_shared<Var>();
|
variable_ = std::make_shared<Var>();
|
||||||
accumulation_ = std::make_shared<Var>();
|
accumulation_ = std::make_shared<Var>();
|
||||||
learning_rate_ = std::make_shared<Var>();
|
learning_rate_ = std::make_shared<Var>();
|
||||||
|
@ -38,9 +38,10 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
|
||||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
static bool IsScalar(const BaseRef &n);
|
||||||
|
|
||||||
VarPtr weight_decay_;
|
VarPtr weight_decay_;
|
||||||
VarPtr scale_;
|
VarPtr scale_;
|
||||||
|
|
||||||
VarPtr variable_;
|
VarPtr variable_;
|
||||||
VarPtr accumulation_;
|
VarPtr accumulation_;
|
||||||
VarPtr learning_rate_;
|
VarPtr learning_rate_;
|
||||||
|
|
Loading…
Reference in New Issue