From 74fb1b56dfe8d9ff3bb5a0947b18c67f462384b1 Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Fri, 15 Jan 2021 16:03:47 +0800 Subject: [PATCH] modify dropout unify mindir pass --- .../ascend/mindir/dropout_unify_mindir.cc | 158 +++++++++++------- .../ascend/mindir/dropout_unify_mindir.h | 39 +++-- .../ccsrc/backend/session/ascend_session.cc | 8 +- 3 files changed, 121 insertions(+), 84 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc index 4145ff63a8d..01af6f5ee45 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc @@ -155,59 +155,20 @@ std::vector CalDropoutGenMaskOutput(const std::vector &shape) MS_LOG(INFO) << "Output_size: " << ret; return {ret}; } + +bool NeedUpdate(const CNodePtr &getitem_cnode) { + MS_EXCEPTION_IF_NULL(getitem_cnode); + MS_EXCEPTION_IF_NULL(getitem_cnode->input(2)); + auto index_vnode = getitem_cnode->input(2)->cast(); + MS_EXCEPTION_IF_NULL(index_vnode); + auto index_value = index_vnode->value(); + MS_EXCEPTION_IF_NULL(index_value); + auto index = GetValue(index_value); + return index == 1; +} } // namespace -const BaseRef DropoutUnifyMindIR::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Y = std::make_shared(); - auto prim = std::make_shared(kDropoutOpName); - auto ref = VectorRef({prim, X}); - return VectorRef({prim::kPrimTupleGetItem, ref, Y}); -} - -const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto tuple_cnode = node->cast(); - MS_EXCEPTION_IF_NULL(tuple_cnode); - auto dropout_node = tuple_cnode->input(1); - MS_EXCEPTION_IF_NULL(dropout_node); - - auto inputx_type_id = GetInputXDataType(dropout_node); - auto inputx_shape = GetInputXShape(dropout_node); - auto shape_value = CreateShapeValueNode(func_graph, inputx_shape); - auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id); - - // CreateDropoutGenMask - std::vector dropout_gen_mask_inputs{NewValueNode(std::make_shared(kDropoutGenMaskOpName)), - shape_value, keep_prob_value}; - CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs); - MS_EXCEPTION_IF_NULL(dropout_gen_mask); - AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); - auto output_shape = CalDropoutGenMaskOutput(inputx_shape); - auto gen_mask_abstract = std::make_shared(kUInt8, output_shape); - MS_EXCEPTION_IF_NULL(gen_mask_abstract); - dropout_gen_mask->set_abstract(gen_mask_abstract); - dropout_gen_mask->set_scope(node->scope()); - - // CreateDropoutDoMask - MS_EXCEPTION_IF_NULL(dropout_node); - auto dropout_cnode = dropout_node->cast(); - MS_EXCEPTION_IF_NULL(dropout_cnode); - auto dropout_input = dropout_cnode->input(1); - std::vector dropout_do_mask_inputs{NewValueNode(std::make_shared(kDropoutDoMaskOpName)), - dropout_input, dropout_gen_mask, keep_prob_value}; - auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs); - MS_EXCEPTION_IF_NULL(dropout_do_mask); - auto do_mask_abstract = std::make_shared(TypeIdToType(inputx_type_id), inputx_shape); - dropout_do_mask->set_abstract(do_mask_abstract); - dropout_do_mask->set_scope(node->scope()); - - return dropout_do_mask; -} - -const BaseRef DropoutGradUnifyMindIR::DefinePattern() const { +const BaseRef DropoutAndDropoutGradUnifyMindIR::DefinePattern() const { VarPtr X = std::make_shared(); VarPtr Y = std::make_shared(); auto dropout_prim = std::make_shared(kDropoutOpName); @@ -220,8 +181,8 @@ const BaseRef DropoutGradUnifyMindIR::DefinePattern() const { return VectorRef({dropout_grad_prim, grad_input_, ref1}); } -const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { +const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); auto dropout_grad_cnode = node->cast(); @@ -312,13 +273,74 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, return dropout_do_mask; } -const BaseRef DropoutUnifyMindIRPynative::DefinePattern() const { +const BaseRef DropoutUnifyMindIR0::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Y = std::make_shared(); + auto prim = std::make_shared(kDropoutOpName); + auto ref = VectorRef({prim, X}); + return VectorRef({prim::kPrimTupleGetItem, ref, Y}); +} + +const AnfNodePtr DropoutUnifyMindIR0::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto tuple_cnode = node->cast(); + MS_EXCEPTION_IF_NULL(tuple_cnode); + if (!NeedUpdate(tuple_cnode)) { + return nullptr; + } + + auto dropout_node = tuple_cnode->input(1); + MS_EXCEPTION_IF_NULL(dropout_node); + auto inputx_type_id = GetInputXDataType(dropout_node); + auto inputx_shape = GetInputXShape(dropout_node); + auto shape_value = CreateShapeValueNode(func_graph, inputx_shape); + auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id); + + // CreateDropoutGenMask + std::vector dropout_gen_mask_inputs{NewValueNode(std::make_shared(kDropoutGenMaskOpName)), + shape_value, keep_prob_value}; + CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs); + MS_EXCEPTION_IF_NULL(dropout_gen_mask); + AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); + auto output_shape = CalDropoutGenMaskOutput(inputx_shape); + auto gen_mask_abstract = std::make_shared(kUInt8, output_shape); + MS_EXCEPTION_IF_NULL(gen_mask_abstract); + dropout_gen_mask->set_abstract(gen_mask_abstract); + dropout_gen_mask->set_scope(node->scope()); + + // CreateDropoutDoMask + MS_EXCEPTION_IF_NULL(dropout_node); + auto dropout_cnode = dropout_node->cast(); + MS_EXCEPTION_IF_NULL(dropout_cnode); + auto dropout_input = dropout_cnode->input(1); + std::vector dropout_do_mask_inputs{NewValueNode(std::make_shared(kDropoutDoMaskOpName)), + dropout_input, dropout_gen_mask, keep_prob_value}; + auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs); + MS_EXCEPTION_IF_NULL(dropout_do_mask); + auto do_mask_abstract = std::make_shared(TypeIdToType(inputx_type_id), inputx_shape); + dropout_do_mask->set_abstract(do_mask_abstract); + dropout_do_mask->set_scope(node->scope()); + + // make tuple to replace dropout + std::vector make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), dropout_do_mask, dropout_gen_mask}; + auto make_tuple = func_graph->NewCNode(make_tuple_inputs); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(dropout_node, make_tuple); + + tuple_cnode->set_abstract(gen_mask_abstract); + return tuple_cnode; +} + +const BaseRef DropoutUnifyMindIR1::DefinePattern() const { VarPtr X = std::make_shared(); return VectorRef({prim::kPrimDropout, X}); } -const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { +const AnfNodePtr DropoutUnifyMindIR1::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); auto dropout_node = node->cast(); @@ -359,15 +381,15 @@ const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_gr return make_tuple; } -const BaseRef DropoutGradUnifyMindIRPynative::DefinePattern() const { +const BaseRef DropoutGradUnifyMindIR::DefinePattern() const { VarPtr X = std::make_shared(); VarPtr Y = std::make_shared(); auto dropout_grad_prim = std::make_shared(kDropoutGradOpName); return VectorRef({dropout_grad_prim, X, Y}); } -const AnfNodePtr DropoutGradUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { +const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); auto dropout_grad_cnode = node->cast(); @@ -377,9 +399,25 @@ const AnfNodePtr DropoutGradUnifyMindIRPynative::Process(const FuncGraphPtr &fun auto grad_input_shape = GetInputXShape(dropout_grad_cnode); auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_grad_cnode, grad_input_type_id); + // DropoutGrad may not in the same graph with Dropout in heterogeneous scene, and mask input which is a parameter + // in that scene, need to be updated. + auto mask_input = dropout_grad_cnode->input(2); + if (mask_input->isa()) { + // update abstract + auto mask_abstract = mask_input->abstract(); + MS_EXCEPTION_IF_NULL(mask_abstract); + auto mask_shape = CalDropoutGenMaskOutput(grad_input_shape); + mask_abstract = std::make_shared(kUInt8, mask_shape); + mask_input->set_abstract(mask_abstract); + // update kernel info + auto kernel_build_info_builder = std::make_shared(); + kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + kernel_build_info_builder->SetOutputsDeviceType(std::vector{kNumberTypeUInt8}); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), mask_input.get()); + } + // CreateDropoutDoMask auto grad_input = dropout_grad_cnode->input(1); - auto mask_input = dropout_grad_cnode->input(2); std::vector dropout_do_mask_inputs{NewValueNode(std::make_shared(kDropoutDoMaskOpName)), grad_input, mask_input, keep_prob_value}; auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.h b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.h index 3a078a1d312..046d9ea48cc 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.h @@ -21,21 +21,13 @@ namespace mindspore { namespace opt { -class DropoutUnifyMindIR : public PatternProcessPass { +class DropoutAndDropoutGradUnifyMindIR : public PatternProcessPass { public: - explicit DropoutUnifyMindIR(bool multigraph = true) : PatternProcessPass("dropout_unify_mindir", multigraph) {} - ~DropoutUnifyMindIR() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; - -class DropoutGradUnifyMindIR : public PatternProcessPass { - public: - explicit DropoutGradUnifyMindIR(bool multigraph = true) - : PatternProcessPass("dropout_grad_unify_mindir", multigraph) { + explicit DropoutAndDropoutGradUnifyMindIR(bool multigraph = true) + : PatternProcessPass("dropout_and_dropoutgrad_unify_mindir", multigraph) { grad_input_ = std::make_shared(); } - ~DropoutGradUnifyMindIR() override = default; + ~DropoutAndDropoutGradUnifyMindIR() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; @@ -43,20 +35,27 @@ class DropoutGradUnifyMindIR : public PatternProcessPass { VarPtr grad_input_; }; -class DropoutUnifyMindIRPynative : public PatternProcessPass { +class DropoutUnifyMindIR0 : public PatternProcessPass { public: - explicit DropoutUnifyMindIRPynative(bool multigraph = true) - : PatternProcessPass("dropout_unify_mindir_pynative", multigraph) {} - ~DropoutUnifyMindIRPynative() override = default; + explicit DropoutUnifyMindIR0(bool multigraph = true) : PatternProcessPass("dropout_unify_mindir0", multigraph) {} + ~DropoutUnifyMindIR0() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; }; -class DropoutGradUnifyMindIRPynative : public PatternProcessPass { +class DropoutUnifyMindIR1 : public PatternProcessPass { public: - explicit DropoutGradUnifyMindIRPynative(bool multigraph = true) - : PatternProcessPass("dropout_grad_unify_mindir_pynative", multigraph) {} - ~DropoutGradUnifyMindIRPynative() override = default; + explicit DropoutUnifyMindIR1(bool multigraph = true) : PatternProcessPass("dropout_unify_mindir1", multigraph) {} + ~DropoutUnifyMindIR1() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +class DropoutGradUnifyMindIR : public PatternProcessPass { + public: + explicit DropoutGradUnifyMindIR(bool multigraph = true) + : PatternProcessPass("dropoutgrad_unify_mindir", multigraph) {} + ~DropoutGradUnifyMindIR() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; }; diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 7789a62a3c8..a434600503c 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -236,17 +236,17 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { - unify_mindir_pm->AddPass(std::make_shared()); - unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); } else { - unify_mindir_pm->AddPass(std::make_shared()); - unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); } + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); optimizer->AddPassManager(unify_mindir_pm); (void)optimizer->Optimize(graph);