From ef13a4b6fb7fe8705e09e0851687b67c20ca3100 Mon Sep 17 00:00:00 2001
From: Wei Luning
Date: Thu, 25 Jun 2020 00:45:10 +0800
Subject: [PATCH] adjust CSE code when an op has side effects.

---
 mindspore/ccsrc/optimizer/cse.cc              | 86 ++++++++++++-------
 mindspore/ccsrc/optimizer/cse.h               |  2 +-
 .../pass/common_subexpression_elimination.cc  |  2 +-
 .../pass/common_subexpression_elimination.h   |  2 +-
 mindspore/ccsrc/pybind_api/export_flags.cc    |  1 +
 mindspore/ccsrc/pybind_api/export_flags.h     |  2 +-
 mindspore/nn/wrap/cell_wrapper.py             |  4 +-
 mindspore/ops/operations/debug_ops.py         |  2 +-
 tests/ut/python/utils/test_serialize.py       |  1 +
 9 files changed, 66 insertions(+), 36 deletions(-)

diff --git a/mindspore/ccsrc/optimizer/cse.cc b/mindspore/ccsrc/optimizer/cse.cc
index 1af08ea3e12..0b675cca721 100644
--- a/mindspore/ccsrc/optimizer/cse.cc
+++ b/mindspore/ccsrc/optimizer/cse.cc
@@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
   return changed;
 }
-
+// Ops such as Print and the summary ops have no real output and are always used as the input of a Depend node.
+static bool HasSideEffect(const AnfNodePtr &node) {
+  auto prim = GetCNodePrimitive(node);
+  if (prim == nullptr) {
+    return false;
+  }
+  auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT);
+  if (side_effect_v != nullptr && side_effect_v->isa<BoolImm>()) {
+    return GetValue<bool>(side_effect_v);
+  }
+  return false;
+}
+// If this returns true, do not merge the nodes.
 bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
   bool has_random_effect = false;
   auto prim_main = GetCNodePrimitive(main);
   auto prim_node = GetCNodePrimitive(node);
-  if (prim_main == prim_node) {
-    return false;
-  }
+  // If the op has a random effect, do not merge nodes generated by different op objects.
   if (prim_main != nullptr) {
+    if (prim_main == prim_node) {
+      return false;
+    }
     auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
     if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
       has_random_effect = GetValue<bool>(effect_val);
     }
@@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons
   return has_random_effect;
 }
 
-bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
+bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
   MS_EXCEPTION_IF_NULL(main);
   MS_EXCEPTION_IF_NULL(node);
 
-  bool replace = false;
   if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
     auto main_value = GetValueNode(main);
     auto node_value = GetValueNode(node);
-    replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
+    return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
   } else if (main->isa<CNode>() && node->isa<CNode>()) {
     auto c_main = main->cast<CNodePtr>();
     auto c_node = node->cast<CNodePtr>();
+    // Do not merge nodes that have side effects (unless the caller disables this check).
+    if (check_side_effect && HasSideEffect(main)) {
+      return false;
+    }
     const auto &inp1 = c_main->inputs();
     const auto &inp2 = c_node->inputs();
-    if (inp1.size() == inp2.size()) {
-      bool appsame = true;
-      for (size_t j = 0; j < inp1.size(); j++) {
-        MS_EXCEPTION_IF_NULL(inp1[j]);
-        MS_EXCEPTION_IF_NULL(inp2[j]);
-        if (!(*inp1[j] == *inp2[j])) {
-          // Handle the case of two different Tensor, but with the same value
-          if (IsValueNode<tensor::Tensor>(inp1[j]) && IsValueNode<tensor::Tensor>(inp2[j])) {
-            auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1[j]);
-            auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2[j]);
-            if (tensor1->ValueEqual(*tensor2)) {
-              continue;
-            }
-          }
-          appsame = false;
-          break;
-        }
-      }
-      if (CheckRandomEffect(c_main, c_node)) {
-        appsame = false;
-      }
-      replace = appsame;
+    if (inp1.size() != inp2.size()) {
+      return false;
     }
+    for (size_t j = 0; j < inp1.size(); j++) {
+      auto inp1_j = inp1[j];
+      auto inp2_j = inp2[j];
+      MS_EXCEPTION_IF_NULL(inp1_j);
+      MS_EXCEPTION_IF_NULL(inp2_j);
+      if (!(*inp1_j == *inp2_j)) {
+        // Handle the case of two different Tensors that hold the same value.
+        if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) {
+          auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j);
+          auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j);
+          if (tensor1->ValueEqual(*tensor2)) {
+            continue;
+          }
+        } else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) {
+          // Two matching side-effect nodes used as inputs do not block the merge: such nodes can
+          // only appear as inputs of `Depend` nodes, and when the `Depend` nodes are duplicates it
+          // is the `Depend` node that gets merged, so compare the inputs without the side-effect check.
+          if (CheckReplace(inp1_j, inp2_j, false)) {
+            continue;
+          }
+        }
+        return false;
+      }
+    }
+    // Even when all inputs match, do not merge nodes that have a random effect.
+    if (CheckRandomEffect(c_main, c_node)) {
+      return false;
+    }
+    return true;
   }
-  return replace;
+  // Neither a ValueNode nor a CNode (e.g. a Parameter): do not replace.
+  return false;
 }
 
 bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,
diff --git a/mindspore/ccsrc/optimizer/cse.h b/mindspore/ccsrc/optimizer/cse.h
index fd90f61eebc..57163cc5c9d 100644
--- a/mindspore/ccsrc/optimizer/cse.h
+++ b/mindspore/ccsrc/optimizer/cse.h
@@ -41,7 +41,7 @@ class CSE {
     return chg && report_changes_;
   }
 
-  virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const;
+  virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const;
 
   virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const;
 
diff --git a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc
index 9af50eac330..297a167aa8e 100644
--- a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc
+++ b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc
@@ -35,7 +35,7 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) {
 }
 }  // namespace
 
-bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
+bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const {
   MS_EXCEPTION_IF_NULL(main);
   MS_EXCEPTION_IF_NULL(node);
 
diff --git a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h
index 8e1768ea99c..18f433ab955 100644
--- a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h
+++ b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h
@@ -31,7 +31,7 @@ class BackendCSE : public CSE {
  public:
   BackendCSE() = default;
   ~BackendCSE() override = default;
-  bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const override;
+  bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override;
 };
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pybind_api/export_flags.cc b/mindspore/ccsrc/pybind_api/export_flags.cc
index 83392784f3e..253e271e525 100644
--- a/mindspore/ccsrc/pybind_api/export_flags.cc
+++ b/mindspore/ccsrc/pybind_api/export_flags.cc
@@ -33,5 +33,6 @@ const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
 const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
 const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
 const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
+const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";
 
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pybind_api/export_flags.h b/mindspore/ccsrc/pybind_api/export_flags.h
index 74c27ff35d2..6ea584e66d8 100644
--- a/mindspore/ccsrc/pybind_api/export_flags.h
+++ b/mindspore/ccsrc/pybind_api/export_flags.h
@@ -34,7 +34,7 @@ extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
 extern const char GRAPH_FLAG_HAS_EFFECT[];
 extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
 extern const char GRAPH_FLAG_RANDOM_EFFECT[];
-
+extern const char GRAPH_FLAG_SIDE_EFFECT[];
 }  // namespace mindspore
 
 #endif  // PYBIND_API_EXPORT_FLAGS_H_
diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py
index f0d920f51fa..9e3d00cc959 100644
--- a/mindspore/nn/wrap/cell_wrapper.py
+++ b/mindspore/nn/wrap/cell_wrapper.py
@@ -220,7 +220,9 @@ class DataWrapper(Cell):
 
     def __init__(self, network, dataset_types, dataset_shapes, queue_name):
         super(DataWrapper, self).__init__(auto_prefix=False,
                                           flags=network.get_flags())
-
+        # Also copy the flags set on `network.construct`.
+        flags = getattr(network.__class__.construct, "_mindspore_flags", {})
+        self.add_flags(**flags)
         self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
         self.network = network
diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py
index bafc72897e6..c62b3f1ab8b 100644
--- a/mindspore/ops/operations/debug_ops.py
+++ b/mindspore/ops/operations/debug_ops.py
@@ -334,7 +334,7 @@ class Print(PrimitiveWithInfer):
 
     @prim_attr_register
     def __init__(self):
-        pass
+        self.add_prim_attr("_side_effect", True)
 
     def __call__(self, *args):
         for arg in args:
diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py
index c5b4586566e..035ea878459 100644
--- a/tests/ut/python/utils/test_serialize.py
+++ b/tests/ut/python/utils/test_serialize.py
@@ -336,6 +336,7 @@ class PrintNet(nn.Cell):
     def construct(self, int8, uint8, int16, uint16, int32, uint32, int64, uint64,
                   flt16, flt32, flt64, bool_, scale1, scale2):
         self.print('============tensor int8:==============', int8)
+        self.print('============tensor int8:==============', int8)
         self.print('============tensor uint8:==============', uint8)
         self.print('============tensor int16:==============', int16)
         self.print('============tensor uint16:==============', uint16)
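
Not part of the patch: for reviewers, a minimal sketch of how an operator opts into the new `_side_effect` flag on the Python side, mirroring the `Print` change above. The class name `SideEffectOp` and the import path are illustrative assumptions rather than code from this patch.

    # Hypothetical example: a primitive marked with "_side_effect" will no longer be
    # merged by CSE::CheckReplace, even when two calls have identical inputs.
    from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register

    class SideEffectOp(PrimitiveWithInfer):
        """Illustrative op whose duplicate calls must survive common subexpression elimination."""

        @prim_attr_register
        def __init__(self):
            self.add_prim_attr("_side_effect", True)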