From cd6e8d65427d452810f7c0424d11ef48379abce3 Mon Sep 17 00:00:00 2001 From: huanghui Date: Wed, 22 Apr 2020 14:43:25 +0800 Subject: [PATCH] fix ReluV2's mask shape in derelu fusion pass --- .../ascend/ascend_backend_optimization.cc | 6 +++++- .../ir_fusion/confusion_mul_grad_fusion.cc | 4 ++++ .../ascend/ir_fusion/derelu_fusion.cc | 20 +++++++++++++++++-- mindspore/ccsrc/utils/utils.h | 2 +- .../confusion_mul_grad_fusion_test.cc | 5 +++++ .../gtest_input/pre_activate/derelu_fusion.py | 2 +- 6 files changed, 34 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index a2d82525e9a..4a2b5def1a8 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -46,6 +46,8 @@ #include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h" #include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h" #include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" +#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h" +#include "pre_activate/ascend/ir_fusion/derelu_fusion.h" #include "pre_activate/ascend/format_type/insert_trans_op.h" #include "pre_activate/pass/getitem_tuple.h" #include "pre_activate/pass/optimize_dependence.h" @@ -94,8 +96,10 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); } } // namespace diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc index 
6b7f732a6a2..47098379bf6 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "session/anf_runtime_algorithm.h" #include "ir/primitive.h" #include "utils/utils.h" @@ -89,6 +90,9 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons auto reduce_sum = node->cast(); MS_EXCEPTION_IF_NULL(reduce_sum); auto mul1 = reduce_sum->input(1); + if (mul1->fullname_with_scope().find("bert/encoder") == std::string::npos) { + return nullptr; + } if (IsUsedByOthers(graph, mul1)) { MS_LOG(INFO) << "Mul1 is used by others, quit fusion!"; return nullptr; diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc index d5ea315de11..74b63a5b526 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc @@ -50,9 +50,22 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) { MS_EXCEPTION_IF_NULL(new_node); new_node->set_scope(relu->scope()); - // ReluV2's 2rd output is mask whose data type is uint8 and value is 0 or 1, so shape is an empty vector + // ReluV2's 2nd output is mask whose data type is uint8 TypeId mask_dtype = kNumberTypeUInt8; - std::vector mask_shape; + std::vector mask_shape = AnfAlgo::GetOutputInferShape(relu, 0); + if (mask_shape.size() != 4) { + MS_LOG(WARNING) << "relu's infer shape size not equal 4"; + return nullptr; + } + auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(relu, 0); + if (input_dtype == kNumberTypeUInt8 || input_dtype == kNumberTypeInt8) { + mask_shape[1] = (mask_shape[1] + 31) / 32; + mask_shape.push_back(4); + } else { + mask_shape[1] = (mask_shape[1] + 15) / 16; + mask_shape.push_back(2); + } + auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), 
mask_dtype}; auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape}; AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get()); @@ -91,6 +104,9 @@ const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodeP MS_EXCEPTION_IF_NULL(relu); auto relu_v2 = CreateReluV2(graph, relu); + if (relu_v2 == nullptr) { + return nullptr; + } std::vector relu_v2_node_outputs; CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 6829a7e888e..b967f0c3552 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -120,7 +120,7 @@ constexpr auto kStreamActiveOpName = "StreamActive"; constexpr auto kAssignAddOpName = "AssignAdd"; constexpr auto kSendOpName = "Send"; constexpr auto kRecvOpName = "Recv"; -constexpr auto kReluV2OpName = "ReluV2"; +constexpr auto kReluV2OpName = "ReLUV2"; constexpr auto kReluGradV2OpName = "ReluGradV2"; // attr key name diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc index e3bf09d2cbd..4b5d38d3757 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc @@ -32,6 +32,11 @@ class TestHWOptimizeConfusionMulGradFusion : public BackendCommon { TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) { FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "before"); EXPECT_NE(g, nullptr); + auto bert_scope = std::make_shared("bert/encoder"); + for (auto node : TopoSort(g->get_return())) { + node->set_scope(bert_scope); + } + std::vector shp{1, 1, 1, 1}; auto x_abstract = std::make_shared(kFloat32, shp); AbstractBasePtrList args_spec_list; diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/derelu_fusion.py 
b/tests/ut/cpp/python_input/gtest_input/pre_activate/derelu_fusion.py index 497975542b8..767f85332f6 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/derelu_fusion.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/derelu_fusion.py @@ -17,7 +17,7 @@ from mindspore.ops import Primitive relu = P.ReLU() relu_grad = Primitive('ReluGrad') -relu_v2 = Primitive('ReluV2') +relu_v2 = Primitive('ReLUV2') relu_grad_v2 = Primitive('ReluGradV2') make_tuple = Primitive('make_tuple') tuple_getitem = Primitive('tuple_getitem')