Replace the hard-coded DropoutGenMask exclusion in CSE with a generic
"_random_effect" primitive attribute.

* optimizer/cse.cc, optimizer/cse.h: add virtual CSE::CheckRandomEffect(main, node).
  It returns false when GetCNodePrimitive() yields equal results for both nodes
  (this includes the case where both are null). Otherwise it returns the boolean
  value of the GRAPH_FLAG_RANDOM_EFFECT attribute on main's primitive, or false
  when main has no primitive or the attribute is absent / not a bool.
  CSE::CheckReplace now clears `appsame` when CheckRandomEffect() is true instead
  of testing for prim::kPrimDropoutGenMask directly.
* pybind_api/export_flags.cc, export_flags.h: define and declare
  GRAPH_FLAG_RANDOM_EFFECT ("_random_effect").
* ops/operations/nn_ops.py: DropoutGenMask sets the "_random_effect" primitive
  attribute to True.
* tests/st/networks/models/bert/bert_tdt_lossscale.py: update expected loss values.

NOTE(review): this patch was recovered from a whitespace-collapsed copy. Line
breaks, leading indentation, blank context lines and the template arguments
isa<BoolImm>() / GetValue<bool>() were reconstructed; confirm against the
upstream commit (or apply with whitespace tolerance) before relying on it.

diff --git a/mindspore/ccsrc/optimizer/cse.cc b/mindspore/ccsrc/optimizer/cse.cc
index ff7291cd86b..1af08ea3e12 100644
--- a/mindspore/ccsrc/optimizer/cse.cc
+++ b/mindspore/ccsrc/optimizer/cse.cc
@@ -90,6 +90,22 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
   return changed;
 }
 
+bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
+  bool has_random_effect = false;
+  auto prim_main = GetCNodePrimitive(main);
+  auto prim_node = GetCNodePrimitive(node);
+  if (prim_main == prim_node) {
+    return false;
+  }
+  if (prim_main != nullptr) {
+    auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
+    if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
+      has_random_effect = GetValue<bool>(effect_val);
+    }
+  }
+  return has_random_effect;
+}
+
 bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
   MS_EXCEPTION_IF_NULL(main);
   MS_EXCEPTION_IF_NULL(node);
@@ -122,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
           break;
         }
       }
-      if (IsPrimitiveCNode(c_main, prim::kPrimDropoutGenMask)) {
+      if (CheckRandomEffect(c_main, c_node)) {
         appsame = false;
       }
       replace = appsame;
diff --git a/mindspore/ccsrc/optimizer/cse.h b/mindspore/ccsrc/optimizer/cse.h
index 544e6cb6a36..fd90f61eebc 100644
--- a/mindspore/ccsrc/optimizer/cse.h
+++ b/mindspore/ccsrc/optimizer/cse.h
@@ -43,6 +43,8 @@ class CSE {
 
   virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const;
 
+  virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const;
+
   bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const;
 
  private:
diff --git a/mindspore/ccsrc/pybind_api/export_flags.cc b/mindspore/ccsrc/pybind_api/export_flags.cc
index 931e9e17b11..83392784f3e 100644
--- a/mindspore/ccsrc/pybind_api/export_flags.cc
+++ b/mindspore/ccsrc/pybind_api/export_flags.cc
@@ -32,5 +32,6 @@ const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32";
 const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
 const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
 const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
+const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
 
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pybind_api/export_flags.h b/mindspore/ccsrc/pybind_api/export_flags.h
index ed68da17dad..74c27ff35d2 100644
--- a/mindspore/ccsrc/pybind_api/export_flags.h
+++ b/mindspore/ccsrc/pybind_api/export_flags.h
@@ -33,6 +33,7 @@ extern const char GRAPH_FLAG_MIX_PRECISION_FP32[];
 extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
 extern const char GRAPH_FLAG_HAS_EFFECT[];
 extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
+extern const char GRAPH_FLAG_RANDOM_EFFECT[];
 
 }  // namespace mindspore
 
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index a80500c0e64..545e1e9f2da 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1877,6 +1877,7 @@ class DropoutGenMask(Primitive):
         self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output'])
         validator.check_value_type("Seed0", Seed0, [int], self.name)
         validator.check_value_type("Seed1", Seed1, [int], self.name)
+        self.add_prim_attr("_random_effect", True)
 
 
 class DropoutDoMask(PrimitiveWithInfer):
diff --git a/tests/st/networks/models/bert/bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_tdt_lossscale.py
index 1bd72d02217..e6578af7491 100644
--- a/tests/st/networks/models/bert/bert_tdt_lossscale.py
+++ b/tests/st/networks/models/bert/bert_tdt_lossscale.py
@@ -162,7 +162,7 @@ def test_bert_tdt():
 
     # assertion occurs while the loss value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
-    expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982188, 11.974092, 12.610916, 12.17565, 12.840416, 12.40291, 12.621661]
+    expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982189, 11.973948, 12.610932, 12.17564, 12.840248, 12.40294, 12.621653]
     print("loss value: {}".format(loss_value))
     assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)
 