forked from mindspore-Ecosystem/mindspore
!1156 Add batch_norm grad infer fission pass
Merge pull request !1156 from YuJianfeng/master
This commit is contained in:
commit
c786596641
|
@ -0,0 +1,169 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h"
|
||||
#include <vector>
|
||||
#include "pre_activate/common/helper.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kBatchNormGradInferOutputNum = 3;
|
||||
bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
if (manager->node_users().find(node) == manager->node_users().end()) {
|
||||
MS_LOG(DEBUG) << "The node " << node->DebugString() << " should have some outputs";
|
||||
return false;
|
||||
}
|
||||
for (const auto &node_index : manager->node_users()[node]) {
|
||||
AnfNodePtr output = node_index.first;
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
auto tuple_getiterm_cnode = output->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode);
|
||||
auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(index_node);
|
||||
auto value_node = index_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
int index = GetValue<int>(value_node->value());
|
||||
if (index == kBatchNormGradInferOutputNum || index == kBatchNormGradInferOutputNum + 1) {
|
||||
MS_LOG(DEBUG) << "The output " << index << " of node " << node->DebugString() << " is not null, no need change";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr BatchNormGradInferFission::CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(bn_grad);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
// Set inputs
|
||||
auto iter_input0 = (*equiv).find(input0_var_);
|
||||
if (iter_input0 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched.";
|
||||
}
|
||||
auto iter_input2 = (*equiv).find(input2_var_);
|
||||
if (iter_input2 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input2 var after matched.";
|
||||
}
|
||||
auto iter_input4 = (*equiv).find(input4_var_);
|
||||
if (iter_input4 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched.";
|
||||
}
|
||||
std::vector<AnfNodePtr> bn_infer_grad_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kBNInferGradOpName)), utils::cast<AnfNodePtr>(iter_input0->second),
|
||||
utils::cast<AnfNodePtr>(iter_input2->second), utils::cast<AnfNodePtr>(iter_input4->second)};
|
||||
auto bn_infer_grad = func_graph->NewCNode(bn_infer_grad_inputs);
|
||||
MS_EXCEPTION_IF_NULL(bn_infer_grad);
|
||||
// Set abstract, the output of new node is taking the place of the 0th output of bn_grad.
|
||||
auto bn_grad_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_grad->abstract());
|
||||
MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple);
|
||||
if (bn_grad_abstract_tuple->elements().empty()) {
|
||||
MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be empty";
|
||||
}
|
||||
bn_infer_grad->set_abstract(bn_grad_abstract_tuple->elements()[0]);
|
||||
AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_infer_grad);
|
||||
bn_infer_grad->set_scope(bn_grad->scope());
|
||||
return bn_infer_grad;
|
||||
}
|
||||
|
||||
AnfNodePtr BatchNormGradInferFission::CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph,
|
||||
const AnfNodePtr &bn_grad,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(bn_grad);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
// Set inputs
|
||||
auto iter_input0 = (*equiv).find(input0_var_);
|
||||
if (iter_input0 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched.";
|
||||
}
|
||||
auto iter_input1 = (*equiv).find(input1_var_);
|
||||
if (iter_input1 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input1 var after matched.";
|
||||
}
|
||||
auto iter_input3 = (*equiv).find(input3_var_);
|
||||
if (iter_input3 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input3 var after matched.";
|
||||
}
|
||||
auto iter_input4 = (*equiv).find(input4_var_);
|
||||
if (iter_input4 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched.";
|
||||
}
|
||||
std::vector<AnfNodePtr> bn_training_update_grad_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)),
|
||||
utils::cast<AnfNodePtr>(iter_input0->second), utils::cast<AnfNodePtr>(iter_input1->second),
|
||||
utils::cast<AnfNodePtr>(iter_input3->second), utils::cast<AnfNodePtr>(iter_input4->second)};
|
||||
auto bn_training_update_grad = func_graph->NewCNode(bn_training_update_grad_inputs);
|
||||
MS_EXCEPTION_IF_NULL(bn_training_update_grad);
|
||||
// Set abstract, the outputs of new node are taking the place of the 1st and 2nd outputs of bn_grad.
|
||||
auto bn_grad_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_grad->abstract());
|
||||
MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple);
|
||||
if (bn_grad_abstract_tuple->elements().size() < kBatchNormGradInferOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be less than 3";
|
||||
}
|
||||
std::vector<AbstractBasePtr> abstract_list{bn_grad_abstract_tuple->elements()[1],
|
||||
bn_grad_abstract_tuple->elements()[2]};
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
bn_training_update_grad->set_abstract(abstract_tuple);
|
||||
AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_training_update_grad);
|
||||
bn_training_update_grad->set_scope(bn_grad->scope());
|
||||
return bn_training_update_grad;
|
||||
}
|
||||
|
||||
const BaseRef BatchNormGradInferFission::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
return VectorRef({prim::kPrimBatchNormGrad, input0_var_, input1_var_, input2_var_, input3_var_, input4_var_, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr BatchNormGradInferFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, node->cast<CNodePtr>())) {
|
||||
MS_LOG(DEBUG) << "The BatchNormGrad " << node->DebugString() << " has no is_training attr, should not be changed";
|
||||
return nullptr;
|
||||
}
|
||||
if (AnfAlgo::GetNodeAttr<bool>(node, kAttrIsTraining)) {
|
||||
MS_LOG(DEBUG) << "The is_training attr value of " << node->DebugString() << " is true, no need change";
|
||||
return nullptr;
|
||||
}
|
||||
if (!CheckOutputsIndex(func_graph, node)) {
|
||||
MS_LOG(DEBUG) << "The output 3 or 4 of BatchNormGrad is not null, no need change";
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr bn_infer_grad = CreateBNInferGrad(func_graph, node, equiv);
|
||||
AnfNodePtr bn_training_update_grad = CreateBNTrainingUpdateGrad(func_graph, node, equiv);
|
||||
std::vector<AnfNodePtr> bn_training_update_grad_outputs;
|
||||
CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_grad, kBNTrainingUpdateGradOutputNum,
|
||||
&bn_training_update_grad_outputs);
|
||||
if (bn_training_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "The output size of " << bn_training_update_grad << " should be "
|
||||
<< kBNTrainingUpdateGradOutputNum << ", but it is " << bn_training_update_grad_outputs.size();
|
||||
}
|
||||
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_infer_grad,
|
||||
bn_training_update_grad_outputs[0], bn_training_update_grad_outputs[1]};
|
||||
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
return make_tuple;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_
|
||||
|
||||
#include <memory>
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class BatchNormGradInferFission : public PatternProcessPass {
|
||||
public:
|
||||
explicit BatchNormGradInferFission(bool multigraph = true)
|
||||
: PatternProcessPass("batch_norm_grad_infer_fission", multigraph),
|
||||
input0_var_(std::make_shared<Var>()),
|
||||
input1_var_(std::make_shared<Var>()),
|
||||
input2_var_(std::make_shared<Var>()),
|
||||
input3_var_(std::make_shared<Var>()),
|
||||
input4_var_(std::make_shared<Var>()) {}
|
||||
~BatchNormGradInferFission() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
AnfNodePtr CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, const EquivPtr &equiv) const;
|
||||
AnfNodePtr CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad,
|
||||
const EquivPtr &equiv) const;
|
||||
|
||||
VarPtr input0_var_;
|
||||
VarPtr input1_var_;
|
||||
VarPtr input2_var_;
|
||||
VarPtr input3_var_;
|
||||
VarPtr input4_var_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_
|
|
@ -139,6 +139,7 @@ constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2D
|
|||
constexpr auto kLabelSetOpName = "LabelSet";
|
||||
constexpr auto kLabelSwitchOpName = "LabelSwitch";
|
||||
constexpr auto kLabelGotoOpName = "LabelGoto";
|
||||
constexpr auto kBNInferGradOpName = "BNInferGrad";
|
||||
|
||||
// attr key name
|
||||
constexpr auto kAttrInputNames = "input_names";
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h"
|
||||
#include "common/backend_common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TestHWBatchNormGradInferFission : public BackendCommon {
|
||||
public:
|
||||
TestHWBatchNormGradInferFission()
|
||||
: get_py_fun_("gtest_input.pre_activate.batch_norm_grad_infer_fission_test", true) {}
|
||||
~TestHWBatchNormGradInferFission() override = default;
|
||||
|
||||
UT::PyFuncGraphFetcher get_py_fun_;
|
||||
};
|
||||
|
||||
TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_fission) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp_x{32, 64, 112, 112};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 5; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::BatchNormGradInferFission>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_no_fission1) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "before_is_training");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp_x{32, 64, 112, 112};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 5; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::BatchNormGradInferFission>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_no_fission2) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "before_output3_not_null");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp_x{32, 64, 112, 112};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 5; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::BatchNormGradInferFission>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops import Primitive
|
||||
|
||||
make_tuple = Primitive('make_tuple')
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
BatchNormGradTraining = G.BatchNormGrad(is_training=True)
|
||||
BatchNormGradInfer = G.BatchNormGrad(is_training=False)
|
||||
BNInferGrad = Primitive('BNInferGrad')
|
||||
BNTrainingUpdateGrad = Primitive('BNTrainingUpdateGrad')
|
||||
|
||||
|
||||
class FnDict:
|
||||
def __init__(self):
|
||||
self.fnDict = {}
|
||||
|
||||
def __call__(self, fn):
|
||||
self.fnDict[fn.__name__] = fn
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.fnDict[name]
|
||||
|
||||
|
||||
def test_batch_norm_grad_infer_fission(tag):
|
||||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def before(input0, input1, input2, input3, input4):
|
||||
batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4)
|
||||
outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2))
|
||||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
||||
@fns
|
||||
def before_is_training(input0, input1, input2, input3, input4):
|
||||
batch_norm = BatchNormGradTraining(input0, input1, input2, input3, input4)
|
||||
outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2))
|
||||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
||||
@fns
|
||||
def before_output3_not_null(input0, input1, input2, input3, input4):
|
||||
batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4)
|
||||
outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 3))
|
||||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
||||
@fns
|
||||
def after(input0, input1, input2, input3, input4):
|
||||
bn_infer_grad = BNInferGrad(input0, input2, input4)
|
||||
bn_training_update_grad = BNTrainingUpdateGrad(input0, input1, input3, input4)
|
||||
outputs = make_tuple(bn_infer_grad, tuple_getitem(bn_training_update_grad, 0),
|
||||
tuple_getitem(bn_training_update_grad, 1))
|
||||
new_outputs = make_tuple(tuple_getitem(outputs, 0), tuple_getitem(outputs, 1), tuple_getitem(outputs, 2))
|
||||
output = tuple_getitem(new_outputs, 0)
|
||||
return make_tuple(output)
|
||||
|
||||
return fns[tag]
|
Loading…
Reference in New Issue