diff --git a/model_zoo/official/nlp/pangu_alpha/src/dataset.py b/model_zoo/official/nlp/pangu_alpha/src/dataset.py
index 6f6fc147b31..b8966d870c4 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/dataset.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/dataset.py
@@ -95,6 +95,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_
         os.path.join(home_path, name) for name in files
         if not name.endswith(".db")
     ]
+    # Ensure the order of the mindrecord files is the same on all machines; otherwise the loss may fail to converge.
+    data.sort()
 
     # Load data files and preprocess
     dataset = ds.MindDataset(data[data_start_index:], columns_list=[column_name], shuffle=False)
diff --git a/model_zoo/official/nlp/pangu_alpha/train.py b/model_zoo/official/nlp/pangu_alpha/train.py
index db0aaa84451..7834cd682a6 100644
--- a/model_zoo/official/nlp/pangu_alpha/train.py
+++ b/model_zoo/official/nlp/pangu_alpha/train.py
@@ -279,15 +279,11 @@ def run_train_pipeline(args_opt):
         optimizer = nn.Lamb(group_params, learning_rate=lr)
     else:
         optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
-    if context.get_auto_parallel_context("full_batch"):
-        ds = create_dataset(config.batch_size, data_path=cache_url, eod_reset=True,
-                            data_start_index=0, full_batch=True, column_name=args_opt.data_column_name)
-    else:
-        if batch_size % stage_device_num != 0:
-            raise ValueError("Batch_size should be divisible by device_num")
-        ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
-                            rank=rank_id, eod_reset=True, data_start_index=0, full_batch=False,
-                            column_name=args_opt.data_column_name)
+
+    ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
+                        rank=rank_id % stage_device_num, eod_reset=True, data_start_index=0,
+                        full_batch=context.get_auto_parallel_context("full_batch"),
+                        column_name=args_opt.data_column_name)
     epoch_num = args_opt.epoch_size
     step_per_epoch = ds.get_dataset_size()
     callback_size = args_opt.sink_size
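
A minimal, self-contained sketch (not part of the patch; the shard file names below are made up) of why the added data.sort() matters: os.listdir() returns entries in an arbitrary, filesystem-dependent order, so without sorting, different machines can build different file lists from the same mindrecord directory and end up reading different data slices.

    import os
    import tempfile

    # Create a throwaway directory with a few hypothetical mindrecord shards plus
    # their .db index files, mimicking the layout create_dataset() expects.
    data_path = tempfile.mkdtemp()
    for name in ("part2.mindrecord", "part0.mindrecord", "part1.mindrecord"):
        open(os.path.join(data_path, name), "w").close()
        open(os.path.join(data_path, name + ".db"), "w").close()

    # Same filtering as the patch: keep shard files, drop the .db index files.
    files = os.listdir(data_path)  # order is filesystem-dependent and may differ per machine
    data = [os.path.join(data_path, name) for name in files if not name.endswith(".db")]

    # Without sort(), machine A might see ["part2", "part0", "part1"] while machine B
    # sees ["part0", "part1", "part2"]; data[data_start_index:] then selects different
    # shards on each machine, which is the loss-convergence problem the patch fixes.
    data.sort()
    print(data)  # deterministic order on every machine

The train.py change appears to serve the same goal of a single deterministic path: the full_batch flag is read from the auto-parallel context instead of branching, and rank_id % stage_device_num maps the global rank onto its position within the pipeline stage before sharding.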