!19820 Fix Dataset Import Error in DataParallel Mode

Merge pull request !19820 from huangxinjing/code_docs_fix_dataset_import
This commit is contained in:
i-robot 2021-07-09 07:36:57 +00:00 committed by Gitee
commit b137e8a812
2 changed files with 7 additions and 9 deletions

View File

@ -95,6 +95,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_
os.path.join(home_path, name) for name in files
if not name.endswith(".db")
]
# Ensure the order of mindrecords is same in all machines, otherwise it will meet loss converge problem.
data.sort()
# Load data files and preprocess
dataset = ds.MindDataset(data[data_start_index:], columns_list=[column_name], shuffle=False)

View File

@ -279,15 +279,11 @@ def run_train_pipeline(args_opt):
optimizer = nn.Lamb(group_params, learning_rate=lr)
else:
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
if context.get_auto_parallel_context("full_batch"):
ds = create_dataset(config.batch_size, data_path=cache_url, eod_reset=True,
data_start_index=0, full_batch=True, column_name=args_opt.data_column_name)
else:
if batch_size % stage_device_num != 0:
raise ValueError("Batch_size should be divisible by device_num")
ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
rank=rank_id, eod_reset=True, data_start_index=0, full_batch=False,
column_name=args_opt.data_column_name)
ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
rank=rank_id % stage_device_num, eod_reset=True, data_start_index=0,
full_batch=context.get_auto_parallel_context("full_batch"),
column_name=args_opt.data_column_name)
epoch_num = args_opt.epoch_size
step_per_epoch = ds.get_dataset_size()
callback_size = args_opt.sink_size