forked from mindspore-Ecosystem/mindspore

refactor multiple patterns pass

commit f16ff539ba
parent 964a757db2
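This commit extracts the machinery for matching a second, overlapping pattern out of LambNextMVWithDecayRule into a new MultipleOutputPatternProcessPass base class (a child PatternEngine plus MatchAnotherPattern and an IsShareNodes hook), rebases LambNextMVRule and LambNextMVWithDecayRule on it as abstract rules, and moves their concrete patterns into new LambNextMVRuleCond4 and LambNextMVWithDecayRuleCond4 subclasses, which replace the old passes in the Ascend IR-fusion pass list. The hand-written graph walking (GetSharedNodesByPattern, MatchRealDiv4) is deleted in favor of a second declarative pattern plus a shared-node check.

A rough sketch of a rule built on the new base class (MyFusionRule and shared_var_ are invented names for illustration, not part of this commit):

class MyFusionRule : public MultipleOutputPatternProcessPass {
 public:
  explicit MyFusionRule(bool multigraph = true)
      : MultipleOutputPatternProcessPass("my_fusion_rule", multigraph), shared_var_(std::make_shared<Var>()) {}
  ~MyFusionRule() override = default;
  // Main pattern, rooted at the node the pass visits.
  const BaseRef DefinePattern() const override;
  // Second pattern, rooted at another user of the shared node(s).
  BaseRef DefineAnotherPattern() const override;
  // Fuse only if both matches bind the shared variables to the same nodes.
  bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override {
    return IsSameNode(equiv1, equiv2, shared_var_);
  }
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  VarPtr shared_var_;
};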
@@ -99,11 +99,11 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
   ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
   ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
-  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond3>());
-  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>());
+  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond4>());
+  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
   ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
   ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
@@ -27,82 +27,12 @@

 namespace mindspore {
 namespace opt {
-namespace {
-std::tuple<CNodePtr, CNodePtr, AnfNodePtr> GetSharedNodesByPattern(const AnfNodePtr &node) {
-  auto add3_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kAddInputNum);
-  MS_EXCEPTION_IF_NULL(add3_cnode);
-  auto real_div2_cnode = CheckAnfNodeIfCNodeAndInputSize(add3_cnode->input(1), kMulInputNum);
-  MS_EXCEPTION_IF_NULL(real_div2_cnode);
-  auto real_div0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div2_cnode->input(1), kRealDivInputNum);
-  MS_EXCEPTION_IF_NULL(real_div0_cnode);
-  auto sqrt0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div2_cnode->input(2), kSqrtInputNum);
-  MS_EXCEPTION_IF_NULL(sqrt0_cnode);
-  auto add2_cnode = CheckAnfNodeIfCNodeAndInputSize(sqrt0_cnode->input(1), kAddInputNum);
-  MS_EXCEPTION_IF_NULL(add2_cnode);
-  auto real_div1_cnode = CheckAnfNodeIfCNodeAndInputSize(add2_cnode->input(1), kRealDivInputNum);
-  auto constant_add2_y = add2_cnode->input(2);
-
-  return std::make_tuple(real_div0_cnode, real_div1_cnode, constant_add2_y);
-}
-
-bool MatchRealDiv4(const AnfNodePtr &real_div4, const AnfNodePtr &real_div1, const AnfNodePtr &constant_add2_y) {
-  if (real_div4 == nullptr || !real_div4->isa<CNode>()) {
-    return false;
-  }
-  auto real_div4_cnode = real_div4->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(real_div4_cnode);
-  if (AnfAlgo::GetCNodeName(real_div4_cnode) != kRealDivOpName || real_div4_cnode->inputs().size() < kRealDivInputNum) {
-    return false;
-  }
-
-  CNodePtr add4_cnode = nullptr;
-  if (!CheckIfCNodeAndInputSize(real_div4_cnode->input(2), kAddInputNum, &add4_cnode) ||
-      AnfAlgo::GetCNodeName(add4_cnode) != prim::kPrimTensorAdd->name()) {
-    return false;
-  }
-  CNodePtr sqrt1_cnode = nullptr;
-  if (!CheckIfCNodeAndInputSize(add4_cnode->input(1), kSqrtInputNum, &sqrt1_cnode) ||
-      AnfAlgo::GetCNodeName(sqrt1_cnode) != kSqrtOpName) {
-    return false;
-  }
-
-  MS_EXCEPTION_IF_NULL(add4_cnode->input(2));
-  MS_EXCEPTION_IF_NULL(constant_add2_y);
-  return sqrt1_cnode->input(1) == real_div1 && *(add4_cnode->input(2)) == *constant_add2_y;
-}
-}  // namespace
-
-const BaseRef LambNextMVRule::DefinePattern() const {
-  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
-  MS_EXCEPTION_IF_NULL(prim_rsqrt);
-  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
-  MS_EXCEPTION_IF_NULL(prim_deal_div);
-
-  auto mul0 = VectorRef({prim::kPrimMul, input_varptr_[7], input_varptr_[4]});
-  auto mul1 = VectorRef({prim::kPrimMul, input_varptr_[8], input_varptr_[3]});
-  auto mul2 = VectorRef({prim::kPrimMul, input_varptr_[9], input_varptr_[1]});
-  auto mul3 = VectorRef({prim::kPrimMul, input_varptr_[10], input_varptr_[0]});
-  auto mul4 = VectorRef({prim::kPrimMul, input_varptr_[11], input_varptr_[6]});
-  auto add0 = VectorRef({prim::kPrimTensorAdd, mul0, mul1});
-  auto add1 = VectorRef({prim::kPrimTensorAdd, mul2, mul3});
-
-  auto real_div0 = VectorRef({prim_deal_div, add0, input_varptr_[5]});
-  auto real_div1 = VectorRef({prim_deal_div, add1, input_varptr_[2]});
-
-  auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, input_varptr_[12]});
-  auto sqrt0 = VectorRef({prim_rsqrt, add2});
-  auto real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
-
-  return VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
-}
-
-bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
                                    std::vector<AnfNodePtr> *old_pattern_outputs) const {
   MS_EXCEPTION_IF_NULL(func_graph);
-  CNodePtr real_div0 = nullptr;
-  CNodePtr real_div1 = nullptr;
-  AnfNodePtr constant_add2_y = nullptr;
-  std::tie(real_div0, real_div1, constant_add2_y) = GetSharedNodesByPattern(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto real_div0 = GetAnfNodeByVar(equiv, real_div0_var_);
+  auto real_div2 = GetAnfNodeByVar(equiv, real_div2_var_);

   auto manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
@@ -112,19 +42,17 @@ bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
   }
   AnfNodeIndexSet real_div0_outputs = users[real_div0];
   auto iter = std::find_if(real_div0_outputs.begin(), real_div0_outputs.end(),
-                           [&node, &real_div1, &constant_add2_y](const std::pair<AnfNodePtr, int> &node_index) {
-                             return node_index.first != node && node_index.second == 1 &&
-                                    MatchRealDiv4(node_index.first, real_div1, constant_add2_y);
+                           [&real_div2, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) {
+                             return node_index.first != real_div2 && node_index.second == 1 &&
+                                    MatchAnotherPattern(node_index.first, equiv);
                            });
   if (iter == real_div0_outputs.end()) {
     return false;
   }

-  auto add0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div0->input(1), kAddInputNum);
-  auto add1_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div1->input(1), kAddInputNum);
   (*old_pattern_outputs).push_back(node);
-  (*old_pattern_outputs).push_back(add0_cnode);
-  (*old_pattern_outputs).push_back(add1_cnode);
+  (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add0_var_));
+  (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add1_var_));
   (*old_pattern_outputs).push_back(iter->first);

   return true;
@@ -136,8 +64,19 @@ AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
   MS_EXCEPTION_IF_NULL(func_graph);
   auto prim = std::make_shared<Primitive>(kLambNextMVOpName);
   std::vector<AnfNodePtr> lamb_next_mv_rule_inputs = {NewValueNode(prim)};
-  (void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(lamb_next_mv_rule_inputs),
-                       [&equiv](const VarPtr &in) { return utils::cast<AnfNodePtr>((*equiv)[in]); });
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input0_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input1_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input2_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input3_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input4_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input5_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input6_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul0_x_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul1_sub_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul2_x_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul3_sub1_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul4_x_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[add2_y_]));
   auto lamb_next_mv_rule = func_graph->NewCNode(lamb_next_mv_rule_inputs);
   MS_EXCEPTION_IF_NULL(lamb_next_mv_rule);

@@ -162,14 +101,60 @@ AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
   return lamb_next_mv_rule_outputs[0];
 }

+bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
+  return IsSameNode(equiv1, equiv2, real_div0_var_) && IsSameNode(equiv1, equiv2, real_div1_var_) &&
+         IsSameNode(equiv1, equiv2, add2_y_);
+}
+
 const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const EquivPtr &equiv) const {
   std::vector<AnfNodePtr> old_pattern_outputs;
-  if (!IsRuleMatched(func_graph, node, &old_pattern_outputs)) {
+  if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) {
     return nullptr;
   }

   return CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv);
 }
+
+const BaseRef LambNextMVRuleCond4::DefinePattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_rsqrt);
+
+  auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
+  auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
+  auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_});
+  auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
+  auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_});
+  auto add0 = VectorRef({add0_var_, mul0, mul1});
+  auto add1 = VectorRef({add1_var_, mul2, mul3});
+
+  auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
+  auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
+
+  auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_});
+  auto sqrt0 = VectorRef({prim_rsqrt, add2});
+  auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0});
+
+  return VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
+}
+
+BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_sqrt);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
+  MS_EXCEPTION_IF_NULL(prim_real_div);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Ys);
+  // Two patterns share: real_div0, real_div1, add2_y_
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
+  VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
+  return real_div4;
+}
 }  // namespace opt
 }  // namespace mindspore
@@ -29,23 +29,71 @@

 namespace mindspore {
 namespace opt {
-class LambNextMVRule : public PatternProcessPass {
+class LambNextMVRule : public MultipleOutputPatternProcessPass {
  public:
-  explicit LambNextMVRule(bool multigraph = true) : PatternProcessPass("lamb_next_mv_rule", multigraph) {
-    for (size_t i = 0; i < kLambNextMVRuleInputNum - 1; ++i) {
-      input_varptr_.push_back(std::make_shared<Var>());
-    }
+  explicit LambNextMVRule(const std::string &name = "", bool multigraph = true)
+      : MultipleOutputPatternProcessPass(name, multigraph) {
+    input0_ = std::make_shared<Var>();
+    input1_ = std::make_shared<Var>();
+    input2_ = std::make_shared<Var>();
+    input3_ = std::make_shared<Var>();
+    input4_ = std::make_shared<Var>();
+    input5_ = std::make_shared<Var>();
+    input6_ = std::make_shared<Var>();
+    mul0_x_ = std::make_shared<Var>();
+    mul1_sub_ = std::make_shared<Var>();
+    mul2_x_ = std::make_shared<Var>();
+    mul3_sub1_ = std::make_shared<Var>();
+    mul4_x_ = std::make_shared<Var>();
+    add2_y_ = std::make_shared<Var>();
+    real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
+    real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
+    real_div2_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
+    add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+    add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
   }
   ~LambNextMVRule() override = default;
-  const BaseRef DefinePattern() const override;
+  const BaseRef DefinePattern() const override = 0;
+  BaseRef DefineAnotherPattern() const override = 0;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+  bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override;

- private:
-  std::vector<VarPtr> input_varptr_;
-  bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+ protected:
+  bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
                      std::vector<AnfNodePtr> *old_pattern_outputs) const;
   AnfNodePtr CreateLambNextMVNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &old_pattern_outputs,
                                   const EquivPtr &equiv) const;
+
+  VarPtr input0_;
+  VarPtr input1_;
+  VarPtr input2_;
+  VarPtr input3_;
+  VarPtr input4_;
+  VarPtr input5_;
+  VarPtr input6_;
+  VarPtr mul0_x_;
+  VarPtr mul1_sub_;
+  VarPtr mul2_x_;
+  VarPtr mul3_sub1_;
+  VarPtr mul4_x_;
+  VarPtr add2_y_;
+  // nodes which the two patterns share, plus add2_y_
+  VarPtr real_div0_var_;
+  VarPtr real_div1_var_;
+  // part of the output nodes
+  VarPtr add0_var_;
+  VarPtr add1_var_;
+  // other node
+  VarPtr real_div2_var_;
 };
+
+class LambNextMVRuleCond4 : public LambNextMVRule {
+ public:
+  explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {}
+
+  ~LambNextMVRuleCond4() override = default;
+  const BaseRef DefinePattern() const override;
+  BaseRef DefineAnotherPattern() const override;
+};
 }  // namespace opt
 }  // namespace mindspore
@@ -79,63 +79,6 @@ AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph,
   return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv);
 }

-const BaseRef LambNextMVWithDecayRule::DefinePattern() const {
-  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
-  MS_EXCEPTION_IF_NULL(prim_sqrt);
-  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
-  MS_EXCEPTION_IF_NULL(prim_deal_div);
-  VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
-  VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
-  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
-  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
-  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
-  VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
-  VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
-  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
-  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
-  VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
-  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
-  return add5;
-}
-
-const BaseRef LambNextMVWithDecayRule::DefineAnotherPattern() const {
-  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
-  MS_EXCEPTION_IF_NULL(prim_rsqrt);
-  VarPtr Xs = std::make_shared<SeqVar>();
-  VarPtr Ys = std::make_shared<SeqVar>();
-  VarPtr Zs = std::make_shared<SeqVar>();
-  MS_EXCEPTION_IF_NULL(Xs);
-  MS_EXCEPTION_IF_NULL(Ys);
-  MS_EXCEPTION_IF_NULL(Zs);
-  // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_
-  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
-  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
-  VectorRef mul4 = VectorRef({mul4_var_, Zs});
-
-  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
-  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
-  VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
-  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
-  return add3;
-}
-
-bool LambNextMVWithDecayRule::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(equiv);
-  VarPtr fg = std::make_shared<Var>("RootG");
-  auto empty_equiv = std::make_shared<Equiv>();
-  MS_EXCEPTION_IF_NULL(child_primitive_vars_);
-  EquivPtr another_equiv =
-    child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
-                                *child_primitive_vars_, empty_equiv);
-  if (another_equiv != nullptr && !another_equiv->empty()) {
-    return IsShareNodes(equiv, another_equiv);
-  }
-  return false;
-}
-
 bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
   return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) &&
          IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_);
@@ -164,7 +107,7 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph,
   return nullptr;
 }

-const BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
+BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
   MS_EXCEPTION_IF_NULL(prim_rsqrt);
   VarPtr Xs = std::make_shared<SeqVar>();
@@ -205,7 +148,7 @@ const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const {
   return add5;
 }

-const BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
+BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
   MS_EXCEPTION_IF_NULL(prim_rsqrt);
   VarPtr Xs = std::make_shared<SeqVar>();
@@ -246,7 +189,7 @@ const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const {
   return add5;
 }

-const BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
+BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
   MS_EXCEPTION_IF_NULL(prim_rsqrt);
   VarPtr Xs = std::make_shared<SeqVar>();
@@ -286,5 +229,47 @@ const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const {
   VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
   return add5;
 }
+
+BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_rsqrt);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  VarPtr Zs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Ys);
+  MS_EXCEPTION_IF_NULL(Zs);
+  // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+  VectorRef mul4 = VectorRef({mul4_var_, Zs});
+
+  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
+  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
+  VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
+  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
+  return add3;
+}
+
+const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_sqrt);
+  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
+  MS_EXCEPTION_IF_NULL(prim_deal_div);
+  VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
+  VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
+  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
+  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
+  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
+  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
+  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
+  VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
+  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
+  return add5;
+}
 }  // namespace opt
 }  // namespace mindspore
@@ -24,15 +24,10 @@

 namespace mindspore {
 namespace opt {
-class LambNextMVWithDecayRule : public PatternProcessPass {
+class LambNextMVWithDecayRule : public MultipleOutputPatternProcessPass {
  public:
-  explicit LambNextMVWithDecayRule(const std::string &name = "lamb_next_mv_with_decay_rule_cond4",
-                                   bool multigraph = true)
-      : PatternProcessPass(name, multigraph),
-        child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
-                                            std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
-                                            std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
-        child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {
+  explicit LambNextMVWithDecayRule(const std::string &name = "", bool multigraph = true)
+      : MultipleOutputPatternProcessPass(name, multigraph) {
     for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) {
       input_vars_.push_back(std::make_shared<Var>());
     }
@@ -48,21 +43,16 @@ class LambNextMVWithDecayRule : public PatternProcessPass {
   }

   ~LambNextMVWithDecayRule() override = default;
-  const BaseRef DefinePattern() const override;
-  virtual const BaseRef DefineAnotherPattern() const;
+  const BaseRef DefinePattern() const override = 0;
+  BaseRef DefineAnotherPattern() const override = 0;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+  bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override;

  protected:
-  bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
-  // check whether the two patterns share the same nodes
-  bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const;
-
   AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node,
                                           const AnfNodePtr &add3, const AnfNodePtr &add5, const EquivPtr &equiv) const;
   AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3,
                                            const AnfNodePtr &add5, const EquivPtr &equiv) const;
-  PatternEngine child_pattern_engine_;
-  PrimitiveVarMapPtr child_primitive_vars_;
   std::vector<VarPtr> input_vars_;
   std::vector<VarPtr> constant_mul_input_vars_;
   // nodes which the two patterns share
@@ -82,7 +72,7 @@ class LambNextMVWithDecayRuleCond1 : public LambNextMVWithDecayRule {

   ~LambNextMVWithDecayRuleCond1() override = default;
   const BaseRef DefinePattern() const override;
-  const BaseRef DefineAnotherPattern() const override;
+  BaseRef DefineAnotherPattern() const override;
 };

 class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule {
@@ -92,7 +82,7 @@ class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule {

   ~LambNextMVWithDecayRuleCond2() override = default;
   const BaseRef DefinePattern() const override;
-  const BaseRef DefineAnotherPattern() const override;
+  BaseRef DefineAnotherPattern() const override;
 };

 class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule {
@@ -102,7 +92,17 @@ class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule {

   ~LambNextMVWithDecayRuleCond3() override = default;
   const BaseRef DefinePattern() const override;
-  const BaseRef DefineAnotherPattern() const override;
+  BaseRef DefineAnotherPattern() const override;
 };
+
+class LambNextMVWithDecayRuleCond4 : public LambNextMVWithDecayRule {
+ public:
+  explicit LambNextMVWithDecayRuleCond4(bool multigraph = true)
+      : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond4", multigraph) {}
+
+  ~LambNextMVWithDecayRuleCond4() override = default;
+  const BaseRef DefinePattern() const override;
+  BaseRef DefineAnotherPattern() const override;
+};
 }  // namespace opt
 }  // namespace mindspore
@@ -62,6 +62,21 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
   return nullptr;
 }

+bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  VarPtr fg = std::make_shared<Var>("RootG");
+  auto empty_equiv = std::make_shared<Equiv>();
+  MS_EXCEPTION_IF_NULL(child_primitive_vars_);
+  EquivPtr another_equiv =
+    child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
+                                *child_primitive_vars_, empty_equiv);
+  if (another_equiv != nullptr && !another_equiv->empty()) {
+    return IsShareNodes(equiv, another_equiv);
+  }
+  return false;
+}
+
 void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
   if (pass_manager != nullptr) {
     pass_managers_.push_back(pass_manager);
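MatchAnotherPattern (added above) re-runs the child engine with DefineAnotherPattern() rooted at a candidate node and then defers to IsShareNodes to confirm that both equivs bind the shared variables to the same nodes. A rule typically probes each user of a shared node with it; a minimal sketch mirroring LambNextMVRule::IsRuleMatched (shared_node_users is an assumed AnfNodeIndexSet obtained from the graph manager):

auto iter = std::find_if(shared_node_users.begin(), shared_node_users.end(),
                         [&equiv, this](const std::pair<AnfNodePtr, int> &node_index) {
                           // Accept a user only if the second pattern matches there with
                           // bindings consistent with the first match.
                           return node_index.second == 1 && MatchAnotherPattern(node_index.first, equiv);
                         });
bool both_patterns_matched = (iter != shared_node_users.end());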
@@ -51,6 +51,25 @@ class PatternProcessPass : public NodePass {
   PrimitiveVarMapPtr primitive_vars_;
 };

+class MultipleOutputPatternProcessPass : public PatternProcessPass {
+ public:
+  explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true)
+      : PatternProcessPass(name, multigraph),
+        child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
+                                            std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
+                                            std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
+        child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
+  ~MultipleOutputPatternProcessPass() override = default;
+  virtual BaseRef DefineAnotherPattern() const = 0;
+  // check whether the two patterns share the same nodes
+  virtual bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const = 0;
+
+ protected:
+  bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
+  PatternEngine child_pattern_engine_;
+  PrimitiveVarMapPtr child_primitive_vars_;
+};
+
 class GraphOptimizer {
  public:
  explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {}
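The concrete subclasses register like any other pattern pass; minimal usage, as exercised by the unit tests below (func_graph stands for the graph under optimization):

auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
// The Cond4 subclasses take over for the old LambNextMVRule / LambNextMVWithDecayRule passes.
pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(func_graph);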
@@ -30,7 +30,7 @@ class TestHWLambNextMVRule : public BackendCommon {
   UT::PyFuncGraphFetcher get_py_fun_;
 };

-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_matched) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_matched) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -54,7 +54,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_matched) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -65,15 +65,15 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_matched) {

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);

-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }

-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div4) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div4) {
   /*
    * def before_unmatched_real_div4(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
    *                                constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
|
|||
* output = tuple_getitem(outputs, 0)
|
||||
* return output
|
||||
*/
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div4");
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div4");
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
|
@@ -109,14 +109,14 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div4) {

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);

   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
 }

-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div2) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div2) {
   /*
    * def before_unmatched_real_div2(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
    *                                constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -140,7 +140,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div2) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div2");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div2");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -152,14 +152,14 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div2) {

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);

   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
 }

-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div0) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div0) {
   /*
    * def before_unmatched_real_div0(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
    *                                constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -183,7 +183,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div0) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div0");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div0");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
|
@ -195,14 +195,14 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div0) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::LambNextMVRule>());
|
||||
pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||
|
||||
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div1) {
|
||||
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div1) {
|
||||
/*
|
||||
* def before_unmatched_real_div1(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
|
||||
* constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
|
||||
|
@@ -226,7 +226,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div1) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div1");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div1");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -238,7 +238,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div1) {

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);

@@ -30,7 +30,7 @@ class TestHWLambNextMVWithDecayRule : public BackendCommon {
   UT::PyFuncGraphFetcher get_py_fun_;
 };

-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_matched) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -55,7 +55,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -66,15 +66,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);

-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }

-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add3) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_add3) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -99,7 +99,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add3) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_add3");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_add3");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -111,15 +111,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add3) {

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
 }

-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul4) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_mul4) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -144,7 +144,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul4) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_mul4");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_mul4");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -156,15 +156,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul4) {

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
 }

-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div0) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_real_div0) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -189,7 +189,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div0) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_real_div0");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_real_div0");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -201,15 +201,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div0) {

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
 }

-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div1) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_real_div1) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -234,7 +234,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div1) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_real_div1");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_real_div1");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -246,11 +246,11 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div1) {

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
 }

@@ -36,7 +36,7 @@ class FnDict:
         return self.fnDict[name]


-def test_lamb_next_mv_rule(tag):
+def test_lamb_next_mv_rule_cond4(tag):
     fns = FnDict()

     @fns
@@ -34,7 +34,7 @@ class FnDict:
     def __getitem__(self, name):
         return self.fnDict[name]

-def test_lamb_next_mv_with_decay_rule(tag):
+def test_lamb_next_mv_with_decay_rule_cond4(tag):
     fns = FnDict()

     @fns