!5277 add option choices for transformer

Merge pull request !5277 from yuchaojie/transformer_st
This commit is contained in:
mindspore-ci-bot 2020-08-27 19:47:38 +08:00 committed by Gitee
commit e4909f9050
2 changed files with 12 additions and 7 deletions

View File

@ -18,6 +18,7 @@
import os import os
import time import time
import argparse import argparse
import ast
import random import random
import numpy as np import numpy as np
@ -41,7 +42,7 @@ np.random.seed(1)
de.config.set_seed(1) de.config.set_seed(1)
parser = argparse.ArgumentParser(description="FasterRcnn training") parser = argparse.ArgumentParser(description="FasterRcnn training")
parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default: false.") parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset name, default: coco.") parser.add_argument("--dataset", type=str, default="coco", help="Dataset name, default: coco.")
parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.") parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")

View File

@ -89,16 +89,20 @@ def argparse_init():
Argparse init. Argparse init.
""" """
parser = argparse.ArgumentParser(description='transformer') parser = argparse.ArgumentParser(description='transformer')
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") parser.add_argument("--distribute", type=str, default="false", choices=['true', 'false'],
help="Run distribute, default is false.")
parser.add_argument("--epoch_size", type=int, default=52, help="Epoch size, default is 52.") parser.add_argument("--epoch_size", type=int, default=52, help="Epoch size, default is 52.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is true.") parser.add_argument("--enable_lossscale", type=str, default="true", choices=['true', 'false'],
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") help="Use lossscale or not, default is true.")
parser.add_argument("--enable_data_sink", type=str, default="false", help="Enable data sink, default is false.") parser.add_argument("--do_shuffle", type=str, default="true", choices=['true', 'false'],
help="Enable shuffle for dataset, default is true.")
parser.add_argument("--enable_data_sink", type=str, default="false", choices=['true', 'false'],
help="Enable data sink, default is false.")
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path") parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, " parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=['true', 'false'],
"default is true.") help="Enable save checkpoint, default is true.")
parser.add_argument("--save_checkpoint_steps", type=int, default=2500, help="Save checkpoint steps, " parser.add_argument("--save_checkpoint_steps", type=int, default=2500, help="Save checkpoint steps, "
"default is 2500.") "default is 2500.")
parser.add_argument("--save_checkpoint_num", type=int, default=30, help="Save checkpoint numbers, default is 30.") parser.add_argument("--save_checkpoint_num", type=int, default=30, help="Save checkpoint numbers, default is 30.")