forked from mindspore-Ecosystem/mindspore
!18440 Fix the transformer network overflow while training standalone with 1 card.
Merge pull request !18440 from casgj/master_0615transformer
This commit is contained in:
commit
a5dff69a86
|
@ -78,7 +78,10 @@ python eval.py > eval.log 2>&1 &
|
||||||
├─process_output.sh
|
├─process_output.sh
|
||||||
├─replace-quote.perl
|
├─replace-quote.perl
|
||||||
├─run_distribute_train_ascend.sh
|
├─run_distribute_train_ascend.sh
|
||||||
└─run_standalone_train_ascend.sh
|
├─run_distribute_train_ascend_multi_machines.sh
|
||||||
|
├─run_eval.sh
|
||||||
|
├─run_infer_310.sh
|
||||||
|
└─run_standalone_train.sh
|
||||||
├─src
|
├─src
|
||||||
├─__init__.py
|
├─__init__.py
|
||||||
├─beam_search.py
|
├─beam_search.py
|
||||||
|
@ -93,6 +96,10 @@ python eval.py > eval.log 2>&1 &
|
||||||
└─weight_init.py
|
└─weight_init.py
|
||||||
├─create_data.py
|
├─create_data.py
|
||||||
├─eval.py
|
├─eval.py
|
||||||
|
├─export.py
|
||||||
|
├─mindspore_hub_conf.py
|
||||||
|
├─postprocess.py
|
||||||
|
├─preprocess.py
|
||||||
└─train.py
|
└─train.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -201,7 +208,7 @@ Parameters for learning rate:
|
||||||
- Run `run_standalone_train.sh` for non-distributed training of Transformer model.
|
- Run `run_standalone_train.sh` for non-distributed training of Transformer model.
|
||||||
|
|
||||||
``` bash
|
``` bash
|
||||||
sh scripts/run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE DATA_PATH
|
sh scripts/run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE GRADIENT_ACCUMULATE_STEP DATA_PATH
|
||||||
```
|
```
|
||||||
|
|
||||||
- Run `run_distribute_train_ascend.sh` for distributed training of Transformer model.
|
- Run `run_distribute_train_ascend.sh` for distributed training of Transformer model.
|
||||||
|
|
|
@ -84,7 +84,10 @@ python eval.py > eval.log 2>&1 &
|
||||||
├─process_output.sh
|
├─process_output.sh
|
||||||
├─replace-quote.perl
|
├─replace-quote.perl
|
||||||
├─run_distribute_train_ascend.sh
|
├─run_distribute_train_ascend.sh
|
||||||
└─run_standalone_train_ascend.sh
|
├─run_distribute_train_ascend_multi_machines.sh
|
||||||
|
├─run_eval.sh
|
||||||
|
├─run_infer_310.sh
|
||||||
|
└─run_standalone_train.sh
|
||||||
├─src
|
├─src
|
||||||
├─__init__.py
|
├─__init__.py
|
||||||
├─beam_search.py
|
├─beam_search.py
|
||||||
|
@ -99,6 +102,10 @@ python eval.py > eval.log 2>&1 &
|
||||||
└─weight_init.py
|
└─weight_init.py
|
||||||
├─create_data.py
|
├─create_data.py
|
||||||
├─eval.py
|
├─eval.py
|
||||||
|
├─export.py
|
||||||
|
├─mindspore_hub_conf.py
|
||||||
|
├─postprocess.py
|
||||||
|
├─preprocess.py
|
||||||
└─train.py
|
└─train.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -208,7 +215,7 @@ Parameters for learning rate:
|
||||||
- 运行`run_standalone_train.sh`,进行Transformer模型的非分布式训练。
|
- 运行`run_standalone_train.sh`,进行Transformer模型的非分布式训练。
|
||||||
|
|
||||||
``` bash
|
``` bash
|
||||||
sh scripts/run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE DATA_PATH
|
sh scripts/run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE GRADIENT_ACCUMULATE_STEP DATA_PATH
|
||||||
```
|
```
|
||||||
|
|
||||||
- 运行`run_distribute_train_ascend.sh`,进行Transformer模型的非分布式训练。
|
- 运行`run_distribute_train_ascend.sh`,进行Transformer模型的非分布式训练。
|
||||||
|
|
|
@ -13,13 +13,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
if [ $# != 4 ] ; then
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
echo "Please run the script as: "
|
echo "Please run the script as: "
|
||||||
echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_PATH RANK_TABLE_FILE"
|
echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_PATH RANK_TABLE_FILE"
|
||||||
echo "for example: sh run_distribute_pretrain.sh 8 52 /path/ende-l128-mindrecord00 /path/hccl.json"
|
echo "for example: sh run_distribute_pretrain.sh 8 52 /path/ende-l128-mindrecord00 /path/hccl.json"
|
||||||
echo "It is better to use absolute path."
|
echo "It is better to use absolute path."
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
|
exit 1;
|
||||||
|
fi
|
||||||
|
|
||||||
rm -rf run_distribute_train
|
rm -rf run_distribute_train
|
||||||
mkdir run_distribute_train
|
mkdir run_distribute_train
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
if [ $# != 5 ] ; then
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "sh run_distribute_train_ascend_multi_machines.sh DEVICE_NUM SERVER_ID EPOCH_SIZE DATA_PATH RANK_TABLE_FILE"
|
||||||
|
echo "for example: sh run_distribute_train_ascend_multi_machines.sh 32 0 52 /path/ende-l128-mindrecord00 /path/hccl.json"
|
||||||
|
echo "It is better to use absolute path."
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
exit 1;
|
||||||
|
fi
|
||||||
|
|
||||||
|
rm -rf run_distribute_train
|
||||||
|
mkdir run_distribute_train
|
||||||
|
cd run_distribute_train || exit
|
||||||
|
|
||||||
|
EPOCH_SIZE=$3
|
||||||
|
DATA_PATH=$4
|
||||||
|
|
||||||
|
export HCCL_CONNECT_TIMEOUT=600
|
||||||
|
export RANK_TABLE_FILE=$5
|
||||||
|
export RANK_SIZE=$1
|
||||||
|
export SERVER_ID=$2
|
||||||
|
export DEVICE_NUM=8
|
||||||
|
export HCCL_FLAG=1
|
||||||
|
export DEPLOY_MODE=0
|
||||||
|
|
||||||
|
RANK_START=$((DEVICE_NUM*SERVER_ID))
|
||||||
|
|
||||||
|
for((i=0;i<DEVICE_NUM;i++))
|
||||||
|
do
|
||||||
|
export DEVICE_ID=$i
|
||||||
|
export RANK_ID=$((i+RANK_START))
|
||||||
|
export GE_USE_STATIC_MEMORY=1
|
||||||
|
|
||||||
|
mkdir helper$i
|
||||||
|
cp -rf ../src/ ../train.py ./helper$i
|
||||||
|
cd ./helper$i || exit
|
||||||
|
echo "start training for rank $i, device $DEVICE_ID"
|
||||||
|
env > env.log
|
||||||
|
python train.py \
|
||||||
|
--distribute="true" \
|
||||||
|
--epoch_size=$EPOCH_SIZE \
|
||||||
|
--device_id=$DEVICE_ID \
|
||||||
|
--device_num=$DEVICE_NUM \
|
||||||
|
--enable_save_ckpt="true" \
|
||||||
|
--enable_lossscale="true" \
|
||||||
|
--do_shuffle="true" \
|
||||||
|
--checkpoint_path="" \
|
||||||
|
--save_checkpoint_steps=2500 \
|
||||||
|
--save_checkpoint_num=30 \
|
||||||
|
--data_path=$DATA_PATH \
|
||||||
|
--bucket_boundaries=[16,32,48,64,128] > log.txt 2>&1 &
|
||||||
|
cd ../
|
||||||
|
done
|
||||||
|
cd ..
|
|
@ -13,13 +13,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
if [ $# != 3 ] ; then
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
echo "Please run the script as: "
|
echo "Please run the script as: "
|
||||||
echo "sh run_distribute_pretrain_gpu.sh DEVICE_NUM EPOCH_SIZE DATA_PATH"
|
echo "sh run_distribute_train_gpu.sh DEVICE_NUM EPOCH_SIZE DATA_PATH"
|
||||||
echo "for example: sh run_distribute_pretrain.sh 8 55 /path/ende-l128-mindrecord00"
|
echo "for example: sh run_distribute_pretrain.sh 8 55 /path/ende-l128-mindrecord00"
|
||||||
echo "It is better to use absolute path."
|
echo "It is better to use absolute path."
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
|
exit 1;
|
||||||
|
fi
|
||||||
|
|
||||||
rm -rf run_distribute_train
|
rm -rf run_distribute_train
|
||||||
mkdir run_distribute_train
|
mkdir run_distribute_train
|
||||||
|
|
|
@ -13,13 +13,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
if [ $# != 2 ] ; then
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
echo "Please run the script as: "
|
echo "Please run the script as: "
|
||||||
echo "sh run_eval.sh DEVICE_TARGET DEVICE_ID"
|
echo "sh run_eval.sh DEVICE_TARGET DEVICE_ID"
|
||||||
echo "for example: sh run_eval.sh Ascend 0"
|
echo "for example: sh run_eval.sh Ascend 0"
|
||||||
echo "Note: set the checkpoint and dataset path in src/eval_config.py"
|
echo "Note: set the checkpoint and dataset path in src/eval_config.py"
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
|
exit 1;
|
||||||
|
fi
|
||||||
|
|
||||||
export DEVICE_TARGET=$1
|
export DEVICE_TARGET=$1
|
||||||
DEVICE_ID=$2
|
DEVICE_ID=$2
|
||||||
|
|
|
@ -13,13 +13,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
if [ $# != 5 ] ; then
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
echo "Please run the script as: "
|
echo "Please run the script as: "
|
||||||
echo "sh run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE DATA_PATH"
|
echo "sh run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE GRADIENT_ACCUMULATE_STEP DATA_PATH"
|
||||||
echo "for example: sh run_standalone_train.sh Ascend 0 52 /path/ende-l128-mindrecord00"
|
echo "for example: sh run_standalone_train.sh Ascend 0 52 8 /path/ende-l128-mindrecord00"
|
||||||
echo "It is better to use absolute path."
|
echo "It is better to use absolute path."
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
|
exit 1;
|
||||||
|
fi
|
||||||
|
|
||||||
rm -rf run_standalone_train
|
rm -rf run_standalone_train
|
||||||
mkdir run_standalone_train
|
mkdir run_standalone_train
|
||||||
|
@ -29,12 +31,14 @@ cd run_standalone_train || exit
|
||||||
export DEVICE_TARGET=$1
|
export DEVICE_TARGET=$1
|
||||||
DEVICE_ID=$2
|
DEVICE_ID=$2
|
||||||
EPOCH_SIZE=$3
|
EPOCH_SIZE=$3
|
||||||
DATA_PATH=$4
|
GRADIENT_ACCUMULATE_STEP=$4
|
||||||
|
DATA_PATH=$5
|
||||||
|
|
||||||
if [ $DEVICE_TARGET == 'Ascend' ];then
|
if [ $DEVICE_TARGET == 'Ascend' ];then
|
||||||
python train.py \
|
python train.py \
|
||||||
--distribute="false" \
|
--distribute="false" \
|
||||||
--epoch_size=$EPOCH_SIZE \
|
--epoch_size=$EPOCH_SIZE \
|
||||||
|
--accumulation_steps=$GRADIENT_ACCUMULATE_STEP \
|
||||||
--device_target=$DEVICE_TARGET \
|
--device_target=$DEVICE_TARGET \
|
||||||
--device_id=$DEVICE_ID \
|
--device_id=$DEVICE_ID \
|
||||||
--enable_save_ckpt="true" \
|
--enable_save_ckpt="true" \
|
||||||
|
|
|
@ -13,7 +13,9 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Transformer for training."""
|
"""Transformer for training."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore.common.initializer import initializer
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
|
@ -23,6 +25,8 @@ from mindspore.common.parameter import Parameter
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||||
from mindspore.communication.management import get_group_size
|
from mindspore.communication.management import get_group_size
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from mindspore import context
|
||||||
|
|
||||||
from .transformer_model import TransformerModel
|
from .transformer_model import TransformerModel
|
||||||
|
|
||||||
|
@ -279,3 +283,190 @@ class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell)
|
||||||
succ = self.optimizer(grads)
|
succ = self.optimizer(grads)
|
||||||
ret = (loss, cond, scaling_sens)
|
ret = (loss, cond, scaling_sens)
|
||||||
return F.depend(ret, succ)
|
return F.depend(ret, succ)
|
||||||
|
|
||||||
|
|
||||||
|
cast = P.Cast()
|
||||||
|
add_grads = C.MultitypeFuncGraph("add_grads")
|
||||||
|
|
||||||
|
|
||||||
|
@add_grads.register("Tensor", "Tensor")
|
||||||
|
def _add_grads(accu_grad, grad):
|
||||||
|
return accu_grad + cast(grad, mstype.float32)
|
||||||
|
|
||||||
|
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
|
||||||
|
|
||||||
|
@update_accu_grads.register("Tensor", "Tensor")
|
||||||
|
def _update_accu_grads(accu_grad, grad):
|
||||||
|
succ = True
|
||||||
|
return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))
|
||||||
|
|
||||||
|
accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")
|
||||||
|
|
||||||
|
@accumulate_accu_grads.register("Tensor", "Tensor")
|
||||||
|
def _accumulate_accu_grads(accu_grad, grad):
|
||||||
|
succ = True
|
||||||
|
return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
|
||||||
|
|
||||||
|
|
||||||
|
zeroslike = P.ZerosLike()
|
||||||
|
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")
|
||||||
|
|
||||||
|
|
||||||
|
@reset_accu_grads.register("Tensor")
|
||||||
|
def _reset_accu_grads(accu_grad):
|
||||||
|
succ = True
|
||||||
|
return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
|
||||||
|
"""
|
||||||
|
Encapsulation class of bert network training.
|
||||||
|
|
||||||
|
Append an optimizer to the training network after that the construct
|
||||||
|
function can be called to create the backward graph.
|
||||||
|
|
||||||
|
To mimic higher batch size, gradients are accumulated N times before weight update.
|
||||||
|
|
||||||
|
For distribution mode, allreduce will only be implemented in the weight updated step,
|
||||||
|
i.e. the sub-step after gradients accumulated N times.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): The training network. Note that loss function should have been added.
|
||||||
|
optimizer (Optimizer): Optimizer for updating the weights.
|
||||||
|
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||||
|
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
|
||||||
|
batch_size * accumulation_steps. Default: 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=8, enable_global_norm=False):
|
||||||
|
super(TransformerTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||||
|
self.network = network
|
||||||
|
self.network.set_grad()
|
||||||
|
self.weights = optimizer.parameters
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.accumulation_steps = accumulation_steps
|
||||||
|
self.enable_global_norm = enable_global_norm
|
||||||
|
self.one = Tensor(np.array([1]).astype(np.int32))
|
||||||
|
self.zero = Tensor(np.array([0]).astype(np.int32))
|
||||||
|
self.local_step = Parameter(initializer(0, [1], mstype.int32))
|
||||||
|
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
|
||||||
|
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
|
||||||
|
self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
|
||||||
|
|
||||||
|
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||||
|
self.reducer_flag = False
|
||||||
|
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||||
|
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||||
|
self.reducer_flag = True
|
||||||
|
self.grad_reducer = F.identity
|
||||||
|
self.degree = 1
|
||||||
|
if self.reducer_flag:
|
||||||
|
self.degree = get_group_size()
|
||||||
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||||
|
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||||
|
self.overflow_reducer = F.identity
|
||||||
|
if self.is_distributed:
|
||||||
|
self.overflow_reducer = P.AllReduce()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.alloc_status = P.NPUAllocFloatStatus()
|
||||||
|
self.get_status = P.NPUGetFloatStatus()
|
||||||
|
self.clear_status = P.NPUClearFloatStatus()
|
||||||
|
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||||
|
self.base = Tensor(1, mstype.float32)
|
||||||
|
self.less_equal = P.LessEqual()
|
||||||
|
self.logical_or = P.LogicalOr()
|
||||||
|
self.not_equal = P.NotEqual()
|
||||||
|
self.select = P.Select()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.hyper_map = C.HyperMap()
|
||||||
|
self.loss_scale = None
|
||||||
|
self.loss_scaling_manager = scale_update_cell
|
||||||
|
if scale_update_cell:
|
||||||
|
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||||
|
|
||||||
|
def construct(self,
|
||||||
|
source_eos_ids,
|
||||||
|
source_eos_mask,
|
||||||
|
target_sos_ids,
|
||||||
|
target_sos_mask,
|
||||||
|
target_eos_ids,
|
||||||
|
target_eos_mask,
|
||||||
|
sens=None):
|
||||||
|
"""Defines the computation performed."""
|
||||||
|
source_ids = source_eos_ids
|
||||||
|
source_mask = source_eos_mask
|
||||||
|
target_ids = target_sos_ids
|
||||||
|
target_mask = target_sos_mask
|
||||||
|
label_ids = target_eos_ids
|
||||||
|
label_weights = target_eos_mask
|
||||||
|
|
||||||
|
weights = self.weights
|
||||||
|
loss = self.network(source_ids,
|
||||||
|
source_mask,
|
||||||
|
target_ids,
|
||||||
|
target_mask,
|
||||||
|
label_ids,
|
||||||
|
label_weights)
|
||||||
|
if sens is None:
|
||||||
|
scaling_sens = self.loss_scale
|
||||||
|
else:
|
||||||
|
scaling_sens = sens
|
||||||
|
# alloc status and clear should be right before gradoperation
|
||||||
|
init = self.alloc_status()
|
||||||
|
init = F.depend(init, loss)
|
||||||
|
clear_status = self.clear_status(init)
|
||||||
|
scaling_sens = F.depend(scaling_sens, clear_status)
|
||||||
|
# update accumulation parameters
|
||||||
|
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
|
||||||
|
self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
|
||||||
|
self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
|
||||||
|
mean_loss = self.accu_loss / self.local_step
|
||||||
|
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
|
||||||
|
|
||||||
|
grads = self.grad(self.network, weights)(source_ids,
|
||||||
|
source_mask,
|
||||||
|
target_ids,
|
||||||
|
target_mask,
|
||||||
|
label_ids,
|
||||||
|
label_weights,
|
||||||
|
self.cast(scaling_sens,
|
||||||
|
mstype.float32))
|
||||||
|
|
||||||
|
accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
|
||||||
|
mean_loss = F.depend(mean_loss, accu_succ)
|
||||||
|
|
||||||
|
init = F.depend(init, mean_loss)
|
||||||
|
get_status = self.get_status(init)
|
||||||
|
init = F.depend(init, get_status)
|
||||||
|
flag_sum = self.reduce_sum(init, (0,))
|
||||||
|
overflow = self.less_equal(self.base, flag_sum)
|
||||||
|
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
|
||||||
|
accu_overflow = self.select(overflow, self.one, self.zero)
|
||||||
|
self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
|
||||||
|
|
||||||
|
if is_accu_step:
|
||||||
|
succ = False
|
||||||
|
else:
|
||||||
|
# apply grad reducer on grads
|
||||||
|
grads = self.grad_reducer(self.accu_grads)
|
||||||
|
scaling = scaling_sens * self.degree * self.accumulation_steps
|
||||||
|
grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
|
||||||
|
if self.enable_global_norm:
|
||||||
|
grads = C.clip_by_global_norm(grads, 1.0, None)
|
||||||
|
else:
|
||||||
|
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||||
|
accu_overflow = F.depend(accu_overflow, grads)
|
||||||
|
accu_overflow = self.overflow_reducer(accu_overflow)
|
||||||
|
overflow = self.less_equal(self.base, accu_overflow)
|
||||||
|
accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
|
||||||
|
overflow = F.depend(overflow, accu_succ)
|
||||||
|
overflow = self.reshape(overflow, (()))
|
||||||
|
if sens is None:
|
||||||
|
overflow = self.loss_scaling_manager(self.loss_scale, overflow)
|
||||||
|
if overflow:
|
||||||
|
succ = False
|
||||||
|
else:
|
||||||
|
succ = self.optimizer(grads)
|
||||||
|
|
||||||
|
ret = (mean_loss, overflow, scaling_sens)
|
||||||
|
return F.depend(ret, succ)
|
||||||
|
|
|
@ -34,7 +34,8 @@ from mindspore import context
|
||||||
from mindspore.common import set_seed
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \
|
from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \
|
||||||
TransformerTrainOneStepWithLossScaleCell
|
TransformerTrainOneStepWithLossScaleCell, \
|
||||||
|
TransformerTrainAccumulationAllReducePostWithLossScaleCell
|
||||||
from src.config import cfg, transformer_net_cfg, transformer_net_cfg_gpu
|
from src.config import cfg, transformer_net_cfg, transformer_net_cfg_gpu
|
||||||
from src.dataset import create_transformer_dataset
|
from src.dataset import create_transformer_dataset
|
||||||
from src.lr_schedule import create_dynamic_lr
|
from src.lr_schedule import create_dynamic_lr
|
||||||
|
@ -118,10 +119,10 @@ def argparse_init():
|
||||||
parser.add_argument("--data_path", type=str, default="", help="Data path, it is better to use absolute path")
|
parser.add_argument("--data_path", type=str, default="", help="Data path, it is better to use absolute path")
|
||||||
parser.add_argument("--bucket_boundaries", type=ast.literal_eval, default=[16, 32, 48, 64, 128],
|
parser.add_argument("--bucket_boundaries", type=ast.literal_eval, default=[16, 32, 48, 64, 128],
|
||||||
help="sequence length for different bucket")
|
help="sequence length for different bucket")
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps, default is 1.")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def run_transformer_train():
|
def run_transformer_train():
|
||||||
"""
|
"""
|
||||||
Transformer training.
|
Transformer training.
|
||||||
|
@ -203,8 +204,13 @@ def run_transformer_train():
|
||||||
scale_factor=cfg.scale_factor,
|
scale_factor=cfg.scale_factor,
|
||||||
scale_window=cfg.scale_window)
|
scale_window=cfg.scale_window)
|
||||||
update_cell = scale_manager.get_update_cell()
|
update_cell = scale_manager.get_update_cell()
|
||||||
netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
|
if args.accumulation_steps > 1:
|
||||||
scale_update_cell=update_cell)
|
netwithgrads = TransformerTrainAccumulationAllReducePostWithLossScaleCell(netwithloss, optimizer,
|
||||||
|
update_cell,
|
||||||
|
args.accumulation_steps)
|
||||||
|
else:
|
||||||
|
netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
|
||||||
|
scale_update_cell=update_cell)
|
||||||
else:
|
else:
|
||||||
netwithgrads = TransformerTrainOneStepCell(netwithloss, optimizer=optimizer)
|
netwithgrads = TransformerTrainOneStepCell(netwithloss, optimizer=optimizer)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue