forked from mindspore-Ecosystem/mindspore
add 3 patterns for lamb_next_mv_with_decay_rule pass
parent ca74e624e2
commit eaff850f11
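
For orientation: the subgraph that LambNextMVWithDecayRule and the three new Cond variants match (the variants differ in the operand order of commutative Add/Mul nodes) computes the LAMB moving averages and two decayed update terms. Below is a minimal NumPy sketch of that dataflow; it is not part of the commit, and the names merely mirror the Python test graphs further down.

import numpy as np

def lamb_next_mv_with_decay(input0, input1, input2, input3, input4, input5, input6,
                            mul0_x, mul1_sub, mul2_x, mul3_sub1, mul4_x, add2_y):
    # add0/add1: the two weighted-sum accumulations (outputs 1 and 2 of the fused op)
    add0 = input4 * mul0_x + input3 * mul1_sub
    add1 = input1 * mul2_x + input0 * mul3_sub1
    real_div0 = add0 / input5
    real_div1 = add1 / input2
    # add3/add5: update terms; add2_y plays the role of a small epsilon constant
    add3 = mul4_x * input6 + real_div0 / np.sqrt(real_div1 + add2_y)    # Rsqrt branch
    add5 = mul4_x * input6 + real_div0 / (np.sqrt(real_div1) + add2_y)  # Sqrt branch
    return add3, add0, add1, add5
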
@@ -97,6 +97,9 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
  ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
  ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond3>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
  ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());

@@ -163,5 +163,128 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph
  }
  return nullptr;
}

const BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
  MS_EXCEPTION_IF_NULL(prim_rsqrt);
  VarPtr Xs = std::make_shared<SeqVar>();
  VarPtr Ys = std::make_shared<SeqVar>();
  VarPtr Zs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(Xs);
  MS_EXCEPTION_IF_NULL(Ys);
  MS_EXCEPTION_IF_NULL(Zs);
  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
  VectorRef mul4 = VectorRef({mul4_var_, Zs});

  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
  VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
  return add3;
}

const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const {
  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
  MS_EXCEPTION_IF_NULL(prim_sqrt);
  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
  MS_EXCEPTION_IF_NULL(prim_deal_div);
  VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]});
  VectorRef mul3 = VectorRef({prim::kPrimMul, input_vars_[0], constant_mul_input_vars_[3]});
  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
  VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
  VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
  VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
  return add5;
}

const BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
  MS_EXCEPTION_IF_NULL(prim_rsqrt);
  VarPtr Xs = std::make_shared<SeqVar>();
  VarPtr Ys = std::make_shared<SeqVar>();
  VarPtr Zs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(Xs);
  MS_EXCEPTION_IF_NULL(Ys);
  MS_EXCEPTION_IF_NULL(Zs);
  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
  VectorRef mul4 = VectorRef({mul4_var_, Zs});

  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
  VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
  return add3;
}

const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const {
  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
  MS_EXCEPTION_IF_NULL(prim_sqrt);
  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
  MS_EXCEPTION_IF_NULL(prim_deal_div);
  VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
  VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, sqrt1});
  VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
  VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
  VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
  return add5;
}

const BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
  MS_EXCEPTION_IF_NULL(prim_rsqrt);
  VarPtr Xs = std::make_shared<SeqVar>();
  VarPtr Ys = std::make_shared<SeqVar>();
  VarPtr Zs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(Xs);
  MS_EXCEPTION_IF_NULL(Ys);
  MS_EXCEPTION_IF_NULL(Zs);
  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
  VectorRef mul4 = VectorRef({mul4_var_, Zs});

  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
  VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
  return add3;
}

const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const {
  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
  MS_EXCEPTION_IF_NULL(prim_sqrt);
  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
  MS_EXCEPTION_IF_NULL(prim_deal_div);
  VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]});
  VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
  VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
  VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
  VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]});
  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
  return add5;
}
}  // namespace opt
}  // namespace mindspore

@@ -74,6 +74,36 @@ class LambNextMVWithDecayRule : public PatternProcessPass {
  VarPtr add0_var_;
  VarPtr add1_var_;
};

class LambNextMVWithDecayRuleCond1 : public LambNextMVWithDecayRule {
 public:
  explicit LambNextMVWithDecayRuleCond1(bool multigraph = true)
      : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond1", multigraph) {}

  ~LambNextMVWithDecayRuleCond1() override = default;
  const BaseRef DefinePattern() const override;
  const BaseRef DefineAnotherPattern() const override;
};

class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule {
 public:
  explicit LambNextMVWithDecayRuleCond2(bool multigraph = true)
      : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond2", multigraph) {}

  ~LambNextMVWithDecayRuleCond2() override = default;
  const BaseRef DefinePattern() const override;
  const BaseRef DefineAnotherPattern() const override;
};

class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule {
 public:
  explicit LambNextMVWithDecayRuleCond3(bool multigraph = true)
      : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond3", multigraph) {}

  ~LambNextMVWithDecayRuleCond3() override = default;
  const BaseRef DefinePattern() const override;
  const BaseRef DefineAnotherPattern() const override;
};
}  // namespace opt
}  // namespace mindspore

@@ -24,6 +24,8 @@
#include "pre_activate/common/helper.h"

namespace mindspore {
namespace opt {
namespace {
bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(add);
@@ -36,6 +38,14 @@ bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_
    MS_EXCEPTION_IF_NULL(cnode);
    if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) {
      if (!opt::IsUsedByOthers(graph, cnode)) {
        auto full_name = cnode->fullname_with_scope();
        // exclude lamb and adam, and only work in bert
        if (std::string::npos != full_name.find("adam") || std::string::npos != full_name.find("lamb") ||
            std::string::npos == full_name.find("bert")) {
          MS_LOG(INFO) << "Mul is in adam or lamb or not a bert network, quit fusion";
          return false;
        }

        *mul = cnode;
        *mul_index = index;
        return true;
@@ -45,8 +55,7 @@ bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_
  }
  return false;
}

namespace opt {
}  // namespace
const BaseRef MulAddFusion::DefinePattern() const {
  VarPtr x = std::make_shared<Var>();
  VarPtr y = std::make_shared<Var>();
@@ -74,7 +83,12 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
  for (size_t index = 1; index < mul->size(); ++index) {
    inputs.push_back(mul->input(index));
  }
  inputs.push_back(add->input(add->size() - mul_index));
  auto another_input_node = add->input(add->size() - mul_index);
  if (IsUsedByOthers(graph, another_input_node)) {
    MS_LOG(INFO) << "Add's another input node is used by others, do not fuse";
    return nullptr;
  }
  inputs.push_back(another_input_node);
  auto fusion_node = graph->NewCNode(inputs);
  fusion_node->set_scope(add->scope());
  fusion_node->set_abstract(add->abstract());

@@ -253,5 +253,134 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
  EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond1) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "before");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 13; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond1>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond1_un_match) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "un_match");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 13; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);
  auto origin_graph = std::make_shared<session::KernelGraph>(*fg);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond1>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);
  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "un_match");
  EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond2) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "before");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 13; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);
  DumpIR("fg.ir", fg, true);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond2>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "after");
  DumpIR("g_after.ir", g_after, true);
  DumpIR("new_graph.ir", new_graph, true);
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond2_un_match) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "un_match");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 13; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);
  auto origin_graph = std::make_shared<session::KernelGraph>(*fg);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond2>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);
  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "un_match");
  EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond3) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "before");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 13; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);
  DumpIR("fg.ir", fg, true);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond3>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "after");
  DumpIR("g_after.ir", g_after, true);
  DumpIR("new_graph.ir", new_graph, true);
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond3_un_match) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "un_match");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 13; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);
  auto origin_graph = std::make_shared<session::KernelGraph>(*fg);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond3>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);
  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "un_match");
  EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
}
}  // namespace opt
}  // namespace mindspore

@@ -37,6 +37,10 @@ TEST_F(TestHWMulAddFusion, test_mul_add_fusion1) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);
  auto scope = std::make_shared<Scope>("bert");
  for (auto nd : fg->execution_order()) {
    nd->set_scope(scope);
  }

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
@@ -57,6 +61,10 @@ TEST_F(TestHWMulAddFusion, test_mul_add_fusion2) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);
  auto scope = std::make_shared<Scope>("bert");
  for (auto nd : fg->execution_order()) {
    nd->set_scope(scope);
  }

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();

@@ -174,3 +174,201 @@ def test_lamb_next_mv_with_decay_rule(tag):
        return output

    return fns[tag]

def test_lamb_next_mv_with_decay_rule_cond1(tag):
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        mul1 = Mul(input3, constant_mul1_sub)
        mul0 = Mul(input4, constant_mul0_x)
        add0 = Add(mul0, mul1)
        mul2 = Mul(input1, constant_mul2_x)
        mul3 = Mul(input0, constant_mul3_sub1)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(constant_add2_y, real_div1)
        sqrt1 = Sqrt(real_div1)
        real_div0 = RealDiv(add0, input5)
        add4 = Add(sqrt1, constant_add2_y)
        sqrt0 = Rsqrt(add2)
        mul4 = Mul(constant_mul4_x, input6)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        add5 = Add(mul4, real_div4)
        add3 = Add(mul4, real_div2)
        outputs = make_tuple(add3, add0, add1, add5)
        output = tuple_getitem(outputs, 0)
        return output

    @fns
    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6,
                                                      constant_mul0_x, constant_mul1_sub,
                                                      constant_mul2_x, constant_mul3_sub1, constant_mul4_x,
                                                      constant_add2_y)
        outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1),
                             tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3))
        output = tuple_getitem(outputs, 0)
        return make_tuple(output)

    @fns
    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        mul1 = Mul(input3, constant_mul1_sub)
        mul0 = Mul(input4, constant_mul0_x)
        add0 = Add(mul0, mul1)
        mul2 = Mul(input1, constant_mul2_x)
        mul3 = Mul(input0, constant_mul3_sub1)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(constant_add2_y, real_div1)
        sqrt1 = Sqrt(real_div1)
        real_div0 = RealDiv(add0, input5)
        add4 = Add(sqrt1, constant_add2_y)
        sqrt0 = Rsqrt(add2)
        mul4 = Mul(constant_mul4_x, input6)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        add5 = Add(mul4, real_div4)
        # un match
        add3 = Add(real_div2, mul4)
        outputs = make_tuple(add3, add0, add1, add5)
        output = tuple_getitem(outputs, 0)
        return output

    return fns[tag]

def test_lamb_next_mv_with_decay_rule_cond2(tag):
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        mul1 = Mul(constant_mul1_sub, input3)
        mul0 = Mul(constant_mul0_x, input4)
        add0 = Add(mul0, mul1)
        mul2 = Mul(constant_mul2_x, input1)
        mul3 = Mul(constant_mul3_sub1, input0)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(constant_add2_y, real_div1)
        sqrt1 = Sqrt(real_div1)
        real_div0 = RealDiv(add0, input5)
        add4 = Add(constant_add2_y, sqrt1)
        sqrt0 = Rsqrt(add2)
        mul4 = Mul(constant_mul4_x, input6)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        add5 = Add(mul4, real_div4)
        add3 = Add(mul4, real_div2)
        outputs = make_tuple(add3, add0, add1, add5)
        output = tuple_getitem(outputs, 0)
        return output

    @fns
    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6,
                                                      constant_mul0_x, constant_mul1_sub,
                                                      constant_mul2_x, constant_mul3_sub1, constant_mul4_x,
                                                      constant_add2_y)
        outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1),
                             tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3))
        output = tuple_getitem(outputs, 0)
        return make_tuple(output)

    @fns
    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        mul1 = Mul(constant_mul1_sub, input3)
        mul0 = Mul(constant_mul0_x, input4)
        add0 = Add(mul0, mul1)
        mul2 = Mul(constant_mul2_x, input1)
        mul3 = Mul(constant_mul3_sub1, input0)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(constant_add2_y, real_div1)
        sqrt1 = Sqrt(real_div1)
        real_div0 = RealDiv(add0, input5)
        add4 = Add(constant_add2_y, sqrt1)
        sqrt0 = Rsqrt(add2)
        mul4 = Mul(constant_mul4_x, input6)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        add5 = Add(mul4, real_div4)
        # un_match
        add3 = Add(real_div2, mul4)
        outputs = make_tuple(add3, add0, add1, add5)
        output = tuple_getitem(outputs, 0)
        return output

    return fns[tag]

def test_lamb_next_mv_with_decay_rule_cond3(tag):
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        mul1 = Mul(input3, constant_mul1_sub)
        mul0 = Mul(input4, constant_mul0_x)
        add0 = Add(mul0, mul1)
        mul2 = Mul(input1, constant_mul2_x)
        mul3 = Mul(constant_mul3_sub1, input0)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(real_div1, constant_add2_y)
        sqrt1 = Sqrt(real_div1)
        real_div0 = RealDiv(add0, input5)
        add4 = Add(sqrt1, constant_add2_y)
        sqrt0 = Rsqrt(add2)
        mul4 = Mul(input6, constant_mul4_x)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        add5 = Add(mul4, real_div4)
        add3 = Add(mul4, real_div2)
        outputs = make_tuple(add3, add0, add1, add5)
        output = tuple_getitem(outputs, 0)
        return output

    @fns
    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6,
                                                      constant_mul0_x, constant_mul1_sub,
                                                      constant_mul2_x, constant_mul3_sub1, constant_mul4_x,
                                                      constant_add2_y)
        outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1),
                             tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3))
        output = tuple_getitem(outputs, 0)
        return make_tuple(output)

    @fns
    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        mul1 = Mul(input3, constant_mul1_sub)
        mul0 = Mul(input4, constant_mul0_x)
        add0 = Add(mul0, mul1)
        mul2 = Mul(input1, constant_mul2_x)
        mul3 = Mul(constant_mul3_sub1, input0)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(real_div1, constant_add2_y)
        sqrt1 = Sqrt(real_div1)
        real_div0 = RealDiv(add0, input5)
        add4 = Add(sqrt1, constant_add2_y)
        sqrt0 = Rsqrt(add2)
        mul4 = Mul(input6, constant_mul4_x)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        add5 = Add(mul4, real_div4)
        # un match
        add3 = Add(real_div2, mul4)
        outputs = make_tuple(add3, add0, add1, add5)
        output = tuple_getitem(outputs, 0)
        return output

    return fns[tag]