forked from mindspore-Ecosystem/mindspore
!5466 remove bool parameter parser in wide_and_deep
Merge pull request !5466 from yao_yf/remove_bool_paraser_in_wide_and_deep_r0.7
This commit is contained in:
commit
a9943a382c
|
@ -129,9 +129,9 @@ class EmbeddingLookup(Cell):
|
|||
embedding_size (int): The size of each embedding vector.
|
||||
param_init (str): The initialize way of embedding table. Default: 'normal'.
|
||||
target (str): Specify the target where the op is executed. The value should in
|
||||
['DEVICE', 'CPU']. Default: 'CPU'.
|
||||
['DEVICE', 'CPU']. Default: 'CPU'.
|
||||
slice_mode (str): The slicing way in semi auto parallel/auto parallel. The value should get through
|
||||
nn.EmbeddingLookUpSplitMode. Default: 'batch_slice'.
|
||||
nn.EmbeddingLookUpSplitMode. Default: nn.EmbeddingLookUpSplitMode.BATCH_SLICE.
|
||||
manual_shapes (tuple): The accompaniment array in field slice mode.
|
||||
|
||||
Inputs:
|
||||
|
|
|
@ -25,7 +25,7 @@ def argparse_init():
|
|||
parser.add_argument("--data_path", type=str, default="./test_raw_data/",
|
||||
help="This should be set to the same directory given to the data_download's data_dir argument")
|
||||
parser.add_argument("--epochs", type=int, default=15, help="Total train epochs")
|
||||
parser.add_argument("--full_batch", type=bool, default=False, help="Enable loading the full batch ")
|
||||
parser.add_argument("--full_batch", type=int, default=0, help="Enable loading the full batch ")
|
||||
parser.add_argument("--batch_size", type=int, default=16000, help="Training batch size.")
|
||||
parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.")
|
||||
parser.add_argument("--field_size", type=int, default=39, help="The number of features.")
|
||||
|
@ -88,7 +88,7 @@ class WideDeepConfig():
|
|||
self.device_target = args.device_target
|
||||
self.data_path = args.data_path
|
||||
self.epochs = args.epochs
|
||||
self.full_batch = args.full_batch
|
||||
self.full_batch = bool(args.full_batch)
|
||||
self.batch_size = args.batch_size
|
||||
self.eval_batch_size = args.eval_batch_size
|
||||
self.field_size = args.field_size
|
||||
|
|
|
@ -31,7 +31,7 @@ def argparse_init():
|
|||
parser.add_argument("--adam_lr", type=float, default=0.003) # The Adam lr
|
||||
parser.add_argument("--ftrl_lr", type=float, default=0.1) # The ftrl lr.
|
||||
parser.add_argument("--l2_coef", type=float, default=0.0) # The l2 coefficient.
|
||||
parser.add_argument("--is_tf_dataset", type=bool, default=True) # The l2 coefficient.
|
||||
parser.add_argument("--is_tf_dataset", type=int, default=1) # Is tf_dataset.
|
||||
parser.add_argument("--dropout_flag", type=int, default=1) # The dropout rate
|
||||
|
||||
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
|
||||
|
@ -87,7 +87,7 @@ class WideDeepConfig():
|
|||
self.l2_coef = args.l2_coef
|
||||
self.ftrl_lr = args.ftrl_lr
|
||||
self.adam_lr = args.adam_lr
|
||||
self.is_tf_dataset = args.is_tf_dataset
|
||||
self.is_tf_dataset = bool(args.is_tf_dataset)
|
||||
|
||||
self.output_path = args.output_path
|
||||
self.eval_file_name = args.eval_file_name
|
||||
|
|
Loading…
Reference in New Issue