forked from mindspore-Ecosystem/mindspore
!4884 fix wide&deep dropout set strategy
Merge pull request !4884 from yao_yf/fix_wide_and_deep_dropout
This commit is contained in:
commit
7ba977c435
|
@ -93,9 +93,7 @@ class Dropout(Cell):
|
||||||
self.dropout_do_mask = P.DropoutDoMask()
|
self.dropout_do_mask = P.DropoutDoMask()
|
||||||
self.cast = P.Cast()
|
self.cast = P.Cast()
|
||||||
self.is_gpu = context.get_context('device_target') in ["GPU"]
|
self.is_gpu = context.get_context('device_target') in ["GPU"]
|
||||||
|
self.dropout = P.Dropout(keep_prob)
|
||||||
if self.is_gpu:
|
|
||||||
self.dropout = P.Dropout(keep_prob)
|
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
if not self.training:
|
if not self.training:
|
||||||
|
|
|
@ -128,8 +128,10 @@ class EmbeddingLookup(Cell):
|
||||||
vocab_size (int): Size of the dictionary of embeddings.
|
vocab_size (int): Size of the dictionary of embeddings.
|
||||||
embedding_size (int): The size of each embedding vector.
|
embedding_size (int): The size of each embedding vector.
|
||||||
param_init (str): The initialize way of embedding table. Default: 'normal'.
|
param_init (str): The initialize way of embedding table. Default: 'normal'.
|
||||||
target (str): Specify the target where the op is executed. Default: 'CPU'.
|
target (str): Specify the target where the op is executed. The value should in
|
||||||
slice_mode (str): The slicing way in semi auto parallel/auto parallel. Default: 'batch_slice'.
|
['DEVICE', 'CPU']. Default: 'CPU'.
|
||||||
|
slice_mode (str): The slicing way in semi auto parallel/auto parallel. The value should get through
|
||||||
|
nn.EmbeddingLookUpSplitMode. Default: 'batch_slice'.
|
||||||
manual_shapes (tuple): The accompaniment array in field slice mode.
|
manual_shapes (tuple): The accompaniment array in field slice mode.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
|
|
Loading…
Reference in New Issue