forked from mindspore-Ecosystem/mindspore

fix bn_add_relu fusion

The BatchNorm matched by these passes now carries a UMonad as its sixth input, so the bn_add_relu fusion patterns must match it and forward it to the fused op; the fused BatchNorm is also detached from its UpdateState user first, so the replacement cannot form a cycle in the graph.

commit c4ed23995a
parent 244e078ad3
@@ -29,13 +29,40 @@
 namespace mindspore {
 namespace opt {
 const BaseRef BatchNormAddReluFusion::DefinePattern() const {
-  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_});
+  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_, umonad_});
   VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
   VectorRef tensor_add = VectorRef({prim::kPrimAdd, tuple_get_item, z_});
   VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});
   return relu;
 }
+
+AnfNodePtr RemoveNodeFromUpdateState(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &updatestate) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(updatestate);
+  auto updatestate_cnode = updatestate->cast<CNodePtr>();
+  auto inputs = updatestate_cnode->inputs();
+  std::vector<AnfNodePtr> new_inputs;
+  (void)std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
+                     [node](const AnfNodePtr &input) { return node != input; });
+  AnfNodePtr new_updatestate = nullptr;
+  constexpr size_t updatestate_input_size = 3;
+  // If the UpdateState would be left with only one CNode input:
+  //   old_updatestate = UpdateState(umonad, cnode1)
+  //   cnode2 = CNode2(..., old_updatestate)
+  // then after removing cnode1, replace old_updatestate with umonad:
+  //   cnode2 = CNode2(..., umonad)
+  if (new_inputs.size() < updatestate_input_size) {
+    new_updatestate = updatestate_cnode->input(1);
+  } else {
+    new_updatestate = graph->NewCNode(new_inputs);
+  }
+  MS_EXCEPTION_IF_NULL(new_updatestate);
+  new_updatestate->set_scope(updatestate->scope());
+  new_updatestate->set_abstract(updatestate->abstract());
+  return new_updatestate;
+}
 
 const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                  const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(graph);
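To make the helper's behavior concrete, here is a hypothetical trace; the node labels (%u, %bn1, %us) are illustrative and not part of the patch:

// Hypothetical trace of RemoveNodeFromUpdateState (a sketch, not patch code):
//   %u   = a UMonad value
//   %bn1 = BatchNorm(...)
//   %us  = UpdateState(%u, %bn1)
// The inputs of %us are {UpdateState primitive, %u, %bn1}. Filtering out %bn1
// leaves two entries, below updatestate_input_size (3), so the helper returns
// input(1), i.e. %u. The caller then replaces %us with %u, and the UpdateState
// drops out of the graph.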
@@ -77,12 +104,29 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
   if (shape.back() % kBNChannelMultipleFactor != 0) {
     return nullptr;
   }
+  // Replace the UpdateState(%b, %a) that uses the BN(%a) being fused with UpdateState(%b),
+  // to avoid a cycle in the graph. Otherwise a cycle would be formed like:
+  //   (BN1)->UpdateState2->BN2->BNActivation
+  //     ^                          |
+  //     |__________________________|
+  //                ^
+  //                |-----> needs to be removed
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto user_nodes = manager->node_users()[batch_norm];
+  for (auto user_node : user_nodes) {
+    if (common::AnfAlgo::CheckPrimitiveType(user_node.first, prim::kPrimUpdateState)) {
+      auto new_updatestate = RemoveNodeFromUpdateState(graph, batch_norm, user_node.first);
+      (void)manager->Replace(user_node.first, new_updatestate);
+    }
+  }
+
   auto x = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex0);
   auto scale = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex1);
   auto bias = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex2);
   auto mean = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex3);
   auto var = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex4);
+  auto umonad = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex5);
   auto z = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), kIndex1);
 
   MS_EXCEPTION_IF_NULL(x);
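A before/after sketch of the cycle this loop prevents; the node names are illustrative, and the reasoning is an inference from the comment above rather than something the patch states:

// Before the rewrite (sketch):
//   %bn1 = BatchNorm(...)              // root of the pattern being fused
//   %us2 = UpdateState(%u, %bn1)
//   %bn2 = BatchNorm(..., %us2)
// Fusing %bn1's subgraph into one BNActivation-style node while %us2 still
// lists %bn1 leaves the fused node upstream of %us2 and, through the chain
// reaching back into the fused region, downstream of it as well: a cycle.
// After the rewrite, users of %us2 take %u directly, so the back edge is gone.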
@@ -90,6 +134,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
   MS_EXCEPTION_IF_NULL(bias);
   MS_EXCEPTION_IF_NULL(mean);
   MS_EXCEPTION_IF_NULL(var);
+  MS_EXCEPTION_IF_NULL(umonad);
   MS_EXCEPTION_IF_NULL(z);
 
   auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation);
@@ -108,8 +153,6 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
   common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get());
   common::AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu);
 
-  auto manager = graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
   manager->Replace(batch_norm, fused_batch_norm_with_add_relu);
   device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu);
   return tuple_get_item;
@@ -32,6 +32,7 @@ class BatchNormAddReluFusion : public PatternProcessPass {
     var_ = std::make_shared<Var>();
     index_ = std::make_shared<Var>();
     z_ = std::make_shared<Var>();
+    umonad_ = std::make_shared<Var>();
   }
   ~BatchNormAddReluFusion() override = default;
   const BaseRef DefinePattern() const override;
@@ -45,6 +46,7 @@ class BatchNormAddReluFusion : public PatternProcessPass {
   VarPtr var_;
   VarPtr index_;
   VarPtr z_;
+  VarPtr umonad_;
 };
 }  // namespace opt
 }  // namespace mindspore
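The header-side change is identical in all three passes. A minimal sketch of why it matters, assuming the usual PatternProcessPass semantics that a Var in a VectorRef matches any single node:

// Without umonad_, the five-input pattern cannot match a BatchNorm that
// carries a monad input, and the fusion silently stops firing.
umonad_ = std::make_shared<Var>();  // wildcard: matches whatever monad feeds BatchNorm
VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_, umonad_});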
@@ -29,7 +29,7 @@
 namespace mindspore {
 namespace opt {
 const BaseRef BatchNormReluFusion::DefinePattern() const {
-  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_});
+  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_, umonad_});
   VectorRef tuple_get = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
   VectorRef relu = VectorRef({prim::kPrimRelu, tuple_get});
   return relu;
@@ -80,16 +80,18 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
   auto bias = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex2);
   auto mean = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex3);
   auto var = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex4);
+  auto umonad = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex5);
 
   MS_EXCEPTION_IF_NULL(x);
   MS_EXCEPTION_IF_NULL(scale);
   MS_EXCEPTION_IF_NULL(bias);
   MS_EXCEPTION_IF_NULL(mean);
   MS_EXCEPTION_IF_NULL(var);
+  MS_EXCEPTION_IF_NULL(umonad);
 
   auto prim = std::make_shared<Primitive>(kBatchNormWithActivation);
   MS_EXCEPTION_IF_NULL(prim);
-  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var};
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, umonad};
   auto fused_batch_norm_with_relu = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(fused_batch_norm_with_relu);
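Each Process reads the monad from the matched BatchNorm (kIndex5) and appends it as the last input of the fused node. As far as the hunks show, the point is that the side-effect ordering attached to BatchNorm survives the fusion; a sketch of the resulting node:

// Before: bn    = BatchNorm(x, scale, bias, mean, var, umonad)
// After:  fused = BatchNormWithActivation(x, scale, bias, mean, var, umonad)
// The monad chain now runs through the fused op instead of the original BN.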
@@ -29,6 +29,7 @@ class BatchNormReluFusion : public PatternProcessPass {
     bias_ = std::make_shared<Var>();
     mean_ = std::make_shared<Var>();
     var_ = std::make_shared<Var>();
+    umonad_ = std::make_shared<Var>();
     index_ = std::make_shared<Var>();
   }
   ~BatchNormReluFusion() override = default;
@@ -41,6 +42,7 @@ class BatchNormReluFusion : public PatternProcessPass {
   VarPtr bias_;
   VarPtr mean_;
   VarPtr var_;
+  VarPtr umonad_;
   VarPtr index_;
 };
 }  // namespace opt
@@ -29,7 +29,7 @@
 namespace mindspore {
 namespace opt {
 const BaseRef PostBatchNormAddReluFusion::DefinePattern() const {
-  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_});
+  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_, umonad_});
   VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
   VectorRef tensor_add = VectorRef({prim::kPrimAdd, z_, tuple_get_item});
   VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});
@@ -68,6 +68,7 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
   auto bias = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex2);
   auto mean = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex3);
   auto var = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex4);
+  auto umonad = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex5);
   auto z = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), kIndex0);
 
   MS_EXCEPTION_IF_NULL(x);
@@ -75,11 +76,12 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
   MS_EXCEPTION_IF_NULL(bias);
   MS_EXCEPTION_IF_NULL(mean);
   MS_EXCEPTION_IF_NULL(var);
+  MS_EXCEPTION_IF_NULL(umonad);
   MS_EXCEPTION_IF_NULL(z);
 
   auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation);
   MS_EXCEPTION_IF_NULL(prim);
-  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z};
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z, umonad};
   auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(fused_batch_norm_with_add_relu);
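Note the argument order in the add-fused variant: the residual input z sits before the monad, so the monad is the trailing input of every fused op this patch touches. Summarizing the two layouts (taken directly from the hunks):

// Fused-op input layouts after this patch:
//   BatchNormWithActivation:       {prim, x, scale, bias, mean, var, umonad}
//   BatchNormWithAddAndActivation: {prim, x, scale, bias, mean, var, z, umonad}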
@@ -30,6 +30,7 @@ class PostBatchNormAddReluFusion : public PatternProcessPass {
     bias_ = std::make_shared<Var>();
     mean_ = std::make_shared<Var>();
     var_ = std::make_shared<Var>();
+    umonad_ = std::make_shared<Var>();
     index_ = std::make_shared<Var>();
     z_ = std::make_shared<Var>();
   }
@@ -43,6 +44,7 @@ class PostBatchNormAddReluFusion : public PatternProcessPass {
   VarPtr bias_;
   VarPtr mean_;
   VarPtr var_;
+  VarPtr umonad_;
   VarPtr index_;
   VarPtr z_;
 };