From b8e737f66adfa8502237d2b4c3eb312dbb508f63 Mon Sep 17 00:00:00 2001
From: huanghui
Date: Wed, 2 Sep 2020 10:30:22 +0800
Subject: [PATCH] fix run error when there is a Depend or ControlDepend on
 BatchNorm

---
 .../backend/optimizer/ascend/ascend_helper.cc |  1 -
 .../ascend/ir_fusion/batchnorm_to_bninfer.cc  |  4 +++-
 .../ir_fusion/batchnormgrad_to_bninfergrad.cc | 22 ++++++++++++----------
 .../ccsrc/backend/optimizer/common/helper.cc  | 21 +++++++++++++++++++++
 .../ccsrc/backend/optimizer/common/helper.h   |  3 +++
 .../optimizer/pass/add_atomic_clean.cc        |  1 -
 .../optimizer/pass/communication_op_fusion.cc |  2 +-
 .../pass/const_to_attr_strided_slice_grad.h   |  2 +-
 8 files changed, 41 insertions(+), 15 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
index 1a81c828439..b4eb70c7269 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
@@ -79,7 +79,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
                                   : AnfAlgo::GetOutputInferShape(input_node, insert_index);
   bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
                                       : trans::IsNeedPadding(input_format, input_node_out_shape.size());
-
   if (!need_padding) {
     // don't need padding insert transdata only
     trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc
index c9091112daf..2a952e79862 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc
@@ -121,7 +121,9 @@ const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const Anf
   if (!NeedFusion(graph, node, &batchnorm)) {
     return nullptr;
   }
-  return CreateBNInfer(graph, batchnorm, node);
+  auto bn_infer = CreateBNInfer(graph, batchnorm, node);
+  TransferDepend(batchnorm, graph, bn_infer);
+  return bn_infer;
 }
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc
index 5eb8a86a041..ca2913b5b9f 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc
@@ -81,7 +81,7 @@ bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad
   return true;
 }
 
-bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnormgrad) {
+bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnorm_grad) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
   auto tuple_getitem = node->cast<CNodePtr>();
@@ -93,12 +93,12 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat
     return false;
   }
 
-  AnfNodePtr batchnormgrad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
-  MS_EXCEPTION_IF_NULL(batchnormgrad_anf);
-  MS_EXCEPTION_IF_NULL(batchnormgrad);
-  *batchnormgrad = batchnormgrad_anf->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(*batchnormgrad);
-  return CheckBatchNormGrad(graph, *batchnormgrad);
+  AnfNodePtr batchnorm_grad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
+  MS_EXCEPTION_IF_NULL(batchnorm_grad_anf);
+  MS_EXCEPTION_IF_NULL(batchnorm_grad);
+  *batchnorm_grad = batchnorm_grad_anf->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(*batchnorm_grad);
+  return CheckBatchNormGrad(graph, *batchnorm_grad);
 }
 }  // namespace
 
@@ -117,11 +117,13 @@ const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, c
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
 
-  CNodePtr batchnormgrad = nullptr;
-  if (!NeedFusion(graph, node, &batchnormgrad)) {
+  CNodePtr batchnorm_grad = nullptr;
+  if (!NeedFusion(graph, node, &batchnorm_grad)) {
     return nullptr;
   }
-  return CreateBNInferGrad(graph, batchnormgrad, node);
+  auto bn_infer_grad = CreateBNInferGrad(graph, batchnorm_grad, node);
+  TransferDepend(batchnorm_grad, graph, bn_infer_grad);
+  return bn_infer_grad;
 }
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc
index f9770fc9497..8b0f2e5f369 100644
--- a/mindspore/ccsrc/backend/optimizer/common/helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc
@@ -872,5 +872,26 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
   return new_value_node;
 }
 
+void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
+  MS_EXCEPTION_IF_NULL(old_node);
+  MS_EXCEPTION_IF_NULL(graph);
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  // find BatchNorm's output which is a Depend or ControlDepend
+  for (const auto &node_index : manager->node_users()[old_node]) {
+    AnfNodePtr output = node_index.first;
+    size_t index = IntToSize(node_index.second);
+    MS_EXCEPTION_IF_NULL(output);
+    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
+      auto control_depend = output->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(control_depend);
+      control_depend->set_input(index, new_node);
+    } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
+      auto depend = output->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(depend);
+      depend->set_input(index, new_node);
+    }
+  }
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h
index 87731b54454..9d64a88d020 100644
--- a/mindspore/ccsrc/backend/optimizer/common/helper.h
+++ b/mindspore/ccsrc/backend/optimizer/common/helper.h
@@ -203,6 +203,9 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
 
 // Create a new value node of func graph,not kernel graph
 ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
+
+// Transfer depend or control_depend to the new node
+void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_
diff --git a/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc
index 4e6514a877d..3818e8c1718 100644
--- a/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc
+++ b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc
@@ -27,7 +27,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-
 static std::vector<size_t> g_output_idx;
 
 bool HasAtomic(const AnfNodePtr &input) {
diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
index f3ba66726a9..9994b972db3 100644
--- a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
@@ -98,7 +98,7 @@ void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
   }
 }
 
-bool CheckSegments(size_t segments, size_t communication_op_node_size, std::vector<size_t> *segment_index) {
+bool CheckSegments(size_t segments, size_t communication_op_node_size, const std::vector<size_t> *segment_index) {
   MS_EXCEPTION_IF_NULL(segment_index);
   if (segments >= communication_op_node_size) {
     MS_LOG(INFO) << "fusion not changed: segment_num=" << segments
diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h
index b2faa0437a0..ef0a6f8ae9f 100644
--- a/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h
+++ b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h
@@ -24,7 +24,7 @@ namespace opt {
 class ConstToAttrStridedSliceGradPass : public PatternProcessPass {
  public:
   explicit ConstToAttrStridedSliceGradPass(bool multigraph = true)
-      : PatternProcessPass("const_to_attr_strided_slice_grad_", multigraph) {}
+      : PatternProcessPass("const_to_attr_strided_slice_grad", multigraph) {}
   ~ConstToAttrStridedSliceGradPass() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
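
The substance of the patch is the new TransferDepend helper in helper.cc: when BatchNorm2BNInfer or BatchNormGrad2BNInferGrad swaps a BatchNorm(Grad) node for a BNInfer(Grad) node, any Depend or ControlDepend consumer still wired to the old node keeps a dangling reference, which is the run error named in the subject. Re-pointing those consumers at the replacement node preserves the control-flow edges. Below is a minimal standalone C++ sketch of that redirection pattern; the types (Node, Graph) and the Users lookup are simplified stand-ins for illustration, not MindSpore's actual AnfNodePtr/FuncGraphManager API.

// Simplified stand-in types (hypothetical, for illustration only).
#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::string op;                             // e.g. "BatchNorm", "Depend"
  std::vector<std::shared_ptr<Node>> inputs;  // operand edges
};
using NodePtr = std::shared_ptr<Node>;

struct Graph {
  std::vector<NodePtr> nodes;
  // Analogue of manager->node_users(): (consumer, input-slot) pairs for `node`.
  std::vector<std::pair<NodePtr, size_t>> Users(const NodePtr &node) const {
    std::vector<std::pair<NodePtr, size_t>> users;
    for (const auto &n : nodes) {
      for (size_t i = 0; i < n->inputs.size(); ++i) {
        if (n->inputs[i] == node) users.emplace_back(n, i);
      }
    }
    return users;
  }
};

// Redirect every Depend/ControlDepend edge from old_node to new_node, the
// same bookkeeping the patch performs right after creating the fused node.
void TransferDepend(const NodePtr &old_node, const Graph &graph, const NodePtr &new_node) {
  for (const auto &[user, index] : graph.Users(old_node)) {
    if (user->op == "Depend" || user->op == "ControlDepend") {
      user->inputs[index] = new_node;  // same slot, new producer
    }
  }
}

int main() {
  auto bn = std::make_shared<Node>(Node{"BatchNorm", {}});
  auto depend = std::make_shared<Node>(Node{"Depend", {bn}});
  Graph g{{bn, depend}};

  // The fusion pass replaces BatchNorm with BNInfer; without the transfer,
  // `depend` would still point at the dead BatchNorm node.
  auto bn_infer = std::make_shared<Node>(Node{"BNInfer", {}});
  TransferDepend(bn, g, bn_infer);

  std::cout << "Depend now consumes: " << depend->inputs[0]->op << "\n";  // BNInfer
  return 0;
}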