diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc index b1afa338d4d..f89c833d06b 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc @@ -24,7 +24,8 @@ namespace mindspore { namespace opt { -std::vector AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv) const { +std::vector AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv, + const AnfNodePtr &final_node) const { MS_EXCEPTION_IF_NULL(equiv); auto input0 = utils::cast((*equiv)[input0_]); auto input1 = utils::cast((*equiv)[input1_]); @@ -37,7 +38,12 @@ std::vector AdamApplyOneWithDecayRule::GetFusionNodeInputs(const Equ auto mul3_x = utils::cast((*equiv)[mul3_x_]); auto mul4_x = utils::cast((*equiv)[mul4_x_]); auto add2_y = utils::cast((*equiv)[add2_y_]); - auto prim = std::make_shared(kAdamApplyOneWithDecayOpName); + PrimitivePtr prim = nullptr; + if (AnfAlgo::CheckPrimitiveType(final_node, prim::kPrimDepend)) { + prim = std::make_shared(kAdamApplyOneWithDecayAssignOpName); + } else { + prim = std::make_shared(kAdamApplyOneWithDecayOpName); + } return {NewValueNode(prim), input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y}; } @@ -141,18 +147,152 @@ const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const { return sub0; } +const BaseRef AdamApplyOneWithDecayAssignRuleCond1::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, input4_, add3}); + VectorRef sub0({sub0_var_, input3_, mul5}); + VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); + VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0}); + VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); + VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1}); + return VectorRef({prim::kPrimDepend, depend1, assign2}); +} + +const BaseRef AdamApplyOneWithDecayAssignRuleCond2::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, input2_, mul0_x_}); + VectorRef mul1({prim::kPrimMul, input0_, mul1_x_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, input1_, mul2_x_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); + VectorRef mul4({prim::kPrimMul, input3_, mul4_x_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({sub0_var_, input3_, mul5}); + VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); + VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0}); + VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); + VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1}); + return VectorRef({prim::kPrimDepend, depend1, assign2}); +} + +const BaseRef AdamApplyOneWithDecayAssignRuleCond3::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, square0, mul3_x_}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({sub0_var_, input3_, mul5}); + VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); + VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0}); + VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); + VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1}); + return VectorRef({prim::kPrimDepend, depend1, assign2}); +} + +const BaseRef AdamApplyOneWithDecayAssignRuleCond4::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({sub0_var_, input3_, mul5}); + VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); + VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0}); + VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); + VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1}); + return VectorRef({prim::kPrimDepend, depend1, assign2}); +} + +const BaseRef AdamApplyOneWithDecayAssignRuleCond5::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({sub0_var_, input3_, mul5}); + VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); + VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0}); + VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); + VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1}); + return VectorRef({prim::kPrimDepend, depend1, assign2}); +} + const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { if (graph == nullptr || node == nullptr || equiv == nullptr) { return nullptr; } - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + auto sub0 = node; + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { + auto iter_sub0 = (*equiv).find(sub0_var_); + if (iter_sub0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the sub0 var after matched."; + } + sub0 = utils::cast(iter_sub0->second); + } + MS_EXCEPTION_IF_NULL(sub0); + if (!CheckSupportDataType(sub0, kFloatDataTypeSet)) { return nullptr; } - std::vector inputs = GetFusionNodeInputs(equiv); + std::vector inputs = GetFusionNodeInputs(equiv, node); auto fusion_node = graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(fusion_node); - fusion_node->set_scope(node->scope()); + fusion_node->set_scope(sub0->scope()); auto iter_add0 = (*equiv).find(add0_var_); if (iter_add0 == (*equiv).end()) { @@ -167,9 +307,9 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c auto add1 = utils::cast(iter_add1->second); MS_EXCEPTION_IF_NULL(add1); auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0), - AnfAlgo::GetOutputInferDataType(node, 0)}; + AnfAlgo::GetOutputInferDataType(sub0, 0)}; auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0), - AnfAlgo::GetOutputInferShape(node, 0)}; + AnfAlgo::GetOutputInferShape(sub0, 0)}; AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); std::vector fusion_node_outputs; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h index 98bc63a6f1a..e43c5ad496d 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h @@ -40,13 +40,14 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass { add2_y_ = std::make_shared(); add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + sub0_var_ = std::make_shared(std::make_shared(prim::kPrimSub->name())); } ~AdamApplyOneWithDecayRule() override = default; const BaseRef DefinePattern() const override = 0; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; protected: - std::vector GetFusionNodeInputs(const EquivPtr &equiv) const; + std::vector GetFusionNodeInputs(const EquivPtr &equiv, const AnfNodePtr &final_node) const; VarPtr input0_; VarPtr input1_; VarPtr input2_; @@ -60,6 +61,7 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass { VarPtr add2_y_; VarPtr add0_var_; VarPtr add1_var_; + VarPtr sub0_var_; }; class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule { @@ -106,6 +108,51 @@ class AdamApplyOneWithDecayRuleCond5 : public AdamApplyOneWithDecayRule { ~AdamApplyOneWithDecayRuleCond5() override = default; const BaseRef DefinePattern() const override; }; + +class AdamApplyOneWithDecayAssignRuleCond1 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayAssignRuleCond1(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond1", multigraph) {} + + ~AdamApplyOneWithDecayAssignRuleCond1() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayAssignRuleCond2 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayAssignRuleCond2(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond2", multigraph) {} + + ~AdamApplyOneWithDecayAssignRuleCond2() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayAssignRuleCond3 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayAssignRuleCond3(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond3", multigraph) {} + + ~AdamApplyOneWithDecayAssignRuleCond3() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayAssignRuleCond4 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayAssignRuleCond4(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond4", multigraph) {} + + ~AdamApplyOneWithDecayAssignRuleCond4() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayAssignRuleCond5 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayAssignRuleCond5(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond5", multigraph) {} + + ~AdamApplyOneWithDecayAssignRuleCond5() override = default; + const BaseRef DefinePattern() const override; +}; } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index da7fa32ae95..95261889910 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -122,6 +122,7 @@ constexpr auto kLayerNormBetaGammaBackpropOpName = "LayerNormBetaGammaBackprop"; constexpr auto kLambNextMVOpName = "LambNextMV"; constexpr auto kConfusionTransposeDOpName = "ConfusionTransposeD"; constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay"; +constexpr auto kAdamApplyOneWithDecayAssignOpName = "AdamApplyOneWithDecayAssign"; constexpr auto kBatchNormGradOpName = "BatchNormGrad"; constexpr auto kBNInferOpName = "BNInfer"; constexpr auto kAdamApplyOneOpName = "AdamApplyOne"; diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc index 78c815bf506..8cdb03d870c 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc @@ -31,7 +31,7 @@ class TestHWOptimizeAdamApplyOneWithDecayRule : public BackendCommon { }; TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond1) { - FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond1", "before"); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond1"); std::vector shp{2, 32, 224, 224}; auto x_abstract = std::make_shared(kFloat32, shp); @@ -47,12 +47,12 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(fg); - FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond1", "after"); + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond2) { - FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond2", "before"); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond2"); std::vector shp{2, 32, 224, 224}; auto x_abstract = std::make_shared(kFloat32, shp); @@ -68,12 +68,12 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(fg); - FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond2", "after"); + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond3) { - FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond3", "before"); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond3"); std::vector shp{2, 32, 224, 224}; auto x_abstract = std::make_shared(kFloat32, shp); @@ -89,12 +89,12 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(fg); - FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond3", "after"); + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond4) { - FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond4", "before"); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond4"); std::vector shp{2, 32, 224, 224}; auto x_abstract = std::make_shared(kFloat32, shp); @@ -110,12 +110,12 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(fg); - FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond4", "after"); + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond5) { - FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond5", "before"); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond5"); std::vector shp{2, 32, 224, 224}; auto x_abstract = std::make_shared(kFloat32, shp); @@ -131,7 +131,112 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(fg); - FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond5", "after"); + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond1) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond1"); + + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 11; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond2) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond2"); + + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 11; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond3) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond3"); + + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 11; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond4) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond4"); + + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 11; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond5) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond5"); + + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 11; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } } // namespace opt diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_with_decay_rule.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_with_decay_rule.py index a13b1bc5834..96bbe7c057f 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_with_decay_rule.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_with_decay_rule.py @@ -15,6 +15,7 @@ from mindspore.ops import Primitive from mindspore.ops import operations as P +from mindspore.ops import functional as F mul = P.Mul() add = P.TensorAdd() @@ -22,9 +23,11 @@ square = P.Square() sqrt = P.Sqrt() real_div = P.RealDiv() sub = P.Sub() +Assign = P.Assign() make_tuple = Primitive('make_tuple') tuple_getitem = Primitive('tuple_getitem') adam_apply_one_with_decay = Primitive('AdamApplyOneWithDecay') +adam_apply_one_with_decay_assign = Primitive('AdamApplyOneWithDecayAssign') class FnDict: @@ -39,63 +42,10 @@ class FnDict: def test_adam_apply_one_with_decay_rule(tag): - """ test_adam_apply_one_with_decay_rule """ fns = FnDict() @fns - def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): - mul0 = mul(mul0_x, input2) - mul1 = mul(mul1_x, input0) - square0 = square(input0) - add0 = add(mul0, mul1) - mul2 = mul(mul2_x, input1) - mul3 = mul(mul3_x, square0) - add1 = add(mul2, mul3) - sqrt0 = sqrt(add1) - add2 = add(sqrt0, add2_y) - mul4 = mul(mul4_x, input3) - real_div0 = real_div(add0, add2) - add3 = add(real_div0, mul4) - mul5 = mul(input4, add3) - sub0 = sub(input3, mul5) - return make_tuple(add1, add0, sub0) - - @fns - def no_match(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): - mul0 = mul(mul0_x, input2) - mul1 = mul(mul1_x, input0) - square0 = square(input0) - # diff mul from original add - add0 = mul(mul0, mul1) - mul2 = mul(mul2_x, input1) - mul3 = mul(mul3_x, square0) - add1 = add(mul2, mul3) - sqrt0 = sqrt(add1) - add2 = add(sqrt0, add2_y) - mul4 = mul(mul4_x, input3) - real_div0 = real_div(add0, add2) - add3 = add(real_div0, mul4) - mul5 = mul(input4, add3) - sub0 = sub(input3, mul5) - return make_tuple(add1, add0, sub0) - - @fns - def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): - res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, - add2_y) - item0 = tuple_getitem(res, 0) - item1 = tuple_getitem(res, 1) - item2 = tuple_getitem(res, 2) - return make_tuple(make_tuple(item0, item1, item2)) - - return fns[tag] - - -def test_adam_apply_one_with_decay_rule_cond1(tag): - fns = FnDict() - - @fns - def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): mul0 = mul(mul0_x, input2) mul1 = mul(mul1_x, input0) square0 = square(input0) @@ -113,22 +63,7 @@ def test_adam_apply_one_with_decay_rule_cond1(tag): return make_tuple(add1, add0, sub0) @fns - def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): - res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, - add2_y) - item0 = tuple_getitem(res, 0) - item1 = tuple_getitem(res, 1) - item2 = tuple_getitem(res, 2) - return make_tuple(make_tuple(item0, item1, item2)) - - return fns[tag] - - -def test_adam_apply_one_with_decay_rule_cond2(tag): - fns = FnDict() - - @fns - def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): mul0 = mul(input2, mul0_x) mul1 = mul(input0, mul1_x) square0 = square(input0) @@ -146,22 +81,7 @@ def test_adam_apply_one_with_decay_rule_cond2(tag): return make_tuple(add1, add0, sub0) @fns - def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): - res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, - add2_y) - item0 = tuple_getitem(res, 0) - item1 = tuple_getitem(res, 1) - item2 = tuple_getitem(res, 2) - return make_tuple(make_tuple(item0, item1, item2)) - - return fns[tag] - - -def test_adam_apply_one_with_decay_rule_cond3(tag): - fns = FnDict() - - @fns - def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): mul0 = mul(mul0_x, input2) mul1 = mul(mul1_x, input0) square0 = square(input0) @@ -179,22 +99,7 @@ def test_adam_apply_one_with_decay_rule_cond3(tag): return make_tuple(add1, add0, sub0) @fns - def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): - res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, - add2_y) - item0 = tuple_getitem(res, 0) - item1 = tuple_getitem(res, 1) - item2 = tuple_getitem(res, 2) - return make_tuple(make_tuple(item0, item1, item2)) - - return fns[tag] - - -def test_adam_apply_one_with_decay_rule_cond4(tag): - fns = FnDict() - - @fns - def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): mul0 = mul(mul0_x, input2) mul1 = mul(mul1_x, input0) square0 = square(input0) @@ -212,22 +117,7 @@ def test_adam_apply_one_with_decay_rule_cond4(tag): return make_tuple(add1, add0, sub0) @fns - def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): - res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, - add2_y) - item0 = tuple_getitem(res, 0) - item1 = tuple_getitem(res, 1) - item2 = tuple_getitem(res, 2) - return make_tuple(make_tuple(item0, item1, item2)) - - return fns[tag] - - -def test_adam_apply_one_with_decay_rule_cond5(tag): - fns = FnDict() - - @fns - def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + def before_cond5(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): mul0 = mul(mul0_x, input2) mul1 = mul(mul1_x, input0) square0 = square(input0) @@ -254,3 +144,138 @@ def test_adam_apply_one_with_decay_rule_cond5(tag): return make_tuple(make_tuple(item0, item1, item2)) return fns[tag] + + +def test_adam_apply_one_with_decay_assign_rule(tag): + fns = FnDict() + + @fns + def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + mul0 = mul(mul0_x, input2) + mul1 = mul(mul1_x, input0) + square0 = square(input0) + add0 = add(mul0, mul1) + mul2 = mul(mul2_x, input1) + mul3 = mul(mul3_x, square0) + add1 = add(mul2, mul3) + sqrt0 = sqrt(add1) + add2 = add(add2_y, sqrt0) + mul4 = mul(mul4_x, input3) + real_div0 = real_div(add0, add2) + add3 = add(mul4, real_div0) + mul5 = mul(input4, add3) + sub0 = sub(input3, mul5) + assign0 = Assign(input3, sub0) + depend0 = F.depend(sub0, assign0) + assign1 = Assign(input2, add0) + depend1 = F.depend(depend0, assign1) + assign2 = Assign(input1, add1) + depend2 = F.depend(depend1, assign2) + return make_tuple(add1, add0, depend2) + + @fns + def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + mul0 = mul(input2, mul0_x) + mul1 = mul(input0, mul1_x) + square0 = square(input0) + add0 = add(mul0, mul1) + mul2 = mul(input1, mul2_x) + mul3 = mul(mul3_x, square0) + add1 = add(mul2, mul3) + sqrt0 = sqrt(add1) + add2 = add(sqrt0, add2_y) + mul4 = mul(input3, mul4_x) + real_div0 = real_div(add0, add2) + add3 = add(mul4, real_div0) + mul5 = mul(add3, input4) + sub0 = sub(input3, mul5) + assign0 = Assign(input3, sub0) + depend0 = F.depend(sub0, assign0) + assign1 = Assign(input2, add0) + depend1 = F.depend(depend0, assign1) + assign2 = Assign(input1, add1) + depend2 = F.depend(depend1, assign2) + return make_tuple(add1, add0, depend2) + + @fns + def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + mul0 = mul(mul0_x, input2) + mul1 = mul(mul1_x, input0) + square0 = square(input0) + add0 = add(mul0, mul1) + mul2 = mul(mul2_x, input1) + mul3 = mul(square0, mul3_x) + add1 = add(mul2, mul3) + sqrt0 = sqrt(add1) + add2 = add(sqrt0, add2_y) + mul4 = mul(mul4_x, input3) + real_div0 = real_div(add0, add2) + add3 = add(mul4, real_div0) + mul5 = mul(add3, input4) + sub0 = sub(input3, mul5) + assign0 = Assign(input3, sub0) + depend0 = F.depend(sub0, assign0) + assign1 = Assign(input2, add0) + depend1 = F.depend(depend0, assign1) + assign2 = Assign(input1, add1) + depend2 = F.depend(depend1, assign2) + return make_tuple(add1, add0, depend2) + + @fns + def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + mul0 = mul(mul0_x, input2) + mul1 = mul(mul1_x, input0) + square0 = square(input0) + add0 = add(mul0, mul1) + mul2 = mul(mul2_x, input1) + mul3 = mul(mul3_x, square0) + add1 = add(mul2, mul3) + sqrt0 = sqrt(add1) + add2 = add(add2_y, sqrt0) + mul4 = mul(mul4_x, input3) + real_div0 = real_div(add0, add2) + add3 = add(mul4, real_div0) + mul5 = mul(add3, input4) + sub0 = sub(input3, mul5) + assign0 = Assign(input3, sub0) + depend0 = F.depend(sub0, assign0) + assign1 = Assign(input2, add0) + depend1 = F.depend(depend0, assign1) + assign2 = Assign(input1, add1) + depend2 = F.depend(depend1, assign2) + return make_tuple(add1, add0, depend2) + + @fns + def before_cond5(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + mul0 = mul(mul0_x, input2) + mul1 = mul(mul1_x, input0) + square0 = square(input0) + add0 = add(mul0, mul1) + mul2 = mul(mul2_x, input1) + mul3 = mul(mul3_x, square0) + add1 = add(mul2, mul3) + sqrt0 = sqrt(add1) + add2 = add(sqrt0, add2_y) + mul4 = mul(mul4_x, input3) + real_div0 = real_div(add0, add2) + add3 = add(mul4, real_div0) + mul5 = mul(add3, input4) + sub0 = sub(input3, mul5) + assign0 = Assign(input3, sub0) + depend0 = F.depend(sub0, assign0) + assign1 = Assign(input2, add0) + depend1 = F.depend(depend0, assign1) + assign2 = Assign(input1, add1) + depend2 = F.depend(depend1, assign2) + return make_tuple(add1, add0, depend2) + + @fns + def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + res = adam_apply_one_with_decay_assign(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, + mul4_x, add2_y) + item0 = tuple_getitem(res, 0) + item1 = tuple_getitem(res, 1) + item2 = tuple_getitem(res, 2) + return make_tuple(make_tuple(item0, item1, item2)) + + return fns[tag]