!5466 remove bool parameter parser in wide_and_deep

Merge pull request !5466 from yao_yf/remove_bool_paraser_in_wide_and_deep_r0.7
This commit is contained in:
mindspore-ci-bot 2020-08-29 14:42:51 +08:00 committed by Gitee
commit a9943a382c
3 changed files with 6 additions and 6 deletions

View File

@ -129,9 +129,9 @@ class EmbeddingLookup(Cell):
embedding_size (int): The size of each embedding vector.
param_init (str): The initialize way of embedding table. Default: 'normal'.
target (str): Specify the target where the op is executed. The value should in
['DEVICE', 'CPU']. Default: 'CPU'.
['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi auto parallel/auto parallel. The value should get through
nn.EmbeddingLookUpSplitMode. Default: 'batch_slice'.
nn.EmbeddingLookUpSplitMode. Default: nn.EmbeddingLookUpSplitMode.BATCH_SLICE.
manual_shapes (tuple): The accompaniment array in field slice mode.
Inputs:

View File

@ -25,7 +25,7 @@ def argparse_init():
parser.add_argument("--data_path", type=str, default="./test_raw_data/",
help="This should be set to the same directory given to the data_download's data_dir argument")
parser.add_argument("--epochs", type=int, default=15, help="Total train epochs")
parser.add_argument("--full_batch", type=bool, default=False, help="Enable loading the full batch ")
parser.add_argument("--full_batch", type=int, default=0, help="Enable loading the full batch ")
parser.add_argument("--batch_size", type=int, default=16000, help="Training batch size.")
parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.")
parser.add_argument("--field_size", type=int, default=39, help="The number of features.")
@ -88,7 +88,7 @@ class WideDeepConfig():
self.device_target = args.device_target
self.data_path = args.data_path
self.epochs = args.epochs
self.full_batch = args.full_batch
self.full_batch = bool(args.full_batch)
self.batch_size = args.batch_size
self.eval_batch_size = args.eval_batch_size
self.field_size = args.field_size

View File

@ -31,7 +31,7 @@ def argparse_init():
parser.add_argument("--adam_lr", type=float, default=0.003) # The Adam lr
parser.add_argument("--ftrl_lr", type=float, default=0.1) # The ftrl lr.
parser.add_argument("--l2_coef", type=float, default=0.0) # The l2 coefficient.
parser.add_argument("--is_tf_dataset", type=bool, default=True) # The l2 coefficient.
parser.add_argument("--is_tf_dataset", type=int, default=1) # Is tf_dataset.
parser.add_argument("--dropout_flag", type=int, default=1) # The dropout rate
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
@ -87,7 +87,7 @@ class WideDeepConfig():
self.l2_coef = args.l2_coef
self.ftrl_lr = args.ftrl_lr
self.adam_lr = args.adam_lr
self.is_tf_dataset = args.is_tf_dataset
self.is_tf_dataset = bool(args.is_tf_dataset)
self.output_path = args.output_path
self.eval_file_name = args.eval_file_name