forked from mindspore-Ecosystem/mindspore
!4187 tinybert script suit for gpu
Merge pull request !4187 from hanhuifeng/tinybert_suit_gpu
This commit is contained in:
commit
881fc135bf
|
@ -46,7 +46,7 @@ usage: run_standalone_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_T
|
|||
|
||||
options:
|
||||
--distribute whether to run distributely: "true" | "false"
|
||||
--device_target target device to run, currently only support "Ascend"
|
||||
--device_target targeted device to run task: "Ascend" | "GPU"
|
||||
--epoch_size epoch size: N, default is 1
|
||||
--device_id device id: N, default is 0
|
||||
--enable_data_sink enable data sink: "true" | "false", default is "true"
|
||||
|
@ -64,7 +64,7 @@ usage: run_distribute_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_T
|
|||
|
||||
options:
|
||||
--distribute whether to run distributely: "true" | "false"
|
||||
--device_target target device to run, currently only support "Ascend"
|
||||
--device_target targeted device to run task: "Ascend" | "GPU"
|
||||
--epoch_size epoch size: N, default is 1
|
||||
--device_id device id: N, default is 0
|
||||
--device_num device id to run task
|
||||
|
|
|
@ -20,16 +20,20 @@ import argparse
|
|||
import datetime
|
||||
import numpy
|
||||
import mindspore.communication.management as D
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.callback import TimeMonitor
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from mindspore.nn.optim import AdamWeightDecay
|
||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
from mindspore import log as logger
|
||||
from src.dataset import create_tinybert_dataset
|
||||
from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
|
||||
from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
|
||||
from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd
|
||||
from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell
|
||||
|
||||
|
||||
|
||||
def run_general_distill():
|
||||
"""
|
||||
|
@ -53,7 +57,6 @@ def run_general_distill():
|
|||
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
context.set_context(reserve_class_name_in_scope=False)
|
||||
context.set_context(variable_memory_max_size="30GB")
|
||||
|
@ -61,13 +64,17 @@ def run_general_distill():
|
|||
save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
|
||||
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
|
||||
if not os.path.exists(save_ckpt_dir):
|
||||
os.makedirs(save_ckpt_dir)
|
||||
|
||||
if args_opt.distribute == "true":
|
||||
D.init('hccl')
|
||||
device_num = args_opt.device_num
|
||||
rank = args_opt.device_id % device_num
|
||||
if args_opt.device_target == 'Ascend':
|
||||
D.init('hccl')
|
||||
device_num = args_opt.device_num
|
||||
rank = args_opt.device_id % device_num
|
||||
else:
|
||||
D.init('nccl')
|
||||
device_num = D.get_group_size()
|
||||
rank = D.get_rank()
|
||||
save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
|
||||
device_num=device_num)
|
||||
|
@ -75,6 +82,21 @@ def run_general_distill():
|
|||
rank = 0
|
||||
device_num = 1
|
||||
|
||||
if not os.path.exists(save_ckpt_dir):
|
||||
os.makedirs(save_ckpt_dir)
|
||||
|
||||
enable_loss_scale = True
|
||||
if args_opt.device_target == "GPU":
|
||||
if bert_teacher_net_cfg.compute_type != mstype.float32:
|
||||
logger.warning('GPU only support fp32 temporarily, run with fp32.')
|
||||
bert_teacher_net_cfg.compute_type = mstype.float32
|
||||
if bert_student_net_cfg.compute_type != mstype.float32:
|
||||
logger.warning('GPU only support fp32 temporarily, run with fp32.')
|
||||
bert_student_net_cfg.compute_type = mstype.float32
|
||||
# Both the forward and backward of the network are calculated using fp32,
|
||||
# and the loss scale is not necessary
|
||||
enable_loss_scale = False
|
||||
|
||||
netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
|
||||
teacher_ckpt=args_opt.load_teacher_ckpt_path,
|
||||
student_config=bert_student_net_cfg,
|
||||
|
@ -82,11 +104,11 @@ def run_general_distill():
|
|||
|
||||
dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
|
||||
args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
|
||||
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
print('dataset size: ', dataset_size)
|
||||
print("dataset repeatcount: ", dataset.get_repeat_count())
|
||||
if args_opt.enable_data_sink == "true":
|
||||
repeat_count = args_opt.epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
|
||||
repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
|
||||
time_monitor_steps = args_opt.data_sink_steps
|
||||
else:
|
||||
repeat_count = args_opt.epoch_size
|
||||
|
@ -110,12 +132,13 @@ def run_general_distill():
|
|||
args_opt.save_ckpt_step,
|
||||
args_opt.max_ckpt_num,
|
||||
save_ckpt_dir)]
|
||||
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
|
||||
scale_factor=common_cfg.scale_factor,
|
||||
scale_window=common_cfg.scale_window)
|
||||
|
||||
netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
if enable_loss_scale:
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
|
||||
scale_factor=common_cfg.scale_factor,
|
||||
scale_window=common_cfg.scale_window)
|
||||
netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
else:
|
||||
netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
|
||||
model = Model(netwithgrads)
|
||||
model.train(repeat_count, dataset, callbacks=callback,
|
||||
dataset_sink_mode=(args_opt.enable_data_sink == "true"),
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
import os
|
||||
import re
|
||||
import argparse
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
|
@ -25,11 +26,12 @@ from mindspore.train.callback import TimeMonitor
|
|||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
from mindspore.nn.optim import AdamWeightDecay
|
||||
from mindspore import log as logger
|
||||
from src.dataset import create_tinybert_dataset
|
||||
from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
|
||||
from src.assessment_method import Accuracy
|
||||
from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
|
||||
from src.tinybert_for_gd_td import BertEvaluationCell, BertNetworkWithLoss_td
|
||||
from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
|
||||
from src.tinybert_model import BertModelCLS
|
||||
|
||||
_cur_dir = os.getcwd()
|
||||
|
@ -45,14 +47,14 @@ def parse_args():
|
|||
parse args
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='tinybert task distill')
|
||||
parser.add_argument("--device_target", type=str, default="Ascend", help="NPU device, default is Ascend.")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented. (Default: Ascend)')
|
||||
parser.add_argument("--do_train", type=str, default="true", help="Do train task, default is true.")
|
||||
parser.add_argument("--do_eval", type=str, default="true", help="Do eval task, default is true.")
|
||||
parser.add_argument("--td_phase1_epoch_size", type=int, default=10,
|
||||
help="Epoch size for td phase 1, default is 10.")
|
||||
parser.add_argument("--td_phase2_epoch_size", type=int, default=3, help="Epoch size for td phase 2, default is 3.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--num_labels", type=int, default=2, help="Classfication task, support SST2, QNLI, MNLI.")
|
||||
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
|
||||
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
|
||||
parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable data sink, default is true.")
|
||||
|
@ -64,11 +66,43 @@ def parse_args():
|
|||
parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
|
||||
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
|
||||
help="The name of the task to train.")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
args_opt = parse_args()
|
||||
|
||||
DEFAULT_NUM_LABELS = 2
|
||||
DEFAULT_SEQ_LENGTH = 128
|
||||
task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
|
||||
"QNLI": {"num_labels": 2, "seq_length": 128},
|
||||
"MNLI": {"num_labels": 3, "seq_length": 128}}
|
||||
|
||||
|
||||
class Task:
|
||||
"""
|
||||
Encapsulation class of get the task parameter.
|
||||
"""
|
||||
def __init__(self, task_name):
|
||||
self.task_name = task_name
|
||||
|
||||
@property
|
||||
def num_labels(self):
|
||||
if self.task_name in task_params and "num_labels" in task_params[self.task_name]:
|
||||
return task_params[self.task_name]["num_labels"]
|
||||
return DEFAULT_NUM_LABELS
|
||||
|
||||
@property
|
||||
def seq_length(self):
|
||||
if self.task_name in task_params and "seq_length" in task_params[self.task_name]:
|
||||
return task_params[self.task_name]["seq_length"]
|
||||
return DEFAULT_SEQ_LENGTH
|
||||
task = Task(args_opt.task_name)
|
||||
|
||||
|
||||
def run_predistill():
|
||||
"""
|
||||
run predistill
|
||||
|
@ -81,7 +115,7 @@ def run_predistill():
|
|||
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
|
||||
student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
|
||||
is_training=True, task_type='classification',
|
||||
num_labels=args_opt.num_labels, is_predistill=True)
|
||||
num_labels=task.num_labels, is_predistill=True)
|
||||
|
||||
rank = 0
|
||||
device_num = 1
|
||||
|
@ -91,8 +125,9 @@ def run_predistill():
|
|||
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
print('td1 dataset size: ', dataset_size)
|
||||
print('td1 dataset repeatcount: ', dataset.get_repeat_count())
|
||||
if args_opt.enable_data_sink == 'true':
|
||||
repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
|
||||
repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps
|
||||
time_monitor_steps = args_opt.data_sink_steps
|
||||
else:
|
||||
repeat_count = args_opt.td_phase1_epoch_size
|
||||
|
@ -117,10 +152,14 @@ def run_predistill():
|
|||
args_opt.save_ckpt_step,
|
||||
args_opt.max_ckpt_num,
|
||||
td_phase1_save_ckpt_dir)]
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
|
||||
scale_factor=cfg.scale_factor,
|
||||
scale_window=cfg.scale_window)
|
||||
netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
if enable_loss_scale:
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
|
||||
scale_factor=cfg.scale_factor,
|
||||
scale_window=cfg.scale_window)
|
||||
netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
else:
|
||||
netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
|
||||
|
||||
model = Model(netwithgrads)
|
||||
model.train(repeat_count, dataset, callbacks=callback,
|
||||
dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
|
||||
|
@ -139,7 +178,7 @@ def run_task_distill(ckpt_file):
|
|||
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
|
||||
student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
|
||||
is_training=True, task_type='classification',
|
||||
num_labels=args_opt.num_labels, is_predistill=False)
|
||||
num_labels=task.num_labels, is_predistill=False)
|
||||
|
||||
rank = 0
|
||||
device_num = 1
|
||||
|
@ -149,6 +188,7 @@ def run_task_distill(ckpt_file):
|
|||
|
||||
dataset_size = train_dataset.get_dataset_size()
|
||||
print('td2 train dataset size: ', dataset_size)
|
||||
print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())
|
||||
if args_opt.enable_data_sink == 'true':
|
||||
repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps
|
||||
time_monitor_steps = args_opt.data_sink_steps
|
||||
|
@ -175,6 +215,7 @@ def run_task_distill(ckpt_file):
|
|||
eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
|
||||
device_num, rank, args_opt.do_shuffle,
|
||||
args_opt.eval_data_dir, args_opt.schema_dir)
|
||||
print('td2 eval dataset size: ', eval_dataset.get_dataset_size())
|
||||
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
|
||||
|
@ -185,11 +226,14 @@ def run_task_distill(ckpt_file):
|
|||
args_opt.save_ckpt_step,
|
||||
args_opt.max_ckpt_num,
|
||||
td_phase2_save_ckpt_dir)]
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
|
||||
scale_factor=cfg.scale_factor,
|
||||
scale_window=cfg.scale_window)
|
||||
if enable_loss_scale:
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
|
||||
scale_factor=cfg.scale_factor,
|
||||
scale_window=cfg.scale_window)
|
||||
|
||||
netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
else:
|
||||
netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
|
||||
model = Model(netwithgrads)
|
||||
model.train(repeat_count, train_dataset, callbacks=callback,
|
||||
dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
|
||||
|
@ -203,7 +247,7 @@ def do_eval_standalone():
|
|||
if ckpt_file == '':
|
||||
raise ValueError("Student ckpt file should not be None")
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
eval_model = BertModelCLS(td_student_net_cfg, False, args_opt.num_labels, 0.0, phase_type="student")
|
||||
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
|
||||
param_dict = load_checkpoint(ckpt_file)
|
||||
new_param_dict = {}
|
||||
for key, value in param_dict.items():
|
||||
|
@ -213,10 +257,13 @@ def do_eval_standalone():
|
|||
load_param_into_net(eval_model, new_param_dict)
|
||||
eval_model.set_train(False)
|
||||
|
||||
eval_dataset = create_tinybert_dataset('td', batch_size=1,
|
||||
eval_dataset = create_tinybert_dataset('td', batch_size=td_student_net_cfg.batch_size,
|
||||
device_num=1, rank=0, do_shuffle="false",
|
||||
data_dir=args_opt.eval_data_dir,
|
||||
schema_dir=args_opt.schema_dir)
|
||||
print('eval dataset size: ', eval_dataset.get_dataset_size())
|
||||
print('eval dataset batch size: ', eval_dataset.get_batch_size())
|
||||
|
||||
callback = Accuracy()
|
||||
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
|
||||
for data in eval_dataset.create_dict_iterator():
|
||||
|
@ -231,9 +278,26 @@ def do_eval_standalone():
|
|||
print("============== acc is {}".format(acc))
|
||||
print("======================================")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
|
||||
raise ValueError("do_train or do eval must have one be true, please confirm your config")
|
||||
|
||||
enable_loss_scale = True
|
||||
if args_opt.device_target == "GPU":
|
||||
if td_teacher_net_cfg.compute_type != mstype.float32:
|
||||
logger.warning('GPU only support fp32 temporarily, run with fp32.')
|
||||
td_teacher_net_cfg.compute_type = mstype.float32
|
||||
if td_student_net_cfg.compute_type != mstype.float32:
|
||||
logger.warning('GPU only support fp32 temporarily, run with fp32.')
|
||||
td_student_net_cfg.compute_type = mstype.float32
|
||||
# Both the forward and backward of the network are calculated using fp32,
|
||||
# and the loss scale is not necessary
|
||||
enable_loss_scale = False
|
||||
|
||||
td_teacher_net_cfg.seq_length = task.seq_length
|
||||
td_student_net_cfg.seq_length = task.seq_length
|
||||
|
||||
if args_opt.do_train == "true":
|
||||
# run predistill
|
||||
run_predistill()
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "bash run_distribute_gd_for_gpu.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR TEACHER_CKPT_PATH"
|
||||
echo "for example: bash run_distribute_gd_for_gpu.sh 8 3 /path/data/ /path/datasetSchema.json /path/bert_base.ckpt"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
RANK_SIZE=$1
|
||||
EPOCH_SIZE=$2
|
||||
DATA_DIR=$3
|
||||
SCHEMA_DIR=$4
|
||||
TEACHER_CKPT_PATH=$5
|
||||
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
|
||||
mpirun --allow-run-as-root -n $RANK_SIZE \
|
||||
python ${PROJECT_DIR}/../run_general_distill.py \
|
||||
--distribute="true" \
|
||||
--device_target="GPU" \
|
||||
--epoch_size=$EPOCH_SIZE \
|
||||
--save_ckpt_path="" \
|
||||
--data_dir=$DATA_DIR \
|
||||
--schema_dir=$SCHEMA_DIR \
|
||||
--load_teacher_ckpt_path=$TEACHER_CKPT_PATH > log.txt 2>&1 &
|
|
@ -32,7 +32,7 @@ python ${PROJECT_DIR}/../run_task_distill.py \
|
|||
--do_eval="true" \
|
||||
--td_phase1_epoch_size=10 \
|
||||
--td_phase2_epoch_size=3 \
|
||||
--num_labels=2 \
|
||||
--task_name="" \
|
||||
--do_shuffle="true" \
|
||||
--enable_data_sink="true" \
|
||||
--data_sink_steps=100 \
|
||||
|
|
|
@ -19,7 +19,6 @@ import os
|
|||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine.datasets as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from mindspore import log as logger
|
||||
|
||||
def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
|
||||
do_shuffle="true", data_dir=None, schema_dir=None):
|
||||
|
@ -45,7 +44,5 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
|
|||
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
logger.info("data size: {}".format(ds.get_dataset_size()))
|
||||
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
|
||||
|
||||
return ds
|
||||
|
|
|
@ -292,6 +292,60 @@ class BertTrainWithLossScaleCell(nn.Cell):
|
|||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
class BertTrainCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
sens (Number): The adjust parameter. Default: 1.0.
|
||||
"""
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(BertTrainCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.sens = sens
|
||||
self.grad = C.GradOperation('grad',
|
||||
get_by_list=True,
|
||||
sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = F.identity
|
||||
self.degree = 1
|
||||
if self.reducer_flag:
|
||||
mean = context.get_auto_parallel_context("mirror_mean")
|
||||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree)
|
||||
self.cast = P.Cast()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
self.cast(F.tuple_to_array((self.sens,)),
|
||||
mstype.float32))
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
succ = self.optimizer(grads)
|
||||
return F.depend(loss, succ)
|
||||
|
||||
class BertNetworkWithLoss_td(nn.Cell):
|
||||
"""
|
||||
Provide bert pre-training loss through network.
|
||||
|
@ -411,12 +465,12 @@ class BertNetworkWithLoss_td(nn.Cell):
|
|||
total_loss += cls_loss
|
||||
return self.cast(total_loss, mstype.float32)
|
||||
|
||||
class BertEvaluationCell(nn.Cell):
|
||||
class BertEvaluationWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Especifically defined for finetuning where only four inputs tensor are needed.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(BertEvaluationCell, self).__init__(auto_prefix=False)
|
||||
super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
|
@ -496,3 +550,54 @@ class BertEvaluationCell(nn.Cell):
|
|||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
|
||||
class BertEvaluationCell(nn.Cell):
|
||||
"""
|
||||
Especifically defined for finetuning where only four inputs tensor are needed.
|
||||
"""
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(BertEvaluationCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.sens = sens
|
||||
self.grad = C.GradOperation('grad',
|
||||
get_by_list=True,
|
||||
sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = F.identity
|
||||
self.degree = 1
|
||||
if self.reducer_flag:
|
||||
mean = context.get_auto_parallel_context("mirror_mean")
|
||||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.cast = P.Cast()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
label_ids):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
label_ids)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
label_ids,
|
||||
self.cast(F.tuple_to_array((self.sens,)),
|
||||
mstype.float32))
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
succ = self.optimizer(grads)
|
||||
return F.depend(loss, succ)
|
||||
|
|
|
@ -110,7 +110,10 @@ class EvalCallBack(Callback):
|
|||
if acc > self.global_acc:
|
||||
self.global_acc = acc
|
||||
print("The best acc is {}".format(acc))
|
||||
_exec_save_checkpoint(self.network, "eval_model.ckpt")
|
||||
eval_model_ckpt_file = "eval_model.ckpt"
|
||||
if os.path.exists(eval_model_ckpt_file):
|
||||
os.remove(eval_model_ckpt_file)
|
||||
_exec_save_checkpoint(self.network, eval_model_ckpt_file)
|
||||
|
||||
class BertLearningRate(LearningRateSchedule):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue