Fix bug in DropoutDoMask handling: insert the virtual div only for input[0]

This commit is contained in:
yangzhenzhang 2021-04-25 16:47:44 +08:00
parent 45a17e3edf
commit 5828973978
2 changed files with 5 additions and 3 deletions

View File

@ -987,6 +987,11 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
if (IsSomePrimitive(node, DROPOUT_DO_MASK)) {
MS_LOG(INFO) << "Handle dropout do mask, only insert the virtual div to input[0]";
node_size = 2;
}
for (size_t index = 1; index < node_size; ++index) {
AnfNodePtr input = node->input(index);
MS_EXCEPTION_IF_NULL(input);

View File

@ -25,13 +25,11 @@ class Net(Cell):
def __init__(self, mul_weight, strategy1=None, strategy2=None):
    """Build the test network: Mul -> DropoutDoMask -> Mul.

    Args:
        mul_weight: initial tensor used for both weight parameters.
        strategy1: shard strategy applied to both Mul ops (None = unset).
        strategy2: shard strategy applied to DropoutDoMask (None = unset).
    """
    super().__init__()
    # Both Mul ops share strategy1 so their sharding matches.
    self.mul = P.Mul().shard(strategy1)
    self.mul2 = P.Mul().shard(strategy1)
    # DropoutDoMask carries its own strategy — presumably the op this
    # parallel test exercises; verify against the test cases below.
    self.dropout_do_mask = P.DropoutDoMask().shard(strategy2)
    self.dropout_gen_mask = P.DropoutGenMask()
    self.get_shape = P.Shape()
    self.cast = P.Cast()
    # Two distinct Parameters ("w1", "w2") initialized from the same tensor.
    self.mul_weight = Parameter(mul_weight, "w1")
    self.mul_weight2 = Parameter(mul_weight, "w2")
    # Keep probability fed to both gen-mask and do-mask (cast in construct).
    self.keep_prob = Tensor(0.9)
def construct(self, x, b):
@ -41,7 +39,6 @@ class Net(Cell):
keep_prob = self.cast(self.keep_prob, dtype)
mask = self.dropout_gen_mask(shape, keep_prob)
out = self.dropout_do_mask(out, mask, keep_prob)
out = self.mul2(out, self.mul_weight2)
return out