add 5 new patterns for the AdamApplyOneWithDecayRule fusion pass
commit ff05aa1faa (parent 7d763a9162)
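
For orientation: the five Cond patterns added below all describe the same arithmetic and differ only in the operand order of the commutative Mul/TensorAdd nodes; the pass replaces the whole matched subgraph with a single fused AdamApplyOneWithDecay node. A minimal NumPy sketch of that computation, derived from the DefinePattern bodies and the Python test graphs in this diff (the function name and the reading of the inputs as moments, gradient, parameter and learning rate are illustrative assumptions, not something the pass states):

import numpy as np

def adam_apply_one_with_decay_reference(input0, input1, input2, input3, input4,
                                        mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
    add0 = mul0_x * input2 + mul1_x * input0                   # updated first moment
    add1 = mul2_x * input1 + mul3_x * np.square(input0)        # updated second moment
    add3 = add0 / (np.sqrt(add1) + add2_y) + mul4_x * input3   # decayed update direction
    sub0 = input3 - input4 * add3                               # updated parameter
    # The fused op exposes exactly these three values, matching the
    # make_tuple(add1, add0, sub0) outputs used in the test graphs below.
    return add1, add0, sub0
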
@@ -99,6 +99,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
  ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>());
  ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
  ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
  ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusion>());
  ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>());
@@ -114,11 +115,15 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
  ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
  ir_fusion_pm->AddPass(std::make_shared<ClipByValueFusion>());
  ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRule>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond1Fusion>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond2Fusion>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond3Fusion>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond4Fusion>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond1>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond2>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond3>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond4>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond5>());
  ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>());
  ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>());
  ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
@@ -41,24 +41,104 @@ std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const Equ
  return {NewValueNode(prim), input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y};
}

const BaseRef AdamApplyOneWithDecayRule::DefinePattern() const {
const BaseRef AdamApplyOneWithDecayRuleCond1::DefinePattern() const {
  auto sqrt = std::make_shared<Primitive>(kSqrtOpName);
  auto real_div = std::make_shared<Primitive>(kRealDivOpName);
  VectorRef mul0_pattern({prim::kPrimMul, mul0_x_, input2_});
  VectorRef mul1_pattern({prim::kPrimMul, mul1_x_, input0_});
  VectorRef square0_pattern({prim::kPrimSquare, input0_});
  VectorRef add0_pattern({add0_var_, mul0_pattern, mul1_pattern});
  VectorRef mul2_pattern({prim::kPrimMul, mul2_x_, input1_});
  VectorRef mul3_pattern({prim::kPrimMul, mul3_x_, square0_pattern});
  VectorRef add1_pattern({add1_var_, mul2_pattern, mul3_pattern});
  VectorRef sqrt0_pattern({sqrt, add1_pattern});
  VectorRef add2_pattern({prim::kPrimTensorAdd, sqrt0_pattern, add2_y_});
  VectorRef mul4_pattern({prim::kPrimMul, mul4_x_, input3_});
  VectorRef real_div_pattern({real_div, add0_pattern, add2_pattern});
  VectorRef add3_pattern({prim::kPrimTensorAdd, real_div_pattern, mul4_pattern});
  VectorRef mul5_pattern({prim::kPrimMul, input4_, add3_pattern});
  VectorRef sub0_pattern({prim::kPrimSub, input3_, mul5_pattern});
  return sub0_pattern;
  VectorRef mul0({prim::kPrimMul, mul0_x_, input2_});
  VectorRef mul1({prim::kPrimMul, mul1_x_, input0_});
  VectorRef square0({prim::kPrimSquare, input0_});
  VectorRef add0({add0_var_, mul0, mul1});
  VectorRef mul2({prim::kPrimMul, mul2_x_, input1_});
  VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
  VectorRef add1({add1_var_, mul2, mul3});
  VectorRef sqrt0({sqrt, add1});
  VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
  VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
  VectorRef real_div0({real_div, add0, add2});
  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
  VectorRef mul5({prim::kPrimMul, input4_, add3});
  VectorRef sub0({prim::kPrimSub, input3_, mul5});
  return sub0;
}

const BaseRef AdamApplyOneWithDecayRuleCond2::DefinePattern() const {
  auto sqrt = std::make_shared<Primitive>(kSqrtOpName);
  auto real_div = std::make_shared<Primitive>(kRealDivOpName);
  VectorRef mul0({prim::kPrimMul, input2_, mul0_x_});
  VectorRef mul1({prim::kPrimMul, input0_, mul1_x_});
  VectorRef square0({prim::kPrimSquare, input0_});
  VectorRef add0({add0_var_, mul0, mul1});
  VectorRef mul2({prim::kPrimMul, input1_, mul2_x_});
  VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
  VectorRef add1({add1_var_, mul2, mul3});
  VectorRef sqrt0({sqrt, add1});
  VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
  VectorRef mul4({prim::kPrimMul, input3_, mul4_x_});
  VectorRef real_div0({real_div, add0, add2});
  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
  VectorRef mul5({prim::kPrimMul, add3, input4_});
  VectorRef sub0({prim::kPrimSub, input3_, mul5});
  return sub0;
}

const BaseRef AdamApplyOneWithDecayRuleCond3::DefinePattern() const {
  auto sqrt = std::make_shared<Primitive>(kSqrtOpName);
  auto real_div = std::make_shared<Primitive>(kRealDivOpName);
  VectorRef mul0({prim::kPrimMul, mul0_x_, input2_});
  VectorRef mul1({prim::kPrimMul, mul1_x_, input0_});
  VectorRef square0({prim::kPrimSquare, input0_});
  VectorRef add0({add0_var_, mul0, mul1});
  VectorRef mul2({prim::kPrimMul, mul2_x_, input1_});
  VectorRef mul3({prim::kPrimMul, square0, mul3_x_});
  VectorRef add1({add1_var_, mul2, mul3});
  VectorRef sqrt0({sqrt, add1});
  VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
  VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
  VectorRef real_div0({real_div, add0, add2});
  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
  VectorRef mul5({prim::kPrimMul, add3, input4_});
  VectorRef sub0({prim::kPrimSub, input3_, mul5});
  return sub0;
}

const BaseRef AdamApplyOneWithDecayRuleCond4::DefinePattern() const {
  auto sqrt = std::make_shared<Primitive>(kSqrtOpName);
  auto real_div = std::make_shared<Primitive>(kRealDivOpName);
  VectorRef mul0({prim::kPrimMul, mul0_x_, input2_});
  VectorRef mul1({prim::kPrimMul, mul1_x_, input0_});
  VectorRef square0({prim::kPrimSquare, input0_});
  VectorRef add0({add0_var_, mul0, mul1});
  VectorRef mul2({prim::kPrimMul, mul2_x_, input1_});
  VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
  VectorRef add1({add1_var_, mul2, mul3});
  VectorRef sqrt0({sqrt, add1});
  VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
  VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
  VectorRef real_div0({real_div, add0, add2});
  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
  VectorRef mul5({prim::kPrimMul, add3, input4_});
  VectorRef sub0({prim::kPrimSub, input3_, mul5});
  return sub0;
}

const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const {
  auto sqrt = std::make_shared<Primitive>(kSqrtOpName);
  auto real_div = std::make_shared<Primitive>(kRealDivOpName);
  VectorRef mul0({prim::kPrimMul, mul0_x_, input2_});
  VectorRef mul1({prim::kPrimMul, mul1_x_, input0_});
  VectorRef square0({prim::kPrimSquare, input0_});
  VectorRef add0({add0_var_, mul0, mul1});
  VectorRef mul2({prim::kPrimMul, mul2_x_, input1_});
  VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
  VectorRef add1({add1_var_, mul2, mul3});
  VectorRef sqrt0({sqrt, add1});
  VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
  VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
  VectorRef real_div0({real_div, add0, add2});
  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
  VectorRef mul5({prim::kPrimMul, add3, input4_});
  VectorRef sub0({prim::kPrimSub, input3_, mul5});
  return sub0;
}

const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
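
Separate Cond1-Cond5 patterns are needed because the pattern match above is purely structural: mul(mul3_x, square0) and mul(square0, mul3_x) are different trees even though Mul is commutative, and the frontend can emit either ordering. A toy sketch of that point, with plain Python tuples standing in for graph nodes (this is not MindSpore's matcher, only an illustration):

def matches(pattern, node):
    # A bare string is a pattern variable and matches any subtree; otherwise the
    # op name and every operand must match positionally.
    if isinstance(pattern, str):
        return True
    op_p, *args_p = pattern
    op_n, *args_n = node
    return op_p == op_n and len(args_p) == len(args_n) and all(
        matches(p, n) for p, n in zip(args_p, args_n))

cond1_mul3 = ("Mul", "mul3_x", ("Square", "input0"))               # Cond1/Cond2 operand order
node_a = ("Mul", ("Const", 0.001), ("Square", ("Param", "grad")))
node_b = ("Mul", ("Square", ("Param", "grad")), ("Const", 0.001))  # operands swapped
assert matches(cond1_mul3, node_a)
assert not matches(cond1_mul3, node_b)   # the swapped ordering is what Cond3 covers
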
@@ -18,14 +18,15 @@
#include <vector>
#include <memory>
#include <string>
#include "pre_activate/common/optimizer.h"
#include "utils/utils.h"
namespace mindspore {
namespace opt {
class AdamApplyOneWithDecayRule : public PatternProcessPass {
 public:
  explicit AdamApplyOneWithDecayRule(bool multigraph = true)
      : PatternProcessPass("adam_apply_one_with_decay_rule", multigraph) {
  explicit AdamApplyOneWithDecayRule(const std::string &name = "adam_apply_one_with_decay_rule", bool multigraph = true)
      : PatternProcessPass(name, multigraph) {
    input0_ = std::make_shared<Var>();
    input1_ = std::make_shared<Var>();
    input2_ = std::make_shared<Var>();
@@ -41,10 +42,10 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass {
    add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
  }
  ~AdamApplyOneWithDecayRule() override = default;
  const BaseRef DefinePattern() const override;
  const BaseRef DefinePattern() const override = 0;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
 protected:
  std::vector<AnfNodePtr> GetFusionNodeInputs(const EquivPtr &equiv) const;
  VarPtr input0_;
  VarPtr input1_;
@@ -60,6 +61,51 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass {
  VarPtr add0_var_;
  VarPtr add1_var_;
};

class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule {
 public:
  explicit AdamApplyOneWithDecayRuleCond1(bool multigraph = true)
      : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond1", multigraph) {}

  ~AdamApplyOneWithDecayRuleCond1() override = default;
  const BaseRef DefinePattern() const override;
};

class AdamApplyOneWithDecayRuleCond2 : public AdamApplyOneWithDecayRule {
 public:
  explicit AdamApplyOneWithDecayRuleCond2(bool multigraph = true)
      : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond2", multigraph) {}

  ~AdamApplyOneWithDecayRuleCond2() override = default;
  const BaseRef DefinePattern() const override;
};

class AdamApplyOneWithDecayRuleCond3 : public AdamApplyOneWithDecayRule {
 public:
  explicit AdamApplyOneWithDecayRuleCond3(bool multigraph = true)
      : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond3", multigraph) {}

  ~AdamApplyOneWithDecayRuleCond3() override = default;
  const BaseRef DefinePattern() const override;
};

class AdamApplyOneWithDecayRuleCond4 : public AdamApplyOneWithDecayRule {
 public:
  explicit AdamApplyOneWithDecayRuleCond4(bool multigraph = true)
      : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond4", multigraph) {}

  ~AdamApplyOneWithDecayRuleCond4() override = default;
  const BaseRef DefinePattern() const override;
};

class AdamApplyOneWithDecayRuleCond5 : public AdamApplyOneWithDecayRule {
 public:
  explicit AdamApplyOneWithDecayRuleCond5(bool multigraph = true)
      : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond5", multigraph) {}

  ~AdamApplyOneWithDecayRuleCond5() override = default;
  const BaseRef DefinePattern() const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_
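
The header change turns AdamApplyOneWithDecayRule into a shared base class: the constructor now takes the pass name, DefinePattern() becomes pure virtual, and the pattern variables move from private to protected, so each Cond subclass only supplies its own operand ordering while Process() and GetFusionNodeInputs() stay shared. The same layout sketched in Python, purely for illustration (these names are hypothetical, not part of the codebase):

from abc import ABC, abstractmethod

class AdamApplyOneWithDecayRuleBase(ABC):
    # Owns the shared pattern variables and the replacement step.
    def __init__(self, name="adam_apply_one_with_decay_rule"):
        self.name = name

    @abstractmethod
    def define_pattern(self):
        # Each Cond variant returns its own operand ordering.
        ...

    def process(self, matched_nodes):
        # Shared replacement logic, standing in for the C++ Process() above.
        return ("AdamApplyOneWithDecay", matched_nodes)

class Cond1(AdamApplyOneWithDecayRuleBase):
    def __init__(self):
        super().__init__("adam_apply_one_with_decay_rule_cond1")

    def define_pattern(self):
        return "mul(mul0_x, input2), ..., sub(input3, mul5)"
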
@@ -30,8 +30,8 @@ class TestHWOptimizeAdamApplyOneWithDecayRule : public BackendCommon {
  UT::PyFuncGraphFetcher get_py_fun_;
};

TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before");
TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond1) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond1", "before");

  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
@@ -43,16 +43,16 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRule>());
  pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRuleCond1>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after");
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond1", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_no_match) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "no_match");
TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond2) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond2", "before");

  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
@@ -61,15 +61,78 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_no_match) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);
  auto origin_graph = std::make_shared<session::KernelGraph>(*fg);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRule>());
  pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRuleCond2>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);

  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond2", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond3) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond3", "before");

  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 11; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRuleCond3>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond3", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond4) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond4", "before");

  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 11; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRuleCond4>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond4", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond5) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond5", "before");

  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 11; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRuleCond5>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond5", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore
@@ -89,3 +89,168 @@ def test_adam_apply_one_with_decay_rule(tag):
        return make_tuple(make_tuple(item0, item1, item2))

    return fns[tag]


def test_adam_apply_one_with_decay_rule_cond1(tag):
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
        mul0 = mul(mul0_x, input2)
        mul1 = mul(mul1_x, input0)
        square0 = square(input0)
        add0 = add(mul0, mul1)
        mul2 = mul(mul2_x, input1)
        mul3 = mul(mul3_x, square0)
        add1 = add(mul2, mul3)
        sqrt0 = sqrt(add1)
        add2 = add(add2_y, sqrt0)
        mul4 = mul(mul4_x, input3)
        real_div0 = real_div(add0, add2)
        add3 = add(mul4, real_div0)
        mul5 = mul(input4, add3)
        sub0 = sub(input3, mul5)
        return make_tuple(add1, add0, sub0)

    @fns
    def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
        res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x,
                                        add2_y)
        item0 = tuple_getitem(res, 0)
        item1 = tuple_getitem(res, 1)
        item2 = tuple_getitem(res, 2)
        return make_tuple(make_tuple(item0, item1, item2))

    return fns[tag]


def test_adam_apply_one_with_decay_rule_cond2(tag):
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
        mul0 = mul(input2, mul0_x)
        mul1 = mul(input0, mul1_x)
        square0 = square(input0)
        add0 = add(mul0, mul1)
        mul2 = mul(input1, mul2_x)
        mul3 = mul(mul3_x, square0)
        add1 = add(mul2, mul3)
        sqrt0 = sqrt(add1)
        add2 = add(sqrt0, add2_y)
        mul4 = mul(input3, mul4_x)
        real_div0 = real_div(add0, add2)
        add3 = add(mul4, real_div0)
        mul5 = mul(add3, input4)
        sub0 = sub(input3, mul5)
        return make_tuple(add1, add0, sub0)

    @fns
    def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
        res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x,
                                        add2_y)
        item0 = tuple_getitem(res, 0)
        item1 = tuple_getitem(res, 1)
        item2 = tuple_getitem(res, 2)
        return make_tuple(make_tuple(item0, item1, item2))

    return fns[tag]


def test_adam_apply_one_with_decay_rule_cond3(tag):
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
        mul0 = mul(mul0_x, input2)
        mul1 = mul(mul1_x, input0)
        square0 = square(input0)
        add0 = add(mul0, mul1)
        mul2 = mul(mul2_x, input1)
        mul3 = mul(square0, mul3_x)
        add1 = add(mul2, mul3)
        sqrt0 = sqrt(add1)
        add2 = add(sqrt0, add2_y)
        mul4 = mul(mul4_x, input3)
        real_div0 = real_div(add0, add2)
        add3 = add(mul4, real_div0)
        mul5 = mul(add3, input4)
        sub0 = sub(input3, mul5)
        return make_tuple(add1, add0, sub0)

    @fns
    def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
        res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x,
                                        add2_y)
        item0 = tuple_getitem(res, 0)
        item1 = tuple_getitem(res, 1)
        item2 = tuple_getitem(res, 2)
        return make_tuple(make_tuple(item0, item1, item2))

    return fns[tag]


def test_adam_apply_one_with_decay_rule_cond4(tag):
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
        mul0 = mul(mul0_x, input2)
        mul1 = mul(mul1_x, input0)
        square0 = square(input0)
        add0 = add(mul0, mul1)
        mul2 = mul(mul2_x, input1)
        mul3 = mul(mul3_x, square0)
        add1 = add(mul2, mul3)
        sqrt0 = sqrt(add1)
        add2 = add(add2_y, sqrt0)
        mul4 = mul(mul4_x, input3)
        real_div0 = real_div(add0, add2)
        add3 = add(mul4, real_div0)
        mul5 = mul(add3, input4)
        sub0 = sub(input3, mul5)
        return make_tuple(add1, add0, sub0)

    @fns
    def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
        res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x,
                                        add2_y)
        item0 = tuple_getitem(res, 0)
        item1 = tuple_getitem(res, 1)
        item2 = tuple_getitem(res, 2)
        return make_tuple(make_tuple(item0, item1, item2))

    return fns[tag]


def test_adam_apply_one_with_decay_rule_cond5(tag):
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
        mul0 = mul(mul0_x, input2)
        mul1 = mul(mul1_x, input0)
        square0 = square(input0)
        add0 = add(mul0, mul1)
        mul2 = mul(mul2_x, input1)
        mul3 = mul(mul3_x, square0)
        add1 = add(mul2, mul3)
        sqrt0 = sqrt(add1)
        add2 = add(sqrt0, add2_y)
        mul4 = mul(mul4_x, input3)
        real_div0 = real_div(add0, add2)
        add3 = add(mul4, real_div0)
        mul5 = mul(add3, input4)
        sub0 = sub(input3, mul5)
        return make_tuple(add1, add0, sub0)

    @fns
    def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
        res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x,
                                        add2_y)
        item0 = tuple_getitem(res, 0)
        item1 = tuple_getitem(res, 1)
        item2 = tuple_getitem(res, 2)
        return make_tuple(make_tuple(item0, item1, item2))

    return fns[tag]
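
All of the new Python test graphs follow the same template: before() spells out the fourteen small ops of one operand ordering, after() builds the graph around the fused adam_apply_one_with_decay call, and the tag picks which graph the C++ test fetches through CallAndParseRet. FnDict itself is defined in the shared test helpers; a minimal sketch consistent with how it is used here (not the actual helper):

class FnDict:
    # Registers functions by name when used as a decorator; looks them up by tag.
    def __init__(self):
        self.fn_dict = {}

    def __call__(self, fn):
        self.fn_dict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        return self.fn_dict[name]

With that, get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond1", "before") in the C++ tests resolves to the cond1 before() graph defined above.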