forked from mindspore-Ecosystem/mindspore
Add AdamApplyOneAssign and AdamApplyOneWithDecayAssign fusion pass
This commit is contained in:
parent
fe2c2e8330
commit
8a77751988
|
@ -124,6 +124,10 @@ void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) {
|
||||||
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>());
|
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
|
ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
|
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
|
||||||
|
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneAssignCond1Fusion>());
|
||||||
|
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneAssignCond2Fusion>());
|
||||||
|
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneAssignCond3Fusion>());
|
||||||
|
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneAssignCond4Fusion>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond1Fusion>());
|
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond1Fusion>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond2Fusion>());
|
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond2Fusion>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond3Fusion>());
|
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond3Fusion>());
|
||||||
|
|
|
@ -15,30 +15,9 @@
|
||||||
*/
|
*/
|
||||||
#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h"
|
#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h"
|
||||||
#include "backend/optimizer/common/helper.h"
|
#include "backend/optimizer/common/helper.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const {
|
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
|
||||||
MS_EXCEPTION_IF_NULL(equiv);
|
|
||||||
auto prim = std::make_shared<Primitive>(kAdamApplyOneOpName);
|
|
||||||
std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim)};
|
|
||||||
for (const auto &input_var : input_vars_) {
|
|
||||||
auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_var]);
|
|
||||||
MS_EXCEPTION_IF_NULL(input_node);
|
|
||||||
new_node_inputs.push_back(input_node);
|
|
||||||
}
|
|
||||||
for (const auto &mul_x_input_var : mul_x_input_vars_) {
|
|
||||||
auto mul_x_input_node = utils::cast<AnfNodePtr>((*equiv)[mul_x_input_var]);
|
|
||||||
MS_EXCEPTION_IF_NULL(mul_x_input_node);
|
|
||||||
new_node_inputs.push_back(mul_x_input_node);
|
|
||||||
}
|
|
||||||
auto add2_y_node = utils::cast<AnfNodePtr>((*equiv)[add2_y_]);
|
|
||||||
MS_EXCEPTION_IF_NULL(add2_y_node);
|
|
||||||
new_node_inputs.push_back(add2_y_node);
|
|
||||||
auto new_node = func_graph->NewCNode(new_node_inputs);
|
|
||||||
return new_node;
|
|
||||||
}
|
|
||||||
|
|
||||||
const BaseRef AdamApplyOneFusion::DefinePattern() const {
|
const BaseRef AdamApplyOneFusion::DefinePattern() const {
|
||||||
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
||||||
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||||
|
@ -104,16 +83,152 @@ const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const {
|
||||||
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
|
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const BaseRef AdamApplyOneAssignFusion::DefinePattern() const {
|
||||||
|
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
||||||
|
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||||
|
VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
|
||||||
|
VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
|
||||||
|
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
||||||
|
VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
|
||||||
|
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
|
||||||
|
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
|
||||||
|
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||||
|
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
|
||||||
|
VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
|
||||||
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
|
||||||
|
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
|
||||||
|
VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0});
|
||||||
|
VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
|
||||||
|
VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1});
|
||||||
|
return VectorRef({prim::kPrimDepend, depend1, assign2});
|
||||||
|
}
|
||||||
|
|
||||||
|
const BaseRef AdamApplyOneAssignCond1Fusion::DefinePattern() const {
|
||||||
|
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
||||||
|
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||||
|
VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
|
||||||
|
VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
|
||||||
|
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
||||||
|
VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
|
||||||
|
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
|
||||||
|
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
|
||||||
|
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||||
|
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
|
||||||
|
VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
|
||||||
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
|
||||||
|
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
|
||||||
|
VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0});
|
||||||
|
VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
|
||||||
|
VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1});
|
||||||
|
return VectorRef({prim::kPrimDepend, depend1, assign2});
|
||||||
|
}
|
||||||
|
|
||||||
|
const BaseRef AdamApplyOneAssignCond2Fusion::DefinePattern() const {
|
||||||
|
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
||||||
|
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||||
|
VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
|
||||||
|
VectorRef mul3 = VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[0]}), mul_x_input_vars_[3]});
|
||||||
|
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
||||||
|
VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
|
||||||
|
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
|
||||||
|
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
|
||||||
|
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||||
|
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
|
||||||
|
VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
|
||||||
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
|
||||||
|
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
|
||||||
|
VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0});
|
||||||
|
VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
|
||||||
|
VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1});
|
||||||
|
return VectorRef({prim::kPrimDepend, depend1, assign2});
|
||||||
|
}
|
||||||
|
|
||||||
|
const BaseRef AdamApplyOneAssignCond3Fusion::DefinePattern() const {
|
||||||
|
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
||||||
|
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||||
|
VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
|
||||||
|
VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
|
||||||
|
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
||||||
|
VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
|
||||||
|
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
|
||||||
|
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
|
||||||
|
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||||
|
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
|
||||||
|
VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
|
||||||
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
|
||||||
|
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
|
||||||
|
VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0});
|
||||||
|
VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
|
||||||
|
VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1});
|
||||||
|
return VectorRef({prim::kPrimDepend, depend1, assign2});
|
||||||
|
}
|
||||||
|
|
||||||
|
const BaseRef AdamApplyOneAssignCond4Fusion::DefinePattern() const {
|
||||||
|
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
||||||
|
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||||
|
VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
|
||||||
|
VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
|
||||||
|
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
||||||
|
VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
|
||||||
|
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
|
||||||
|
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
|
||||||
|
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||||
|
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
|
||||||
|
VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
|
||||||
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
|
||||||
|
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
|
||||||
|
VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0});
|
||||||
|
VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
|
||||||
|
VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1});
|
||||||
|
return VectorRef({prim::kPrimDepend, depend1, assign2});
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||||
|
const AnfNodePtr &final_node) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(equiv);
|
||||||
|
PrimitivePtr prim = nullptr;
|
||||||
|
if (AnfAlgo::CheckPrimitiveType(final_node, prim::kPrimDepend)) {
|
||||||
|
prim = std::make_shared<Primitive>(kAdamApplyOneAssignOpName);
|
||||||
|
} else {
|
||||||
|
prim = std::make_shared<Primitive>(kAdamApplyOneOpName);
|
||||||
|
}
|
||||||
|
std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim)};
|
||||||
|
for (const auto &input_var : input_vars_) {
|
||||||
|
auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_var]);
|
||||||
|
MS_EXCEPTION_IF_NULL(input_node);
|
||||||
|
new_node_inputs.push_back(input_node);
|
||||||
|
}
|
||||||
|
for (const auto &mul_x_input_var : mul_x_input_vars_) {
|
||||||
|
auto mul_x_input_node = utils::cast<AnfNodePtr>((*equiv)[mul_x_input_var]);
|
||||||
|
MS_EXCEPTION_IF_NULL(mul_x_input_node);
|
||||||
|
new_node_inputs.push_back(mul_x_input_node);
|
||||||
|
}
|
||||||
|
auto add2_y_node = utils::cast<AnfNodePtr>((*equiv)[add2_y_]);
|
||||||
|
MS_EXCEPTION_IF_NULL(add2_y_node);
|
||||||
|
new_node_inputs.push_back(add2_y_node);
|
||||||
|
auto new_node = func_graph->NewCNode(new_node_inputs);
|
||||||
|
return new_node;
|
||||||
|
}
|
||||||
|
|
||||||
const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
const EquivPtr &equiv) const {
|
const EquivPtr &equiv) const {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
auto sub0 = node;
|
||||||
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
|
||||||
|
auto iter_sub0 = (*equiv).find(sub0_var_);
|
||||||
|
if (iter_sub0 == (*equiv).end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the sub0 var after matched.";
|
||||||
|
}
|
||||||
|
sub0 = utils::cast<AnfNodePtr>(iter_sub0->second);
|
||||||
|
}
|
||||||
|
MS_EXCEPTION_IF_NULL(sub0);
|
||||||
|
if (!CheckSupportDataType(sub0, kFloatDataTypeSet)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto new_node = CreateAdamApplyOneNode(func_graph, equiv);
|
auto new_node = CreateAdamApplyOneNode(func_graph, equiv, node);
|
||||||
MS_EXCEPTION_IF_NULL(new_node);
|
MS_EXCEPTION_IF_NULL(new_node);
|
||||||
new_node->set_scope(node->scope());
|
new_node->set_scope(sub0->scope());
|
||||||
// Set abstract of new node
|
// Set abstract of new node
|
||||||
AbstractBasePtrList new_node_abstract_list;
|
AbstractBasePtrList new_node_abstract_list;
|
||||||
auto iter_add0 = (*equiv).find(add0_var_);
|
auto iter_add0 = (*equiv).find(add0_var_);
|
||||||
|
@ -130,7 +245,7 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con
|
||||||
MS_EXCEPTION_IF_NULL(add1);
|
MS_EXCEPTION_IF_NULL(add1);
|
||||||
new_node_abstract_list.push_back(add1->abstract());
|
new_node_abstract_list.push_back(add1->abstract());
|
||||||
new_node_abstract_list.push_back(add0->abstract());
|
new_node_abstract_list.push_back(add0->abstract());
|
||||||
new_node_abstract_list.push_back(node->abstract());
|
new_node_abstract_list.push_back(sub0->abstract());
|
||||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_abstract_list);
|
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_abstract_list);
|
||||||
new_node->set_abstract(abstract_tuple);
|
new_node->set_abstract(abstract_tuple);
|
||||||
// Create tuple_getitem node for outputs
|
// Create tuple_getitem node for outputs
|
||||||
|
|
|
@ -40,6 +40,7 @@ class AdamApplyOneFusion : public PatternProcessPass {
|
||||||
add2_y_ = std::make_shared<Var>();
|
add2_y_ = std::make_shared<Var>();
|
||||||
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
|
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
|
||||||
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
|
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
|
||||||
|
sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name()));
|
||||||
}
|
}
|
||||||
|
|
||||||
~AdamApplyOneFusion() override = default;
|
~AdamApplyOneFusion() override = default;
|
||||||
|
@ -47,12 +48,14 @@ class AdamApplyOneFusion : public PatternProcessPass {
|
||||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const;
|
AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||||
|
const AnfNodePtr &final_node) const;
|
||||||
std::vector<VarPtr> input_vars_;
|
std::vector<VarPtr> input_vars_;
|
||||||
std::vector<VarPtr> mul_x_input_vars_;
|
std::vector<VarPtr> mul_x_input_vars_;
|
||||||
VarPtr add2_y_;
|
VarPtr add2_y_;
|
||||||
VarPtr add0_var_;
|
VarPtr add0_var_;
|
||||||
VarPtr add1_var_;
|
VarPtr add1_var_;
|
||||||
|
VarPtr sub0_var_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class AdamApplyOneCond1Fusion : public AdamApplyOneFusion {
|
class AdamApplyOneCond1Fusion : public AdamApplyOneFusion {
|
||||||
|
@ -90,6 +93,51 @@ class AdamApplyOneCond4Fusion : public AdamApplyOneFusion {
|
||||||
~AdamApplyOneCond4Fusion() override = default;
|
~AdamApplyOneCond4Fusion() override = default;
|
||||||
const BaseRef DefinePattern() const override;
|
const BaseRef DefinePattern() const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class AdamApplyOneAssignFusion : public AdamApplyOneFusion {
|
||||||
|
public:
|
||||||
|
explicit AdamApplyOneAssignFusion(bool multigraph = true)
|
||||||
|
: AdamApplyOneFusion("adam_apply_one_assign_fusion", multigraph) {}
|
||||||
|
|
||||||
|
~AdamApplyOneAssignFusion() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class AdamApplyOneAssignCond1Fusion : public AdamApplyOneFusion {
|
||||||
|
public:
|
||||||
|
explicit AdamApplyOneAssignCond1Fusion(bool multigraph = true)
|
||||||
|
: AdamApplyOneFusion("adam_apply_one_assign_cond1_fusion", multigraph) {}
|
||||||
|
|
||||||
|
~AdamApplyOneAssignCond1Fusion() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class AdamApplyOneAssignCond2Fusion : public AdamApplyOneFusion {
|
||||||
|
public:
|
||||||
|
explicit AdamApplyOneAssignCond2Fusion(bool multigraph = true)
|
||||||
|
: AdamApplyOneFusion("adam_apply_one_assign_cond2_fusion", multigraph) {}
|
||||||
|
|
||||||
|
~AdamApplyOneAssignCond2Fusion() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class AdamApplyOneAssignCond3Fusion : public AdamApplyOneFusion {
|
||||||
|
public:
|
||||||
|
explicit AdamApplyOneAssignCond3Fusion(bool multigraph = true)
|
||||||
|
: AdamApplyOneFusion("adam_apply_one_assign_cond3_fusion", multigraph) {}
|
||||||
|
|
||||||
|
~AdamApplyOneAssignCond3Fusion() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class AdamApplyOneAssignCond4Fusion : public AdamApplyOneFusion {
|
||||||
|
public:
|
||||||
|
explicit AdamApplyOneAssignCond4Fusion(bool multigraph = true)
|
||||||
|
: AdamApplyOneFusion("adam_apply_one_assign_cond4_fusion", multigraph) {}
|
||||||
|
|
||||||
|
~AdamApplyOneAssignCond4Fusion() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
};
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_
|
||||||
|
|
|
@ -119,6 +119,7 @@ constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay";
|
||||||
constexpr auto kBatchNormGradOpName = "BatchNormGrad";
|
constexpr auto kBatchNormGradOpName = "BatchNormGrad";
|
||||||
constexpr auto kBNInferOpName = "BNInfer";
|
constexpr auto kBNInferOpName = "BNInfer";
|
||||||
constexpr auto kAdamApplyOneOpName = "AdamApplyOne";
|
constexpr auto kAdamApplyOneOpName = "AdamApplyOne";
|
||||||
|
constexpr auto kAdamApplyOneAssignOpName = "AdamApplyOneAssign";
|
||||||
constexpr auto kResizeNearestNeighborGradOpName = "ResizeNearestNeighborGrad";
|
constexpr auto kResizeNearestNeighborGradOpName = "ResizeNearestNeighborGrad";
|
||||||
constexpr auto kFusedMulAddOpName = "FusedMulAdd";
|
constexpr auto kFusedMulAddOpName = "FusedMulAdd";
|
||||||
constexpr auto kFusedMulAddNOpName = "FusedMulAddN";
|
constexpr auto kFusedMulAddNOpName = "FusedMulAddN";
|
||||||
|
|
|
@ -217,5 +217,105 @@ TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond4_fusion) {
|
||||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
|
||||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_fusion) {
|
||||||
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before");
|
||||||
|
std::vector<int> shp{2, 32, 224, 224};
|
||||||
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||||
|
AbstractBasePtrList args_spec_list;
|
||||||
|
for (size_t i = 0; i < 10; ++i) {
|
||||||
|
args_spec_list.push_back(x_abstract);
|
||||||
|
}
|
||||||
|
auto fg = GetKernelGraph(g, args_spec_list);
|
||||||
|
|
||||||
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
|
auto pm = std::make_shared<opt::PassManager>();
|
||||||
|
pm->AddPass(std::make_shared<opt::AdamApplyOneAssignFusion>());
|
||||||
|
optimizer->AddPassManager(pm);
|
||||||
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||||
|
|
||||||
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after");
|
||||||
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_cond1_fusion) {
|
||||||
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before_cond1");
|
||||||
|
std::vector<int> shp{2, 32, 224, 224};
|
||||||
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||||
|
AbstractBasePtrList args_spec_list;
|
||||||
|
for (size_t i = 0; i < 10; ++i) {
|
||||||
|
args_spec_list.push_back(x_abstract);
|
||||||
|
}
|
||||||
|
auto fg = GetKernelGraph(g, args_spec_list);
|
||||||
|
|
||||||
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
|
auto pm = std::make_shared<opt::PassManager>();
|
||||||
|
pm->AddPass(std::make_shared<opt::AdamApplyOneAssignCond1Fusion>());
|
||||||
|
optimizer->AddPassManager(pm);
|
||||||
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||||
|
|
||||||
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after");
|
||||||
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_cond2_fusion) {
|
||||||
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before_cond2");
|
||||||
|
std::vector<int> shp{2, 32, 224, 224};
|
||||||
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||||
|
AbstractBasePtrList args_spec_list;
|
||||||
|
for (size_t i = 0; i < 10; ++i) {
|
||||||
|
args_spec_list.push_back(x_abstract);
|
||||||
|
}
|
||||||
|
auto fg = GetKernelGraph(g, args_spec_list);
|
||||||
|
|
||||||
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
|
auto pm = std::make_shared<opt::PassManager>();
|
||||||
|
pm->AddPass(std::make_shared<opt::AdamApplyOneAssignCond2Fusion>());
|
||||||
|
optimizer->AddPassManager(pm);
|
||||||
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||||
|
|
||||||
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after");
|
||||||
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_cond3_fusion) {
|
||||||
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before_cond3");
|
||||||
|
std::vector<int> shp{2, 32, 224, 224};
|
||||||
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||||
|
AbstractBasePtrList args_spec_list;
|
||||||
|
for (size_t i = 0; i < 10; ++i) {
|
||||||
|
args_spec_list.push_back(x_abstract);
|
||||||
|
}
|
||||||
|
auto fg = GetKernelGraph(g, args_spec_list);
|
||||||
|
|
||||||
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
|
auto pm = std::make_shared<opt::PassManager>();
|
||||||
|
pm->AddPass(std::make_shared<opt::AdamApplyOneAssignCond3Fusion>());
|
||||||
|
optimizer->AddPassManager(pm);
|
||||||
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||||
|
|
||||||
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after");
|
||||||
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_cond4_fusion) {
|
||||||
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before_cond4");
|
||||||
|
std::vector<int> shp{2, 32, 224, 224};
|
||||||
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||||
|
AbstractBasePtrList args_spec_list;
|
||||||
|
for (size_t i = 0; i < 10; ++i) {
|
||||||
|
args_spec_list.push_back(x_abstract);
|
||||||
|
}
|
||||||
|
auto fg = GetKernelGraph(g, args_spec_list);
|
||||||
|
|
||||||
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
|
auto pm = std::make_shared<opt::PassManager>();
|
||||||
|
pm->AddPass(std::make_shared<opt::AdamApplyOneAssignCond4Fusion>());
|
||||||
|
optimizer->AddPassManager(pm);
|
||||||
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||||
|
|
||||||
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after");
|
||||||
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||||
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
|
||||||
Add = P.TensorAdd()
|
Add = P.TensorAdd()
|
||||||
Sub = P.Sub()
|
Sub = P.Sub()
|
||||||
|
@ -21,9 +22,11 @@ Mul = P.Mul()
|
||||||
RealDiv = P.RealDiv()
|
RealDiv = P.RealDiv()
|
||||||
Sqrt = P.Sqrt()
|
Sqrt = P.Sqrt()
|
||||||
Square = P.Square()
|
Square = P.Square()
|
||||||
|
Assign = P.Assign()
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive('tuple_getitem')
|
||||||
AdamApplyOne = Primitive('AdamApplyOne')
|
AdamApplyOne = Primitive('AdamApplyOne')
|
||||||
|
AdamApplyOneAssign = Primitive('AdamApplyOneAssign')
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
@ -139,3 +142,138 @@ def test_adam_apply_one_fusion(tag):
|
||||||
return make_tuple(output)
|
return make_tuple(output)
|
||||||
|
|
||||||
return fns[tag]
|
return fns[tag]
|
||||||
|
|
||||||
|
|
||||||
|
def test_adam_apply_one_assign_fusion(tag):
|
||||||
|
fns = FnDict()
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||||
|
square0 = Square(input0)
|
||||||
|
mul1 = Mul(mul1_x, input0)
|
||||||
|
mul0 = Mul(mul0_x, input2)
|
||||||
|
mul2 = Mul(mul2_x, input1)
|
||||||
|
mul3 = Mul(mul3_x, square0)
|
||||||
|
add0 = Add(mul0, mul1)
|
||||||
|
add1 = Add(mul2, mul3)
|
||||||
|
sqrt0 = Sqrt(add1)
|
||||||
|
add2 = Add(sqrt0, add2_y)
|
||||||
|
true_div0 = RealDiv(add0, add2)
|
||||||
|
mul4 = Mul(input4, true_div0)
|
||||||
|
sub0 = Sub(input3, mul4)
|
||||||
|
assign0 = Assign(input3, sub0)
|
||||||
|
depend0 = F.depend(sub0, assign0)
|
||||||
|
assign1 = Assign(input2, add0)
|
||||||
|
depend1 = F.depend(depend0, assign1)
|
||||||
|
assign2 = Assign(input1, add1)
|
||||||
|
depend2 = F.depend(depend1, assign2)
|
||||||
|
outputs = make_tuple(add1, add0, depend2)
|
||||||
|
output = tuple_getitem(outputs, 0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||||
|
square0 = Square(input0)
|
||||||
|
mul1 = Mul(mul1_x, input0)
|
||||||
|
mul0 = Mul(mul0_x, input2)
|
||||||
|
mul2 = Mul(mul2_x, input1)
|
||||||
|
mul3 = Mul(mul3_x, square0)
|
||||||
|
add0 = Add(mul0, mul1)
|
||||||
|
add1 = Add(mul2, mul3)
|
||||||
|
sqrt0 = Sqrt(add1)
|
||||||
|
add2 = Add(add2_y, sqrt0)
|
||||||
|
true_div0 = RealDiv(add0, add2)
|
||||||
|
mul4 = Mul(input4, true_div0)
|
||||||
|
sub0 = Sub(input3, mul4)
|
||||||
|
assign0 = Assign(input3, sub0)
|
||||||
|
depend0 = F.depend(sub0, assign0)
|
||||||
|
assign1 = Assign(input2, add0)
|
||||||
|
depend1 = F.depend(depend0, assign1)
|
||||||
|
assign2 = Assign(input1, add1)
|
||||||
|
depend2 = F.depend(depend1, assign2)
|
||||||
|
outputs = make_tuple(add1, add0, depend2)
|
||||||
|
output = tuple_getitem(outputs, 0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||||
|
square0 = Square(input0)
|
||||||
|
mul1 = Mul(mul1_x, input0)
|
||||||
|
mul0 = Mul(mul0_x, input2)
|
||||||
|
mul2 = Mul(mul2_x, input1)
|
||||||
|
mul3 = Mul(square0, mul3_x)
|
||||||
|
add0 = Add(mul0, mul1)
|
||||||
|
add1 = Add(mul2, mul3)
|
||||||
|
sqrt0 = Sqrt(add1)
|
||||||
|
add2 = Add(sqrt0, add2_y)
|
||||||
|
true_div0 = RealDiv(add0, add2)
|
||||||
|
mul4 = Mul(true_div0, input4)
|
||||||
|
sub0 = Sub(input3, mul4)
|
||||||
|
assign0 = Assign(input3, sub0)
|
||||||
|
depend0 = F.depend(sub0, assign0)
|
||||||
|
assign1 = Assign(input2, add0)
|
||||||
|
depend1 = F.depend(depend0, assign1)
|
||||||
|
assign2 = Assign(input1, add1)
|
||||||
|
depend2 = F.depend(depend1, assign2)
|
||||||
|
outputs = make_tuple(add1, add0, depend2)
|
||||||
|
output = tuple_getitem(outputs, 0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||||
|
square0 = Square(input0)
|
||||||
|
mul1 = Mul(mul1_x, input0)
|
||||||
|
mul0 = Mul(mul0_x, input2)
|
||||||
|
mul2 = Mul(mul2_x, input1)
|
||||||
|
mul3 = Mul(mul3_x, square0)
|
||||||
|
add0 = Add(mul0, mul1)
|
||||||
|
add1 = Add(mul2, mul3)
|
||||||
|
sqrt0 = Sqrt(add1)
|
||||||
|
add2 = Add(sqrt0, add2_y)
|
||||||
|
true_div0 = RealDiv(add0, add2)
|
||||||
|
mul4 = Mul(true_div0, input4)
|
||||||
|
sub0 = Sub(input3, mul4)
|
||||||
|
assign0 = Assign(input3, sub0)
|
||||||
|
depend0 = F.depend(sub0, assign0)
|
||||||
|
assign1 = Assign(input2, add0)
|
||||||
|
depend1 = F.depend(depend0, assign1)
|
||||||
|
assign2 = Assign(input1, add1)
|
||||||
|
depend2 = F.depend(depend1, assign2)
|
||||||
|
outputs = make_tuple(add1, add0, depend2)
|
||||||
|
output = tuple_getitem(outputs, 0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||||
|
square0 = Square(input0)
|
||||||
|
mul1 = Mul(mul1_x, input0)
|
||||||
|
mul0 = Mul(mul0_x, input2)
|
||||||
|
mul2 = Mul(mul2_x, input1)
|
||||||
|
mul3 = Mul(mul3_x, square0)
|
||||||
|
add0 = Add(mul0, mul1)
|
||||||
|
add1 = Add(mul2, mul3)
|
||||||
|
sqrt0 = Sqrt(add1)
|
||||||
|
add2 = Add(add2_y, sqrt0)
|
||||||
|
true_div0 = RealDiv(add0, add2)
|
||||||
|
mul4 = Mul(true_div0, input4)
|
||||||
|
sub0 = Sub(input3, mul4)
|
||||||
|
assign0 = Assign(input3, sub0)
|
||||||
|
depend0 = F.depend(sub0, assign0)
|
||||||
|
assign1 = Assign(input2, add0)
|
||||||
|
depend1 = F.depend(depend0, assign1)
|
||||||
|
assign2 = Assign(input1, add1)
|
||||||
|
depend2 = F.depend(depend1, assign2)
|
||||||
|
outputs = make_tuple(add1, add0, depend2)
|
||||||
|
output = tuple_getitem(outputs, 0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||||
|
adam_apply_one_assign = AdamApplyOneAssign(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x,
|
||||||
|
mul3_x, add2_y)
|
||||||
|
outputs = make_tuple(tuple_getitem(adam_apply_one_assign, 0), tuple_getitem(adam_apply_one_assign, 1),
|
||||||
|
tuple_getitem(adam_apply_one_assign, 2))
|
||||||
|
output = tuple_getitem(outputs, 0)
|
||||||
|
return make_tuple(output)
|
||||||
|
|
||||||
|
return fns[tag]
|
||||||
|
|
Loading…
Reference in New Issue