forked from mindspore-Ecosystem/mindspore
!1266 Refactor LambNextMVWithDecayRule fusion pass
Merge pull request !1266 from huanghui/LambNextMvWithDecayRuleConds-fusion-pass
This commit is contained in:
commit
e4795974be
|
@ -112,7 +112,6 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
|
|||
ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DereluFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
|
||||
|
|
|
@ -20,28 +20,23 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node,
|
||||
const AnfNodePtr &add3, const AnfNodePtr &add5, const AnfNodePtr &real_div0,
|
||||
const AnfNodePtr &real_div1) {
|
||||
AnfNodePtr LambNextMVWithDecayRule::GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph,
|
||||
const AnfNodePtr &new_node, const AnfNodePtr &add3,
|
||||
const AnfNodePtr &add5, const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
MS_EXCEPTION_IF_NULL(add3);
|
||||
MS_EXCEPTION_IF_NULL(real_div0);
|
||||
MS_EXCEPTION_IF_NULL(real_div1);
|
||||
MS_EXCEPTION_IF_NULL(add5);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto add0 = GetAnfNodeByVar(equiv, add0_var_);
|
||||
MS_EXCEPTION_IF_NULL(add0);
|
||||
auto add1 = GetAnfNodeByVar(equiv, add1_var_);
|
||||
MS_EXCEPTION_IF_NULL(add1);
|
||||
|
||||
// Set abstract of new node
|
||||
AbstractBasePtrList new_node_list;
|
||||
new_node_list.push_back(add3->abstract());
|
||||
auto real_div0_cnode = real_div0->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(real_div0_cnode);
|
||||
AnfNodePtr add0 = real_div0_cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(add0);
|
||||
new_node_list.push_back(add0->abstract());
|
||||
auto real_div1_cnode = real_div1->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(real_div1_cnode);
|
||||
AnfNodePtr add1 = real_div1_cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(add1);
|
||||
new_node_list.push_back(add1->abstract());
|
||||
new_node_list.push_back(add5->abstract());
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_list);
|
||||
|
@ -58,94 +53,8 @@ AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const An
|
|||
return new_node_outputs[3];
|
||||
}
|
||||
|
||||
void GetSharedInputNodesByAdd5(const AnfNodePtr &node, AnfNodePtr *mul4, AnfNodePtr *real_div0, AnfNodePtr *real_div1,
|
||||
AnfNodePtr *constant_add2_y_input) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto add5_cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(add5_cnode);
|
||||
if (add5_cnode->inputs().size() < kAddInputNum) {
|
||||
MS_LOG(EXCEPTION) << "The input size of Add5 is less than " << kAddInputNum;
|
||||
}
|
||||
*mul4 = add5_cnode->input(2);
|
||||
|
||||
AnfNodePtr real_div4 = add5_cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(real_div4);
|
||||
auto real_div4_cnode = real_div4->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(real_div4_cnode);
|
||||
if (real_div4_cnode->inputs().size() < kRealDivInputNum) {
|
||||
MS_LOG(EXCEPTION) << "The input size of RealDiv4 is less than " << kRealDivInputNum;
|
||||
}
|
||||
*real_div0 = real_div4_cnode->input(1);
|
||||
|
||||
AnfNodePtr add4 = real_div4_cnode->input(2);
|
||||
MS_EXCEPTION_IF_NULL(add4);
|
||||
auto add4_cnode = add4->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(add4_cnode);
|
||||
if (add4_cnode->inputs().size() < kAddInputNum) {
|
||||
MS_LOG(EXCEPTION) << "The input size of Add4 is less than " << kAddInputNum;
|
||||
}
|
||||
AnfNodePtr sqrt1 = add4_cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(sqrt1);
|
||||
auto sqrt1_cnode = sqrt1->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sqrt1_cnode);
|
||||
if (sqrt1_cnode->inputs().size() < kSqrtInputNum) {
|
||||
MS_LOG(EXCEPTION) << "The input size of Sqrt1 is less than " << kSqrtInputNum;
|
||||
}
|
||||
*real_div1 = sqrt1_cnode->input(1);
|
||||
*constant_add2_y_input = add4_cnode->input(2);
|
||||
}
|
||||
|
||||
bool MatchAdd3(const AnfNodePtr &add3, const AnfNodePtr &mul4, const AnfNodePtr &real_div0, const AnfNodePtr &real_div1,
|
||||
const AnfNodePtr &constant_add2_y) {
|
||||
if (add3 == nullptr || !add3->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto add3_cnode = add3->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(add3_cnode);
|
||||
if (AnfAlgo::GetCNodeName(add3_cnode) != prim::kPrimTensorAdd->name() ||
|
||||
add3_cnode->inputs().size() != kAddInputNum) {
|
||||
return false;
|
||||
}
|
||||
// Check the shared input nodes.
|
||||
if (add3_cnode->input(2) != mul4) {
|
||||
return false;
|
||||
}
|
||||
AnfNodePtr real_div2 = add3_cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(real_div2);
|
||||
auto real_div2_cnode = real_div2->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(real_div2_cnode);
|
||||
if (AnfAlgo::GetCNodeName(real_div2_cnode) != prim::kPrimMul->name() ||
|
||||
real_div2_cnode->inputs().size() != kMulInputNum) {
|
||||
return false;
|
||||
}
|
||||
if (real_div2_cnode->input(1) != real_div0) {
|
||||
return false;
|
||||
}
|
||||
AnfNodePtr sqrt0 = real_div2_cnode->input(2);
|
||||
MS_EXCEPTION_IF_NULL(sqrt0);
|
||||
auto sqrt0_cnode = sqrt0->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sqrt0_cnode);
|
||||
if (AnfAlgo::GetCNodeName(sqrt0_cnode) != kRsqrtOpName || sqrt0_cnode->inputs().size() != kRsqrtInputNum) {
|
||||
return false;
|
||||
}
|
||||
AnfNodePtr add2 = sqrt0_cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(add2);
|
||||
auto add2_cnode = add2->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(add2_cnode);
|
||||
if (AnfAlgo::GetCNodeName(add2_cnode) != prim::kPrimTensorAdd->name() ||
|
||||
add2_cnode->inputs().size() != kAddInputNum) {
|
||||
return false;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(add2_cnode->input(2));
|
||||
MS_EXCEPTION_IF_NULL(constant_add2_y);
|
||||
return add2_cnode->input(1) == real_div1 && *(add2_cnode->input(2)) == *constant_add2_y;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph,
|
||||
const AnfNodePtr &add3, const AnfNodePtr &add5,
|
||||
const AnfNodePtr &real_div0,
|
||||
const AnfNodePtr &real_div1,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(add3);
|
||||
|
@ -167,7 +76,7 @@ AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGrap
|
|||
MS_EXCEPTION_IF_NULL(constant_add2_y_node);
|
||||
new_node_inputs.push_back(constant_add2_y_node);
|
||||
auto new_node = func_graph->NewCNode(new_node_inputs);
|
||||
return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, real_div0, real_div1);
|
||||
return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv);
|
||||
}
|
||||
|
||||
const BaseRef LambNextMVWithDecayRule::DefinePattern() const {
|
||||
|
@ -175,44 +84,82 @@ const BaseRef LambNextMVWithDecayRule::DefinePattern() const {
|
|||
MS_EXCEPTION_IF_NULL(prim_sqrt);
|
||||
const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||
MS_EXCEPTION_IF_NULL(prim_deal_div);
|
||||
VectorRef mul4 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[4], input_vars_[6]});
|
||||
VectorRef add0 =
|
||||
VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}),
|
||||
VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]})});
|
||||
VectorRef real_div0 = VectorRef({prim_deal_div, add0, input_vars_[5]});
|
||||
VectorRef add1 =
|
||||
VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}),
|
||||
VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]})});
|
||||
VectorRef real_div1 = VectorRef({prim_deal_div, add1, input_vars_[2]});
|
||||
VectorRef real_div4 = VectorRef(
|
||||
{prim_deal_div, real_div0, VectorRef({prim::kPrimTensorAdd, VectorRef({prim_sqrt, real_div1}), constant_add2_y_})});
|
||||
return VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
|
||||
VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
|
||||
VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
|
||||
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
||||
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
|
||||
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
|
||||
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
|
||||
VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
|
||||
VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
|
||||
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
|
||||
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
|
||||
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
|
||||
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
|
||||
return add5;
|
||||
}
|
||||
|
||||
const BaseRef LambNextMVWithDecayRule::DefineAnotherPattern() const {
|
||||
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
|
||||
MS_EXCEPTION_IF_NULL(prim_rsqrt);
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
VarPtr Ys = std::make_shared<SeqVar>();
|
||||
VarPtr Zs = std::make_shared<SeqVar>();
|
||||
MS_EXCEPTION_IF_NULL(Xs);
|
||||
MS_EXCEPTION_IF_NULL(Ys);
|
||||
MS_EXCEPTION_IF_NULL(Zs);
|
||||
// Two patterns share: real_div0, real_div1, mul4, constant_add2_y_
|
||||
VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
|
||||
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
|
||||
VectorRef mul4 = VectorRef({mul4_var_, Zs});
|
||||
|
||||
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
|
||||
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
|
||||
VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
|
||||
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
|
||||
return add3;
|
||||
}
|
||||
|
||||
bool LambNextMVWithDecayRule::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
VarPtr fg = std::make_shared<Var>("RootG");
|
||||
auto empty_equiv = std::make_shared<Equiv>();
|
||||
MS_EXCEPTION_IF_NULL(child_primitive_vars_);
|
||||
EquivPtr another_equiv =
|
||||
child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
|
||||
*child_primitive_vars_, empty_equiv);
|
||||
if (another_equiv != nullptr && !another_equiv->empty()) {
|
||||
return IsShareNodes(equiv, another_equiv);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
|
||||
return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) &&
|
||||
IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_);
|
||||
}
|
||||
|
||||
const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// Get the shared input nodes in patterns of add5 and add3
|
||||
AnfNodePtr mul4 = nullptr;
|
||||
AnfNodePtr real_div0 = nullptr;
|
||||
AnfNodePtr real_div1 = nullptr;
|
||||
AnfNodePtr constant_add2_y_input = nullptr;
|
||||
GetSharedInputNodesByAdd5(node, &mul4, &real_div0, &real_div1, &constant_add2_y_input);
|
||||
// Get add3 and try to match the add3 pattern
|
||||
AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_);
|
||||
MS_EXCEPTION_IF_NULL(mul4);
|
||||
// Get add3 and match the add3 pattern
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
if (manager->node_users().find(mul4) == manager->node_users().end()) {
|
||||
MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input";
|
||||
}
|
||||
AnfNodeIndexSet mul4_output_node_index_set = manager->node_users()[mul4];
|
||||
auto iter = std::find_if(
|
||||
mul4_output_node_index_set.begin(), mul4_output_node_index_set.end(),
|
||||
[&node, &mul4, &real_div0, &real_div1, &constant_add2_y_input](const std::pair<AnfNodePtr, int> &node_index) {
|
||||
return node_index.first != node && MatchAdd3(node_index.first, mul4, real_div0, real_div1, constant_add2_y_input);
|
||||
});
|
||||
if (iter != mul4_output_node_index_set.end()) {
|
||||
return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, real_div0, real_div1, equiv);
|
||||
AnfNodeIndexSet mul4_outputs = manager->node_users()[mul4];
|
||||
auto iter = std::find_if(mul4_outputs.begin(), mul4_outputs.end(),
|
||||
[&node, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) {
|
||||
return node_index.first != node && MatchAnotherPattern(node_index.first, equiv);
|
||||
});
|
||||
if (iter != mul4_outputs.end()) {
|
||||
return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, equiv);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
#include "pre_activate/common/helper.h"
|
||||
|
||||
|
@ -25,8 +26,13 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class LambNextMVWithDecayRule : public PatternProcessPass {
|
||||
public:
|
||||
explicit LambNextMVWithDecayRule(bool multigraph = true)
|
||||
: PatternProcessPass("lamb_next_mv_with_decay_rule", multigraph) {
|
||||
explicit LambNextMVWithDecayRule(const std::string &name = "lamb_next_mv_with_decay_rule_cond4",
|
||||
bool multigraph = true)
|
||||
: PatternProcessPass(name, multigraph),
|
||||
child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
|
||||
std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
|
||||
std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
|
||||
child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {
|
||||
for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) {
|
||||
input_vars_.push_back(std::make_shared<Var>());
|
||||
}
|
||||
|
@ -34,20 +40,39 @@ class LambNextMVWithDecayRule : public PatternProcessPass {
|
|||
constant_mul_input_vars_.push_back(std::make_shared<Var>());
|
||||
}
|
||||
constant_add2_y_ = std::make_shared<Var>();
|
||||
mul4_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
|
||||
real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
|
||||
real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
|
||||
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
|
||||
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
|
||||
}
|
||||
|
||||
~LambNextMVWithDecayRule() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
virtual const BaseRef DefineAnotherPattern() const;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3,
|
||||
const AnfNodePtr &add5, const AnfNodePtr &real_div0,
|
||||
const AnfNodePtr &real_div1, const EquivPtr &equiv) const;
|
||||
protected:
|
||||
bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
|
||||
// check two patterns whether share the same nodes or not
|
||||
bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const;
|
||||
|
||||
AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node,
|
||||
const AnfNodePtr &add3, const AnfNodePtr &add5, const EquivPtr &equiv) const;
|
||||
AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3,
|
||||
const AnfNodePtr &add5, const EquivPtr &equiv) const;
|
||||
PatternEngine child_pattern_engine_;
|
||||
PrimitiveVarMapPtr child_primitive_vars_;
|
||||
std::vector<VarPtr> input_vars_;
|
||||
std::vector<VarPtr> constant_mul_input_vars_;
|
||||
// nodes which two patterns share
|
||||
VarPtr constant_add2_y_;
|
||||
VarPtr mul4_var_;
|
||||
VarPtr real_div0_var_;
|
||||
VarPtr real_div1_var_;
|
||||
// part of output nodes
|
||||
VarPtr add0_var_;
|
||||
VarPtr add1_var_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -64,6 +64,8 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
|
|||
AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node);
|
||||
AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node);
|
||||
AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(false), new_node);
|
||||
auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0);
|
||||
AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node);
|
||||
|
||||
return new_node;
|
||||
}
|
||||
|
|
|
@ -64,6 +64,8 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
|
|||
AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node);
|
||||
AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node);
|
||||
AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(true), new_node);
|
||||
auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0);
|
||||
AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node);
|
||||
|
||||
return new_node;
|
||||
}
|
||||
|
|
|
@ -539,5 +539,169 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i
|
|||
primitive->set_attr(kAttrInputNames, MakeValue(new_input_names));
|
||||
}
|
||||
}
|
||||
|
||||
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
|
||||
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
|
||||
auto a_node = utils::cast<AnfNodePtr>(a);
|
||||
auto b_node = utils::cast<AnfNodePtr>(b);
|
||||
if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
|
||||
auto a_value_node = a_node->cast<ValueNodePtr>();
|
||||
auto a_value = a_value_node->value();
|
||||
auto a_prim = a_value->cast<PrimitivePtr>();
|
||||
|
||||
auto b_value_node = b_node->cast<ValueNodePtr>();
|
||||
auto b_value = b_value_node->value();
|
||||
auto b_prim = b_value->cast<PrimitivePtr>();
|
||||
|
||||
return a_prim->name() == b_prim->name();
|
||||
} else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
|
||||
auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
|
||||
if (a_value_node_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "cast value node ptr fail";
|
||||
}
|
||||
auto a_value_ptr = a_value_node_ptr->value();
|
||||
if (a_value_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "value ptr is nullptr";
|
||||
}
|
||||
|
||||
auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
|
||||
if (b_value_node_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "cast value node ptr fail";
|
||||
}
|
||||
auto b_value_ptr = b_value_node_ptr->value();
|
||||
if (b_value_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "value ptr is nullptr";
|
||||
}
|
||||
|
||||
return (*a_value_ptr) == (*b_value_ptr);
|
||||
}
|
||||
MS_LOG(DEBUG) << "check AnfNodePtr equal";
|
||||
}
|
||||
if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
|
||||
MS_LOG(DEBUG) << "check GraphPtr equal";
|
||||
}
|
||||
return a == b;
|
||||
}
|
||||
|
||||
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
|
||||
// To matchCNode and Kernel's type
|
||||
if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
|
||||
return true;
|
||||
}
|
||||
return a.type() == b.type();
|
||||
}
|
||||
|
||||
namespace {
|
||||
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
|
||||
if (utils::isa<int>(sexp)) {
|
||||
return NewValueNode(utils::cast<int>(sexp));
|
||||
}
|
||||
if (utils::isa<float>(sexp)) {
|
||||
return NewValueNode(utils::cast<float>(sexp));
|
||||
}
|
||||
if (utils::isa<bool>(sexp)) {
|
||||
return NewValueNode(utils::cast<bool>(sexp));
|
||||
}
|
||||
if (utils::isa<ValuePtr>(sexp)) {
|
||||
return NewValueNode(utils::cast<ValuePtr>(sexp));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
|
||||
if (utils::isa<FuncGraphPtr>(graph)) {
|
||||
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
|
||||
}
|
||||
if (utils::isa<VarPtr>(graph)) {
|
||||
return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
|
||||
if (utils::isa<VarPtr>(graph)) {
|
||||
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
|
||||
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
|
||||
}
|
||||
if (utils::isa<FuncGraphPtr>(graph)) {
|
||||
MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
|
||||
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
|
||||
}
|
||||
MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
||||
bool multigraph) {
|
||||
MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
||||
std::vector<AnfNodePtr> input_nodes;
|
||||
const auto &tuple = utils::cast<VectorRef>(sexp);
|
||||
if (multigraph && utils::isa<VarPtr>(graph)) {
|
||||
for (auto &x : tuple) {
|
||||
AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
|
||||
input_nodes.push_back(node);
|
||||
}
|
||||
VarPtr var_ptr = utils::cast<VarPtr>(graph);
|
||||
return std::make_shared<CNode>(input_nodes, var_ptr);
|
||||
}
|
||||
|
||||
for (auto &x : tuple) {
|
||||
AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
|
||||
input_nodes.push_back(node);
|
||||
}
|
||||
return CreateCNodeWithGraph(input_nodes, graph);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
|
||||
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
||||
MS_EXCEPTION_IF_NULL(primitive_vars);
|
||||
if (utils::isa<VectorRef>(sexp)) {
|
||||
return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
|
||||
}
|
||||
if (utils::isa<VarPtr>(sexp)) {
|
||||
auto var_ptr = utils::cast<VarPtr>(sexp);
|
||||
MS_EXCEPTION_IF_NULL(var_ptr);
|
||||
if (var_ptr->primitive()) {
|
||||
(*primitive_vars)[var_ptr->primitive()] = var_ptr;
|
||||
return NewValueNode(var_ptr->primitive());
|
||||
}
|
||||
return CreateVarNodeWithSexp(sexp, graph);
|
||||
}
|
||||
if (utils::isa<AnfNodePtr>(sexp)) {
|
||||
return utils::cast<AnfNodePtr>(sexp);
|
||||
}
|
||||
auto value_node = CreateValueNodeWithSexp(sexp);
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
|
||||
}
|
||||
return value_node;
|
||||
}
|
||||
|
||||
bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) {
|
||||
MS_EXCEPTION_IF_NULL(equiv1);
|
||||
MS_EXCEPTION_IF_NULL(equiv2);
|
||||
MS_EXCEPTION_IF_NULL(var_node);
|
||||
auto equiv1_node = GetAnfNodeByVar(equiv1, var_node);
|
||||
MS_EXCEPTION_IF_NULL(equiv1_node);
|
||||
auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
|
||||
MS_EXCEPTION_IF_NULL(equiv2_node);
|
||||
return equiv1_node == equiv2_node;
|
||||
}
|
||||
|
||||
AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
MS_EXCEPTION_IF_NULL(var_node);
|
||||
auto iter = (*equiv).find(var_node);
|
||||
if (iter == (*equiv).end()) {
|
||||
MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched.";
|
||||
return nullptr;
|
||||
}
|
||||
auto res = utils::cast<AnfNodePtr>(iter->second);
|
||||
if (res == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cast fail! Maybe var is not a anf node";
|
||||
}
|
||||
return res;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "session/kernel_graph.h"
|
||||
#include "common/utils.h"
|
||||
#include "pre_activate/common/pattern_engine.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -162,6 +163,19 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
|
||||
|
||||
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs);
|
||||
|
||||
bool AnfEqual(const BaseRef &a, const BaseRef &b);
|
||||
|
||||
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
|
||||
|
||||
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
||||
bool multigraph = false);
|
||||
|
||||
// Check var_node in two equivs is the same node
|
||||
bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node);
|
||||
|
||||
// Get anf_node from equiv by var_node
|
||||
AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_
|
||||
|
|
|
@ -29,148 +29,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
||||
bool multigraph);
|
||||
|
||||
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
|
||||
if (utils::isa<int>(sexp)) {
|
||||
return NewValueNode(utils::cast<int>(sexp));
|
||||
}
|
||||
if (utils::isa<float>(sexp)) {
|
||||
return NewValueNode(utils::cast<float>(sexp));
|
||||
}
|
||||
if (utils::isa<bool>(sexp)) {
|
||||
return NewValueNode(utils::cast<bool>(sexp));
|
||||
}
|
||||
if (utils::isa<ValuePtr>(sexp)) {
|
||||
return NewValueNode(utils::cast<ValuePtr>(sexp));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
|
||||
if (utils::isa<FuncGraphPtr>(graph)) {
|
||||
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
|
||||
}
|
||||
if (utils::isa<VarPtr>(graph)) {
|
||||
return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
|
||||
if (utils::isa<VarPtr>(graph)) {
|
||||
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
|
||||
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
|
||||
}
|
||||
if (utils::isa<FuncGraphPtr>(graph)) {
|
||||
MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
|
||||
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
|
||||
}
|
||||
MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
||||
bool multigraph = false) {
|
||||
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
||||
MS_EXCEPTION_IF_NULL(primitive_vars);
|
||||
if (utils::isa<VectorRef>(sexp)) {
|
||||
return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
|
||||
}
|
||||
if (utils::isa<VarPtr>(sexp)) {
|
||||
auto var_ptr = utils::cast<VarPtr>(sexp);
|
||||
MS_EXCEPTION_IF_NULL(var_ptr);
|
||||
if (var_ptr->primitive()) {
|
||||
(*primitive_vars)[var_ptr->primitive()] = var_ptr;
|
||||
return NewValueNode(var_ptr->primitive());
|
||||
}
|
||||
return CreateVarNodeWithSexp(sexp, graph);
|
||||
}
|
||||
if (utils::isa<AnfNodePtr>(sexp)) {
|
||||
return utils::cast<AnfNodePtr>(sexp);
|
||||
}
|
||||
auto value_node = CreateValueNodeWithSexp(sexp);
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
|
||||
}
|
||||
return value_node;
|
||||
}
|
||||
|
||||
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
||||
bool multigraph) {
|
||||
MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
||||
std::vector<AnfNodePtr> input_nodes;
|
||||
const auto &tuple = utils::cast<VectorRef>(sexp);
|
||||
if (multigraph && utils::isa<VarPtr>(graph)) {
|
||||
for (auto &x : tuple) {
|
||||
AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
|
||||
input_nodes.push_back(node);
|
||||
}
|
||||
VarPtr var_ptr = utils::cast<VarPtr>(graph);
|
||||
return std::make_shared<CNode>(input_nodes, var_ptr);
|
||||
}
|
||||
|
||||
for (auto &x : tuple) {
|
||||
AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
|
||||
input_nodes.push_back(node);
|
||||
}
|
||||
return CreateCNodeWithGraph(input_nodes, graph);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
static bool AnfEqual(const BaseRef &a, const BaseRef &b) {
|
||||
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
|
||||
auto a_node = utils::cast<AnfNodePtr>(a);
|
||||
auto b_node = utils::cast<AnfNodePtr>(b);
|
||||
if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
|
||||
auto a_value_node = a_node->cast<ValueNodePtr>();
|
||||
auto a_value = a_value_node->value();
|
||||
auto a_prim = a_value->cast<PrimitivePtr>();
|
||||
|
||||
auto b_value_node = b_node->cast<ValueNodePtr>();
|
||||
auto b_value = b_value_node->value();
|
||||
auto b_prim = b_value->cast<PrimitivePtr>();
|
||||
|
||||
return a_prim->name() == b_prim->name();
|
||||
} else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
|
||||
auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
|
||||
if (a_value_node_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "cast value node ptr fail";
|
||||
}
|
||||
auto a_value_ptr = a_value_node_ptr->value();
|
||||
if (a_value_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "value ptr is nullptr";
|
||||
}
|
||||
|
||||
auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
|
||||
if (b_value_node_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "cast value node ptr fail";
|
||||
}
|
||||
auto b_value_ptr = b_value_node_ptr->value();
|
||||
if (b_value_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "value ptr is nullptr";
|
||||
}
|
||||
|
||||
return (*a_value_ptr) == (*b_value_ptr);
|
||||
}
|
||||
MS_LOG(DEBUG) << "check AnfNodePtr equal";
|
||||
}
|
||||
if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
|
||||
MS_LOG(DEBUG) << "check GraphPtr equal";
|
||||
}
|
||||
return a == b;
|
||||
}
|
||||
|
||||
static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
|
||||
// To matchCNode and Kernel's type
|
||||
if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
|
||||
return true;
|
||||
}
|
||||
return a.type() == b.type();
|
||||
}
|
||||
|
||||
PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph)
|
||||
: NodePass(name),
|
||||
multigraph_(multigraph),
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "pre_activate/common/pattern_engine.h"
|
||||
#include "utils/graph_utils.h"
|
||||
#include "common/utils.h"
|
||||
#include "pre_activate/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "common/backend_common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
#include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# ============================================================================
|
||||
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
|
||||
batch_norm_grad = G.BatchNormGrad(is_training=False)
|
||||
|
|
|
@ -24,7 +24,6 @@ make_tuple = Primitive('make_tuple')
|
|||
tuple_getitem = Primitive('tuple_getitem')
|
||||
LambNextMVWithDecay = Primitive('LambNextMVWithDecay')
|
||||
|
||||
|
||||
class FnDict:
|
||||
def __init__(self):
|
||||
self.fnDict = {}
|
||||
|
@ -35,7 +34,6 @@ class FnDict:
|
|||
def __getitem__(self, name):
|
||||
return self.fnDict[name]
|
||||
|
||||
|
||||
def test_lamb_next_mv_with_decay_rule(tag):
|
||||
fns = FnDict()
|
||||
|
||||
|
|
Loading…
Reference in New Issue