forked from mindspore-Ecosystem/mindspore
!19820 Fix Dataset Import Error in DataParallel Mode
Merge pull request !19820 from huangxinjing/code_docs_fix_dataset_import
This commit is contained in:
commit
b137e8a812
|
@ -95,6 +95,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_
|
|||
os.path.join(home_path, name) for name in files
|
||||
if not name.endswith(".db")
|
||||
]
|
||||
# Ensure the order of the mindrecord files is the same on all machines; otherwise the loss may fail to converge.
|
||||
data.sort()
|
||||
|
||||
# Load data files and preprocess
|
||||
dataset = ds.MindDataset(data[data_start_index:], columns_list=[column_name], shuffle=False)
|
||||
|
|
|
@ -279,15 +279,11 @@ def run_train_pipeline(args_opt):
|
|||
optimizer = nn.Lamb(group_params, learning_rate=lr)
|
||||
else:
|
||||
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
|
||||
if context.get_auto_parallel_context("full_batch"):
|
||||
ds = create_dataset(config.batch_size, data_path=cache_url, eod_reset=True,
|
||||
data_start_index=0, full_batch=True, column_name=args_opt.data_column_name)
|
||||
else:
|
||||
if batch_size % stage_device_num != 0:
|
||||
raise ValueError("Batch_size should be divisible by device_num")
|
||||
ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
|
||||
rank=rank_id, eod_reset=True, data_start_index=0, full_batch=False,
|
||||
column_name=args_opt.data_column_name)
|
||||
|
||||
ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
|
||||
rank=rank_id % stage_device_num, eod_reset=True, data_start_index=0,
|
||||
full_batch=context.get_auto_parallel_context("full_batch"),
|
||||
column_name=args_opt.data_column_name)
|
||||
epoch_num = args_opt.epoch_size
|
||||
step_per_epoch = ds.get_dataset_size()
|
||||
callback_size = args_opt.sink_size
|
||||
|
|
Loading…
Reference in New Issue