From 5828973978e2108681b379e4db3faae8c28b57e5 Mon Sep 17 00:00:00 2001 From: yangzhenzhang Date: Sun, 25 Apr 2021 16:47:44 +0800 Subject: [PATCH] fix bug for dropout do mask --- mindspore/ccsrc/frontend/parallel/step_parallel.cc | 5 +++++ tests/ut/python/parallel/test_dropout_do_mask.py | 3 --- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 516bb40ccc4..cf56e896eed 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -987,6 +987,11 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node FuncGraphManagerPtr manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); + if (IsSomePrimitive(node, DROPOUT_DO_MASK)) { + MS_LOG(INFO) << "Handle dropout do mask, only insert the virtual div to input[0]"; + node_size = 2; + } + for (size_t index = 1; index < node_size; ++index) { AnfNodePtr input = node->input(index); MS_EXCEPTION_IF_NULL(input); diff --git a/tests/ut/python/parallel/test_dropout_do_mask.py b/tests/ut/python/parallel/test_dropout_do_mask.py index f727105123c..4c32f41e503 100644 --- a/tests/ut/python/parallel/test_dropout_do_mask.py +++ b/tests/ut/python/parallel/test_dropout_do_mask.py @@ -25,13 +25,11 @@ class Net(Cell): def __init__(self, mul_weight, strategy1=None, strategy2=None): super().__init__() self.mul = P.Mul().shard(strategy1) - self.mul2 = P.Mul().shard(strategy1) self.dropout_do_mask = P.DropoutDoMask().shard(strategy2) self.dropout_gen_mask = P.DropoutGenMask() self.get_shape = P.Shape() self.cast = P.Cast() self.mul_weight = Parameter(mul_weight, "w1") - self.mul_weight2 = Parameter(mul_weight, "w2") self.keep_prob = Tensor(0.9) def construct(self, x, b): @@ -41,7 +39,6 @@ class Net(Cell): keep_prob = self.cast(self.keep_prob, dtype) mask = self.dropout_gen_mask(shape, keep_prob) out = self.dropout_do_mask(out, mask, keep_prob) - out = self.mul2(out, self.mul_weight2) return out