forked from mindspore-Ecosystem/mindspore
!972 Fix CSE bug for some operations like `DropoutGenMask` which may have random effect
Merge pull request !972 from seatea/fix-cse-bug-for-random-effect
This commit is contained in:
commit
78a29bba3c
|
@ -90,6 +90,22 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
|
|||
return changed;
|
||||
}
|
||||
|
||||
bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
||||
bool has_random_effect = false;
|
||||
auto prim_main = GetCNodePrimitive(main);
|
||||
auto prim_node = GetCNodePrimitive(node);
|
||||
if (prim_main == prim_node) {
|
||||
return false;
|
||||
}
|
||||
if (prim_main != nullptr) {
|
||||
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
|
||||
if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
|
||||
has_random_effect = GetValue<bool>(effect_val);
|
||||
}
|
||||
}
|
||||
return has_random_effect;
|
||||
}
|
||||
|
||||
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(main);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -122,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
|||
break;
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(c_main, prim::kPrimDropoutGenMask)) {
|
||||
if (CheckRandomEffect(c_main, c_node)) {
|
||||
appsame = false;
|
||||
}
|
||||
replace = appsame;
|
||||
|
|
|
@ -43,6 +43,8 @@ class CSE {
|
|||
|
||||
virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const;
|
||||
|
||||
virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const;
|
||||
|
||||
bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const;
|
||||
|
||||
private:
|
||||
|
|
|
@ -32,5 +32,6 @@ const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32";
|
|||
const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
|
||||
const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
|
||||
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
|
||||
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,6 +33,7 @@ extern const char GRAPH_FLAG_MIX_PRECISION_FP32[];
|
|||
extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
|
||||
extern const char GRAPH_FLAG_HAS_EFFECT[];
|
||||
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
|
||||
extern const char GRAPH_FLAG_RANDOM_EFFECT[];
|
||||
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -1872,6 +1872,7 @@ class DropoutGenMask(Primitive):
|
|||
self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output'])
|
||||
validator.check_value_type("Seed0", Seed0, [int], self.name)
|
||||
validator.check_value_type("Seed1", Seed1, [int], self.name)
|
||||
self.add_prim_attr("_random_effect", True)
|
||||
|
||||
|
||||
class DropoutDoMask(PrimitiveWithInfer):
|
||||
|
|
|
@ -162,7 +162,7 @@ def test_bert_tdt():
|
|||
|
||||
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
|
||||
loss_value = np.array(callback.loss_list)
|
||||
expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982188, 11.974092, 12.610916, 12.17565, 12.840416, 12.40291, 12.621661]
|
||||
expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982189, 11.973948, 12.610932, 12.17564, 12.840248, 12.40294, 12.621653]
|
||||
print("loss value: {}".format(loss_value))
|
||||
assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)
|
||||
|
||||
|
|
Loading…
Reference in New Issue