forked from mindspore-Ecosystem/mindspore
!16338 Add FP16 Prediction for PanGu-Alpha Model
From: @huangxinjing Reviewed-by: @yangzhenzhang, @stsuteng Signed-off-by: @stsuteng
Commit 18e3180ca4
@@ -115,7 +115,7 @@ def create_group(group, rank_num, rank_ids):
        c_group = c_str(group)
        ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
        if ret != 0:
-           raise RuntimeError('Create group error, the error code is ', ret)
+           raise RuntimeError('Create group error, the error code is ' + str(ret))
    else:
        raise TypeError('Rank ids must be a python list.')
@@ -53,6 +53,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
    Returns:
        Tensor, the new value of v after updating.
    """
+   op_cast = P.Cast()
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
@@ -60,7 +61,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
-       op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()

        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
@@ -84,7 +84,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
-   return gradient
+   return op_cast(gradient, F.dtype(param))


@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
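The changed return above is the `optim_filter` pass-through path: the update branch casts everything to float32, so when the model is initialized in fp16 the untouched gradient must also be handed back in the parameter's own dtype. A standalone sketch of that cast pattern (toy tensors, not the repository code):

```python
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F

op_cast = P.Cast()
param = Tensor(np.zeros((2, 2)), mstype.float16)    # fp16-initialized parameter
gradient = Tensor(np.ones((2, 2)), mstype.float32)  # gradient produced in fp32

# Returning `gradient` directly would hand an fp32 tensor to an fp16 parameter slot;
# casting to the parameter's dtype keeps both return paths type-consistent.
out = op_cast(gradient, F.dtype(param))
assert out.dtype == mstype.float16
```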
@@ -70,17 +70,18 @@ python src/preprocess.py --input_glob data/*.txt

## Run Training

-After installing MindSpore via the official website, you can start training as follows:
+After installing MindSpore via the official website, you can start training the 2.6B model
+on 8 cards as follows:

```bash

# run distributed training example

-bash scripts/run_distribute_training.sh /path/dataset /path/hccl.json 8
+bash scripts/run_distribute_training.sh /path/dataset /path/hccl.json 8 fp32

```

-We recommend to run the code on 32 Ascend cards.
+We recommend running the code on 32 Ascend cards for training the 13B model.

For distributed training, an hccl configuration file with JSON format needs to be created in advance.
Please follow the instructions in the link below:
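The rank table schema itself is described in the linked instructions and is not part of this commit. Purely as a hedged illustration (field names follow the commonly used `version 1.0` layout, and all IPs are placeholders), a single-server 8-card file could be generated roughly like this:

```python
# Illustrative only: confirm the exact schema against the official instructions
# linked above before using it; server_id and device_ip values are placeholders.
import json

rank_table = {
    "version": "1.0",
    "server_count": "1",
    "server_list": [{
        "server_id": "10.0.0.1",
        "device": [{"device_id": str(i),
                    "device_ip": "192.168.100.10%d" % i,
                    "rank_id": str(i)} for i in range(8)],
        "host_nic_ip": "reserve",
    }],
    "status": "completed",
}

with open("/path/hccl.json", "w") as f:
    json.dump(rank_table, f, indent=4)
```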
@@ -96,12 +97,41 @@ Please refer to the [website](https://git.openi.org.cn/PCL-Platform.Intelligence
- checkpoint file: \*.part\[0-4\] and *.npy under the same parameter size
- strategy file: a file describing how the parameters are sliced across different devices.

### Run Prediction

Here we suppose the downloaded checkpoint, tokenizer and strategy file are organized as follows:

```shell
ckpts
├── checkpoint_file
│   ├── filtered_*.ckpt
│   ├── word_embedding.npy
│   ├── top_query_embedding.npy
│   └── position_embedding.npy
├── strategy_load_ckpt
│   └── strategy.ckpt
└── tokenizer
    ├── vocab10.model
    └── vocab10.vocab
```

### Run Prediction in Distributed Mode

The following script will run prediction on 8 Ascend cards.

```bash
-$FILE_PATH=/home/your_path
+$FILE_PATH=/home/your_path/ckpts
bash scripts/run_distribute_predict.sh 8 /home/config/rank_table_8p.json ${FILE_PATH}/strategy_load_ckpt/strategy.ckpt \
-${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B
+${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B fp32
```

### Run Prediction Using One Device

The following script will run prediction on 1 Ascend card. The difference is that the network is initialized with the float16 type,
and the rank table should be configured for a single device.

```bash
$FILE_PATH=/home/your_path/ckpts
bash scripts/run_distribute_predict.sh 1 /home/config/rank_table_1p.json ${FILE_PATH}/strategy_load_ckpt/strategy.ckpt \
${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B fp16
```

### Run Serving
@@ -97,7 +97,8 @@ def run_predict(args_opt):
        micro_size=args_opt.micro_size,
        eod_reset=False,
        word_emb_dp=True,
-       load_ckpt_path=args_opt.load_ckpt_path)
+       load_ckpt_path=args_opt.load_ckpt_path,
+       param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16)
    print("===config is: ", config, flush=True)
    print("=====args_opt is: ", args_opt, flush=True)
@@ -9,6 +9,7 @@ export TOKENIZER=$4
export CKPT_PATH=$5
export CKPT_NAME=$6
export MODE=$7
+export PARAM_INIT_TYPE=$8

for((i=0;i<$RANK_SIZE;i++));
do
@@ -18,5 +19,5 @@ do
    export RANK_ID=$i
    export DEVICE_ID=$i
    python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --tokenizer_path=$TOKENIZER --load_ckpt_path=$CKPT_PATH \
-   --load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict >train_deep$i.log 2>&1 &
+   --load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict --param_init_type=$PARAM_INIT_TYPE >log$i.log 2>&1 &
done
@@ -25,6 +25,7 @@ ROOT_PATH=`pwd`
DATA_DIR=$1
export RANK_TABLE_FILE=$2
RANK_SIZE=$3
+PARAM_INIT_TYPE=$4


for((i=0;i<${RANK_SIZE};i++));
@@ -34,5 +35,6 @@ do
    cd ${ROOT_PATH}/device$i || exit
    export RANK_ID=$i
    export DEVICE_ID=$i
-   python ${ROOT_PATH}/train.py --distribute=true --device_num=$RANK_SIZE --data_url=$DATA_DIR --run_type=train >log$i.log 2>&1 &
+   python ${ROOT_PATH}/train.py --distribute=true --device_num=$RANK_SIZE --data_url=$DATA_DIR --run_type=train \
+       --param_init_type=$PARAM_INIT_TYPE > log$i.log 2>&1 &
done
@@ -126,11 +126,11 @@ class Mapping(nn.Cell):
        self.output_size = output_size
        self.input_size = input_size
        self.weight = Parameter(initializer(Normal(sigma=0.02 * scale),
-                                           [input_size, output_size]),
+                                           [input_size, output_size], config.param_init_type),
                                name="mapping_weight")
        self.bias = Parameter(initializer("zeros", [
            output_size,
-       ]),
+       ], config.param_init_type),
            name="mapping_bias",
            parallel_optimizer=False)
        self.dtype = config.compute_dtype
@@ -167,11 +167,12 @@ class Mapping_output(nn.Cell):
        self.output_size = output_size
        self.input_size = input_size
        self.weight = Parameter(initializer(Normal(sigma=0.02 * scale),
-                                           [input_size, output_size]),
+                                           [input_size, output_size],
+                                           config.param_init_type),
                                name="mapping_weight")
        self.bias = Parameter(initializer("zeros", [
            output_size,
-       ]),
+       ], config.param_init_type),
            name="mapping_bias")
        self.dtype = config.compute_dtype
        self.cast = P.Cast()
@@ -358,22 +359,38 @@ class Attention(nn.Cell):
        self.softmax.softmax.shard(((config.dp, config.mp, 1),))
        self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))

+       dense_shape = [config.embedding_size, config.embedding_size]
+       bias_shape = [config.embedding_size]
        # Query
        self.dense1 = nn.Dense(config.embedding_size,
-                              config.embedding_size).to_float(
-                                  config.compute_dtype)
+                              config.embedding_size,
+                              weight_init=initializer(init='normal', shape=dense_shape,
+                                                      dtype=config.param_init_type),
+                              bias_init=initializer(init='zeros', shape=bias_shape,
+                                                    dtype=config.param_init_type)).to_float(config.compute_dtype)
        self.dense1.matmul.shard(((config.dp, 1), (config.mp, 1)))
        self.dense1.bias_add.shard(((config.dp, config.mp), (config.mp,)))
        # Key
        self.dense2 = nn.Dense(config.embedding_size,
-                              config.embedding_size).to_float(
-                                  config.compute_dtype)
+                              config.embedding_size,
+                              weight_init=initializer(init='normal',
+                                                      shape=dense_shape,
+                                                      dtype=config.param_init_type),
+                              bias_init=initializer(init='zeros',
+                                                    shape=bias_shape,
+                                                    dtype=config.param_init_type)).to_float(config.compute_dtype)
        self.dense2.matmul.shard(((config.dp, 1), (config.mp, 1)))
        self.dense2.bias_add.shard(((config.dp, config.mp), (config.mp,)))
        # Value
        self.dense3 = nn.Dense(config.embedding_size,
-                              config.embedding_size).to_float(
-                                  config.compute_dtype)
+                              config.embedding_size,
+                              weight_init=initializer(init='normal',
+                                                      shape=dense_shape,
+                                                      dtype=config.param_init_type),
+                              bias_init=initializer(init='zeros',
+                                                    shape=bias_shape,
+                                                    dtype=config.param_init_type)).to_float(config.compute_dtype)
        self.dense3.matmul.shard(((config.dp, 1), (config.mp, 1)))
        self.dense3.bias_add.shard(((config.dp, config.mp), (config.mp,)))
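The pattern is the same for all three projections: the weight and bias tensors are created in `config.param_init_type`, while `to_float(config.compute_dtype)` still controls the dtype the matmul runs in. A minimal standalone sketch of that split with toy shapes (not the repository code):

```python
import numpy as np
from mindspore import nn, Tensor
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer

param_init_type = mstype.float16   # what the new fp16 mode selects
compute_dtype = mstype.float16

# Weights/bias are materialized in param_init_type, halving parameter memory,
# while the forward computation is cast to compute_dtype by to_float().
dense = nn.Dense(4, 4,
                 weight_init=initializer('normal', [4, 4], dtype=param_init_type),
                 bias_init=initializer('zeros', [4], dtype=param_init_type)).to_float(compute_dtype)

x = Tensor(np.ones((2, 4)), mstype.float32)
print(dense.weight.dtype, dense(x).dtype)
```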
@@ -723,13 +740,19 @@ class PanguAlpha_Model(nn.Cell):
            per_block = Block(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
-           # Each layer will be remoputed in the backward process. The output activation of each layer will be saved,
-           # in other words, in backward process each block will be almosttotally recomputed.
-           per_block.recompute()
-           # Dropout will not be recomputed to ensure the consistency between forward and the corresponding backward.
-           per_block.attention.dropout.dropout_gen_mask.recompute(False)
-           per_block.attention.prob_dropout.dropout_gen_mask.recompute(False)
-           per_block.output.dropout.dropout_gen_mask.recompute(False)
+           if config.use_recompute:
+               per_block.recompute()
+               # Dropout will not be recomputed to ensure the consistency between forward and the corresponding backward.
+               per_block.attention.dropout.dropout_gen_mask.recompute(False)
+               per_block.attention.prob_dropout.dropout_gen_mask.recompute(False)
+               per_block.output.dropout.dropout_gen_mask.recompute(False)
+           if config.param_init_type == mstype.float16:
+               # If the model is initialized with fp16, the fusion of layernorm (fp32 gradient) will mix up with
+               # the bias parameter in linear models (fp16 gradient), causing dtype error for communication operators.
+               # so we fuse communications of layernorm to a large value(+100)
+               per_block.layernorm1.set_comm_fusion(int(int(i / fusion_group_size) + 100))
+               per_block.layernorm2.set_comm_fusion(int(int(i / fusion_group_size) + 100))
            self.blocks.append(per_block)

        if config.self_layernorm:
            self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(
                mstype.float32).set_comm_fusion(
@@ -741,6 +764,11 @@ class PanguAlpha_Model(nn.Cell):
            self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
            self.layernorm.gamma.parallel_optimizer = False
            self.layernorm.beta.parallel_optimizer = False
+           if config.param_init_type == mstype.float16:
+               # If the model is initialized with fp16, the fusion of layernorm (fp32 gradient) will mix up with
+               # the bias parameter in linear models (fp16 gradient), causing dtype error for communication operators.
+               # so we fuse communications of layernorm to a large value(+100)
+               self.layernorm.set_comm_fusion(int(num_layers / fusion_group_size + 100))
        self.use_past = config.use_past
        self.past = tuple([None] * config.num_layers)
        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
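The same guard appears for the per-block layernorms above and for the top-query embedding below. The idea is that `set_comm_fusion` buckets gradients for fused AllReduce, so the fp32 layernorm gradients are pushed into an index range (+100) that never shares a fused operator with the fp16 gradients. A rough sketch of that bucketing rule (illustrative helper, not part of the diff; `block` is assumed to expose the same attributes as the `Block` cell here):

```python
# Illustrative helper: gradients sharing a fusion index are fused into one
# AllReduce, so fp32 layernorm parameters get their own, disjoint index range.
def assign_comm_fusion(block, layer_id, fusion_group_size, fp16_init):
    regular_bucket = layer_id // fusion_group_size + 2
    block.set_comm_fusion(regular_bucket)                   # fp16 block parameters
    if fp16_init:
        layernorm_bucket = layer_id // fusion_group_size + 100
        block.layernorm1.set_comm_fusion(layernorm_bucket)  # fp32 layernorm params
        block.layernorm2.set_comm_fusion(layernorm_bucket)
```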
@@ -768,19 +796,24 @@ class PanguAlpha_Model(nn.Cell):

        self.top_query_embedding = nn.Embedding(config.seq_length, config.embedding_size,
                                                embedding_table=top_query_table_param)
-       self.top_query_embedding.set_comm_fusion(int((config.num_layers - 1) / fusion_group_num) + 2)
+       # If the model is initialized with fp16, the fusion of layernorm (fp32 gradient) will mix up with
+       # the bias parameter in linear models (fp16 gradient), causing dtype error for communication operators.
+       # so we fuse communications of embedding to a large value(+100)
+       self.top_query_embedding.set_comm_fusion(int((config.num_layers - 1) / fusion_group_num) + 200)
        self.top_query_embedding.embedding_table.parallel_optimizer = False
        self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))
        self.top_query_embedding.expand.shard(((config.dp, 1),))
        self.top_query_layer = QueryLayer(config)
+       if config.use_recompute:
            self.top_query_layer.recompute()

            self.top_query_layer.output.dropout.dropout_gen_mask.recompute(False)
            self.top_query_layer.attention.dropout.dropout_gen_mask.recompute(False)
            self.top_query_layer.attention.prob_dropout.dropout_gen_mask.recompute(False)

-       self.top_query_layer.set_comm_fusion(int((config.num_layers - 1) / fusion_group_num) + 2)
+       self.top_query_layer.layernorm1.set_comm_fusion(int(config.num_layers / fusion_group_size + 100))
+       self.top_query_layer.layernorm2.set_comm_fusion(int(config.num_layers / fusion_group_size + 100))

        self.use_top_query_attention = config.use_top_query_attention
@@ -43,6 +43,7 @@ class PANGUALPHAConfig:
                 micro_size=32,
                 load_ckpt_path=None,
                 use_top_query_attention=True,
+                param_init_type=mstype.float32,
                 use_recompute=True):
        self.batch_size = batch_size
        self.seq_length = seq_length
@@ -70,6 +71,7 @@ class PANGUALPHAConfig:
        self.load_ckpt_path = load_ckpt_path
        self.use_top_query_attention = use_top_query_attention
        self.use_recompute = use_recompute
+       self.param_init_type = param_init_type

    def __str__(self):
        info = "[PANGUALPHAConfig]" + '===' * 10 + '\n'
@@ -27,6 +27,45 @@ from mindspore.nn.learning_rate_schedule import LearningRateSchedule, Polynomial

from mindspore.parallel._utils import _get_global_rank
from mindspore.communication.management import get_group_size
+from mindspore.nn import AdamWeightDecay
+from mindspore.common import Parameter, ParameterTuple
+from mindspore.common.initializer import initializer
+
+
+class FP32StateAdamWeightDecay(AdamWeightDecay):
+    r"""
+    This class is almost the same as MindSpore's AdamWeightDecay implementation; the
+    only difference is that the optimizer states are always initialized with float32,
+    whereas the original AdamWeightDecay initializes the states with float16
+    if the parameters are initialized with fp16.
+    This setting avoids overflow when training the PanGu-Alpha model with fp16.
+    """
+
+    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
+        super(FP32StateAdamWeightDecay, self).__init__(params, learning_rate=learning_rate,
+                                                       beta1=beta1,
+                                                       beta2=beta2,
+                                                       eps=eps,
+                                                       weight_decay=weight_decay)
+
+        self.moments1 = self.clone_state(self.parameters, prefix='adam_m', init='zeros')
+        self.moments2 = self.clone_state(self.parameters, prefix='adam_v', init='zeros')
+
+    def clone_state(self, parameter_tuple, prefix, init):
+        r"""
+        parameter_tuple: ParameterTuple. The parameters of the network
+        prefix: str. The prefix name of the parameters
+        init: str. The initialization method
+        """
+        new = []
+        for old_param in parameter_tuple:
+            new_state = Parameter(initializer(init, shape=old_param.shape, dtype=mstype.float32))
+            new_state.param_info = old_param.param_info.clone()
+            new_state.is_init = False
+            new_state.set_data(initializer(init, shape=old_param.shape, dtype=mstype.float32))
+            new_state.name = prefix + '.' + new_state.name
+            new.append(new_state)
+        return ParameterTuple(new)
+
+
get_square_sum = C.MultitypeFuncGraph("get_square_sum")
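For context, the train.py hunk near the end of this commit swaps this class in for `nn.AdamWeightDecay`. A hypothetical, self-contained usage sketch (a toy fp16-initialized network stands in for PanguAlpha; the eps/beta values mirror the train.py change):

```python
from mindspore import nn
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer

# Toy stand-in for the fp16-initialized PanGu-Alpha network.
net = nn.Dense(4, 4, weight_init=initializer('normal', [4, 4], dtype=mstype.float16))

optimizer = FP32StateAdamWeightDecay(net.trainable_params(), learning_rate=1e-4,
                                     eps=1e-8, beta1=0.9, beta2=0.95)

# Even though the network weight is float16, both Adam states are float32.
print(optimizer.moments1[0].dtype, optimizer.moments2[0].dtype)
```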
@@ -256,6 +295,10 @@ def get_args():
                        type=str,
                        default="./tokenizer_path",
                        help="The path where stores vocab and vocab model file")
+   parser.add_argument("--param_init_type",
+                       type=str,
+                       default="fp32",
+                       help="The initialization type for parameters. Default fp32.")

    add_training_params(parser)
    args_opt = parser.parse_args()
@@ -34,7 +34,7 @@ from src.dataset import create_dataset
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
-from src.utils import LearningRate, get_args
+from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay


class LossCallBack(Callback):
@@ -95,7 +95,7 @@ def run_train(args_opt):
        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
        gradients_mean=False,
        device_num=device_num,
-       full_batch=True,
+       full_batch=False,
        enable_parallel_optimizer=True)
    auto_parallel_context().set_loss_repeated_mean(True)
    set_algo_parameters(elementwise_op_strategy_follow=True)
@@ -126,6 +126,7 @@ def run_train(args_opt):
        stage_num=args_opt.stage_num,
        micro_size=args_opt.micro_size,
        eod_reset=bool(args_opt.eod_reset),
+       param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
        word_emb_dp=True)
    print("===config is: ", config, flush=True)
@@ -161,7 +162,7 @@ def run_train(args_opt):
    if args_opt.optimizer == "lamb":
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    else:
-       optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
+       optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
    # Initial scaling sens
    loss_scale_value = math.pow(2, 32)
    epoch_num = args_opt.epoch_size