!18808 Add Cloud Training Support for PanGu Model

Merge pull request !18808 from huangxinjing/pangu_master_run
This commit is contained in:
i-robot 2021-06-28 08:47:46 +00:00 committed by Gitee
commit f026c4ae3e
5 changed files with 77 additions and 17 deletions

View File

@ -74,7 +74,7 @@ def run_predict(args_opt):
device_num = 1
# Set model property
model_parallel_num = args_opt.tensor_model_parallel_num
model_parallel_num = args_opt.op_level_model_parallel_num
data_parallel_num = int(device_num / model_parallel_num)
per_batch_size = args_opt.per_batch_size
batch_size = per_batch_size * data_parallel_num

View File

@ -39,7 +39,7 @@ def get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset):
attention_mask: the attention mask considering eod reset
"""
rank = int(rank)
input_ids = input_ids[rank*dis: (rank+1)*dis]
input_ids = input_ids[rank * dis: (rank + 1) * dis]
if not eod_reset:
return input_ids
seq_length = input_ids.shape[1] - 1
@ -60,8 +60,8 @@ def get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset):
for i in range(eod_index.size):
# Reset position_ids and attention_mask considering EOD
index = eod_index[i]
batch_attention_mask[bs_i, (index+1):, :(index+1)] = 0
batch_position_ids[bs_i, (index+1):] -= (index + 1 - prev_index)
batch_attention_mask[bs_i, (index + 1):, :(index + 1)] = 0
batch_position_ids[bs_i, (index + 1):] -= (index + 1 - prev_index)
prev_index = index + 1
return batch_input_ids, batch_position_ids, batch_attention_mask
@ -106,9 +106,10 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_
else:
# Each card slice a small batch from the full batch
dis = int(batch_size / device_num)
if dis <= 0:
if batch_size % device_num != 0:
raise ValueError(
"batch size {} should be a multiple of device number {}.".format(batch_size, device_num))
f"batch size {batch_size} should be a multiple of device number {device_num}."
" You should change the args: per_batch_size.")
map_func = (lambda input_ids: get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset))
# If eod_reset enabled, another two inputs will be generated through input_ids

View File

@ -95,7 +95,7 @@ def set_parse(args_opt):
args_opt.optimizer_shard = 0
args_opt.stage_num = 16
args_opt.micro_size = 32
args_opt.tensor_model_parallel_num = 16
args_opt.op_level_model_parallel_num = 16
elif args_opt.run_type == "predict":
args_opt.stage_num = 4
args_opt.micro_size = 1
@ -103,7 +103,7 @@ def set_parse(args_opt):
args_opt.embedding_size = 5120
args_opt.num_layers = 40
args_opt.num_heads = 40
args_opt.tensor_model_parallel_num = 8
args_opt.op_level_model_parallel_num = 8
if args_opt.run_type == "train":
args_opt.start_lr = 5e-5
args_opt.end_lr = 1e-6
@ -117,7 +117,7 @@ def set_parse(args_opt):
args_opt.embedding_size = 2560
args_opt.num_layers = 32
args_opt.num_heads = 32
args_opt.tensor_model_parallel_num = 8
args_opt.op_level_model_parallel_num = 8
if args_opt.run_type == "train":
args_opt.start_lr = 1e-4
args_opt.end_lr = 1e-6

View File

@ -16,6 +16,8 @@
network config setting, gradient clip function and dynamic learning rate function
"""
import argparse
import os
import time
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
@ -24,7 +26,6 @@ from mindspore.ops import functional as F
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR, CosineDecayLR
from mindspore.parallel._utils import _get_global_rank
from mindspore.communication.management import get_group_size
from mindspore.nn import AdamWeightDecay
@ -67,6 +68,7 @@ class FP32StateAdamWeightDecay(AdamWeightDecay):
new.append(new_state)
return ParameterTuple(new)
get_square_sum = C.MultitypeFuncGraph("get_square_sum")
@ -92,6 +94,7 @@ class GlobalNorm(nn.Cell):
Calculate the global norm value of given tensors
"""
def __init__(self, params):
super(GlobalNorm, self).__init__()
self.norm = nn.Norm()
@ -106,8 +109,9 @@ class GlobalNorm(nn.Cell):
if item:
self.values.append(Tensor([1.0], mstype.float32))
else:
self.values.append(Tensor([self.group_size*1.0], mstype.float32))
self.values.append(Tensor([self.group_size * 1.0], mstype.float32))
self.values = tuple(self.values)
def construct(self, grads):
# Square sum of gradients for current rank
square_sum_dp = self.hyper_map(get_square_sum, grads, self.values)
@ -122,6 +126,7 @@ class ClipByGlobalNorm(nn.Cell):
Clip grads by global norm
"""
def __init__(self, params, clip_norm=1.0):
super(ClipByGlobalNorm, self).__init__()
self.global_norm = GlobalNorm(params)
@ -143,11 +148,11 @@ def _get_model_parallel_group(dp, mp):
return [x + index * mp for x in group]
class LearningRate(LearningRateSchedule):
"""
Warmup-decay learning rate for PanguAlpha network.
"""
def __init__(self,
learning_rate,
end_learning_rate,
@ -186,6 +191,7 @@ class LearningRate(LearningRateSchedule):
lr = decay_lr
return lr
def add_inference_params(opt):
"""Add inference params"""
opt.add_argument("--frequency_penalty",
@ -213,6 +219,7 @@ def add_inference_params(opt):
default=9,
help="the token id for <end of document>")
def add_training_params(opt):
"""Add training params"""
opt.add_argument("--seq_length",
@ -269,8 +276,8 @@ def add_training_params(opt):
default=2,
help="The sink size of the training")
opt.add_argument("--full_batch",
default=0,
help="Import the full size of a batch for each card, default is 0")
default=1,
help="Import the full size of a batch for each card, default is 1")
opt.add_argument("--optimizer_shard",
type=int,
default=1,
@ -279,6 +286,18 @@ def add_training_params(opt):
type=int,
default=6,
help="The batch size for each data parallel way. default 32")
opt.add_argument("--start_lr",
type=float,
default=5e-5,
help="The start learning rate. default 1e-5")
opt.add_argument("--end_lr",
type=float,
default=1e-6,
help="The end learning rate. default 1e-6")
opt.add_argument("--op_level_model_parallel_num",
type=int,
default=8,
help="The model parallel way. default 8")
def get_args(inference=False):
@ -335,10 +354,41 @@ def get_args(inference=False):
type=str,
default="fp32",
help="The initialization type for parameters. Default fp32.")
parser.add_argument("--offline",
type=int,
default=1,
help="Running on cloud of not. Default 1.")
add_training_params(parser)
if inference:
add_inference_params(parser)
args_opt = parser.parse_args()
return args_opt
def download_data(src_data_url, tgt_data_path, rank):
    """
    Download the dataset from OBS to a local path, synchronized across ranks.

    Only the first rank of each 8-card server (rank % 8 == 0) performs the
    actual copy from OBS; it then creates a flag file that every rank
    (including the downloader) polls for, so no rank proceeds before the
    dataset is fully downloaded.

    Args:
        src_data_url (str): dataset path in OBS.
        tgt_data_path (str): local directory to copy the dataset into.
        rank (int): the current rank id.
    """
    cache_url = tgt_data_path
    exec_path = '/tmp'
    # Flag file used as a cross-process barrier on the local server.
    flag_file = os.path.join(exec_path, 'install.txt')
    if rank % 8 == 0:
        # moxing is only available on the ModelArts cloud environment,
        # so import it lazily inside the downloading branch.
        import moxing as mox
        print("Modify the time out from 300 to 30000")
        print("begin download dataset", flush=True)
        # exist_ok=True makes a prior existence check redundant and
        # tolerates a concurrent creation of the directory.
        os.makedirs(cache_url, exist_ok=True)
        mox.file.copy_parallel(src_url=src_data_url,
                               dst_url=cache_url)
        print("Dataset download succeed!", flush=True)
        # Touch the flag file to signal download completion.
        with open(flag_file, 'w'):
            pass
    # Barrier: wait until the downloading rank has created the flag file.
    while not os.path.exists(flag_file):
        time.sleep(1)

View File

@ -29,11 +29,13 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.common.dtype as mstype
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from src.dataset import create_dataset
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
from src.utils import download_data
class LossCallBack(Callback):
@ -41,6 +43,7 @@ class LossCallBack(Callback):
Monitor the loss in training.
If the loss in NAN or INF terminating training.
"""
def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0):
super(LossCallBack, self).__init__()
self._dataset_size = dataset_size
@ -103,8 +106,14 @@ def run_train(args_opt):
rank = 0
device_num = 1
# copy data from the cloud to the /cache/Data
cache_url = '/cache/Data/'
if args_opt.offline:
cache_url = args_opt.data_url
else:
download_data(src_data_url=args_opt.data_url, tgt_data_path=cache_url, rank=rank)
# Set model property
model_parallel_num = args_opt.tensor_model_parallel_num
model_parallel_num = args_opt.op_level_model_parallel_num
data_parallel_num = int(device_num / model_parallel_num)
batch_size = args_opt.per_batch_size * data_parallel_num
config = PANGUALPHAConfig(
@ -162,7 +171,7 @@ def run_train(args_opt):
loss_scale_value = math.pow(2, 32)
epoch_num = args_opt.epoch_size
# Dataset loading mindrecord files
ds = create_dataset(config.batch_size, data_path=args_opt.data_url,
ds = create_dataset(config.batch_size, data_path=cache_url,
data_start_index=0, eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch),
eod_id=args_opt.eod_id, device_num=device_num, rank=rank, epoch=epoch_num)
step_per_epoch = ds.get_dataset_size()