forked from mindspore-Ecosystem/mindspore
Add training and evaluation of Transformer on GPU
Add gradients clipping to NASNet training and adjust hyper-parameters
This commit is contained in:
parent
fa5c9c1528
commit
f27f047f14
|
@ -23,7 +23,7 @@ nasnet_a_mobile_config_gpu = edict({
|
|||
'rank': 0,
|
||||
'group_size': 1,
|
||||
'work_nums': 8,
|
||||
'epoch_size': 500,
|
||||
'epoch_size': 600,
|
||||
'keep_checkpoint_max': 100,
|
||||
'ckpt_path': './checkpoint/',
|
||||
'is_save_on_master': 0,
|
||||
|
@ -39,7 +39,7 @@ nasnet_a_mobile_config_gpu = edict({
|
|||
|
||||
### Learning Rate Config
|
||||
# 'lr_decay_method': 'exponential',
|
||||
'lr_init': 0.04,
|
||||
'lr_init': 0.04*8,
|
||||
'lr_decay_rate': 0.97,
|
||||
'num_epoch_per_decay': 2.4,
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ from mindspore.common import dtype as mstype
|
|||
|
||||
from src.config import nasnet_a_mobile_config_gpu as cfg
|
||||
from src.dataset import create_dataset
|
||||
from src.nasnet_a_mobile import NASNetAMobile, CrossEntropy
|
||||
from src.nasnet_a_mobile import NASNetAMobileWithLoss, NASNetAMobileTrainOneStepWithClipGradient
|
||||
from src.lr_generator import get_lr
|
||||
|
||||
|
||||
|
@ -69,13 +69,10 @@ if __name__ == '__main__':
|
|||
batches_per_epoch = dataset.get_dataset_size()
|
||||
|
||||
# network
|
||||
net = NASNetAMobile(cfg.num_classes)
|
||||
net_with_loss = NASNetAMobileWithLoss(cfg)
|
||||
if args_opt.resume:
|
||||
ckpt = load_checkpoint(args_opt.resume)
|
||||
load_param_into_net(net, ckpt)
|
||||
|
||||
#loss
|
||||
loss = CrossEntropy(smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes, factor=cfg.aux_factor)
|
||||
load_param_into_net(net_with_loss, ckpt)
|
||||
|
||||
# learning rate schedule
|
||||
lr = get_lr(lr_init=cfg.lr_init, lr_decay_rate=cfg.lr_decay_rate,
|
||||
|
@ -88,26 +85,28 @@ if __name__ == '__main__':
|
|||
resume = split_result[-2].split("-")
|
||||
resume_epoch = int(resume[-1])
|
||||
step_num_in_epoch = int(split_result[-1])
|
||||
assert step_num_in_epoch == ds_train.get_dataset_size()\
|
||||
assert step_num_in_epoch == dataset.get_dataset_size()\
|
||||
, "This script only supports resuming at the end of epoch"
|
||||
lr = lr[(ds_train.get_dataset_size() * (resume_epoch - 1) + step_num_in_epoch):]
|
||||
lr = lr[(dataset.get_dataset_size() * (resume_epoch - 1) + step_num_in_epoch):]
|
||||
lr = Tensor(lr, mstype.float32)
|
||||
|
||||
# optimizer
|
||||
decayed_params = []
|
||||
no_decayed_params = []
|
||||
for param in net.trainable_params():
|
||||
for param in net_with_loss.trainable_params():
|
||||
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
|
||||
decayed_params.append(param)
|
||||
else:
|
||||
no_decayed_params.append(param)
|
||||
group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
|
||||
{'params': no_decayed_params},
|
||||
{'order_params': net.trainable_params()}]
|
||||
{'order_params': net_with_loss.trainable_params()}]
|
||||
optimizer = RMSProp(group_params, lr, decay=cfg.rmsprop_decay, weight_decay=cfg.weight_decay,
|
||||
momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale)
|
||||
|
||||
model = Model(net, loss_fn=loss, optimizer=optimizer)
|
||||
net_with_grads = NASNetAMobileTrainOneStepWithClipGradient(net_with_loss, optimizer)
|
||||
net_with_grads.set_train()
|
||||
model = Model(net_with_grads)
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
loss_cb = LossMonitor(per_print_times=batches_per_epoch)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
"""Transformer evaluation script."""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
|
@ -97,9 +97,14 @@ def run_transformer_eval():
|
|||
"""
|
||||
Transformer evaluation.
|
||||
"""
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False,
|
||||
device_id=device_id)
|
||||
parser = argparse.ArgumentParser(description='tranformer')
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
help="device where the code will be implemented, default is Ascend")
|
||||
parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False,
|
||||
device_id=args.device_id)
|
||||
|
||||
dataset = load_test_data(batch_size=transformer_net_cfg.batch_size, data_file=cfg.data_file)
|
||||
tfm_model = TransformerModel(config=transformer_net_cfg, is_training=False, use_one_hot_embeddings=False)
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_distribute_pretrain_gpu.sh DEVICE_NUM EPOCH_SIZE DATA_PATH"
|
||||
echo "for example: sh run_distribute_pretrain.sh 8 55 /path/ende-l128-mindrecord00"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
rm -rf run_distribute_train
|
||||
mkdir run_distribute_train
|
||||
cp -rf ./src/ train.py ./run_distribute_train
|
||||
cd run_distribute_train || exit
|
||||
|
||||
export RANK_SIZE=$1
|
||||
EPOCH_SIZE=$2
|
||||
DATA_PATH=$3
|
||||
echo $RANK_SIZE
|
||||
|
||||
mpirun -n $RANK_SIZE \
|
||||
python train.py \
|
||||
--distribute="true" \
|
||||
--device_target="GPU" \
|
||||
--epoch_size=$EPOCH_SIZE \
|
||||
--device_num=$RANK_SIZE \
|
||||
--enable_save_ckpt="true" \
|
||||
--enable_lossscale="true" \
|
||||
--do_shuffle="true" \
|
||||
--checkpoint_path="" \
|
||||
--save_checkpoint_steps=2500 \
|
||||
--save_checkpoint_num=30 \
|
||||
--data_path=$DATA_PATH \
|
||||
--bucket_boundaries=[16,32,48,64,128] > log.txt 2>&1 &
|
|
@ -0,0 +1,29 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_eval.sh DEVICE_TARGET DEVICE_ID"
|
||||
echo "for example: sh run_eval.sh Ascend 0"
|
||||
echo "Note: set the checkpoint and dataset path in src/eval_config.py"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
export DEVICE_TARGET=$1
|
||||
DEVICE_ID=$2
|
||||
|
||||
python eval.py \
|
||||
--device_target=$DEVICE_TARGET \
|
||||
--device_id=$DEVICE_ID \
|
|
@ -16,8 +16,8 @@
|
|||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_PATH"
|
||||
echo "for example: sh run_standalone_train.sh 0 52 /path/ende-l128-mindrecord00"
|
||||
echo "sh run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE DATA_PATH"
|
||||
echo "for example: sh run_standalone_train.sh Ascend 0 52 /path/ende-l128-mindrecord00"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
|
@ -26,13 +26,15 @@ mkdir run_standalone_train
|
|||
cp -rf ./src/ train.py ./run_standalone_train
|
||||
cd run_standalone_train || exit
|
||||
|
||||
export DEVICE_ID=$1
|
||||
EPOCH_SIZE=$2
|
||||
DATA_PATH=$3
|
||||
export DEVICE_TARGET=$1
|
||||
DEVICE_ID=$2
|
||||
EPOCH_SIZE=$3
|
||||
DATA_PATH=$4
|
||||
|
||||
python train.py \
|
||||
--distribute="false" \
|
||||
--epoch_size=$EPOCH_SIZE \
|
||||
--device_target=$DEVICE_TARGET \
|
||||
--device_id=$DEVICE_ID \
|
||||
--enable_save_ckpt="true" \
|
||||
--enable_lossscale="true" \
|
||||
|
@ -42,4 +44,4 @@ python train.py \
|
|||
--save_checkpoint_num=30 \
|
||||
--data_path=$DATA_PATH \
|
||||
--bucket_boundaries=[16,32,48,64,128] > log.txt 2>&1 &
|
||||
cd ..
|
||||
cd ..
|
|
@ -23,6 +23,7 @@ cfg = edict({
|
|||
'scale_factor': 2,
|
||||
'scale_window': 2000,
|
||||
'optimizer': 'Adam',
|
||||
'optimizer_adam_beta2': 0.997,
|
||||
'lr_schedule': edict({
|
||||
'learning_rate': 2.0,
|
||||
'warmup_steps': 8000,
|
||||
|
@ -51,6 +52,23 @@ if cfg.transformer_network == 'large':
|
|||
input_mask_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16)
|
||||
transformer_net_cfg_gpu = TransformerConfig(
|
||||
batch_size=32,
|
||||
seq_length=128,
|
||||
vocab_size=36560,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=6,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="relu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=128,
|
||||
initializer_range=0.02,
|
||||
label_smoothing=0.1,
|
||||
input_mask_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16)
|
||||
if cfg.transformer_network == 'base':
|
||||
transformer_net_cfg = TransformerConfig(
|
||||
batch_size=96,
|
||||
|
|
|
@ -166,7 +166,7 @@ class TransformerTrainOneStepCell(nn.Cell):
|
|||
self.reducer_flag = False
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode not in ParallelMode.MODE_LIST:
|
||||
raise ValueError("Parallel mode does not support: ", parallel_mode)
|
||||
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
|
@ -228,6 +228,12 @@ reciprocal = P.Reciprocal()
|
|||
def tensor_grad_scale(scale, grad):
|
||||
return grad * F.cast(reciprocal(scale), F.dtype(grad))
|
||||
|
||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
grad_overflow = P.FloatStatus()
|
||||
|
||||
@_grad_overflow.register("Tensor")
|
||||
def _tensor_grad_overflow(grad):
|
||||
return grad_overflow(grad)
|
||||
|
||||
class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
|
@ -255,7 +261,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
if self.parallel_mode not in ParallelMode.MODE_LIST:
|
||||
raise ValueError("Parallel mode does not support: ", parallel_mode)
|
||||
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
|
@ -266,9 +272,16 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.clip_gradients = ClipGradients()
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
if context.get_context("device_target") == "GPU":
|
||||
self.gpu_target = True
|
||||
self.float_status = P.FloatStatus()
|
||||
self.addn = P.AddN()
|
||||
self.reshape = P.Reshape()
|
||||
else:
|
||||
self.gpu_target = False
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
|
@ -305,10 +318,12 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
target_mask,
|
||||
label_ids,
|
||||
label_weights)
|
||||
# alloc status
|
||||
init = self.alloc_status()
|
||||
# clear overflow buffer
|
||||
self.clear_before_grad(init)
|
||||
init = False
|
||||
if not self.gpu_target:
|
||||
# alloc status
|
||||
init = self.alloc_status()
|
||||
# clear overflow buffer
|
||||
self.clear_before_grad(init)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
|
@ -327,8 +342,16 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
if self.reducer_flag:
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
self.get_status(init)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
|
||||
if not self.gpu_target:
|
||||
self.get_status(init)
|
||||
# sum overflow buffer elements, 0: not overflow, >0: overflow
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
else:
|
||||
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
|
||||
flag_sum = self.addn(flag_sum)
|
||||
# convert flag_sum to scalar
|
||||
flag_sum = self.reshape(flag_sum, (()))
|
||||
|
||||
if self.is_distributed:
|
||||
# sum overflow flag over devices
|
||||
|
|
|
@ -35,7 +35,7 @@ from mindspore.common import set_seed
|
|||
|
||||
from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \
|
||||
TransformerTrainOneStepWithLossScaleCell
|
||||
from src.config import cfg, transformer_net_cfg
|
||||
from src.config import cfg, transformer_net_cfg, transformer_net_cfg_gpu
|
||||
from src.dataset import create_transformer_dataset
|
||||
from src.lr_schedule import create_dynamic_lr
|
||||
|
||||
|
@ -73,13 +73,17 @@ class LossCallBack(Callback):
|
|||
time_stamp_current = get_ms_timestamp()
|
||||
cb_params = run_context.original_args()
|
||||
print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
|
||||
cb_params.cur_epoch_num, cb_params.cur_step_num,
|
||||
cb_params.cur_epoch_num,
|
||||
cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs)))
|
||||
with open("./loss_{}.log".format(self.rank_id), "a+") as f:
|
||||
f.write("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
|
||||
cb_params.cur_epoch_num,
|
||||
cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs)))
|
||||
f.write("time: {}, epoch: {}, step: {}, loss: {}, overflow: {}, loss_scale: {}".format(
|
||||
time_stamp_current - time_stamp_first,
|
||||
cb_params.cur_epoch_num,
|
||||
cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs[0].asnumpy()),
|
||||
str(cb_params.net_outputs[1].asnumpy()),
|
||||
str(cb_params.net_outputs[2].asnumpy())))
|
||||
f.write('\n')
|
||||
|
||||
|
||||
|
@ -91,6 +95,8 @@ def argparse_init():
|
|||
parser.add_argument("--distribute", type=str, default="false", choices=['true', 'false'],
|
||||
help="Run distribute, default is false.")
|
||||
parser.add_argument("--epoch_size", type=int, default=52, help="Epoch size, default is 52.")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
help="device where the code will be implemented, default is Ascend")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
|
||||
parser.add_argument("--enable_lossscale", type=str, default="true", choices=['true', 'false'],
|
||||
|
@ -116,15 +122,21 @@ def run_transformer_train():
|
|||
"""
|
||||
parser = argparse_init()
|
||||
args, _ = parser.parse_known_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
|
||||
|
||||
if args.distribute == "true":
|
||||
device_num = args.device_num
|
||||
if args.device_target == "Ascend":
|
||||
device_num = args.device_num
|
||||
D.init('hccl')
|
||||
else:
|
||||
D.init('nccl')
|
||||
device_num = D.get_group_size()
|
||||
rank = get_rank()
|
||||
args.device_id = rank
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
||||
device_num=device_num)
|
||||
D.init()
|
||||
rank_id = args.device_id % device_num
|
||||
save_ckpt_path = os.path.join(args.save_checkpoint_path, 'ckpt_' + str(get_rank()) + '/')
|
||||
else:
|
||||
|
@ -135,27 +147,39 @@ def run_transformer_train():
|
|||
rank_id=rank_id, do_shuffle=args.do_shuffle,
|
||||
dataset_path=args.data_path,
|
||||
bucket_boundaries=args.bucket_boundaries)
|
||||
|
||||
netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
|
||||
if args.device_target == "Ascend":
|
||||
netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
|
||||
else:
|
||||
netwithloss = TransformerNetworkWithLoss(transformer_net_cfg_gpu, True)
|
||||
|
||||
if args.checkpoint_path:
|
||||
parameter_dict = load_checkpoint(args.checkpoint_path)
|
||||
load_param_into_net(netwithloss, parameter_dict)
|
||||
|
||||
hidden_size = transformer_net_cfg.hidden_size if args.device_target == "Ascend" \
|
||||
else transformer_net_cfg_gpu.hidden_size
|
||||
lr = Tensor(create_dynamic_lr(schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
|
||||
training_steps=dataset.get_dataset_size()*args.epoch_size,
|
||||
learning_rate=cfg.lr_schedule.learning_rate,
|
||||
warmup_steps=cfg.lr_schedule.warmup_steps,
|
||||
hidden_size=transformer_net_cfg.hidden_size,
|
||||
hidden_size=hidden_size,
|
||||
start_decay_step=cfg.lr_schedule.start_decay_step,
|
||||
min_lr=cfg.lr_schedule.min_lr), mstype.float32)
|
||||
optimizer = Adam(netwithloss.trainable_params(), lr)
|
||||
|
||||
if args.device_target == "GPU" and cfg.transformer_network == "large":
|
||||
optimizer = Adam(netwithloss.trainable_params(), lr, beta2=cfg.optimizer_adam_beta2)
|
||||
else:
|
||||
optimizer = Adam(netwithloss.trainable_params(), lr)
|
||||
|
||||
callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(rank_id=rank_id)]
|
||||
if args.enable_save_ckpt == "true":
|
||||
if device_num == 1 or (device_num > 1 and rank_id == 0):
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps,
|
||||
keep_checkpoint_max=args.save_checkpoint_num)
|
||||
if args.device_target == "Ascend":
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps,
|
||||
keep_checkpoint_max=args.save_checkpoint_num)
|
||||
else:
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset.get_dataset_size(),
|
||||
keep_checkpoint_max=args.save_checkpoint_num)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=save_ckpt_path, config=ckpt_config)
|
||||
callbacks.append(ckpoint_cb)
|
||||
|
||||
|
|
Loading…
Reference in New Issue