forked from mindspore-Ecosystem/mindspore
!1694 Check the input size of BatchNorm before fission in bert
Merge pull request !1694 from YuJianfeng/master
This commit is contained in:
commit
90d98aa6ec
|
@ -149,8 +149,17 @@ const BaseRef BatchNormBertFission::DefinePattern() const {
|
|||
const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
std::vector<AnfNodePtr> bn_outputs;
|
||||
if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) {
|
||||
MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed";
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() != kBatchNormRealInputNum + 1) {
|
||||
MS_LOG(INFO) << "The input size of BatchNorm should be " << kBatchNormRealInputNum
|
||||
<< ". The node should not be changed";
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node);
|
||||
|
|
|
@ -28,7 +28,47 @@ class TestHWBatchNormBertFission : public BackendCommon {
|
|||
UT::PyFuncGraphFetcher get_py_fun_;
|
||||
};
|
||||
|
||||
TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fusion) {
|
||||
TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fission) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp_x{32, 64, 112, 112};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
std::vector<int> shp_y{64};
|
||||
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
|
||||
AbstractBasePtrList args_spec_list{x_abstract};
|
||||
for (size_t i = 0; i < 4; ++i) {
|
||||
args_spec_list.push_back(y_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
auto ret = kg->get_return();
|
||||
EXPECT_NE(ret, nullptr);
|
||||
auto make_tuple0 = ret->input(1);
|
||||
EXPECT_NE(make_tuple0, nullptr);
|
||||
auto tuple_getitem0 = make_tuple0->cast<CNodePtr>()->input(1);
|
||||
EXPECT_NE(tuple_getitem0, nullptr);
|
||||
auto make_tuple1 = tuple_getitem0->cast<CNodePtr>()->input(1);
|
||||
EXPECT_NE(make_tuple1, nullptr);
|
||||
auto tuple_getitem1 = make_tuple1->cast<CNodePtr>()->input(1);
|
||||
EXPECT_NE(tuple_getitem1, nullptr);
|
||||
auto bn = tuple_getitem1->cast<CNodePtr>()->input(1);
|
||||
EXPECT_NE(bn, nullptr);
|
||||
auto bn_cnode = bn->cast<CNodePtr>();
|
||||
EXPECT_NE(bn_cnode, nullptr);
|
||||
auto inputs = bn_cnode->inputs();
|
||||
std::vector<AnfNodePtr> new_inputs(inputs.begin(), inputs.begin() + 4);
|
||||
bn_cnode->set_inputs(new_inputs);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::BatchNormBertFission>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_no_fission) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp_x{32, 64, 112, 112};
|
||||
|
@ -47,8 +87,7 @@ TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fusion) {
|
|||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue