From 48efc182e0d9a5f4cb5e555c9adf18c2086034a4 Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Wed, 26 Aug 2020 20:03:04 +0800 Subject: [PATCH] change option choices for transformer --- model_zoo/official/cv/faster_rcnn/train.py | 3 ++- model_zoo/official/nlp/transformer/train.py | 16 ++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/model_zoo/official/cv/faster_rcnn/train.py b/model_zoo/official/cv/faster_rcnn/train.py index d48466f6216..895ac725cc1 100644 --- a/model_zoo/official/cv/faster_rcnn/train.py +++ b/model_zoo/official/cv/faster_rcnn/train.py @@ -18,6 +18,7 @@ import os import time import argparse +import ast import random import numpy as np @@ -41,7 +42,7 @@ np.random.seed(1) de.config.set_seed(1) parser = argparse.ArgumentParser(description="FasterRcnn training") -parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default: false.") +parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.") parser.add_argument("--dataset", type=str, default="coco", help="Dataset name, default: coco.") parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.") diff --git a/model_zoo/official/nlp/transformer/train.py b/model_zoo/official/nlp/transformer/train.py index 8b7dc434562..7ffa7d47649 100644 --- a/model_zoo/official/nlp/transformer/train.py +++ b/model_zoo/official/nlp/transformer/train.py @@ -89,16 +89,20 @@ def argparse_init(): Argparse init. """ parser = argparse.ArgumentParser(description='transformer') - parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") + parser.add_argument("--distribute", type=str, default="false", choices=['true', 'false'], + help="Run distribute, default is false.") parser.add_argument("--epoch_size", type=int, default=52, help="Epoch size, default is 52.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") - parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is true.") - parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") - parser.add_argument("--enable_data_sink", type=str, default="false", help="Enable data sink, default is false.") + parser.add_argument("--enable_lossscale", type=str, default="true", choices=['true', 'false'], + help="Use lossscale or not, default is true.") + parser.add_argument("--do_shuffle", type=str, default="true", choices=['true', 'false'], + help="Enable shuffle for dataset, default is true.") + parser.add_argument("--enable_data_sink", type=str, default="false", choices=['true', 'false'], + help="Enable data sink, default is false.") parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path") - parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, " - "default is true.") + parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=['true', 'false'], + help="Enable save checkpoint, default is true.") parser.add_argument("--save_checkpoint_steps", type=int, default=2500, help="Save checkpoint steps, " "default is 2500.") parser.add_argument("--save_checkpoint_num", type=int, default=30, help="Save checkpoint numbers, default is 30.")