!23012 fix RCAN training bug

Merge pull request !23012 from liuyu/master
This commit is contained in:
i-robot 2021-09-08 09:41:57 +00:00 committed by Gitee
commit 2bc5f2a79c
1 changed file with 5 additions and 8 deletions

View File

@@ -37,19 +37,16 @@ def train():
device_num = int(os.getenv('RANK_SIZE', '1'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
train_dataset.set_scale(args.task_id)
if device_num > 1:
init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=device_num, global_rank=device_id,
gradients_mean=True)
if args.modelArts_mode:
import moxing as mox
local_data_url = '/cache/data'
if device_num > 1:
init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=device_num, gradients_mean=True)
mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
train_dataset.set_scale(args.task_id)
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,