forked from mindspore-Ecosystem/mindspore
!4814 support minddataset for tinybert
Merge pull request !4814 from dengyutao/tinybert
commit 7127304c67
@@ -28,7 +28,7 @@ from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.optim import AdamWeightDecay
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore import log as logger
-from src.dataset import create_tinybert_dataset
+from src.dataset import create_tinybert_dataset, DataType
 from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
 from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
 from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell
@@ -55,6 +55,7 @@ def run_general_distill():
     parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
     parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
+    parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
     args_opt = parser.parse_args()
 
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
@@ -99,8 +100,15 @@ def run_general_distill():
                                        student_config=bert_student_net_cfg,
                                        is_training=True, use_one_hot_embeddings=False)
 
+    if args_opt.dataset_type == "tfrecord":
+        dataset_type = DataType.TFRECORD
+    elif args_opt.dataset_type == "mindrecord":
+        dataset_type = DataType.MINDRECORD
+    else:
+        raise Exception("dataset format is not supported yet")
     dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
-                                      args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
+                                      args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir,
+                                      data_type=dataset_type)
     dataset_size = dataset.get_dataset_size()
     print('dataset size: ', dataset_size)
     print("dataset repeatcount: ", dataset.get_repeat_count())
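With this change, general distillation can read MindRecord input via the new flag. A minimal invocation sketch in the style of the repo's own launch script (paths are placeholders, not from this commit; the flags shown are the ones the script already exposes):

    python run_general_distill.py \
        --save_ckpt_path=/path/to/save_ckpt \
        --load_teacher_ckpt_path=/path/to/teacher.ckpt \
        --data_dir=/path/to/mindrecord_dir \
        --schema_dir="" \
        --dataset_type="mindrecord" > log.txt 2>&1 &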
@@ -27,7 +27,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.nn.optim import AdamWeightDecay
 from mindspore import log as logger
-from src.dataset import create_tinybert_dataset
+from src.dataset import create_tinybert_dataset, DataType
 from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
 from src.assessment_method import Accuracy
 from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
@@ -68,7 +68,7 @@ def parse_args():
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
     parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
                         help="The name of the task to train.")
-
+    parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
     args = parser.parse_args()
     return args
 
@@ -119,9 +119,17 @@ def run_predistill():
 
     rank = 0
     device_num = 1
+
+    if args_opt.dataset_type == "tfrecord":
+        dataset_type = DataType.TFRECORD
+    elif args_opt.dataset_type == "mindrecord":
+        dataset_type = DataType.MINDRECORD
+    else:
+        raise Exception("dataset format is not supported yet")
     dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                       device_num, rank, args_opt.do_shuffle,
-                                      args_opt.train_data_dir, args_opt.schema_dir)
+                                      args_opt.train_data_dir, args_opt.schema_dir,
+                                      data_type=dataset_type)
 
     dataset_size = dataset.get_dataset_size()
     print('td1 dataset size: ', dataset_size)
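The same string-to-enum mapping now appears in both run_general_distill.py and run_td.py. One way to avoid the duplication, sketched here as a hypothetical helper (str2datatype is not part of this commit), would be a single lookup in src/dataset.py:

    # Hypothetical helper, not in this commit: maps the --dataset_type
    # string onto the DataType enum in one place for both scripts.
    def str2datatype(name):
        type_map = {"tfrecord": DataType.TFRECORD,
                    "mindrecord": DataType.MINDRECORD}
        if name not in type_map:
            raise Exception("dataset format is not supported yet")
        return type_map[name]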
@@ -39,4 +39,5 @@ python ${PROJECT_DIR}/../run_general_distill.py \
     --save_ckpt_path="" \
     --load_teacher_ckpt_path="" \
     --data_dir="" \
-    --schema_dir="" > log.txt 2>&1 &
+    --schema_dir="" \
+    --dataset_type="tfrecord" > log.txt 2>&1 &
@@ -16,26 +16,38 @@
 """create tinybert dataset"""
 
 import os
+from enum import Enum
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine.datasets as de
 import mindspore.dataset.transforms.c_transforms as C
 
+class DataType(Enum):
+    """Enumerate supported dataset format"""
+    TFRECORD = 1
+    MINDRECORD = 2
 
 def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
-                            do_shuffle="true", data_dir=None, schema_dir=None):
+                            do_shuffle="true", data_dir=None, schema_dir=None,
+                            data_type=DataType.TFRECORD):
     """create tinybert dataset"""
     files = os.listdir(data_dir)
     data_files = []
     for file_name in files:
-        if "record" in file_name:
+        if "record" in file_name and "db" not in file_name:
             data_files.append(os.path.join(data_dir, file_name))
     if task == "td":
         columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
     else:
         columns_list = ["input_ids", "input_mask", "segment_ids"]
 
-    ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
-                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
-                            shard_equal_rows=True)
+    if data_type == DataType.MINDRECORD:
+        ds = de.MindDataset(data_files, columns_list=columns_list,
+                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank)
+    else:
+        ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
+                                shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
+                                shard_equal_rows=True)
 
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
     ds = ds.map(input_columns="input_mask", operations=type_cast_op)
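For reference, a minimal usage sketch of the updated function, assuming a directory of MindRecord shards (the path is a placeholder): schema_dir is only consumed on the TFRecord branch, so it can stay None here.

    from src.dataset import create_tinybert_dataset, DataType

    # Build a MindRecord-backed dataset; the '.db' index files in the same
    # directory are skipped by the filename filter above.
    ds = create_tinybert_dataset(task='gd', batch_size=32, device_num=1, rank=0,
                                 do_shuffle="true",
                                 data_dir="/path/to/mindrecord_dir",
                                 schema_dir=None,
                                 data_type=DataType.MINDRECORD)
    print("dataset size:", ds.get_dataset_size())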