forked from mindspore-Ecosystem/mindspore
fix bug for dropout do mask
This commit is contained in:
parent
45a17e3edf
commit
5828973978
|
@ -987,6 +987,11 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
|
|||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
||||
if (IsSomePrimitive(node, DROPOUT_DO_MASK)) {
|
||||
MS_LOG(INFO) << "Handle dropout do mask, only insert the virtual div to input[0]";
|
||||
node_size = 2;
|
||||
}
|
||||
|
||||
for (size_t index = 1; index < node_size; ++index) {
|
||||
AnfNodePtr input = node->input(index);
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
|
|
|
@ -25,13 +25,11 @@ class Net(Cell):
|
|||
def __init__(self, mul_weight, strategy1=None, strategy2=None):
|
||||
super().__init__()
|
||||
self.mul = P.Mul().shard(strategy1)
|
||||
self.mul2 = P.Mul().shard(strategy1)
|
||||
self.dropout_do_mask = P.DropoutDoMask().shard(strategy2)
|
||||
self.dropout_gen_mask = P.DropoutGenMask()
|
||||
self.get_shape = P.Shape()
|
||||
self.cast = P.Cast()
|
||||
self.mul_weight = Parameter(mul_weight, "w1")
|
||||
self.mul_weight2 = Parameter(mul_weight, "w2")
|
||||
self.keep_prob = Tensor(0.9)
|
||||
|
||||
def construct(self, x, b):
|
||||
|
@ -41,7 +39,6 @@ class Net(Cell):
|
|||
keep_prob = self.cast(self.keep_prob, dtype)
|
||||
mask = self.dropout_gen_mask(shape, keep_prob)
|
||||
out = self.dropout_do_mask(out, mask, keep_prob)
|
||||
out = self.mul2(out, self.mul_weight2)
|
||||
return out
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue