forked from mindspore-Ecosystem/mindspore
!18808 Add Cloud Training Support for PanGu Model
Merge pull request !18808 from huangxinjing/pangu_master_run
This commit is contained in:
commit
f026c4ae3e
|
@ -74,7 +74,7 @@ def run_predict(args_opt):
|
|||
device_num = 1
|
||||
|
||||
# Set model property
|
||||
model_parallel_num = args_opt.tensor_model_parallel_num
|
||||
model_parallel_num = args_opt.op_level_model_parallel_num
|
||||
data_parallel_num = int(device_num / model_parallel_num)
|
||||
per_batch_size = args_opt.per_batch_size
|
||||
batch_size = per_batch_size * data_parallel_num
|
||||
|
|
|
@ -39,7 +39,7 @@ def get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset):
|
|||
attention_mask: the attention mask considering eod reset
|
||||
"""
|
||||
rank = int(rank)
|
||||
input_ids = input_ids[rank*dis: (rank+1)*dis]
|
||||
input_ids = input_ids[rank * dis: (rank + 1) * dis]
|
||||
if not eod_reset:
|
||||
return input_ids
|
||||
seq_length = input_ids.shape[1] - 1
|
||||
|
@ -60,8 +60,8 @@ def get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset):
|
|||
for i in range(eod_index.size):
|
||||
# Reset position_ids and attention_mask considering EOD
|
||||
index = eod_index[i]
|
||||
batch_attention_mask[bs_i, (index+1):, :(index+1)] = 0
|
||||
batch_position_ids[bs_i, (index+1):] -= (index + 1 - prev_index)
|
||||
batch_attention_mask[bs_i, (index + 1):, :(index + 1)] = 0
|
||||
batch_position_ids[bs_i, (index + 1):] -= (index + 1 - prev_index)
|
||||
prev_index = index + 1
|
||||
return batch_input_ids, batch_position_ids, batch_attention_mask
|
||||
|
||||
|
@ -106,9 +106,10 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_
|
|||
else:
|
||||
# Each card slice a small batch from the full batch
|
||||
dis = int(batch_size / device_num)
|
||||
if dis <= 0:
|
||||
if batch_size % device_num != 0:
|
||||
raise ValueError(
|
||||
"batch size {} should be a multiple of device number {}.".format(batch_size, device_num))
|
||||
f"batch size {batch_size} should be a multiple of device number {device_num}."
|
||||
" You should change the args: per_batch_size.")
|
||||
|
||||
map_func = (lambda input_ids: get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset))
|
||||
# If eod_reset enabled, another two inputs will be generated through input_ids
|
||||
|
|
|
@ -95,7 +95,7 @@ def set_parse(args_opt):
|
|||
args_opt.optimizer_shard = 0
|
||||
args_opt.stage_num = 16
|
||||
args_opt.micro_size = 32
|
||||
args_opt.tensor_model_parallel_num = 16
|
||||
args_opt.op_level_model_parallel_num = 16
|
||||
elif args_opt.run_type == "predict":
|
||||
args_opt.stage_num = 4
|
||||
args_opt.micro_size = 1
|
||||
|
@ -103,7 +103,7 @@ def set_parse(args_opt):
|
|||
args_opt.embedding_size = 5120
|
||||
args_opt.num_layers = 40
|
||||
args_opt.num_heads = 40
|
||||
args_opt.tensor_model_parallel_num = 8
|
||||
args_opt.op_level_model_parallel_num = 8
|
||||
if args_opt.run_type == "train":
|
||||
args_opt.start_lr = 5e-5
|
||||
args_opt.end_lr = 1e-6
|
||||
|
@ -117,7 +117,7 @@ def set_parse(args_opt):
|
|||
args_opt.embedding_size = 2560
|
||||
args_opt.num_layers = 32
|
||||
args_opt.num_heads = 32
|
||||
args_opt.tensor_model_parallel_num = 8
|
||||
args_opt.op_level_model_parallel_num = 8
|
||||
if args_opt.run_type == "train":
|
||||
args_opt.start_lr = 1e-4
|
||||
args_opt.end_lr = 1e-6
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
network config setting, gradient clip function and dynamic learning rate function
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
@ -24,7 +26,6 @@ from mindspore.ops import functional as F
|
|||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR, CosineDecayLR
|
||||
|
||||
from mindspore.parallel._utils import _get_global_rank
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.nn import AdamWeightDecay
|
||||
|
@ -67,6 +68,7 @@ class FP32StateAdamWeightDecay(AdamWeightDecay):
|
|||
new.append(new_state)
|
||||
return ParameterTuple(new)
|
||||
|
||||
|
||||
get_square_sum = C.MultitypeFuncGraph("get_square_sum")
|
||||
|
||||
|
||||
|
@ -92,6 +94,7 @@ class GlobalNorm(nn.Cell):
|
|||
Calculate the global norm value of given tensors
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, params):
|
||||
super(GlobalNorm, self).__init__()
|
||||
self.norm = nn.Norm()
|
||||
|
@ -106,8 +109,9 @@ class GlobalNorm(nn.Cell):
|
|||
if item:
|
||||
self.values.append(Tensor([1.0], mstype.float32))
|
||||
else:
|
||||
self.values.append(Tensor([self.group_size*1.0], mstype.float32))
|
||||
self.values.append(Tensor([self.group_size * 1.0], mstype.float32))
|
||||
self.values = tuple(self.values)
|
||||
|
||||
def construct(self, grads):
|
||||
# Square sum of gradients for current rank
|
||||
square_sum_dp = self.hyper_map(get_square_sum, grads, self.values)
|
||||
|
@ -122,6 +126,7 @@ class ClipByGlobalNorm(nn.Cell):
|
|||
Clip grads by global norm
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, params, clip_norm=1.0):
|
||||
super(ClipByGlobalNorm, self).__init__()
|
||||
self.global_norm = GlobalNorm(params)
|
||||
|
@ -143,11 +148,11 @@ def _get_model_parallel_group(dp, mp):
|
|||
return [x + index * mp for x in group]
|
||||
|
||||
|
||||
|
||||
class LearningRate(LearningRateSchedule):
|
||||
"""
|
||||
Warmup-decay learning rate for PanguAlpha network.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
learning_rate,
|
||||
end_learning_rate,
|
||||
|
@ -186,6 +191,7 @@ class LearningRate(LearningRateSchedule):
|
|||
lr = decay_lr
|
||||
return lr
|
||||
|
||||
|
||||
def add_inference_params(opt):
|
||||
"""Add inference params"""
|
||||
opt.add_argument("--frequency_penalty",
|
||||
|
@ -213,6 +219,7 @@ def add_inference_params(opt):
|
|||
default=9,
|
||||
help="the token id for <end of document>")
|
||||
|
||||
|
||||
def add_training_params(opt):
|
||||
"""Add training params"""
|
||||
opt.add_argument("--seq_length",
|
||||
|
@ -269,8 +276,8 @@ def add_training_params(opt):
|
|||
default=2,
|
||||
help="The sink size of the training")
|
||||
opt.add_argument("--full_batch",
|
||||
default=0,
|
||||
help="Import the full size of a batch for each card, default is 0")
|
||||
default=1,
|
||||
help="Import the full size of a batch for each card, default is 1")
|
||||
opt.add_argument("--optimizer_shard",
|
||||
type=int,
|
||||
default=1,
|
||||
|
@ -279,6 +286,18 @@ def add_training_params(opt):
|
|||
type=int,
|
||||
default=6,
|
||||
help="The batch size for each data parallel way. default 32")
|
||||
opt.add_argument("--start_lr",
|
||||
type=float,
|
||||
default=5e-5,
|
||||
help="The start learning rate. default 1e-5")
|
||||
opt.add_argument("--end_lr",
|
||||
type=float,
|
||||
default=1e-6,
|
||||
help="The end learning rate. default 1e-6")
|
||||
opt.add_argument("--op_level_model_parallel_num",
|
||||
type=int,
|
||||
default=8,
|
||||
help="The model parallel way. default 8")
|
||||
|
||||
|
||||
def get_args(inference=False):
|
||||
|
@ -335,10 +354,41 @@ def get_args(inference=False):
|
|||
type=str,
|
||||
default="fp32",
|
||||
help="The initialization type for parameters. Default fp32.")
|
||||
|
||||
parser.add_argument("--offline",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Running on cloud of not. Default 1.")
|
||||
add_training_params(parser)
|
||||
if inference:
|
||||
add_inference_params(parser)
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
return args_opt
|
||||
|
||||
|
||||
def download_data(src_data_url, tgt_data_path, rank):
|
||||
"""
|
||||
Download the dataset from the obs.
|
||||
src_data_url (Str): should be the dataset path in the obs
|
||||
tgt_data_path (Str): the local dataset path
|
||||
rank (Int): the current rank id
|
||||
|
||||
"""
|
||||
cache_url = tgt_data_path
|
||||
EXEC_PATH = '/tmp'
|
||||
if rank % 8 == 0:
|
||||
import moxing as mox
|
||||
print("Modify the time out from 300 to 30000")
|
||||
print("begin download dataset", flush=True)
|
||||
|
||||
if not os.path.exists(cache_url):
|
||||
os.makedirs(cache_url, exist_ok=True)
|
||||
mox.file.copy_parallel(src_url=src_data_url,
|
||||
dst_url=cache_url)
|
||||
print("Dataset download succeed!", flush=True)
|
||||
|
||||
f = open("%s/install.txt" % (EXEC_PATH), 'w')
|
||||
f.close()
|
||||
# stop
|
||||
while not os.path.exists("%s/install.txt" % (EXEC_PATH)):
|
||||
time.sleep(1)
|
||||
|
|
|
@ -29,11 +29,13 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
|||
import mindspore.common.dtype as mstype
|
||||
from mindspore.parallel import set_algo_parameters
|
||||
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
|
||||
|
||||
from src.dataset import create_dataset
|
||||
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
|
||||
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell
|
||||
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
|
||||
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
|
||||
from src.utils import download_data
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
|
@ -41,6 +43,7 @@ class LossCallBack(Callback):
|
|||
Monitor the loss in training.
|
||||
If the loss in NAN or INF terminating training.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0):
|
||||
super(LossCallBack, self).__init__()
|
||||
self._dataset_size = dataset_size
|
||||
|
@ -103,8 +106,14 @@ def run_train(args_opt):
|
|||
rank = 0
|
||||
device_num = 1
|
||||
|
||||
# copy data from the cloud to the /cache/Data
|
||||
cache_url = '/cache/Data/'
|
||||
if args_opt.offline:
|
||||
cache_url = args_opt.data_url
|
||||
else:
|
||||
download_data(src_data_url=args_opt.data_url, tgt_data_path=cache_url, rank=rank)
|
||||
# Set model property
|
||||
model_parallel_num = args_opt.tensor_model_parallel_num
|
||||
model_parallel_num = args_opt.op_level_model_parallel_num
|
||||
data_parallel_num = int(device_num / model_parallel_num)
|
||||
batch_size = args_opt.per_batch_size * data_parallel_num
|
||||
config = PANGUALPHAConfig(
|
||||
|
@ -162,7 +171,7 @@ def run_train(args_opt):
|
|||
loss_scale_value = math.pow(2, 32)
|
||||
epoch_num = args_opt.epoch_size
|
||||
# Dataset loading mindrecord files
|
||||
ds = create_dataset(config.batch_size, data_path=args_opt.data_url,
|
||||
ds = create_dataset(config.batch_size, data_path=cache_url,
|
||||
data_start_index=0, eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch),
|
||||
eod_id=args_opt.eod_id, device_num=device_num, rank=rank, epoch=epoch_num)
|
||||
step_per_epoch = ds.get_dataset_size()
|
||||
|
|
Loading…
Reference in New Issue