!18440 Fix the Transformer network overflow when training standalone on a single card.

Merge pull request !18440 from casgj/master_0615transformer
This commit is contained in:
i-robot 2021-06-18 08:52:27 +00:00 committed by Gitee
commit a5dff69a86
9 changed files with 306 additions and 16 deletions


@ -78,7 +78,10 @@ python eval.py > eval.log 2>&1 &
├─process_output.sh
├─replace-quote.perl
├─run_distribute_train_ascend.sh
└─run_standalone_train_ascend.sh
├─run_distribute_train_ascend_multi_machines.sh
├─run_eval.sh
├─run_infer_310.sh
└─run_standalone_train.sh
├─src
├─__init__.py
├─beam_search.py
@ -93,6 +96,10 @@ python eval.py > eval.log 2>&1 &
└─weight_init.py
├─create_data.py
├─eval.py
├─export.py
├─mindspore_hub_conf.py
├─postprocess.py
├─preprocess.py
└─train.py
```
@ -201,7 +208,7 @@ Parameters for learning rate:
- Run `run_standalone_train.sh` for non-distributed training of the Transformer model.
``` bash
sh scripts/run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE DATA_PATH
sh scripts/run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE GRADIENT_ACCUMULATE_STEP DATA_PATH
```
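For example, mirroring the sample invocation printed by `run_standalone_train.sh` (the data path below is a placeholder):
``` bash
sh scripts/run_standalone_train.sh Ascend 0 52 8 /path/ende-l128-mindrecord00
```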
- Run `run_distribute_train_ascend.sh` for distributed training of the Transformer model.


@ -84,7 +84,10 @@ python eval.py > eval.log 2>&1 &
├─process_output.sh
├─replace-quote.perl
├─run_distribute_train_ascend.sh
└─run_standalone_train_ascend.sh
├─run_distribute_train_ascend_multi_machines.sh
├─run_eval.sh
├─run_infer_310.sh
└─run_standalone_train.sh
├─src
├─__init__.py
├─beam_search.py
@ -99,6 +102,10 @@ python eval.py > eval.log 2>&1 &
└─weight_init.py
├─create_data.py
├─eval.py
├─export.py
├─mindspore_hub_conf.py
├─postprocess.py
├─preprocess.py
└─train.py
```
@ -208,7 +215,7 @@ Parameters for learning rate:
- Run `run_standalone_train.sh` for non-distributed training of the Transformer model.
``` bash
sh scripts/run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE DATA_PATH
sh scripts/run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE GRADIENT_ACCUMULATE_STEP DATA_PATH
```
- Run `run_distribute_train_ascend.sh` for distributed training of the Transformer model.


@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ] ; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_PATH RANK_TABLE_FILE"
echo "for example: sh run_distribute_pretrain.sh 8 52 /path/ende-l128-mindrecord00 /path/hccl.json"
echo "It is better to use absolute path."
echo "=============================================================================================================="
exit 1;
fi
rm -rf run_distribute_train
mkdir run_distribute_train


@ -0,0 +1,69 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 5 ] ; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distribute_train_ascend_multi_machines.sh DEVICE_NUM SERVER_ID EPOCH_SIZE DATA_PATH RANK_TABLE_FILE"
echo "for example: sh run_distribute_train_ascend_multi_machines.sh 32 0 52 /path/ende-l128-mindrecord00 /path/hccl.json"
echo "It is better to use absolute path."
echo "=============================================================================================================="
exit 1;
fi
rm -rf run_distribute_train
mkdir run_distribute_train
cd run_distribute_train || exit
EPOCH_SIZE=$3
DATA_PATH=$4
export HCCL_CONNECT_TIMEOUT=600
export RANK_TABLE_FILE=$5
export RANK_SIZE=$1
export SERVER_ID=$2
export DEVICE_NUM=8
export HCCL_FLAG=1
export DEPLOY_MODE=0
RANK_START=$((DEVICE_NUM*SERVER_ID))
for((i=0;i<DEVICE_NUM;i++))
do
export DEVICE_ID=$i
export RANK_ID=$((i+RANK_START))
export GE_USE_STATIC_MEMORY=1
mkdir helper$i
cp -rf ../src/ ../train.py ./helper$i
cd ./helper$i || exit
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
python train.py \
--distribute="true" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--device_num=$DEVICE_NUM \
--enable_save_ckpt="true" \
--enable_lossscale="true" \
--do_shuffle="true" \
--checkpoint_path="" \
--save_checkpoint_steps=2500 \
--save_checkpoint_num=30 \
--data_path=$DATA_PATH \
--bucket_boundaries=[16,32,48,64,128] > log.txt 2>&1 &
cd ../
done
cd ..
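
As a usage sketch (paths are placeholders): a 32-card job spans four 8-card machines, each running this script with its own SERVER_ID, and local devices 0-7 are then mapped to global ranks SERVER_ID*8 through SERVER_ID*8+7.
``` bash
# Hypothetical launch on the second machine (SERVER_ID=1) of a 32-card job;
# RANK_START = 8 * 1 = 8, so local devices 0-7 get global ranks 8-15.
sh run_distribute_train_ascend_multi_machines.sh 32 1 52 /path/ende-l128-mindrecord00 /path/hccl.json
```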


@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ] ; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distribute_pretrain_gpu.sh DEVICE_NUM EPOCH_SIZE DATA_PATH"
echo "sh run_distribute_train_gpu.sh DEVICE_NUM EPOCH_SIZE DATA_PATH"
echo "for example: sh run_distribute_pretrain.sh 8 55 /path/ende-l128-mindrecord00"
echo "It is better to use absolute path."
echo "=============================================================================================================="
exit 1;
fi
rm -rf run_distribute_train
mkdir run_distribute_train


@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ] ; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_eval.sh DEVICE_TARGET DEVICE_ID"
echo "for example: sh run_eval.sh Ascend 0"
echo "Note: set the checkpoint and dataset path in src/eval_config.py"
echo "=============================================================================================================="
exit 1;
fi
export DEVICE_TARGET=$1
DEVICE_ID=$2


@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 5 ] ; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE DATA_PATH"
echo "for example: sh run_standalone_train.sh Ascend 0 52 /path/ende-l128-mindrecord00"
echo "sh run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE GRADIENT_ACCUMULATE_STEP DATA_PATH"
echo "for example: sh run_standalone_train.sh Ascend 0 52 8 /path/ende-l128-mindrecord00"
echo "It is better to use absolute path."
echo "=============================================================================================================="
exit 1;
fi
rm -rf run_standalone_train
mkdir run_standalone_train
@ -29,12 +31,14 @@ cd run_standalone_train || exit
export DEVICE_TARGET=$1
DEVICE_ID=$2
EPOCH_SIZE=$3
DATA_PATH=$4
GRADIENT_ACCUMULATE_STEP=$4
DATA_PATH=$5
if [ $DEVICE_TARGET == 'Ascend' ];then
python train.py \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
--accumulation_steps=$GRADIENT_ACCUMULATE_STEP \
--device_target=$DEVICE_TARGET \
--device_id=$DEVICE_ID \
--enable_save_ckpt="true" \


@ -13,7 +13,9 @@
# limitations under the License.
# ============================================================================
"""Transformer for training."""
import numpy as np
from mindspore.common.initializer import initializer
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
@ -23,6 +25,8 @@ from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore import context
from .transformer_model import TransformerModel
@ -279,3 +283,190 @@ class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell)
succ = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return F.depend(ret, succ)
cast = P.Cast()
add_grads = C.MultitypeFuncGraph("add_grads")
@add_grads.register("Tensor", "Tensor")
def _add_grads(accu_grad, grad):
return accu_grad + cast(grad, mstype.float32)
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
succ = True
return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))
accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")
@accumulate_accu_grads.register("Tensor", "Tensor")
def _accumulate_accu_grads(accu_grad, grad):
succ = True
return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
zeroslike = P.ZerosLike()
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")
@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
succ = True
return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
class TransformerTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
"""
Encapsulation class for Transformer network training.
Appends an optimizer to the training network; after that, the construct
function can be called to create the backward graph.
To mimic a larger batch size, gradients are accumulated N times before each weight update.
In distributed mode, allreduce is applied only at the weight-update step,
i.e. the sub-step after gradients have been accumulated N times.
Args:
network (Cell): The training network. Note that the loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
batch_size * accumulation_steps. Default: 1.
enable_global_norm (bool): Whether to clip the accumulated gradients by global norm. Default: False.
"""
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=8, enable_global_norm=False):
super(TransformerTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.accumulation_steps = accumulation_steps
self.enable_global_norm = enable_global_norm
self.one = Tensor(np.array([1]).astype(np.int32))
self.zero = Tensor(np.array([0]).astype(np.int32))
self.local_step = Parameter(initializer(0, [1], mstype.int32))
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = F.identity
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.overflow_reducer = F.identity
if self.is_distributed:
self.overflow_reducer = P.AllReduce()
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_status = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.logical_or = P.LogicalOr()
self.not_equal = P.NotEqual()
self.select = P.Select()
self.reshape = P.Reshape()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
def construct(self,
source_eos_ids,
source_eos_mask,
target_sos_ids,
target_sos_mask,
target_eos_ids,
target_eos_mask,
sens=None):
"""Defines the computation performed."""
source_ids = source_eos_ids
source_mask = source_eos_mask
target_ids = target_sos_ids
target_mask = target_sos_mask
label_ids = target_eos_ids
label_weights = target_eos_mask
weights = self.weights
loss = self.network(source_ids,
source_mask,
target_ids,
target_mask,
label_ids,
label_weights)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
# alloc status and clear should be right before the grad operation
init = self.alloc_status()
init = F.depend(init, loss)
clear_status = self.clear_status(init)
scaling_sens = F.depend(scaling_sens, clear_status)
# update accumulation parameters
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
mean_loss = self.accu_loss / self.local_step
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
grads = self.grad(self.network, weights)(source_ids,
source_mask,
target_ids,
target_mask,
label_ids,
label_weights,
self.cast(scaling_sens,
mstype.float32))
accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
mean_loss = F.depend(mean_loss, accu_succ)
init = F.depend(init, mean_loss)
get_status = self.get_status(init)
init = F.depend(init, get_status)
flag_sum = self.reduce_sum(init, (0,))
overflow = self.less_equal(self.base, flag_sum)
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
accu_overflow = self.select(overflow, self.one, self.zero)
self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
if is_accu_step:
succ = False
else:
# apply grad reducer on grads
grads = self.grad_reducer(self.accu_grads)
scaling = scaling_sens * self.degree * self.accumulation_steps
grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
if self.enable_global_norm:
grads = C.clip_by_global_norm(grads, 1.0, None)
else:
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
accu_overflow = F.depend(accu_overflow, grads)
accu_overflow = self.overflow_reducer(accu_overflow)
overflow = self.less_equal(self.base, accu_overflow)
accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
overflow = F.depend(overflow, accu_succ)
overflow = self.reshape(overflow, (()))
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, overflow)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (mean_loss, overflow, scaling_sens)
return F.depend(ret, succ)
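
The cell above is easier to follow against a minimal, framework-free sketch of the same accumulate-N-then-update flow. Everything below is illustrative (NumPy stand-ins, hypothetical names); the real cell additionally handles loss scaling, NPU float-status overflow detection, gradient clipping and AllReduce in distributed mode.
```python
# Minimal NumPy sketch (illustrative only) of the accumulate-N-then-update
# flow implemented by TransformerTrainAccumulationAllReducePostWithLossScaleCell.
import numpy as np

class GradAccumulator:
    """Accumulate gradients for N micro-batches, then apply one SGD update."""

    def __init__(self, params, accumulation_steps=8, lr=1e-3):
        self.params = params                                   # list of np.ndarray weights
        self.accu_grads = [np.zeros_like(p) for p in params]   # like the accu_grads buffers
        self.accumulation_steps = accumulation_steps
        self.lr = lr
        self.local_step = 0
        self.accu_overflow = False

    def step(self, grads, overflow):
        """Feed one micro-batch of grads; return True if the weights were updated."""
        self.local_step += 1
        for buf, g in zip(self.accu_grads, grads):              # accumulate_accu_grads
            buf += g
        self.accu_overflow = self.accu_overflow or overflow     # carry the overflow flag
        if self.local_step < self.accumulation_steps:           # still accumulating
            return False
        if not self.accu_overflow:                              # skip the update on overflow
            for p, buf in zip(self.params, self.accu_grads):
                p -= self.lr * buf / self.accumulation_steps    # average over micro-batches
        for buf in self.accu_grads:                             # reset_accu_grads
            buf[...] = 0.0
        self.local_step = 0
        self.accu_overflow = False
        return True
```
With `accumulation_steps=8`, eight calls to `step()` perform a single weight update whose effective batch size is eight micro-batches.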


@ -34,7 +34,8 @@ from mindspore import context
from mindspore.common import set_seed
from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \
TransformerTrainOneStepWithLossScaleCell
TransformerTrainOneStepWithLossScaleCell, \
TransformerTrainAccumulationAllReducePostWithLossScaleCell
from src.config import cfg, transformer_net_cfg, transformer_net_cfg_gpu
from src.dataset import create_transformer_dataset
from src.lr_schedule import create_dynamic_lr
@ -118,10 +119,10 @@ def argparse_init():
parser.add_argument("--data_path", type=str, default="", help="Data path, it is better to use absolute path")
parser.add_argument("--bucket_boundaries", type=ast.literal_eval, default=[16, 32, 48, 64, 128],
help="sequence length for different bucket")
parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps, default is 1.")
return parser
def run_transformer_train():
"""
Transformer training.
@ -203,8 +204,13 @@ def run_transformer_train():
scale_factor=cfg.scale_factor,
scale_window=cfg.scale_window)
update_cell = scale_manager.get_update_cell()
netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
scale_update_cell=update_cell)
if args.accumulation_steps > 1:
netwithgrads = TransformerTrainAccumulationAllReducePostWithLossScaleCell(netwithloss, optimizer,
update_cell,
args.accumulation_steps)
else:
netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
scale_update_cell=update_cell)
else:
netwithgrads = TransformerTrainOneStepCell(netwithloss, optimizer=optimizer)