!33693 fix bn_add_relu fusion

Merge pull request !33693 from wuweikang/bn-fix
This commit is contained in:
i-robot 2022-04-29 07:43:52 +00:00 committed by Gitee
commit 260b464928
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 60 additions and 7 deletions

View File

@ -29,13 +29,40 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
const BaseRef BatchNormAddReluFusion::DefinePattern() const { const BaseRef BatchNormAddReluFusion::DefinePattern() const {
VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_}); VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_, umonad_});
VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_}); VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
VectorRef tensor_add = VectorRef({prim::kPrimAdd, tuple_get_item, z_}); VectorRef tensor_add = VectorRef({prim::kPrimAdd, tuple_get_item, z_});
VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add}); VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});
return relu; return relu;
} }
// Build a replacement for `updatestate` that no longer tracks `node`.
// If removing `node` leaves the UpdateState with no attached CNode, the
// UpdateState is redundant and its input monad is returned instead of a
// new UpdateState node.
// Returns the replacement node (never null; throws on invalid input).
AnfNodePtr RemoveNodeFromUpdateState(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &updatestate) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(updatestate);
  auto updatestate_cnode = updatestate->cast<CNodePtr>();
  // cast<CNodePtr>() returns nullptr when `updatestate` is not a CNode;
  // check before dereferencing via inputs().
  MS_EXCEPTION_IF_NULL(updatestate_cnode);
  auto inputs = updatestate_cnode->inputs();
  std::vector<AnfNodePtr> new_inputs;
  (void)std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
                     [node](const AnfNodePtr &input) { return node != input; });
  AnfNodePtr new_updatestate = nullptr;
  // A meaningful UpdateState needs at least {primitive, monad, one CNode}.
  constexpr size_t updatestate_input_size = 3;
  // If the UpdateState tracked only the removed CNode:
  //   old_updatestate = UpdateState(umonad, cnode1)
  //   cnode2 = CNode2(..., old_updatestate)
  // then after removing cnode1, old_updatestate is replaced by umonad:
  //   cnode2 = CNode2(..., umonad)
  if (new_inputs.size() < updatestate_input_size) {
    constexpr size_t monad_input_index = 1;  // input 0 is the UpdateState primitive
    new_updatestate = updatestate_cnode->input(monad_input_index);
  } else {
    new_updatestate = graph->NewCNode(new_inputs);
  }
  MS_EXCEPTION_IF_NULL(new_updatestate);
  // Preserve scope and abstract so downstream passes see an equivalent node.
  new_updatestate->set_scope(updatestate->scope());
  new_updatestate->set_abstract(updatestate->abstract());
  return new_updatestate;
}
const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const { const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
@ -77,12 +104,29 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
if (shape.back() % kBNChannelMultipleFactor != 0) { if (shape.back() % kBNChannelMultipleFactor != 0) {
return nullptr; return nullptr;
} }
// replace updatestate(%b, %a) after the BN(%a) being fused with updatestate(%b) to avoid circle in graph
// otherwise circle will be formed like:
// (BN1)->UpdateState2->BN2->BNActivation
// ^ |
// |___________________|
// ^
// |-----> need to be removed
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto user_nodes = manager->node_users()[batch_norm];
for (auto user_node : user_nodes) {
if (common::AnfAlgo::CheckPrimitiveType(user_node.first, prim::kPrimUpdateState)) {
auto new_updatestate = RemoveNodeFromUpdateState(graph, batch_norm, user_node.first);
(void)manager->Replace(user_node.first, new_updatestate);
}
}
auto x = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex0); auto x = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex0);
auto scale = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex1); auto scale = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex1);
auto bias = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex2); auto bias = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex2);
auto mean = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex3); auto mean = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex3);
auto var = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex4); auto var = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex4);
auto umonad = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex5);
auto z = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), kIndex1); auto z = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), kIndex1);
MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(x);
@ -90,6 +134,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
MS_EXCEPTION_IF_NULL(bias); MS_EXCEPTION_IF_NULL(bias);
MS_EXCEPTION_IF_NULL(mean); MS_EXCEPTION_IF_NULL(mean);
MS_EXCEPTION_IF_NULL(var); MS_EXCEPTION_IF_NULL(var);
MS_EXCEPTION_IF_NULL(umonad);
MS_EXCEPTION_IF_NULL(z); MS_EXCEPTION_IF_NULL(z);
auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation); auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation);
@ -108,8 +153,6 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get());
common::AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu); common::AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->Replace(batch_norm, fused_batch_norm_with_add_relu); manager->Replace(batch_norm, fused_batch_norm_with_add_relu);
device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu); device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu);
return tuple_get_item; return tuple_get_item;

View File

@ -32,6 +32,7 @@ class BatchNormAddReluFusion : public PatternProcessPass {
var_ = std::make_shared<Var>(); var_ = std::make_shared<Var>();
index_ = std::make_shared<Var>(); index_ = std::make_shared<Var>();
z_ = std::make_shared<Var>(); z_ = std::make_shared<Var>();
umonad_ = std::make_shared<Var>();
} }
~BatchNormAddReluFusion() override = default; ~BatchNormAddReluFusion() override = default;
const BaseRef DefinePattern() const override; const BaseRef DefinePattern() const override;
@ -45,6 +46,7 @@ class BatchNormAddReluFusion : public PatternProcessPass {
VarPtr var_; VarPtr var_;
VarPtr index_; VarPtr index_;
VarPtr z_; VarPtr z_;
VarPtr umonad_;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -29,7 +29,7 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
const BaseRef BatchNormReluFusion::DefinePattern() const { const BaseRef BatchNormReluFusion::DefinePattern() const {
VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_}); VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_, umonad_});
VectorRef tuple_get = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_}); VectorRef tuple_get = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
VectorRef relu = VectorRef({prim::kPrimRelu, tuple_get}); VectorRef relu = VectorRef({prim::kPrimRelu, tuple_get});
return relu; return relu;
@ -80,16 +80,18 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
auto bias = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex2); auto bias = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex2);
auto mean = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex3); auto mean = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex3);
auto var = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex4); auto var = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex4);
auto umonad = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex5);
MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(scale); MS_EXCEPTION_IF_NULL(scale);
MS_EXCEPTION_IF_NULL(bias); MS_EXCEPTION_IF_NULL(bias);
MS_EXCEPTION_IF_NULL(mean); MS_EXCEPTION_IF_NULL(mean);
MS_EXCEPTION_IF_NULL(var); MS_EXCEPTION_IF_NULL(var);
MS_EXCEPTION_IF_NULL(umonad);
auto prim = std::make_shared<Primitive>(kBatchNormWithActivation); auto prim = std::make_shared<Primitive>(kBatchNormWithActivation);
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var}; std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, umonad};
auto fused_batch_norm_with_relu = graph->NewCNode(inputs); auto fused_batch_norm_with_relu = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(fused_batch_norm_with_relu); MS_EXCEPTION_IF_NULL(fused_batch_norm_with_relu);

View File

@ -29,6 +29,7 @@ class BatchNormReluFusion : public PatternProcessPass {
bias_ = std::make_shared<Var>(); bias_ = std::make_shared<Var>();
mean_ = std::make_shared<Var>(); mean_ = std::make_shared<Var>();
var_ = std::make_shared<Var>(); var_ = std::make_shared<Var>();
umonad_ = std::make_shared<Var>();
index_ = std::make_shared<Var>(); index_ = std::make_shared<Var>();
} }
~BatchNormReluFusion() override = default; ~BatchNormReluFusion() override = default;
@ -41,6 +42,7 @@ class BatchNormReluFusion : public PatternProcessPass {
VarPtr bias_; VarPtr bias_;
VarPtr mean_; VarPtr mean_;
VarPtr var_; VarPtr var_;
VarPtr umonad_;
VarPtr index_; VarPtr index_;
}; };
} // namespace opt } // namespace opt

View File

@ -29,7 +29,7 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
const BaseRef PostBatchNormAddReluFusion::DefinePattern() const { const BaseRef PostBatchNormAddReluFusion::DefinePattern() const {
VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_}); VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_, umonad_});
VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_}); VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
VectorRef tensor_add = VectorRef({prim::kPrimAdd, z_, tuple_get_item}); VectorRef tensor_add = VectorRef({prim::kPrimAdd, z_, tuple_get_item});
VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add}); VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});
@ -68,6 +68,7 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
auto bias = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex2); auto bias = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex2);
auto mean = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex3); auto mean = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex3);
auto var = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex4); auto var = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex4);
auto umonad = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex5);
auto z = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), kIndex0); auto z = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), kIndex0);
MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(x);
@ -75,11 +76,12 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
MS_EXCEPTION_IF_NULL(bias); MS_EXCEPTION_IF_NULL(bias);
MS_EXCEPTION_IF_NULL(mean); MS_EXCEPTION_IF_NULL(mean);
MS_EXCEPTION_IF_NULL(var); MS_EXCEPTION_IF_NULL(var);
MS_EXCEPTION_IF_NULL(umonad);
MS_EXCEPTION_IF_NULL(z); MS_EXCEPTION_IF_NULL(z);
auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation); auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation);
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z}; std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z, umonad};
auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs); auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(fused_batch_norm_with_add_relu); MS_EXCEPTION_IF_NULL(fused_batch_norm_with_add_relu);

View File

@ -30,6 +30,7 @@ class PostBatchNormAddReluFusion : public PatternProcessPass {
bias_ = std::make_shared<Var>(); bias_ = std::make_shared<Var>();
mean_ = std::make_shared<Var>(); mean_ = std::make_shared<Var>();
var_ = std::make_shared<Var>(); var_ = std::make_shared<Var>();
umonad_ = std::make_shared<Var>();
index_ = std::make_shared<Var>(); index_ = std::make_shared<Var>();
z_ = std::make_shared<Var>(); z_ = std::make_shared<Var>();
} }
@ -43,6 +44,7 @@ class PostBatchNormAddReluFusion : public PatternProcessPass {
VarPtr bias_; VarPtr bias_;
VarPtr mean_; VarPtr mean_;
VarPtr var_; VarPtr var_;
VarPtr umonad_;
VarPtr index_; VarPtr index_;
VarPtr z_; VarPtr z_;
}; };