forked from mindspore-Ecosystem/mindspore
!23012 fix RCAN training bug
Merge pull request !23012 from liuyu/master
This commit is contained in: commit 2bc5f2a79c
@@ -37,19 +37,16 @@ def train():
     device_num = int(os.getenv('RANK_SIZE', '1'))
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
-
-    train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
-    train_dataset.set_scale(args.task_id)
-
-    if device_num > 1:
-        init()
-        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          device_num=device_num, global_rank=device_id,
-                                          gradients_mean=True)
     if args.modelArts_mode:
         import moxing as mox
         local_data_url = '/cache/data'
+        if device_num > 1:
+            init()
+            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              device_num=device_num, gradients_mean=True)
         mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url)
 
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
+    train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
+    train_dataset.set_scale(args.task_id)
     train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,
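For readers skimming the hunk: the fix moves the distributed init() and auto-parallel setup ahead of the ModelArts data download, and defers context.set_context() and the DIV2K dataset construction until after the data is copied to the local cache. Below is a sketch of how the patched setup in train() reads, reconstructed from the "+" and unchanged lines of the hunk above. The imports, the earlier device_id assignment, and the src.* module paths are assumptions for illustration, not part of the diff.

import os

import mindspore.dataset as ds
from mindspore import context
from mindspore.communication.management import init
from mindspore.context import ParallelMode

from src.args import args          # assumed: repo-local argument parsing
from src.data.div2k import DIV2K   # assumed: repo-local DIV2K dataset


def train():
    device_id = int(os.getenv('DEVICE_ID', '0'))   # assumed: set earlier in train()
    device_num = int(os.getenv('RANK_SIZE', '1'))
    if args.modelArts_mode:
        import moxing as mox
        local_data_url = '/cache/data'
        # After the fix, distributed init happens before the dataset download ...
        if device_num > 1:
            init()
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                              device_num=device_num, gradients_mean=True)
        mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url)

    # ... and the context and dataset are only set up once the data is in place.
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                        save_graphs=False, device_id=device_id)
    train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
    train_dataset.set_scale(args.task_id)
    # The hunk ends mid-call; the dataset is sharded across device_num devices:
    # train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,

Note that the removed version also passed global_rank=device_id to set_auto_parallel_context; the patched call drops it and keeps only device_num and gradients_mean.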