!17714 Enable full Batch and DataParalel Imports For PanGu-Alpha Model

Merge pull request !17714 from huangxinjing/enable_dataset_import
This commit is contained in:
i-robot 2021-06-15 19:37:43 +08:00 committed by Gitee
commit 077845fd0d
7 changed files with 60 additions and 60 deletions

View File

@ -77,11 +77,11 @@ on 8 cards as follows:
# run distributed training example
bash scripts/run_distribute_training.sh /path/dataset /path/hccl.json 8 fp32
bash scripts/run_distribute_training.sh /path/dataset /path/hccl.json 8 fp32 2.6B
```
We recommend running the code on 32 Ascend cards when training 13B models.
Replacing `2.6B` with `13B` switches the program to training the 13B model (at least 16P).
For distributed training, an hccl configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link below:
@ -94,7 +94,7 @@ https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.
Please refer to the [website](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha) to download the following parts:
- tokenizer: vocab.txt and vocab.model
- checkpint file: \*.part\[0-4\] and *.npy under the same parameter size
- checkpoint file: \*.part\[0-4\] and *.npy under the same parameter size
- strategy file: a file that describes how the parameters are sliced across different devices.
Here we suppose the downloaded checkpoint, tokenizer and strategy file are organized as follows:

View File

@ -26,6 +26,7 @@ DATA_DIR=$1
export RANK_TABLE_FILE=$2
RANK_SIZE=$3
PARAM_INIT_TYPE=$4
MODE=$5
for((i=0;i<${RANK_SIZE};i++));
@ -36,5 +37,5 @@ do
export RANK_ID=$i
export DEVICE_ID=$i
python ${ROOT_PATH}/train.py --distribute=true --device_num=$RANK_SIZE --data_url=$DATA_DIR --run_type=train \
--param_init_type=$PARAM_INIT_TYPE > log$i.log 2>&1 &
--param_init_type=$PARAM_INIT_TYPE --mode=$MODE > log$i.log 2>&1 &
done

View File

@ -23,14 +23,16 @@ import mindspore.dataset.transforms.c_transforms as C
import mindspore.common.dtype as mstype
def get_input_data(input_ids, eod_id, rank, dis):
def get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset):
"""
Generate position_id and attention_mask according to input_ids considering eod reset
Inputs:
input_ids: the input token ids
eod_id: the id for <EOD>
rank: the current rank
dis: the slice value for each rank
eod_reset: whether to open eod reset or not
returns:
input_ids: the input token ids
position_id: the position ids considering eod reset
@ -38,8 +40,9 @@ def get_input_data(input_ids, eod_id, rank, dis):
"""
rank = int(rank)
input_ids = input_ids[rank*dis: (rank+1)*dis]
if not eod_reset:
return input_ids
seq_length = input_ids.shape[1] - 1
# Initialize position_ids and attention_mask
batch_input_ids = input_ids
batch_position_ids = np.ones((dis, seq_length))
@ -63,7 +66,7 @@ def get_input_data(input_ids, eod_id, rank, dis):
return batch_input_ids, batch_position_ids, batch_attention_mask
def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_start_index=0,
def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_batch=False, data_start_index=0,
eod_reset=False, eod_id=9, column_name='input_ids', epoch=1):
"""
Create dataset
@ -86,12 +89,6 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_
# Get path for source data files
home_path = os.path.join(os.getcwd(), data_path)
files = os.listdir(data_path)
dis = int(batch_size / device_num)
if dis <= 0:
raise ValueError(
"batch size {} should be a multiple of device number {}.".format(batch_size,
device_num))
data = [
os.path.join(home_path, name) for name in files
if not name.endswith(".db")
@ -102,17 +99,31 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_
type_cast_op = C.TypeCast(mstype.int32)
type_cast_op_float = C.TypeCast(mstype.float16)
if full_batch:
# no need to slice from the inputs
rank = 0
dis = batch_size
else:
# Each card slice a small batch from the full batch
dis = int(batch_size / device_num)
if dis <= 0:
raise ValueError(
"batch size {} should be a multiple of device number {}.".format(batch_size, device_num))
map_func = (lambda input_ids: get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset))
# If eod_reset enabled, another two inputs will be generated through input_ids
if eod_reset:
map_func = (lambda input_ids: get_input_data(input_ids, eod_id, rank, dis))
dataset = dataset.batch(batch_size, drop_remainder=drop)
dataset = dataset.map(operations=map_func, input_columns=[column_name],
output_columns=["input_ids", "position_id", "attention_mask"],
column_order=["input_ids", "position_id", "attention_mask"])
output_columns=[column_name, "position_id", "attention_mask"],
column_order=[column_name, "position_id", "attention_mask"])
dataset = dataset.map(input_columns="position_id", operations=type_cast_op)
dataset = dataset.map(input_columns="attention_mask", operations=type_cast_op_float)
else:
raise ValueError("Not supported here")
dataset = dataset.map(input_columns="input_ids", operations=type_cast_op)
dataset = dataset.map(input_columns=[column_name], operations=type_cast_op)
dataset = dataset.batch(batch_size, drop_remainder=drop)
dataset = dataset.map(operations=map_func, input_columns=[column_name],
output_columns=[column_name])
dataset = dataset.map(input_columns=column_name, operations=type_cast_op)
dataset = dataset.repeat(epoch)
return dataset

View File

@ -742,10 +742,6 @@ class PanguAlpha_Model(nn.Cell):
# in other words, in backward process each block will be almost totally recomputed.
if config.use_recompute:
per_block.recompute()
# Dropout will not be recomputed to ensure the consistency between forward and the corresponding backward.
per_block.attention.dropout.dropout_gen_mask.recompute(False)
per_block.attention.prob_dropout.dropout_gen_mask.recompute(False)
per_block.output.dropout.dropout_gen_mask.recompute(False)
if config.param_init_type == mstype.float16:
# If the model is initialized with fp16, the fusion of layernorm (fp32 gradient) will mix up with
# the bias parameter in linear models (fp16 gradient), causing dtype error for communication operators.
@ -799,18 +795,15 @@ class PanguAlpha_Model(nn.Cell):
# If the model is initialized with fp16, the fusion of layernorm (fp32 gradient) will mix up with
# the bias parameter in linear models (fp16 gradient), causing dtype error for communication operators.
# so we fuse communications of embedding to a large value(+100)
self.top_query_embedding.set_comm_fusion(int((config.num_layers - 1) / fusion_group_num) + 200)
self.top_query_embedding.set_comm_fusion(int((config.num_layers - 1) / fusion_group_size) + 200)
self.top_query_embedding.embedding_table.parallel_optimizer = False
self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))
self.top_query_embedding.expand.shard(((config.dp, 1),))
self.top_query_layer = QueryLayer(config)
if config.use_recompute:
self.top_query_layer.recompute()
self.top_query_layer.output.dropout.dropout_gen_mask.recompute(False)
self.top_query_layer.attention.dropout.dropout_gen_mask.recompute(False)
self.top_query_layer.attention.prob_dropout.dropout_gen_mask.recompute(False)
self.top_query_layer.set_comm_fusion(int((config.num_layers - 1) / fusion_group_num) + 2)
self.top_query_layer.set_comm_fusion(int((config.num_layers - 1) / fusion_group_size) + 2)
self.top_query_layer.layernorm1.set_comm_fusion(int(config.num_layers / fusion_group_size + 100))
self.top_query_layer.layernorm2.set_comm_fusion(int(config.num_layers / fusion_group_size + 100))
@ -1057,13 +1050,15 @@ class EvalNet(nn.Cell):
self.argmax = P.Argmax()
self.generate = generate
self.topk = P.TopK(sorted=True).shard(((1, 1),))
self.gather = P.GatherV2().shard(((1, 1), (1,)))
self.gather = P.GatherV2().shard(((1, 1, 1), (1,)))
self.log_softmax = P.LogSoftmax().shard(((1, 1),))
def construct(self, input_ids, current_index):
"""evaluation net"""
input_mask = F.cast(F.not_equal(input_ids, 0), mstype.float32)
logits = self.backbone(input_ids, input_mask)
logits = self.gather(logits, current_index)
bs, seq_length = F.shape(input_ids)
logits = F.reshape(logits, (bs, seq_length, -1))
logits = self.gather(logits, current_index, 1)
log_probs = self.log_softmax(logits)
return log_probs

View File

@ -86,26 +86,20 @@ def set_parse(args_opt):
Set config according to the mode
"""
if args_opt.mode == "200B":
args_opt.seq_length = 1024
args_opt.vocab_size = 40000
args_opt.embedding_size = 16384
args_opt.num_layers = 64
args_opt.num_heads = 128
if args_opt.run_type == "train":
args_opt.start_lr = 6e-5
args_opt.end_lr = 6e-6
args_opt.optimizer_shard = False
args_opt.optimizer_shard = 0
args_opt.stage_num = 16
args_opt.micro_size = 32
args_opt.tensor_model_parallel_num = 16
args_opt.per_batch_size = 1
elif args_opt.run_type == "predict":
args_opt.stage_num = 4
args_opt.micro_size = 1
args_opt.per_batch_size = 1
elif args_opt.mode == "13B":
args_opt.seq_length = 1024
args_opt.vocab_size = 40000
args_opt.embedding_size = 5120
args_opt.num_layers = 40
args_opt.num_heads = 40
@ -113,17 +107,13 @@ def set_parse(args_opt):
if args_opt.run_type == "train":
args_opt.start_lr = 5e-5
args_opt.end_lr = 1e-6
args_opt.optimizer_shard = True
args_opt.optimizer_shard = 1
args_opt.stage_num = 1
args_opt.micro_size = 1
args_opt.per_batch_size = 16
elif args_opt.run_type == "predict":
args_opt.stage_num = 1
args_opt.micro_size = 1
args_opt.per_batch_size = 1
elif args_opt.mode == "2.6B":
args_opt.seq_length = 1024
args_opt.vocab_size = 40000
args_opt.embedding_size = 2560
args_opt.num_layers = 32
args_opt.num_heads = 32
@ -131,11 +121,9 @@ def set_parse(args_opt):
if args_opt.run_type == "train":
args_opt.start_lr = 1e-4
args_opt.end_lr = 1e-6
args_opt.optimizer_shard = True
args_opt.optimizer_shard = 1
args_opt.stage_num = 1
args_opt.micro_size = 1
args_opt.per_batch_size = 2
elif args_opt.run_type == "predict":
args_opt.stage_num = 1
args_opt.micro_size = 1
args_opt.per_batch_size = 1

View File

@ -154,8 +154,7 @@ class LearningRate(LearningRateSchedule):
warmup_steps,
decay_steps,
power=1.0,
use_cosine=True,
lr_scale=0.125):
use_cosine=True):
super(LearningRate, self).__init__()
self.warmup_flag = False
if warmup_steps > 0:
@ -171,7 +170,6 @@ class LearningRate(LearningRateSchedule):
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
self.use_cosine = use_cosine
self.lr_scale = lr_scale
def construct(self, global_step):
"""dynamic learning rate"""
@ -186,7 +184,7 @@ class LearningRate(LearningRateSchedule):
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
else:
lr = decay_lr
return lr * self.lr_scale
return lr
def add_inference_params(opt):
"""Add inference params"""
@ -270,6 +268,18 @@ def add_training_params(opt):
type=int,
default=2,
help="The sink size of the training")
opt.add_argument("--full_batch",
default=0,
help="Import the full size of a batch for each card, default is 0")
opt.add_argument("--optimizer_shard",
type=int,
default=1,
help="Enable optimizer parallel, default is 1")
opt.add_argument("--per_batch_size",
type=int,
default=6,
help="The batch size for each data parallel way. default 32")
def get_args(inference=False):
"""train function for PanguAlpha"""

View File

@ -27,9 +27,8 @@ import mindspore.nn as nn
from mindspore.train.callback import TimeMonitor, Callback
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.common.dtype as mstype
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from src.dataset import create_dataset
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell
@ -95,9 +94,8 @@ def run_train(args_opt):
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
gradients_mean=False,
device_num=device_num,
full_batch=False,
enable_parallel_optimizer=True)
auto_parallel_context().set_loss_repeated_mean(True)
full_batch=bool(args_opt.full_batch),
enable_parallel_optimizer=bool(args_opt.optimizer_shard))
set_algo_parameters(elementwise_op_strategy_follow=True)
_set_multi_subgraphs()
@ -108,7 +106,7 @@ def run_train(args_opt):
# Set model property
model_parallel_num = args_opt.tensor_model_parallel_num
data_parallel_num = int(device_num / model_parallel_num)
batch_size = args_opt.per_batch_size * device_num
batch_size = args_opt.per_batch_size * data_parallel_num
config = PANGUALPHAConfig(
data_parallel_num=data_parallel_num,
model_parallel_num=model_parallel_num,
@ -121,8 +119,6 @@ def run_train(args_opt):
expand_ratio=4,
dropout_rate=0.1,
compute_dtype=mstype.float16,
use_past=False,
self_layernorm=True,
stage_num=args_opt.stage_num,
micro_size=args_opt.micro_size,
eod_reset=bool(args_opt.eod_reset),
@ -142,8 +138,7 @@ def run_train(args_opt):
lr = LearningRate(learning_rate=args_opt.start_lr,
end_learning_rate=args_opt.end_lr,
warmup_steps=args_opt.warmup_step,
decay_steps=200000,
lr_scale=1)
decay_steps=200000)
# Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest
decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
@ -168,7 +163,7 @@ def run_train(args_opt):
epoch_num = args_opt.epoch_size
# Dataset loading mindrecord files
ds = create_dataset(config.batch_size, data_path=args_opt.data_url,
data_start_index=0, eod_reset=config.eod_reset,
data_start_index=0, eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch),
eod_id=args_opt.eod_id, device_num=device_num, rank=rank, epoch=epoch_num)
step_per_epoch = ds.get_dataset_size()
callback_size = args_opt.sink_size