forked from mindspore-Ecosystem/mindspore
!9911 Fix an Xception training bug about init() calling position and a small bug in TextRCNN running script.
From: @penny369 Reviewed-by: @guoqi1024,@c_34 Signed-off-by: @guoqi1024
This commit is contained in:
commit
f52b11f974
|
@ -98,24 +98,24 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
# init distributed
|
||||
if args_opt.is_distributed:
|
||||
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
||||
rank = get_rank()
|
||||
group_size = get_group_size()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
|
||||
init()
|
||||
else:
|
||||
rank = 0
|
||||
group_size = 1
|
||||
context.set_context(device_id=0)
|
||||
|
||||
if args_opt.device_target == "Ascend":
|
||||
#train on Ascend
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
|
||||
|
||||
# init distributed
|
||||
if args_opt.is_distributed:
|
||||
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
||||
init()
|
||||
rank = get_rank()
|
||||
group_size = get_group_size()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
|
||||
else:
|
||||
rank = 0
|
||||
group_size = 1
|
||||
context.set_context(device_id=0)
|
||||
|
||||
# define network
|
||||
net = xception(class_num=config.class_num)
|
||||
net.to_float(mstype.float16)
|
||||
|
|
|
@ -17,4 +17,4 @@ ulimit -u unlimited
|
|||
|
||||
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
|
||||
python ${BASEPATH}/../eval.py > --ckpt_path $1 ./eval.log 2>&1 &
|
||||
python ${BASEPATH}/../eval.py --ckpt_path $1 > ./eval.log 2>&1 &
|
||||
|
|
Loading…
Reference in New Issue