diff --git a/model_zoo/official/nlp/pangu_alpha/predict.py b/model_zoo/official/nlp/pangu_alpha/predict.py index 0d8ce6bd9d7..a27aa87b239 100644 --- a/model_zoo/official/nlp/pangu_alpha/predict.py +++ b/model_zoo/official/nlp/pangu_alpha/predict.py @@ -74,7 +74,7 @@ def run_predict(args_opt): device_num = 1 # Set model property - model_parallel_num = args_opt.tensor_model_parallel_num + model_parallel_num = args_opt.op_level_model_parallel_num data_parallel_num = int(device_num / model_parallel_num) per_batch_size = args_opt.per_batch_size batch_size = per_batch_size * data_parallel_num diff --git a/model_zoo/official/nlp/pangu_alpha/src/dataset.py b/model_zoo/official/nlp/pangu_alpha/src/dataset.py index 19f290bdc07..8e7707510d0 100644 --- a/model_zoo/official/nlp/pangu_alpha/src/dataset.py +++ b/model_zoo/official/nlp/pangu_alpha/src/dataset.py @@ -39,7 +39,7 @@ def get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset): attention_mask: the attention mask considering eod reset """ rank = int(rank) - input_ids = input_ids[rank*dis: (rank+1)*dis] + input_ids = input_ids[rank * dis: (rank + 1) * dis] if not eod_reset: return input_ids seq_length = input_ids.shape[1] - 1 @@ -60,8 +60,8 @@ def get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset): for i in range(eod_index.size): # Reset position_ids and attention_mask considering EOD index = eod_index[i] - batch_attention_mask[bs_i, (index+1):, :(index+1)] = 0 - batch_position_ids[bs_i, (index+1):] -= (index + 1 - prev_index) + batch_attention_mask[bs_i, (index + 1):, :(index + 1)] = 0 + batch_position_ids[bs_i, (index + 1):] -= (index + 1 - prev_index) prev_index = index + 1 return batch_input_ids, batch_position_ids, batch_attention_mask @@ -106,9 +106,10 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_ else: # Each card slice a small batch from the full batch dis = int(batch_size / device_num) - if dis <= 0: + if batch_size % device_num != 0: raise ValueError( - "batch size {} should be a multiple of device number {}.".format(batch_size, device_num)) + f"batch size {batch_size} should be a multiple of device number {device_num}." + " You should change the args: per_batch_size.") map_func = (lambda input_ids: get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset)) # If eod_reset enabled, another two inputs will be generated through input_ids diff --git a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_config.py b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_config.py index b7e357cd146..ad255cfa9df 100644 --- a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_config.py +++ b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_config.py @@ -95,7 +95,7 @@ def set_parse(args_opt): args_opt.optimizer_shard = 0 args_opt.stage_num = 16 args_opt.micro_size = 32 - args_opt.tensor_model_parallel_num = 16 + args_opt.op_level_model_parallel_num = 16 elif args_opt.run_type == "predict": args_opt.stage_num = 4 args_opt.micro_size = 1 @@ -103,7 +103,7 @@ def set_parse(args_opt): args_opt.embedding_size = 5120 args_opt.num_layers = 40 args_opt.num_heads = 40 - args_opt.tensor_model_parallel_num = 8 + args_opt.op_level_model_parallel_num = 8 if args_opt.run_type == "train": args_opt.start_lr = 5e-5 args_opt.end_lr = 1e-6 @@ -117,7 +117,7 @@ def set_parse(args_opt): args_opt.embedding_size = 2560 args_opt.num_layers = 32 args_opt.num_heads = 32 - args_opt.tensor_model_parallel_num = 8 + args_opt.op_level_model_parallel_num = 8 if args_opt.run_type == "train": args_opt.start_lr = 1e-4 args_opt.end_lr = 1e-6 diff --git a/model_zoo/official/nlp/pangu_alpha/src/utils.py b/model_zoo/official/nlp/pangu_alpha/src/utils.py index c76f5d2b0b8..6e9a60dcb16 100644 --- a/model_zoo/official/nlp/pangu_alpha/src/utils.py +++ b/model_zoo/official/nlp/pangu_alpha/src/utils.py @@ -16,6 +16,8 @@ network config setting, gradient clip function and dynamic learning rate function """ import argparse +import os +import time import numpy as np import mindspore.nn as nn from mindspore.ops import operations as P @@ -24,7 +26,6 @@ from mindspore.ops import functional as F import mindspore.common.dtype as mstype from mindspore.common.tensor import Tensor from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR, CosineDecayLR - from mindspore.parallel._utils import _get_global_rank from mindspore.communication.management import get_group_size from mindspore.nn import AdamWeightDecay @@ -67,6 +68,7 @@ class FP32StateAdamWeightDecay(AdamWeightDecay): new.append(new_state) return ParameterTuple(new) + get_square_sum = C.MultitypeFuncGraph("get_square_sum") @@ -92,6 +94,7 @@ class GlobalNorm(nn.Cell): Calculate the global norm value of given tensors """ + def __init__(self, params): super(GlobalNorm, self).__init__() self.norm = nn.Norm() @@ -106,8 +109,9 @@ class GlobalNorm(nn.Cell): if item: self.values.append(Tensor([1.0], mstype.float32)) else: - self.values.append(Tensor([self.group_size*1.0], mstype.float32)) + self.values.append(Tensor([self.group_size * 1.0], mstype.float32)) self.values = tuple(self.values) + def construct(self, grads): # Square sum of gradients for current rank square_sum_dp = self.hyper_map(get_square_sum, grads, self.values) @@ -122,6 +126,7 @@ class ClipByGlobalNorm(nn.Cell): Clip grads by global norm """ + def __init__(self, params, clip_norm=1.0): super(ClipByGlobalNorm, self).__init__() self.global_norm = GlobalNorm(params) @@ -143,11 +148,11 @@ def _get_model_parallel_group(dp, mp): return [x + index * mp for x in group] - class LearningRate(LearningRateSchedule): """ Warmup-decay learning rate for PanguAlpha network. """ + def __init__(self, learning_rate, end_learning_rate, @@ -186,6 +191,7 @@ class LearningRate(LearningRateSchedule): lr = decay_lr return lr + def add_inference_params(opt): """Add inference params""" opt.add_argument("--frequency_penalty", @@ -213,6 +219,7 @@ def add_inference_params(opt): default=9, help="the token id for ") + def add_training_params(opt): """Add training params""" opt.add_argument("--seq_length", @@ -269,8 +276,8 @@ def add_training_params(opt): default=2, help="The sink size of the training") opt.add_argument("--full_batch", - default=0, - help="Import the full size of a batch for each card, default is 0") + default=1, + help="Import the full size of a batch for each card, default is 1") opt.add_argument("--optimizer_shard", type=int, default=1, @@ -279,6 +286,18 @@ def add_training_params(opt): type=int, default=6, help="The batch size for each data parallel way. default 32") + opt.add_argument("--start_lr", + type=float, + default=5e-5, + help="The start learning rate. default 1e-5") + opt.add_argument("--end_lr", + type=float, + default=1e-6, + help="The end learning rate. default 1e-6") + opt.add_argument("--op_level_model_parallel_num", + type=int, + default=8, + help="The model parallel way. default 8") def get_args(inference=False): @@ -335,10 +354,41 @@ def get_args(inference=False): type=str, default="fp32", help="The initialization type for parameters. Default fp32.") - + parser.add_argument("--offline", + type=int, + default=1, + help="Running on cloud of not. Default 1.") add_training_params(parser) if inference: add_inference_params(parser) args_opt = parser.parse_args() return args_opt + + +def download_data(src_data_url, tgt_data_path, rank): + """ + Download the dataset from the obs. + src_data_url (Str): should be the dataset path in the obs + tgt_data_path (Str): the local dataset path + rank (Int): the current rank id + + """ + cache_url = tgt_data_path + EXEC_PATH = '/tmp' + if rank % 8 == 0: + import moxing as mox + print("Modify the time out from 300 to 30000") + print("begin download dataset", flush=True) + + if not os.path.exists(cache_url): + os.makedirs(cache_url, exist_ok=True) + mox.file.copy_parallel(src_url=src_data_url, + dst_url=cache_url) + print("Dataset download succeed!", flush=True) + + f = open("%s/install.txt" % (EXEC_PATH), 'w') + f.close() + # stop + while not os.path.exists("%s/install.txt" % (EXEC_PATH)): + time.sleep(1) diff --git a/model_zoo/official/nlp/pangu_alpha/train.py b/model_zoo/official/nlp/pangu_alpha/train.py index b2a61504e56..ff4257b69dd 100644 --- a/model_zoo/official/nlp/pangu_alpha/train.py +++ b/model_zoo/official/nlp/pangu_alpha/train.py @@ -29,11 +29,13 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell import mindspore.common.dtype as mstype from mindspore.parallel import set_algo_parameters from mindspore.parallel._cost_model_context import _set_multi_subgraphs + from src.dataset import create_dataset from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell from src.pangu_alpha_config import PANGUALPHAConfig, set_parse from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay +from src.utils import download_data class LossCallBack(Callback): @@ -41,6 +43,7 @@ class LossCallBack(Callback): Monitor the loss in training. If the loss in NAN or INF terminating training. """ + def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0): super(LossCallBack, self).__init__() self._dataset_size = dataset_size @@ -103,8 +106,14 @@ def run_train(args_opt): rank = 0 device_num = 1 + # copy data from the cloud to the /cache/Data + cache_url = '/cache/Data/' + if args_opt.offline: + cache_url = args_opt.data_url + else: + download_data(src_data_url=args_opt.data_url, tgt_data_path=cache_url, rank=rank) # Set model property - model_parallel_num = args_opt.tensor_model_parallel_num + model_parallel_num = args_opt.op_level_model_parallel_num data_parallel_num = int(device_num / model_parallel_num) batch_size = args_opt.per_batch_size * data_parallel_num config = PANGUALPHAConfig( @@ -162,7 +171,7 @@ def run_train(args_opt): loss_scale_value = math.pow(2, 32) epoch_num = args_opt.epoch_size # Dataset loading mindrecord files - ds = create_dataset(config.batch_size, data_path=args_opt.data_url, + ds = create_dataset(config.batch_size, data_path=cache_url, data_start_index=0, eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch), eod_id=args_opt.eod_id, device_num=device_num, rank=rank, epoch=epoch_num) step_per_epoch = ds.get_dataset_size()