forked from mindspore-Ecosystem/mindspore

!17714 Enable Full Batch and DataParallel Imports For PanGu-Alpha Model

Merge pull request !17714 from huangxinjing/enable_dataset_import

Commit: 077845fd0d
@@ -77,11 +77,11 @@ on 8 cards as follows:
 # run distributed training example
-bash scripts/run_distribute_training.sh /path/dataset /path/hccl.json 8 fp32
+bash scripts/run_distribute_training.sh /path/dataset /path/hccl.json 8 fp32 2.6B

 ```

 We recommend to run the code on 32 Ascend cards for training 13B models.
+By replacing `2.6B` with `13B`, the program will switch to train 13B model (at least 16P).

 For distributed training, an hccl configuration file with JSON format needs to be created in advance.
 Please follow the instructions in the link below:
@@ -94,7 +94,7 @@ https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.
 Please refer to the [website](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha) to download the following parts:

 - tokenizer: vocab.txt and vocab.model
-- checkpint file: \*.part\[0-4\] and *.npy under the same parameter size
+- checkpoint file: \*.part\[0-4\] and *.npy under the same parameter size
 - strategy file: a file described how the parameters are sliced across different devices.

 Here we suppose the downloaded checkpoint, tokenizer and strategy file is organized as follows:
@@ -26,6 +26,7 @@ DATA_DIR=$1
 export RANK_TABLE_FILE=$2
 RANK_SIZE=$3
 PARAM_INIT_TYPE=$4
+MODE=$5


 for((i=0;i<${RANK_SIZE};i++));
@@ -36,5 +37,5 @@ do
 export RANK_ID=$i
 export DEVICE_ID=$i
 python ${ROOT_PATH}/train.py --distribute=true --device_num=$RANK_SIZE --data_url=$DATA_DIR --run_type=train \
-    --param_init_type=$PARAM_INIT_TYPE > log$i.log 2>&1 &
+    --param_init_type=$PARAM_INIT_TYPE --mode=$MODE > log$i.log 2>&1 &
 done
@@ -23,14 +23,16 @@ import mindspore.dataset.transforms.c_transforms as C
 import mindspore.common.dtype as mstype


-def get_input_data(input_ids, eod_id, rank, dis):
+def get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset):
     """
     Generate position_id and attention_mask according to input_ids considering eod reset

     Inputs:
         input_ids: the input token ids
         eod_id: the id for <EOD>
         rank: the current rank
         dis: the slice value for each rank
+        eod_reset: whether to open eod reset or not
     returns:
         input_ids: the input token ids
         position_id: the position ids cosidering eod reset
@@ -38,8 +40,9 @@ def get_input_data(input_ids, eod_id, rank, dis):
     """
     rank = int(rank)
     input_ids = input_ids[rank*dis: (rank+1)*dis]
+    if not eod_reset:
+        return input_ids
     seq_length = input_ids.shape[1] - 1

     # Initialize position_ids and attention_mask
     batch_input_ids = input_ids
     batch_position_ids = np.ones((dis, seq_length))
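The docstring above describes producing position_id and attention_mask "considering eod reset". A minimal NumPy sketch of that idea for a single sequence (illustrative only: the token values are made up, 9 stands in for the `<EOD>` id, and the real `get_input_data_batch_slice_map` additionally slices whole batches per rank rather than handling one sequence):

```python
import numpy as np

def eod_reset_single(ids, eod_id):
    """Illustrative only: position ids restart after each <EOD>, and the
    causal attention mask is additionally zeroed across <EOD> boundaries."""
    seq_len = len(ids)
    position_id = np.zeros(seq_len, dtype=np.int32)
    attention_mask = np.tril(np.ones((seq_len, seq_len), dtype=np.float32))
    prev = 0
    for idx in np.where(ids == eod_id)[0]:
        # tokens after this <EOD> cannot attend to tokens before it
        attention_mask[idx + 1:, :idx + 1] = 0
        # and their position ids start again from 0
        position_id[idx + 1:] -= idx + 1 - prev
        prev = idx + 1
    position_id += np.arange(seq_len, dtype=np.int32)
    return position_id, attention_mask

ids = np.array([11, 12, 9, 21, 22, 23, 9, 31])   # 9 plays the role of <EOD>
pos, mask = eod_reset_single(ids, eod_id=9)
print(pos)                                        # [0 1 2 0 1 2 3 0]
```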
@@ -63,7 +66,7 @@ def get_input_data(input_ids, eod_id, rank, dis):
     return batch_input_ids, batch_position_ids, batch_attention_mask


-def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_start_index=0,
+def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_batch=False, data_start_index=0,
                    eod_reset=False, eod_id=9, column_name='input_ids', epoch=1):
     """
     Create dataset
@@ -86,12 +89,6 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_
     # Get path for source data files
     home_path = os.path.join(os.getcwd(), data_path)
     files = os.listdir(data_path)
-    dis = int(batch_size / device_num)
-    if dis <= 0:
-        raise ValueError(
-            "batch size {} should be a multiple of device number {}.".format(batch_size,
-                                                                             device_num))
-
     data = [
         os.path.join(home_path, name) for name in files
         if not name.endswith(".db")
@@ -102,17 +99,31 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_
     type_cast_op = C.TypeCast(mstype.int32)
     type_cast_op_float = C.TypeCast(mstype.float16)

+    if full_batch:
+        # no need to slice from the inputs
+        rank = 0
+        dis = batch_size
+    else:
+        # Each card slice a small batch from the full batch
+        dis = int(batch_size / device_num)
+        if dis <= 0:
+            raise ValueError(
+                "batch size {} should be a multiple of device number {}.".format(batch_size, device_num))
+
+    map_func = (lambda input_ids: get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset))
     # If eod_reset enabled, another two inputs will be generated through input_ids
     if eod_reset:
-        map_func = (lambda input_ids: get_input_data(input_ids, eod_id, rank, dis))
         dataset = dataset.batch(batch_size, drop_remainder=drop)
         dataset = dataset.map(operations=map_func, input_columns=[column_name],
-                              output_columns=["input_ids", "position_id", "attention_mask"],
-                              column_order=["input_ids", "position_id", "attention_mask"])
+                              output_columns=[column_name, "position_id", "attention_mask"],
+                              column_order=[column_name, "position_id", "attention_mask"])
         dataset = dataset.map(input_columns="position_id", operations=type_cast_op)
         dataset = dataset.map(input_columns="attention_mask", operations=type_cast_op_float)
     else:
-        raise ValueError("Not supported here")
-        dataset = dataset.map(input_columns="input_ids", operations=type_cast_op)
+        dataset = dataset.map(input_columns=[column_name], operations=type_cast_op)
+        dataset = dataset.batch(batch_size, drop_remainder=drop)
+        dataset = dataset.map(operations=map_func, input_columns=[column_name],
+                              output_columns=[column_name])
+    dataset = dataset.map(input_columns=column_name, operations=type_cast_op)
     dataset = dataset.repeat(epoch)
     return dataset
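To make the new full_batch path concrete, a small sketch of the per-card slice computed above (the sizes are hypothetical; in the real pipeline the slice is applied inside the dataset map function shown in the diff):

```python
import numpy as np

batch_size, device_num = 8, 4                                      # hypothetical global batch and card count
global_batch = np.arange(batch_size * 3).reshape(batch_size, 3)    # stand-in for a batch of input_ids

for full_batch in (False, True):
    for rank in range(device_num):
        if full_batch:
            # every card keeps the whole batch; the framework slices it internally
            r, dis = 0, batch_size
        else:
            # each card slices its own small batch out of the global batch
            dis = batch_size // device_num                          # 2 rows per card here
            r = rank
        card_batch = global_batch[r * dis:(r + 1) * dis]
        print(full_batch, rank, card_batch.shape)                   # (2, 3) sliced, (8, 3) with full_batch
```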
@@ -742,10 +742,6 @@ class PanguAlpha_Model(nn.Cell):
             # in other words, in backward process each block will be almosttotally recomputed.
             if config.use_recompute:
                 per_block.recompute()
-                # Dropout will not be recomputed to ensure the consistency between forward and the corresponding backward.
-                per_block.attention.dropout.dropout_gen_mask.recompute(False)
-                per_block.attention.prob_dropout.dropout_gen_mask.recompute(False)
-                per_block.output.dropout.dropout_gen_mask.recompute(False)
             if config.param_init_type == mstype.float16:
                 # If the model is initialized with fp16, the fusion of layernorm (fp32 gradient) will mix up with
                 # the bias parameter in linear models (fp16 gradient), causing dtype error for communication operators.
@@ -799,18 +795,15 @@ class PanguAlpha_Model(nn.Cell):
         # If the model is initialized with fp16, the fusion of layernorm (fp32 gradient) will mix up with
         # the bias parameter in linear models (fp16 gradient), causing dtype error for communication operators.
         # so we fuse communications of embedding to a large value(+100)
-        self.top_query_embedding.set_comm_fusion(int((config.num_layers - 1) / fusion_group_num) + 200)
+        self.top_query_embedding.set_comm_fusion(int((config.num_layers - 1) / fusion_group_size) + 200)
         self.top_query_embedding.embedding_table.parallel_optimizer = False
         self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))
         self.top_query_embedding.expand.shard(((config.dp, 1),))
         self.top_query_layer = QueryLayer(config)
         if config.use_recompute:
             self.top_query_layer.recompute()
-            self.top_query_layer.output.dropout.dropout_gen_mask.recompute(False)
-            self.top_query_layer.attention.dropout.dropout_gen_mask.recompute(False)
-            self.top_query_layer.attention.prob_dropout.dropout_gen_mask.recompute(False)

-        self.top_query_layer.set_comm_fusion(int((config.num_layers - 1) / fusion_group_num) + 2)
+        self.top_query_layer.set_comm_fusion(int((config.num_layers - 1) / fusion_group_size) + 2)
         self.top_query_layer.layernorm1.set_comm_fusion(int(config.num_layers / fusion_group_size + 100))
         self.top_query_layer.layernorm2.set_comm_fusion(int(config.num_layers / fusion_group_size + 100))
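For reference, a toy evaluation of the fusion indices assigned above. The values of num_layers and fusion_group_size are hypothetical; the +2, +100 and +200 offsets come from the code and keep the embedding and layernorm communications in fusion buckets well away from the per-layer ones:

```python
num_layers = 32            # hypothetical
fusion_group_size = 4      # hypothetical

top_query_embedding = int((num_layers - 1) / fusion_group_size) + 200   # 207
top_query_layer = int((num_layers - 1) / fusion_group_size) + 2         # 9
top_query_layernorms = int(num_layers / fusion_group_size + 100)        # 108
print(top_query_embedding, top_query_layer, top_query_layernorms)
```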
@@ -1057,13 +1050,15 @@ class EvalNet(nn.Cell):
         self.argmax = P.Argmax()
         self.generate = generate
         self.topk = P.TopK(sorted=True).shard(((1, 1),))
-        self.gather = P.GatherV2().shard(((1, 1), (1,)))
+        self.gather = P.GatherV2().shard(((1, 1, 1), (1,)))
         self.log_softmax = P.LogSoftmax().shard(((1, 1),))

     def construct(self, input_ids, current_index):
         """evaluation net"""
         input_mask = F.cast(F.not_equal(input_ids, 0), mstype.float32)
         logits = self.backbone(input_ids, input_mask)
-        logits = self.gather(logits, current_index)
+        bs, seq_length = F.shape(input_ids)
+        logits = F.reshape(logits, (bs, seq_length, -1))
+        logits = self.gather(logits, current_index, 1)
         log_probs = self.log_softmax(logits)
         return log_probs
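A hedged NumPy sketch of why the reshape now precedes the gather: the backbone is assumed to return logits flattened over (batch, sequence), so they are restored to three dimensions and then gathered along axis 1 at the current position, which is also what the new ((1, 1, 1), (1,)) shard strategy reflects. The shapes below are illustrative:

```python
import numpy as np

bs, seq_length, vocab = 2, 4, 10                     # illustrative sizes
flat_logits = np.random.randn(bs * seq_length, vocab)

logits = flat_logits.reshape(bs, seq_length, -1)     # (bs, seq_length, vocab)

# axis-1 gather, analogous to P.GatherV2()(logits, current_index, 1);
# the same position(s) are selected for every sample in the batch
current_index = np.array([3])                        # e.g. the position currently being generated
picked = np.take(logits, current_index, axis=1)      # (2, 1, 10)
print(picked.shape)
```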
@@ -86,26 +86,20 @@ def set_parse(args_opt):
     Set config according to the mode
     """
     if args_opt.mode == "200B":
         args_opt.seq_length = 1024
         args_opt.vocab_size = 40000
         args_opt.embedding_size = 16384
         args_opt.num_layers = 64
         args_opt.num_heads = 128
         if args_opt.run_type == "train":
             args_opt.start_lr = 6e-5
             args_opt.end_lr = 6e-6
-            args_opt.optimizer_shard = False
+            args_opt.optimizer_shard = 0
             args_opt.stage_num = 16
             args_opt.micro_size = 32
             args_opt.tensor_model_parallel_num = 16
             args_opt.per_batch_size = 1
         elif args_opt.run_type == "predict":
             args_opt.stage_num = 4
             args_opt.micro_size = 1
             args_opt.per_batch_size = 1
     elif args_opt.mode == "13B":
         args_opt.seq_length = 1024
         args_opt.vocab_size = 40000
         args_opt.embedding_size = 5120
         args_opt.num_layers = 40
         args_opt.num_heads = 40
@@ -113,17 +107,13 @@ def set_parse(args_opt):
         if args_opt.run_type == "train":
             args_opt.start_lr = 5e-5
             args_opt.end_lr = 1e-6
-            args_opt.optimizer_shard = True
+            args_opt.optimizer_shard = 1
             args_opt.stage_num = 1
             args_opt.micro_size = 1
             args_opt.per_batch_size = 16
         elif args_opt.run_type == "predict":
             args_opt.stage_num = 1
             args_opt.micro_size = 1
             args_opt.per_batch_size = 1
     elif args_opt.mode == "2.6B":
         args_opt.seq_length = 1024
         args_opt.vocab_size = 40000
         args_opt.embedding_size = 2560
         args_opt.num_layers = 32
         args_opt.num_heads = 32
@@ -131,11 +121,9 @@ def set_parse(args_opt):
         if args_opt.run_type == "train":
             args_opt.start_lr = 1e-4
             args_opt.end_lr = 1e-6
-            args_opt.optimizer_shard = True
+            args_opt.optimizer_shard = 1
             args_opt.stage_num = 1
             args_opt.micro_size = 1
             args_opt.per_batch_size = 2
         elif args_opt.run_type == "predict":
             args_opt.stage_num = 1
             args_opt.micro_size = 1
             args_opt.per_batch_size = 1
@@ -154,8 +154,7 @@ class LearningRate(LearningRateSchedule):
                  warmup_steps,
                  decay_steps,
                  power=1.0,
-                 use_cosine=True,
-                 lr_scale=0.125):
+                 use_cosine=True):
        super(LearningRate, self).__init__()
        self.warmup_flag = False
        if warmup_steps > 0:
@@ -171,7 +170,6 @@ class LearningRate(LearningRateSchedule):
         self.one = Tensor(np.array([1.0]).astype(np.float32))
         self.cast = P.Cast()
         self.use_cosine = use_cosine
-        self.lr_scale = lr_scale

     def construct(self, global_step):
         """dynamic learning rate"""
@@ -186,7 +184,7 @@ class LearningRate(LearningRateSchedule):
             lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
         else:
             lr = decay_lr
-        return lr * self.lr_scale
+        return lr


 def add_inference_params(opt):
     """Add inference params"""
@@ -270,6 +268,18 @@ def add_training_params(opt):
                     type=int,
                     default=2,
                     help="The sink size of the training")
+    opt.add_argument("--full_batch",
+                     default=0,
+                     help="Import the full size of a batch for each card, default is 0")
+    opt.add_argument("--optimizer_shard",
+                     type=int,
+                     default=1,
+                     help="Enable optimizer parallel, default is 1")
+    opt.add_argument("--per_batch_size",
+                     type=int,
+                     default=6,
+                     help="The batch size for each data parallel way. default 32")


 def get_args(inference=False):
     """train function for PanguAlpha"""
@@ -27,9 +27,8 @@ import mindspore.nn as nn
 from mindspore.train.callback import TimeMonitor, Callback
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 import mindspore.common.dtype as mstype
-from mindspore.parallel._cost_model_context import _set_multi_subgraphs
 from mindspore.parallel import set_algo_parameters
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.parallel._cost_model_context import _set_multi_subgraphs
 from src.dataset import create_dataset
 from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
 from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell
@@ -95,9 +94,8 @@ def run_train(args_opt):
         parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
         gradients_mean=False,
         device_num=device_num,
-        full_batch=False,
-        enable_parallel_optimizer=True)
-    auto_parallel_context().set_loss_repeated_mean(True)
+        full_batch=bool(args_opt.full_batch),
+        enable_parallel_optimizer=bool(args_opt.optimizer_shard))
     set_algo_parameters(elementwise_op_strategy_follow=True)
     _set_multi_subgraphs()

@@ -108,7 +106,7 @@ def run_train(args_opt):
     # Set model property
     model_parallel_num = args_opt.tensor_model_parallel_num
     data_parallel_num = int(device_num / model_parallel_num)
-    batch_size = args_opt.per_batch_size * device_num
+    batch_size = args_opt.per_batch_size * data_parallel_num
     config = PANGUALPHAConfig(
         data_parallel_num=data_parallel_num,
         model_parallel_num=model_parallel_num,
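A small worked example of the revised batch-size arithmetic above; the card and model-parallel counts are hypothetical:

```python
device_num = 32                    # hypothetical number of Ascend cards
tensor_model_parallel_num = 8      # hypothetical model-parallel width
per_batch_size = 6

data_parallel_num = int(device_num / tensor_model_parallel_num)    # 4 data-parallel ways
batch_size = per_batch_size * data_parallel_num                     # 24 (the old formula gave 6 * 32 = 192)
print(data_parallel_num, batch_size)
```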
@@ -121,8 +119,6 @@ def run_train(args_opt):
         expand_ratio=4,
         dropout_rate=0.1,
         compute_dtype=mstype.float16,
         use_past=False,
         self_layernorm=True,
-        stage_num=args_opt.stage_num,
-        micro_size=args_opt.micro_size,
         eod_reset=bool(args_opt.eod_reset),
@@ -142,8 +138,7 @@ def run_train(args_opt):
     lr = LearningRate(learning_rate=args_opt.start_lr,
                       end_learning_rate=args_opt.end_lr,
                       warmup_steps=args_opt.warmup_step,
-                      decay_steps=200000,
-                      lr_scale=1)
+                      decay_steps=200000)

     # Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest
     decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
@@ -168,7 +163,7 @@ def run_train(args_opt):
     epoch_num = args_opt.epoch_size
     # Dataset loading mindrecord files
     ds = create_dataset(config.batch_size, data_path=args_opt.data_url,
-                        data_start_index=0, eod_reset=config.eod_reset,
+                        data_start_index=0, eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch),
                         eod_id=args_opt.eod_id, device_num=device_num, rank=rank, epoch=epoch_num)
     step_per_epoch = ds.get_dataset_size()
     callback_size = args_opt.sink_size