diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.cc new file mode 100644 index 00000000000..e1399281343 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.cc @@ -0,0 +1,169 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" +#include +#include "pre_activate/common/helper.h" +#include "session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kBatchNormGradInferOutputNum = 3; +bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(node) == manager->node_users().end()) { + MS_LOG(DEBUG) << "The node " << node->DebugString() << " should have some outputs"; + return false; + } + for (const auto &node_index : manager->node_users()[node]) { + AnfNodePtr output = node_index.first; + MS_EXCEPTION_IF_NULL(output); + auto tuple_getiterm_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); + auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_node); + auto value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int index = GetValue(value_node->value()); + if (index == kBatchNormGradInferOutputNum || index == kBatchNormGradInferOutputNum + 1) { + MS_LOG(DEBUG) << "The output " << index << " of node " << node->DebugString() << " is not null, no need change"; + return false; + } + } + return true; +} +} // namespace + +AnfNodePtr BatchNormGradInferFission::CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(bn_grad); + MS_EXCEPTION_IF_NULL(equiv); + // Set inputs + auto iter_input0 = (*equiv).find(input0_var_); + if (iter_input0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched."; + } + auto iter_input2 = (*equiv).find(input2_var_); + if (iter_input2 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input2 var after matched."; + } + auto iter_input4 = (*equiv).find(input4_var_); + if (iter_input4 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched."; + } + std::vector bn_infer_grad_inputs = { + NewValueNode(std::make_shared(kBNInferGradOpName)), utils::cast(iter_input0->second), + utils::cast(iter_input2->second), utils::cast(iter_input4->second)}; + auto bn_infer_grad = func_graph->NewCNode(bn_infer_grad_inputs); + MS_EXCEPTION_IF_NULL(bn_infer_grad); + // Set abstract, the output of new node is taking the place of the 0th output of bn_grad. + auto bn_grad_abstract_tuple = dyn_cast(bn_grad->abstract()); + MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple); + if (bn_grad_abstract_tuple->elements().empty()) { + MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be empty"; + } + bn_infer_grad->set_abstract(bn_grad_abstract_tuple->elements()[0]); + AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_infer_grad); + bn_infer_grad->set_scope(bn_grad->scope()); + return bn_infer_grad; +} + +AnfNodePtr BatchNormGradInferFission::CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph, + const AnfNodePtr &bn_grad, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(bn_grad); + MS_EXCEPTION_IF_NULL(equiv); + // Set inputs + auto iter_input0 = (*equiv).find(input0_var_); + if (iter_input0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched."; + } + auto iter_input1 = (*equiv).find(input1_var_); + if (iter_input1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input1 var after matched."; + } + auto iter_input3 = (*equiv).find(input3_var_); + if (iter_input3 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input3 var after matched."; + } + auto iter_input4 = (*equiv).find(input4_var_); + if (iter_input4 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched."; + } + std::vector bn_training_update_grad_inputs = { + NewValueNode(std::make_shared(kBNTrainingUpdateGradOpName)), + utils::cast(iter_input0->second), utils::cast(iter_input1->second), + utils::cast(iter_input3->second), utils::cast(iter_input4->second)}; + auto bn_training_update_grad = func_graph->NewCNode(bn_training_update_grad_inputs); + MS_EXCEPTION_IF_NULL(bn_training_update_grad); + // Set abstract, the outputs of new node are taking the place of the 1st and 2nd outputs of bn_grad. + auto bn_grad_abstract_tuple = dyn_cast(bn_grad->abstract()); + MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple); + if (bn_grad_abstract_tuple->elements().size() < kBatchNormGradInferOutputNum) { + MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be less than 3"; + } + std::vector abstract_list{bn_grad_abstract_tuple->elements()[1], + bn_grad_abstract_tuple->elements()[2]}; + auto abstract_tuple = std::make_shared(abstract_list); + bn_training_update_grad->set_abstract(abstract_tuple); + AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_training_update_grad); + bn_training_update_grad->set_scope(bn_grad->scope()); + return bn_training_update_grad; +} + +const BaseRef BatchNormGradInferFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimBatchNormGrad, input0_var_, input1_var_, input2_var_, input3_var_, input4_var_, Xs}); +} + +const AnfNodePtr BatchNormGradInferFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, node->cast())) { + MS_LOG(DEBUG) << "The BatchNormGrad " << node->DebugString() << " has no is_training attr, should not be changed"; + return nullptr; + } + if (AnfAlgo::GetNodeAttr(node, kAttrIsTraining)) { + MS_LOG(DEBUG) << "The is_training attr value of " << node->DebugString() << " is true, no need change"; + return nullptr; + } + if (!CheckOutputsIndex(func_graph, node)) { + MS_LOG(DEBUG) << "The output 3 or 4 of BatchNormGrad is not null, no need change"; + return nullptr; + } + AnfNodePtr bn_infer_grad = CreateBNInferGrad(func_graph, node, equiv); + AnfNodePtr bn_training_update_grad = CreateBNTrainingUpdateGrad(func_graph, node, equiv); + std::vector bn_training_update_grad_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_grad, kBNTrainingUpdateGradOutputNum, + &bn_training_update_grad_outputs); + if (bn_training_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { + MS_LOG(EXCEPTION) << "The output size of " << bn_training_update_grad << " should be " + << kBNTrainingUpdateGradOutputNum << ", but it is " << bn_training_update_grad_outputs.size(); + } + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_infer_grad, + bn_training_update_grad_outputs[0], bn_training_update_grad_outputs[1]}; + auto make_tuple = func_graph->NewCNode(make_tuple_inputs); + MS_EXCEPTION_IF_NULL(make_tuple); + return make_tuple; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h new file mode 100644 index 00000000000..a8eefdaa852 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ + +#include +#include "pre_activate/common/optimizer.h" + +namespace mindspore { +namespace opt { +class BatchNormGradInferFission : public PatternProcessPass { + public: + explicit BatchNormGradInferFission(bool multigraph = true) + : PatternProcessPass("batch_norm_grad_infer_fission", multigraph), + input0_var_(std::make_shared()), + input1_var_(std::make_shared()), + input2_var_(std::make_shared()), + input3_var_(std::make_shared()), + input4_var_(std::make_shared()) {} + ~BatchNormGradInferFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + AnfNodePtr CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, const EquivPtr &equiv) const; + AnfNodePtr CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, + const EquivPtr &equiv) const; + + VarPtr input0_var_; + VarPtr input1_var_; + VarPtr input2_var_; + VarPtr input3_var_; + VarPtr input4_var_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 5b8a8b178e3..c85a1d27457 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -139,6 +139,7 @@ constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2D constexpr auto kLabelSetOpName = "LabelSet"; constexpr auto kLabelSwitchOpName = "LabelSwitch"; constexpr auto kLabelGotoOpName = "LabelGoto"; +constexpr auto kBNInferGradOpName = "BNInferGrad"; // attr key name constexpr auto kAttrInputNames = "input_names"; diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission_test.cc new file mode 100644 index 00000000000..ea4a5c0d5d5 --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission_test.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" + +namespace mindspore { +namespace opt { +class TestHWBatchNormGradInferFission : public BackendCommon { + public: + TestHWBatchNormGradInferFission() + : get_py_fun_("gtest_input.pre_activate.batch_norm_grad_infer_fission_test", true) {} + ~TestHWBatchNormGradInferFission() override = default; + + UT::PyFuncGraphFetcher get_py_fun_; +}; + +TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_fission) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp_x{32, 64, 112, 112}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 5; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_no_fission1) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "before_is_training"); + EXPECT_NE(g, nullptr); + std::vector shp_x{32, 64, 112, 112}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 5; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + EXPECT_TRUE(CheckEqualGraph(kg, new_graph)); +} + +TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_no_fission2) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "before_output3_not_null"); + EXPECT_NE(g, nullptr); + std::vector shp_x{32, 64, 112, 112}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 5; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + EXPECT_TRUE(CheckEqualGraph(kg, new_graph)); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/batch_norm_grad_infer_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/batch_norm_grad_infer_fission_test.py new file mode 100644 index 00000000000..6d63dd24da2 --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/batch_norm_grad_infer_fission_test.py @@ -0,0 +1,71 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore.ops.operations import _grad_ops as G +from mindspore.ops import Primitive + +make_tuple = Primitive('make_tuple') +tuple_getitem = Primitive('tuple_getitem') +BatchNormGradTraining = G.BatchNormGrad(is_training=True) +BatchNormGradInfer = G.BatchNormGrad(is_training=False) +BNInferGrad = Primitive('BNInferGrad') +BNTrainingUpdateGrad = Primitive('BNTrainingUpdateGrad') + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_batch_norm_grad_infer_fission(tag): + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4): + batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4) + outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2)) + output = tuple_getitem(outputs, 0) + return output + + @fns + def before_is_training(input0, input1, input2, input3, input4): + batch_norm = BatchNormGradTraining(input0, input1, input2, input3, input4) + outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2)) + output = tuple_getitem(outputs, 0) + return output + + @fns + def before_output3_not_null(input0, input1, input2, input3, input4): + batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4) + outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 3)) + output = tuple_getitem(outputs, 0) + return output + + @fns + def after(input0, input1, input2, input3, input4): + bn_infer_grad = BNInferGrad(input0, input2, input4) + bn_training_update_grad = BNTrainingUpdateGrad(input0, input1, input3, input4) + outputs = make_tuple(bn_infer_grad, tuple_getitem(bn_training_update_grad, 0), + tuple_getitem(bn_training_update_grad, 1)) + new_outputs = make_tuple(tuple_getitem(outputs, 0), tuple_getitem(outputs, 1), tuple_getitem(outputs, 2)) + output = tuple_getitem(new_outputs, 0) + return make_tuple(output) + + return fns[tag]