!9936 is_save_on_master doesn't work; it is ok now

From: @shuzigood
Reviewed-by: @guoqi1024,@oacjiewen,@linqingke
Signed-off-by: @linqingke
Committed by mindspore-ci-bot on 2020-12-17 14:25:27 +08:00 (via Gitee)
commit 0db846978e
2 changed files with 14 additions and 18 deletions
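
In short: the README usage examples drop the stale --sens option, and the rank_save_ckpt_flag / logger setup moves out of parse_args() into train(), where it now runs after init() and get_rank(). Previously the flag was computed before args.rank was assigned from get_rank(), so is_save_on_master could not take effect.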


@@ -75,7 +75,6 @@ python train.py \
     --is_distributed=0 \
     --lr=0.001 \
     --loss_scale=1024 \
-    --sens=1024 \
     --weight_decay=0.016 \
     --T_max=320 \
     --max_epoch=320 \
@@ -175,8 +174,6 @@ optional arguments:
                         Whether to use label smooth in CE. Default:0
   --label_smooth_factor LABEL_SMOOTH_FACTOR
                         Smooth strength of original one-hot. Default: 0.1
-  --sens SENS
-                        Static sens. Default: 1024
   --log_interval LOG_INTERVAL
                         Logging interval steps. Default: 100
   --ckpt_path CKPT_PATH
@@ -211,7 +208,6 @@ python train.py \
     --is_distributed=0 \
     --lr=0.001 \
     --loss_scale=1024 \
-    --sens=1024 \
     --weight_decay=0.016 \
     --T_max=320 \
     --max_epoch=320 \


@@ -124,20 +124,6 @@ def parse_args():
     args.data_root = os.path.join(args.data_dir, 'train2014')
     args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json')
-
-    # select for master rank save ckpt or all rank save, compatiable for model parallel
-    args.rank_save_ckpt_flag = 0
-    if args.is_save_on_master:
-        if args.rank == 0:
-            args.rank_save_ckpt_flag = 1
-    else:
-        args.rank_save_ckpt_flag = 1
-
-    # logger
-    args.outputs_dir = os.path.join(args.ckpt_path,
-                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
-    args.logger = get_logger(args.outputs_dir, args.rank)
-    args.logger.save_args(args)
     return args
@@ -160,6 +146,20 @@ def train():
         init("nccl")
         args.rank = get_rank()
         args.group_size = get_group_size()
+
+    # select for master rank save ckpt or all rank save, compatiable for model parallel
+    args.rank_save_ckpt_flag = 0
+    if args.is_save_on_master:
+        if args.rank == 0:
+            args.rank_save_ckpt_flag = 1
+    else:
+        args.rank_save_ckpt_flag = 1
+
+    # logger
+    args.outputs_dir = os.path.join(args.ckpt_path,
+                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
+    args.logger = get_logger(args.outputs_dir, args.rank)
+    args.logger.save_args(args)
     if args.need_profiler:
         from mindspore.profiler.profiling import Profiler
         profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)
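
For reference, a minimal standalone sketch of the pattern the moved block enforces: compute the save flag only once the real rank is known. The helper name below is hypothetical and not part of the repo; init("nccl") and get_rank() are the MindSpore communication calls used in train() above, shown only in comments.

def rank_should_save_ckpt(rank, is_save_on_master):
    # Hypothetical helper mirroring the moved block: returns 1 if this
    # rank should save checkpoints, 0 otherwise.
    if is_save_on_master:
        return 1 if rank == 0 else 0  # only the master (rank 0) saves
    return 1  # every rank saves its own checkpoint

# Usage, mirroring train(): init the backend first, then query the rank.
#   init("nccl")
#   rank = get_rank()
#   flag = rank_should_save_ckpt(rank, args.is_save_on_master)
print(rank_should_save_ckpt(0, True))   # 1 -> master saves
print(rank_should_save_ckpt(1, True))   # 0 -> non-master skips saving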