forked from mindspore-Ecosystem/mindspore

!33693 fix bn_add_relu fusion

Merge pull request !33693 from wuweikang/bn-fix

commit 260b464928
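Background for the diff below: after MindSpore's auto-monad pass runs, a BatchNorm CNode carries a trailing UMonad input, so a fusion pattern that binds only {x_, scale_, bias_, mean_, var_} no longer has the same arity as the node, and the BN+Add+ReLU fusion stops matching. The fix adds a umonad_ wildcard to each pattern and threads the monad through the fused kernel's inputs. A minimal sketch of the arity point, under the assumption that the pattern engine requires slot-for-slot alignment (the Matches helper and the string encoding are hypothetical illustrations, not MindSpore's VectorRef machinery):

```cpp
// Hypothetical sketch: each pattern slot is either a literal primitive name or
// "?" (a wildcard Var such as x_, scale_, ..., umonad_). A pattern only fires
// when its arity matches the node's.
#include <iostream>
#include <string>
#include <vector>

bool Matches(const std::vector<std::string> &pattern, const std::vector<std::string> &node) {
  if (pattern.size() != node.size()) return false;  // arity must line up
  for (size_t i = 0; i < pattern.size(); ++i) {
    if (pattern[i] != "?" && pattern[i] != node[i]) return false;
  }
  return true;
}

int main() {
  // After auto-monad, the BatchNorm CNode carries a trailing UMonad input "U".
  std::vector<std::string> bn_node = {"BatchNorm", "x", "scale", "bias", "mean", "var", "U"};
  std::vector<std::string> old_pattern = {"BatchNorm", "?", "?", "?", "?", "?"};       // five wildcards
  std::vector<std::string> new_pattern = {"BatchNorm", "?", "?", "?", "?", "?", "?"};  // adds umonad_
  std::cout << Matches(old_pattern, bn_node) << "\n";  // 0: pattern misses the node
  std::cout << Matches(new_pattern, bn_node) << "\n";  // 1: fusion fires again
  return 0;
}
```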
@@ -29,13 +29,40 @@
 namespace mindspore {
 namespace opt {
 const BaseRef BatchNormAddReluFusion::DefinePattern() const {
-  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_});
+  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_, umonad_});
   VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
   VectorRef tensor_add = VectorRef({prim::kPrimAdd, tuple_get_item, z_});
   VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});
   return relu;
 }
 
+AnfNodePtr RemoveNodeFromUpdateState(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &updatestate) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(updatestate);
+  auto updatestate_cnode = updatestate->cast<CNodePtr>();
+  auto inputs = updatestate_cnode->inputs();
+  std::vector<AnfNodePtr> new_inputs;
+  (void)std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
+                     [node](const AnfNodePtr &input) { return node != input; });
+  AnfNodePtr new_updatestate = nullptr;
+  constexpr size_t updatestate_input_size = 3;
+  // If the UpdateState has only one CNode among its inputs:
+  //   old_updatestate = UpdateState(umonad, cnode1)
+  //   cnode2 = CNode2(..., old_updatestate)
+  // then removing cnode1 means replacing old_updatestate with umonad:
+  //   cnode2 = CNode2(..., umonad)
+  if (new_inputs.size() < updatestate_input_size) {
+    new_updatestate = updatestate_cnode->input(1);
+  } else {
+    new_updatestate = graph->NewCNode(new_inputs);
+  }
+  MS_EXCEPTION_IF_NULL(new_updatestate);
+  new_updatestate->set_scope(updatestate->scope());
+  new_updatestate->set_abstract(updatestate->abstract());
+  return new_updatestate;
+}
+
 const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                  const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(graph);
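Note on the helper added above: RemoveNodeFromUpdateState filters the fused node out of an UpdateState's input list. An UpdateState CNode needs at least three inputs (the primitive, the monad, and one attached node); if the filtered list drops below that, the UpdateState no longer guards anything and its users are rewired directly to the monad input (input(1)). A standalone sketch of the same rule on plain strings (PruneUpdateState is a hypothetical stand-in, not a MindSpore API):

```cpp
// Sketch of the pruning rule, assuming an input list shaped as
// {UpdateState-primitive, monad, attachment_1, attachment_2, ...}.
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

std::vector<std::string> PruneUpdateState(const std::vector<std::string> &inputs, const std::string &node) {
  std::vector<std::string> kept;
  std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(kept),
               [&node](const std::string &input) { return input != node; });
  constexpr size_t kUpdateStateInputSize = 3;  // primitive + monad + at least one attachment
  if (kept.size() < kUpdateStateInputSize) {
    return {kept.at(1)};  // only the monad survives: callers rewire to it directly
  }
  return kept;  // enough attachments remain: rebuild the UpdateState with them
}

int main() {
  // UpdateState(U, BN) with BN fused away collapses to the monad U.
  assert((PruneUpdateState({"UpdateState", "U", "BN"}, "BN") == std::vector<std::string>{"U"}));
  // UpdateState(U, BN, Conv) keeps Conv attached after BN is removed.
  assert((PruneUpdateState({"UpdateState", "U", "BN", "Conv"}, "BN") ==
          std::vector<std::string>{"UpdateState", "U", "Conv"}));
  return 0;
}
```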
@@ -77,12 +104,29 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
   if (shape.back() % kBNChannelMultipleFactor != 0) {
     return nullptr;
   }
+  // Replace updatestate(%b, %a) with updatestate(%b) after BN(%a) is fused, to avoid a cycle in the graph.
+  // Otherwise a cycle would be formed like:
+  //   (BN1)->UpdateState2->BN2->BNActivation
+  //      ^                  |
+  //      |__________________|
+  //                ^
+  //                |-----> needs to be removed
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto user_nodes = manager->node_users()[batch_norm];
+  for (auto user_node : user_nodes) {
+    if (common::AnfAlgo::CheckPrimitiveType(user_node.first, prim::kPrimUpdateState)) {
+      auto new_updatestate = RemoveNodeFromUpdateState(graph, batch_norm, user_node.first);
+      (void)manager->Replace(user_node.first, new_updatestate);
+    }
+  }
 
   auto x = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex0);
   auto scale = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex1);
   auto bias = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex2);
   auto mean = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex3);
   auto var = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex4);
+  auto umonad = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex5);
   auto z = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), kIndex1);
 
   MS_EXCEPTION_IF_NULL(x);

@@ -90,6 +134,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
   MS_EXCEPTION_IF_NULL(bias);
   MS_EXCEPTION_IF_NULL(mean);
   MS_EXCEPTION_IF_NULL(var);
+  MS_EXCEPTION_IF_NULL(umonad);
   MS_EXCEPTION_IF_NULL(z);
 
   auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation);

@@ -108,8 +153,6 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
   common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get());
   common::AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu);
 
-  auto manager = graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
   manager->Replace(batch_norm, fused_batch_norm_with_add_relu);
   device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu);
   return tuple_get_item;
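Why the UpdateState rewrite must happen before manager->Replace fuses the nodes: once BN1 and the downstream activation collapse into a single node, any UpdateState still listing BN1 as an attachment would both consume the fused node and, through BN2, feed it. A small self-contained check of the before/after dependency shapes (plain C++ with hypothetical node names, not graph-engine code):

```cpp
// Edges point from a producer to its consumers. HasCycle does a DFS with a
// "currently visiting" set; revisiting a node on the active path is a back edge.
#include <functional>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using Graph = std::map<std::string, std::vector<std::string>>;

bool HasCycle(const Graph &g) {
  std::set<std::string> visiting, done;
  std::function<bool(const std::string &)> dfs = [&](const std::string &n) {
    if (visiting.count(n)) return true;  // back edge: cycle found
    if (done.count(n)) return false;
    visiting.insert(n);
    if (g.count(n)) {
      for (const auto &succ : g.at(n)) {
        if (dfs(succ)) return true;
      }
    }
    visiting.erase(n);
    done.insert(n);
    return false;
  };
  for (const auto &kv : g) {
    if (dfs(kv.first)) return true;
  }
  return false;
}

int main() {
  // BN1 and BNActivation fused into "Fused" while UpdateState2 still lists
  // BN1 (now Fused) as an attachment: Fused -> UpdateState2 -> BN2 -> Fused.
  Graph before = {{"Fused", {"UpdateState2"}}, {"UpdateState2", {"BN2"}}, {"BN2", {"Fused"}}};
  // After RemoveNodeFromUpdateState, UpdateState2 no longer consumes the fused node.
  Graph after = {{"UpdateState2", {"BN2"}}, {"BN2", {"Fused"}}};
  std::cout << "before rewrite: cycle = " << HasCycle(before) << "\n";  // 1
  std::cout << "after  rewrite: cycle = " << HasCycle(after) << "\n";   // 0
  return 0;
}
```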
@@ -32,6 +32,7 @@ class BatchNormAddReluFusion : public PatternProcessPass {
     var_ = std::make_shared<Var>();
     index_ = std::make_shared<Var>();
     z_ = std::make_shared<Var>();
+    umonad_ = std::make_shared<Var>();
   }
   ~BatchNormAddReluFusion() override = default;
   const BaseRef DefinePattern() const override;

@@ -45,6 +46,7 @@ class BatchNormAddReluFusion : public PatternProcessPass {
   VarPtr var_;
   VarPtr index_;
   VarPtr z_;
+  VarPtr umonad_;
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -29,7 +29,7 @@
 namespace mindspore {
 namespace opt {
 const BaseRef BatchNormReluFusion::DefinePattern() const {
-  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_});
+  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_, umonad_});
   VectorRef tuple_get = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
   VectorRef relu = VectorRef({prim::kPrimRelu, tuple_get});
   return relu;

@@ -80,16 +80,18 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
   auto bias = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex2);
   auto mean = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex3);
   auto var = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex4);
+  auto umonad = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex5);
 
   MS_EXCEPTION_IF_NULL(x);
   MS_EXCEPTION_IF_NULL(scale);
   MS_EXCEPTION_IF_NULL(bias);
   MS_EXCEPTION_IF_NULL(mean);
   MS_EXCEPTION_IF_NULL(var);
+  MS_EXCEPTION_IF_NULL(umonad);
 
   auto prim = std::make_shared<Primitive>(kBatchNormWithActivation);
   MS_EXCEPTION_IF_NULL(prim);
-  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var};
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, umonad};
   auto fused_batch_norm_with_relu = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(fused_batch_norm_with_relu);
@@ -29,6 +29,7 @@ class BatchNormReluFusion : public PatternProcessPass {
     bias_ = std::make_shared<Var>();
     mean_ = std::make_shared<Var>();
     var_ = std::make_shared<Var>();
+    umonad_ = std::make_shared<Var>();
     index_ = std::make_shared<Var>();
   }
   ~BatchNormReluFusion() override = default;

@@ -41,6 +42,7 @@ class BatchNormReluFusion : public PatternProcessPass {
   VarPtr bias_;
   VarPtr mean_;
   VarPtr var_;
+  VarPtr umonad_;
   VarPtr index_;
 };
 }  // namespace opt
@@ -29,7 +29,7 @@
 namespace mindspore {
 namespace opt {
 const BaseRef PostBatchNormAddReluFusion::DefinePattern() const {
-  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_});
+  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_, umonad_});
   VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
   VectorRef tensor_add = VectorRef({prim::kPrimAdd, z_, tuple_get_item});
   VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});

@@ -68,6 +68,7 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
   auto bias = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex2);
   auto mean = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex3);
   auto var = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex4);
+  auto umonad = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), kIndex5);
   auto z = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), kIndex0);
 
   MS_EXCEPTION_IF_NULL(x);

@@ -75,11 +76,12 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
   MS_EXCEPTION_IF_NULL(bias);
   MS_EXCEPTION_IF_NULL(mean);
   MS_EXCEPTION_IF_NULL(var);
+  MS_EXCEPTION_IF_NULL(umonad);
   MS_EXCEPTION_IF_NULL(z);
 
   auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation);
   MS_EXCEPTION_IF_NULL(prim);
-  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z};
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z, umonad};
   auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(fused_batch_norm_with_add_relu);
@@ -30,6 +30,7 @@ class PostBatchNormAddReluFusion : public PatternProcessPass {
     bias_ = std::make_shared<Var>();
     mean_ = std::make_shared<Var>();
     var_ = std::make_shared<Var>();
+    umonad_ = std::make_shared<Var>();
     index_ = std::make_shared<Var>();
     z_ = std::make_shared<Var>();
   }

@@ -43,6 +44,7 @@ class PostBatchNormAddReluFusion : public PatternProcessPass {
   VarPtr bias_;
   VarPtr mean_;
   VarPtr var_;
+  VarPtr umonad_;
   VarPtr index_;
   VarPtr z_;
 };