!22120 Transformer API For Rewriting the PanGU model

Merge pull request !22120 from huangxinjing/transformer_example
i-robot 2021-08-31 07:17:16 +00:00 committed by Gitee
commit 570363c2a3
8 changed files with 350 additions and 1205 deletions
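In short, the hand-written parallel fields of the old PANGUALPHAConfig (data_parallel_num, model_parallel_num, stage_num, micro_size, word_emb_dp) are replaced by a TransformerOpParallelConfig from mindspore.parallel.nn, which the renamed PanguAlphaConfig simply carries, and PanguAlpha becomes PanguAlphaModel. A minimal sketch of the new wiring, using only names introduced in this diff but with illustrative argument values:

    import mindspore.common.dtype as mstype
    from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss
    from src.pangu_alpha import PanGUAlphaWithLoss, PanguAlphaModel
    from src.pangu_alpha_config import PanguAlphaConfig

    # Operator-level parallel settings now live in a single object ...
    parallel_config = TransformerOpParallelConfig(data_parallel=4, model_parallel=2,
                                                  pipeline_stage=1, micro_batch_num=1,
                                                  recompute=True)
    # ... which the model config carries instead of dp/mp/stage_num/micro_size fields.
    config = PanguAlphaConfig(batch_size=8, seq_length=1024, vocab_size=40000,
                              hidden_size=2560, ffn_hidden_size=2560 * 4,
                              num_layers=32, num_heads=32,
                              param_init_type=mstype.float16,
                              parallel_config=parallel_config)
    pangu_alpha = PanguAlphaModel(config=config)
    loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
    net_with_loss = PanGUAlphaWithLoss(config, pangu_alpha, loss)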

View File

@@ -28,8 +28,9 @@ from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.train.model import Model
from mindspore.train.serialization import load_distributed_checkpoint
from src.pangu_alpha import PanguAlpha, EvalNet
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
from mindspore.parallel.nn.transformer import TransformerOpParallelConfig
from src.pangu_alpha import EvalNet, PanguAlphaModel
from src.pangu_alpha_config import set_parse, PanguAlphaConfig
from src.utils import get_args
@@ -74,29 +75,31 @@ def load_model(args_opt):
# Set model property
model_parallel_num = args_opt.op_level_model_parallel_num
data_parallel_num = int(device_num / model_parallel_num)
parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
model_parallel=model_parallel_num,
pipeline_stage=args_opt.stage_num,
micro_batch_num=args_opt.micro_size,
recompute=True)
per_batch_size = args_opt.per_batch_size
batch_size = per_batch_size * data_parallel_num
# Now only support single batch_size for predict
if args_opt.run_type == "predict":
batch_size = 1
config = PANGUALPHAConfig(
data_parallel_num=data_parallel_num,
model_parallel_num=model_parallel_num,
config = PanguAlphaConfig(
batch_size=batch_size,
seq_length=args_opt.seq_length,
vocab_size=args_opt.vocab_size,
embedding_size=args_opt.embedding_size,
hidden_size=args_opt.embedding_size,
num_layers=args_opt.num_layers,
num_heads=args_opt.num_heads,
expand_ratio=4,
post_layernorm_residual=False,
dropout_rate=0.0,
compute_dtype=mstype.float16,
ffn_hidden_size=args_opt.embedding_size * 4,
use_past=use_past,
stage_num=args_opt.stage_num,
micro_size=args_opt.micro_size,
eod_reset=False,
word_emb_dp=True,
parallel_config=parallel_config,
load_ckpt_path=args_opt.load_ckpt_path,
param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16)
print("===config is: ", config, flush=True)
@@ -104,7 +107,7 @@ def load_model(args_opt):
ckpt_name = args_opt.load_ckpt_name
# Define network
pangu_alpha = PanguAlpha(config)
pangu_alpha = PanguAlphaModel(config)
eval_net = EvalNet(pangu_alpha)
eval_net.set_train(False)
model_predict = Model(eval_net)

View File

@@ -1,4 +1,4 @@
numpy
trnsformers
transformers
sentencepiece
jieba

View File

@@ -16,8 +16,8 @@
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distributed_train_gpu.sh RANK_SIZE HOSTFILE DATASET PER_BATCH MODE"
echo "for example: bash run_distributed_train_gpu.sh 16 hostfile_16p /mass_dataset/train_data/ 16 2.6B"
echo "bash run_distributed_train_gpu.sh RANK_SIZE HOSTFILE DATASET MODE"
echo "for example: bash run_distributed_train_gpu.sh 16 hostfile_16p /mass_dataset/train_data/ 2.6B"
echo "It is better to use absolute path."
echo "=============================================================================================================="
@@ -26,8 +26,7 @@ self_path=$(dirname "${script_self}")
RANK_SIZE=$1
HOSTFILE=$2
DATASET=$3
PER_BATCH=$4
MODE=$5
MODE=$4
mpirun --allow-run-as-root -x PATH -x LD_LIBRARY_PATH -x PYTHONPATH -x NCCL_DEBUG -x GLOG_v -n $RANK_SIZE --hostfile $HOSTFILE --output-filename log_output --merge-stderr-to-stdout \
python -s ${self_path}/../train.py \
@@ -36,5 +35,4 @@ mpirun --allow-run-as-root -x PATH -x LD_LIBRARY_PATH -x PYTHONPATH -x NCCL_DEBU
--device_target="GPU" \
--data_url=$DATASET \
--mode=$MODE \
--per_batch_size=$PER_BATCH \
--run_type=train > train_log.txt 2>&1 &

File diff suppressed because it is too large

View File

@@ -17,57 +17,47 @@ network config setting
"""
import mindspore.common.dtype as mstype
class PanguAlphaConfig:
"""
PanGUConfig config class which defines the model size
"""
class PANGUALPHAConfig:
"""
PANGUALPHA config class which defines the model size
"""
def __init__(self,
data_parallel_num,
model_parallel_num,
batch_size=32,
seq_length=1024,
vocab_size=50257,
embedding_size=768,
vocab_size=40000,
hidden_size=768,
ffn_hidden_size=768,
num_layers=12,
num_heads=12,
expand_ratio=4,
load_ckpt_path=None,
param_init_type=mstype.float32,
post_layernorm_residual=False,
dropout_rate=0.1,
compute_dtype=mstype.float16,
eod_token=6,
use_past=False,
word_emb_dp=True,
stage_num=16,
hidden_act='gelu',
eod_reset=True,
micro_size=32,
load_ckpt_path=None,
use_top_query_attention=True,
param_init_type=mstype.float32,
enable_offload=False):
enable_offload=False,
parallel_config=None):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
# The expand ratio of feature size in FFN
self.expand_ratio = expand_ratio
self.eod_token = eod_token
# Use post-layernorm or pre-layernrom, default:pre-layernorm
self.post_layernorm_residual = post_layernorm_residual
self.dropout_rate = dropout_rate
self.compute_dtype = compute_dtype
# Whether use incremental inference
self.use_past = use_past
self.dp = data_parallel_num
self.mp = model_parallel_num
self.stage_num = stage_num
self.micro_size = micro_size
self.word_emb_dp = word_emb_dp
self.eod_reset = eod_reset
# Used for loading embedding tables
self.load_ckpt_path = load_ckpt_path
self.use_top_query_attention = use_top_query_attention
self.param_init_type = param_init_type
self.dropout_rate = dropout_rate
self.compute_dtype = mstype.float16
self.parallel_config = parallel_config
self.ffn_hidden_size = ffn_hidden_size
self.hidden_act = hidden_act
self.use_past = use_past
self.eod_reset = eod_reset
self.enable_offload = enable_offload
def __str__(self):
@@ -78,6 +68,7 @@ class PANGUALPHAConfig:
info += '=' * 10
return info
def set_parse(args_opt):
r"""
Set config according to the mode
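With the data_parallel_num, model_parallel_num, stage_num, micro_size and word_emb_dp fields removed from the config class, call sites now read the same settings from the nested parallel config. The mapping, spelled out in a hypothetical helper (not part of this PR) for reference:

    def parallel_fields(config):
        """Equivalents of the removed flat fields, read from the new nested config."""
        pc = config.parallel_config
        return dict(dp=pc.data_parallel,            # was config.dp
                    mp=pc.model_parallel,           # was config.mp
                    stage_num=pc.pipeline_stage,    # was config.stage_num
                    micro_size=pc.micro_batch_num,  # was config.micro_size
                    word_emb_dp=pc.vocab_emb_dp)    # was config.word_emb_dp

The hunks around PanguAlphaTrainPipelineWithLossScaleCell and GlobalNorm below apply exactly this substitution.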

View File

@@ -70,6 +70,7 @@ reciprocal = P.Reciprocal()
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
@grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_grad_scale_pipeline(scale, grad, accu_grad):
accu_grad = F.depend(accu_grad, grad)
@@ -79,6 +80,7 @@ def tensor_grad_scale_pipeline(scale, grad, accu_grad):
new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
return new_grad
@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
new_grad = grad * reciprocal(scale)
@@ -151,6 +153,7 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
self.optimizer(grads)
return loss, cond, scaling_sens
class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
"""
Encapsulation class of PanguAlpha network training.
@@ -200,7 +203,7 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.clip = ClipByGlobalNorm(self.weights, self.config)
self.micro_size = config.micro_size
self.micro_size = config.parallel_config.micro_batch_num
self.opt_shard = _get_enable_parallel_optimizer()
@C.add_flags(has_effect=True)

View File

@@ -139,12 +139,12 @@ class GlobalNorm(nn.Cell):
super(GlobalNorm, self).__init__()
self.norm = nn.Norm()
self.hyper_map = C.HyperMap()
self.is_pipeline = (config.stage_num > 1)
self.is_pipeline = context.get_auto_parallel_context("pipeline_stages") > 1
if self.is_pipeline:
if context.get_auto_parallel_context("enable_parallel_optimizer"):
group_size = get_group_size() // config.stage_num
group_size = get_group_size() // config.parallel_config.pipeline_stage
else:
group_size = config.mp
group_size = config.parallel_config.model_parallel
group_list, group_name = _get_model_parallel_group(group_size)
create_group(group_name, group_list)
self.allreduce = P.AllReduce(group=group_name)
@@ -160,9 +160,10 @@ class GlobalNorm(nn.Cell):
elif "embedding_table" not in x.name:
self.allreduce_group_size = self.allreduce_group_size + (group_size * 1.0,)
else:
if not config.word_emb_dp and "position_embedding.embedding_table" not in x.name \
if not config.parallel_config.vocab_emb_dp and "position_embedding.embedding_table" not in x.name \
and "top_query_embedding_table" not in x.name:
self.allreduce_group_size = self.allreduce_group_size + (config.dp * 1.0,)
self.allreduce_group_size = self.allreduce_group_size +\
(config.parallel_config.data_parallel * 1.0,)
else:
self.allreduce_group_size = self.allreduce_group_size + (group_size * 1.0,)

View File

@@ -29,17 +29,17 @@ import mindspore.common.dtype as mstype
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss
from src.adam import AdamWeightDecayOp
from src.dataset import create_dataset
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
from src.pangu_alpha import PanGUAlphaWithLoss, PanguAlphaModel
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
from src.pangu_alpha_config import set_parse, PanguAlphaConfig
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
from src.utils import download_data
from src.callbacks import EvalCallBack, LossCallBack
from src.metrics import PPLMetric
project_root = os.path.abspath(
os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..")
print('project_root:', project_root)
@@ -69,7 +69,8 @@ def run_train(args_opt):
The main training process.
"""
# Set execution mode
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, variable_memory_max_size="31GB")
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(variable_memory_max_size="31GB")
# Set parallel context
if args_opt.distribute == "true":
D.init()
@@ -100,27 +101,37 @@ def run_train(args_opt):
model_parallel_num = args_opt.op_level_model_parallel_num
data_parallel_num = int(device_num / model_parallel_num)
batch_size = args_opt.per_batch_size * data_parallel_num
config = PANGUALPHAConfig(
data_parallel_num=data_parallel_num, model_parallel_num=model_parallel_num,
batch_size=batch_size, seq_length=args_opt.seq_length,
vocab_size=args_opt.vocab_size, embedding_size=args_opt.embedding_size,
num_layers=args_opt.num_layers, num_heads=args_opt.num_heads,
expand_ratio=4, dropout_rate=0.1, compute_dtype=mstype.float16,
stage_num=args_opt.stage_num, micro_size=args_opt.micro_size,
eod_reset=bool(args_opt.eod_reset), load_ckpt_path=args_opt.load_ckpt_path,
param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
word_emb_dp=bool(args_opt.word_emb_dp), enable_offload=bool(args_opt.opt_offload))
parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
model_parallel=model_parallel_num,
pipeline_stage=args_opt.stage_num,
micro_batch_num=args_opt.micro_size,
optimizer_shard=bool(args_opt.optimizer_shard),
recompute=True)
config = PanguAlphaConfig(batch_size=batch_size, num_heads=args_opt.num_heads,
hidden_size=args_opt.embedding_size, seq_length=args_opt.seq_length,
vocab_size=args_opt.vocab_size, num_layers=args_opt.num_layers,
ffn_hidden_size=args_opt.embedding_size * 4,
eod_token=bool(args_opt.eod_reset),
load_ckpt_path=args_opt.load_ckpt_path,
param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
enable_offload=bool(args_opt.opt_offload),
parallel_config=parallel_config)
print("===config is: ", config, flush=True)
# Define network
pangu_alpha = PanguAlpha(config)
loss = CrossEntropyLoss(config)
pangu_alpha_with_loss_net = PanguAlphaWithLoss(config, pangu_alpha, loss)
pangu_alpha = PanguAlphaModel(config=config)
loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
pangu_alpha_with_loss_net = PanGUAlphaWithLoss(config, pangu_alpha, loss)
pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
print("=====args_opt is: ", args_opt, flush=True)
# Warm-up and cosine decay learning rate
lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
warmup_steps=args_opt.warmup_step, decay_steps=200000)
params = pangu_alpha.trainable_params()
lr = LearningRate(learning_rate=args_opt.start_lr,
end_learning_rate=args_opt.end_lr,
warmup_steps=args_opt.warmup_step,
decay_steps=200000)
params = pangu_alpha_with_loss.trainable_params()
group_params = set_weight_decay(params)
if args_opt.optimizer == "lamb":
optimizer = nn.Lamb(group_params, learning_rate=lr)
@@ -165,6 +176,7 @@ def run_train(args_opt):
print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True)
model.train(actual_epoch_num, ds, callbacks=callback, sink_size=args_opt.sink_size, dataset_sink_mode=True)
def run_train_pipeline(args_opt):
r"""
The main training process in pipeline.
@@ -203,27 +215,26 @@ def run_train_pipeline(args_opt):
raise ValueError("The dp must large than 1 when applying optimizer shard.")
per_batch_size = args_opt.per_batch_size
batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
config = PANGUALPHAConfig(
data_parallel_num=data_parallel_num,
model_parallel_num=model_parallel_num,
batch_size=batch_size,
seq_length=args_opt.seq_length,
vocab_size=args_opt.vocab_size,
embedding_size=args_opt.embedding_size,
num_layers=args_opt.num_layers,
num_heads=args_opt.num_heads,
expand_ratio=4,
post_layernorm_residual=False,
dropout_rate=0.1,
compute_dtype=mstype.float16,
use_past=False,
stage_num=args_opt.stage_num,
micro_size=args_opt.micro_size,
word_emb_dp=bool(args_opt.word_emb_dp), enable_offload=bool(args_opt.opt_offload))
parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
model_parallel=model_parallel_num,
pipeline_stage=args_opt.stage_num,
micro_batch_num=args_opt.micro_size,
optimizer_shard=bool(args_opt.optimizer_shard),
recompute=True)
config = PanguAlphaConfig(batch_size=batch_size // parallel_config.micro_batch_num,
num_heads=args_opt.num_heads, hidden_size=args_opt.embedding_size,
seq_length=args_opt.seq_length, vocab_size=args_opt.vocab_size,
num_layers=args_opt.num_layers, ffn_hidden_size=args_opt.embedding_size * 4,
eod_token=bool(args_opt.eod_reset), load_ckpt_path=args_opt.load_ckpt_path,
param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
enable_offload=bool(args_opt.opt_offload), parallel_config=parallel_config)
print("===config is: ", config, flush=True)
pangu_alpha = PanguAlpha(config)
loss = CrossEntropyLoss(config)
pangu_alpha_with_loss_net = PipelineCell(PanguAlphaWithLoss(config, pangu_alpha, loss), config.micro_size)
pangu_alpha = PanguAlphaModel(config=config)
loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
pangu_alpha_with_loss_net = PipelineCell(PanGUAlphaWithLoss(config, pangu_alpha, loss),
config.parallel_config.micro_batch_num)
pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
print("=====args_opt is: ", args_opt, flush=True)
lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
@@ -238,7 +249,8 @@ def run_train_pipeline(args_opt):
else:
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
ds = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=cache_url,
device_num=stage_device_num,
rank=rank_id % stage_device_num, eod_reset=True, data_start_index=0,
full_batch=context.get_auto_parallel_context("full_batch"),
column_name=args_opt.data_column_name)
@@ -246,19 +258,20 @@ def run_train_pipeline(args_opt):
step_per_epoch = ds.get_dataset_size()
callback_size = args_opt.sink_size
actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, micro_size=config.micro_size)]
callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id,
micro_size=parallel_config.micro_batch_num)]
loss_scale_value = math.pow(2, 32)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell)
if args_opt.train_and_eval_mode:
ds_eval = create_dataset(config.batch_size // config.micro_size, data_path=eval_cache_url,
ds_eval = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=eval_cache_url,
device_num=stage_device_num, rank=rank_id % stage_device_num, eod_reset=True,
data_start_index=0, full_batch=bool(args_opt.full_batch),
column_name=args_opt.data_column_name,
num_samples=args_opt.eval_steps * config.batch_size)
ppl_metric = PPLMetric(config.seq_length)
pangu_alpha_with_loss_eval_net = _VirtualDatasetCell(PanguAlphaWithLoss(config, pangu_alpha, loss))
pangu_alpha_with_loss_eval_net = _VirtualDatasetCell(PanGUAlphaWithLoss(config, pangu_alpha, loss))
model = Model(pangu_alpha_with_grads, eval_network=pangu_alpha_with_loss_eval_net, metrics={"ppl": ppl_metric})
model.build(ds, ds_eval, sink_size=callback_size)
eval_callback = EvalCallBack(model, ds_eval, ppl_metric)
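One subtlety in the pipeline path above: PanguAlphaConfig now stores the batch size divided by micro_batch_num, and create_dataset multiplies the micro-batch count back in, so the number of samples fed per step is unchanged. A small arithmetic sketch with illustrative numbers (the real values come from args_opt):

    # Illustrative values only.
    per_batch_size = 16
    data_parallel_num = 4
    micro_size = 8

    batch_size = per_batch_size * data_parallel_num * micro_size  # 512 samples per step
    config_batch_size = batch_size // micro_size                  # 64, stored in PanguAlphaConfig
    dataset_batch_size = config_batch_size * micro_size           # 512, passed to create_dataset
    assert dataset_batch_size == batch_size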