forked from mindspore-Ecosystem/mindspore
!19347 tinybert can been used on ModelArts
Merge pull request !19347 from 郑彬/tinybert2
This commit is contained in:
commit
fb6ec96862
|
@ -15,27 +15,16 @@
|
|||
"""export checkpoint file into air models"""
|
||||
|
||||
import re
|
||||
import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.td_config import td_student_net_cfg
|
||||
from src.tinybert_model import BertModelCLS, BertModelNER
|
||||
|
||||
parser = argparse.ArgumentParser(description='tinybert task distill')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="tinybert ckpt file.")
|
||||
parser.add_argument("--file_name", type=str, default="tinybert", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||
parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
|
||||
help="The type of the task to train.")
|
||||
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
|
||||
help="The name of the task to train.")
|
||||
args = parser.parse_args()
|
||||
from src.model_utils.config import config as args, td_student_net_cfg
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
|
@ -50,6 +39,7 @@ task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
|
|||
"TNEWS": {"num_labels": 15, "seq_length": 128},
|
||||
"CLUENER": {"num_labels": 43, "seq_length": 128}}
|
||||
|
||||
|
||||
class Task:
|
||||
"""
|
||||
Encapsulation class of get the task parameter.
|
||||
|
@ -69,7 +59,18 @@ class Task:
|
|||
return task_params[self.task_name]["seq_length"]
|
||||
return DEFAULT_SEQ_LENGTH
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
args.device_id = get_device_id()
|
||||
_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
args.ckpt_file = os.path.join(_file_dir, args.ckpt_file)
|
||||
args.file_name = os.path.join(args.output_path, args.file_name)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_export():
|
||||
"""export function"""
|
||||
task = Task(args.task_name)
|
||||
td_student_net_cfg.seq_length = task.seq_length
|
||||
td_student_net_cfg.batch_size = DEFAULT_BS
|
||||
|
@ -96,3 +97,7 @@ if __name__ == '__main__':
|
|||
|
||||
input_data = [input_ids, token_type_id, input_mask]
|
||||
export(eval_model, *input_data, file_name=args.file_name, file_format=args.file_format)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_export()
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
|
||||
modelarts_dataset_unzip_name: ''
|
||||
folder_name_under_zip_file: './'
|
||||
# ==============================================================================
|
||||
description: 'general_distill'
|
||||
|
||||
distribute: "false"
|
||||
epoch_size: 3
|
||||
device_id: 0
|
||||
device_num: 1
|
||||
save_ckpt_step: 100
|
||||
max_ckpt_num: 1
|
||||
do_shuffle: "true"
|
||||
enable_data_sink: "true"
|
||||
data_sink_steps: 1
|
||||
save_ckpt_path: ''
|
||||
load_teacher_ckpt_path: ''
|
||||
data_dir: ''
|
||||
schema_dir: ''
|
||||
dataset_type: "tfrecord"
|
||||
|
||||
common_cfg:
|
||||
batch_size: 32
|
||||
loss_scale_value: 65536
|
||||
scale_factor: 2
|
||||
scale_window: 1000
|
||||
AdamWeightDecay:
|
||||
learning_rate: 0.00005 # 5e-5
|
||||
end_learning_rate: 0.00000000000001 # 1e-14
|
||||
power: 1.0
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
eps: 0.000001 # 1e-6
|
||||
decay_filter: ['layernorm', 'bias']
|
||||
|
||||
bert_teacher_net_cfg:
|
||||
seq_length: 128
|
||||
vocab_size: 30522
|
||||
hidden_size: 768
|
||||
num_hidden_layers: 12
|
||||
num_attention_heads: 12
|
||||
intermediate_size: 3072
|
||||
hidden_act: "gelu"
|
||||
hidden_dropout_prob: 0.1
|
||||
attention_probs_dropout_prob: 0.1
|
||||
max_position_embeddings: 512
|
||||
type_vocab_size: 2
|
||||
initializer_range: 0.02
|
||||
use_relative_positions: False
|
||||
dtype: mstype.float32
|
||||
compute_type: mstype.float16
|
||||
|
||||
bert_student_net_cfg:
|
||||
seq_length: 128
|
||||
vocab_size: 30522
|
||||
hidden_size: 384
|
||||
num_hidden_layers: 4
|
||||
num_attention_heads: 12
|
||||
intermediate_size: 1536
|
||||
hidden_act: "gelu"
|
||||
hidden_dropout_prob: 0.1
|
||||
attention_probs_dropout_prob: 0.1
|
||||
max_position_embeddings: 512
|
||||
type_vocab_size: 2
|
||||
initializer_range: 0.02
|
||||
use_relative_positions: False
|
||||
dtype: mstype.float32
|
||||
compute_type: mstype.float16
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of the input data."
|
||||
output_path: "The location of the output file."
|
||||
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
modelarts_dataset_unzip_name: ""
|
||||
folder_name_under_zip_file: ''
|
||||
|
||||
distribute: "Run distribute, default is false."
|
||||
epoch_size: "Epoch size, default is 1."
|
||||
device_id: "Device id, default is 0."
|
||||
device_num: "Use device nums, default is 1."
|
||||
save_ckpt_step: "Enable data sink, default is true."
|
||||
max_ckpt_num: ""
|
||||
do_shuffle: "Enable shuffle for dataset, default is true."
|
||||
enable_data_sink: "Enable data sink, default is true."
|
||||
data_sink_steps: "Sink steps for each epoch, default is 1."
|
||||
save_ckpt_path: "Save checkpoint path"
|
||||
load_teacher_ckpt_path: "Load checkpoint file path"
|
||||
data_dir: "Data path, it is better to use absolute path"
|
||||
schema_dir: "Schema path, it is better to use absolute path"
|
||||
dataset_type: "dataset type tfrecord/mindrecord, default is tfrecord"
|
||||
|
||||
---
|
||||
# choices
|
||||
device_target: ['Ascend', 'GPU', 'CPU']
|
||||
distribute: ["true", "false"]
|
||||
do_shuffle: ["true", "false"]
|
||||
enable_data_sink: ["true", "false"]
|
|
@ -16,20 +16,10 @@
|
|||
"""postprocess"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from src.assessment_method import Accuracy, F1
|
||||
from src.td_config import eval_cfg
|
||||
|
||||
parser = argparse.ArgumentParser(description='postprocess')
|
||||
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
|
||||
help="The name of the task to train.")
|
||||
parser.add_argument("--assessment_method", type=str, default="accuracy", choices=["accuracy", "bf1", "mf1"],
|
||||
help="assessment_method include: [accuracy, bf1, mf1], default is accuracy")
|
||||
parser.add_argument("--result_path", type=str, default="./result_Files", help="result path")
|
||||
parser.add_argument("--label_path", type=str, default="./preprocess_Result/label_ids.npy", help="label path")
|
||||
args_opt = parser.parse_args()
|
||||
from src.model_utils.config import eval_cfg, config as args_opt
|
||||
|
||||
|
||||
DEFAULT_NUM_LABELS = 2
|
||||
|
|
|
@ -16,20 +16,11 @@
|
|||
"""preprocess"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from src.td_config import eval_cfg
|
||||
from src.model_utils.config import eval_cfg, config as args_opt
|
||||
from src.dataset import create_tinybert_dataset, DataType
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='preprocess')
|
||||
parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
|
||||
parser.add_argument("--dataset_type", type=str, default="tfrecord",
|
||||
help="dataset type tfrecord/mindrecord, default is tfrecord")
|
||||
parser.add_argument("--result_path", type=str, default="./preprocess_Result/", help="result path")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
if args_opt.dataset_type == "tfrecord":
|
||||
dataset_type = DataType.TFRECORD
|
||||
elif args_opt.dataset_type == "mindrecord":
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
"""general distill script"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
import datetime
|
||||
import mindspore.communication.management as D
|
||||
import mindspore.common.dtype as mstype
|
||||
|
@ -30,40 +30,76 @@ from mindspore import log as logger
|
|||
from mindspore.common import set_seed
|
||||
from src.dataset import create_tinybert_dataset, DataType
|
||||
from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
|
||||
from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
|
||||
from src.model_utils.config import config as args_opt, common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
|
||||
from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
def get_argument():
|
||||
"""Tinybert general distill argument parser."""
|
||||
parser = argparse.ArgumentParser(description='tinybert general distill')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
|
||||
help='device where the code will be implemented. (Default: Ascend)')
|
||||
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
|
||||
help="Run distribute, default is false.")
|
||||
parser.add_argument("--epoch_size", type=int, default="3", help="Epoch size, default is 1.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
|
||||
parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable data sink, default is true.")
|
||||
parser.add_argument("--max_ckpt_num", type=int, default=1, help="Enable data sink, default is true.")
|
||||
parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
|
||||
help="Enable shuffle for dataset, default is true.")
|
||||
parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
|
||||
help="Enable data sink, default is true.")
|
||||
parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
|
||||
parser.add_argument("--save_ckpt_path", type=str, default="", help="Save checkpoint path")
|
||||
parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
|
||||
parser.add_argument("--dataset_type", type=str, default="tfrecord",
|
||||
help="dataset type tfrecord/mindrecord, default is tfrecord")
|
||||
args_opt = parser.parse_args()
|
||||
return args_opt
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, args_opt.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("Unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("Unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("Cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if args_opt.modelarts_dataset_unzip_name:
|
||||
zip_file_1 = os.path.join(args_opt.data_path, args_opt.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(args_opt.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
args_opt.device_id = get_device_id()
|
||||
args_opt.device_num = get_device_num()
|
||||
args_opt.data_dir = os.path.join(args_opt.data_path, args_opt.data_dir)
|
||||
args_opt.schema_dir = os.path.join(args_opt.data_path, args_opt.schema_dir)
|
||||
args_opt.save_ckpt_path = os.path.join(args_opt.output_path, args_opt.save_ckpt_path)
|
||||
args_opt.load_teacher_ckpt_path = os.path.join(_file_dir, args_opt.load_teacher_ckpt_path)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_general_distill():
|
||||
"""
|
||||
run general distill
|
||||
"""
|
||||
args_opt = get_argument()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
|
||||
reserve_class_name_in_scope=False)
|
||||
if args_opt.device_target == "Ascend":
|
||||
|
@ -81,7 +117,7 @@ def run_general_distill():
|
|||
D.init()
|
||||
device_num = D.get_group_size()
|
||||
rank = D.get_rank()
|
||||
save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
|
||||
save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
||||
device_num=device_num)
|
||||
|
@ -164,6 +200,7 @@ def run_general_distill():
|
|||
dataset_sink_mode=(args_opt.enable_data_sink == "true"),
|
||||
sink_size=args_opt.data_sink_steps)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
set_seed(0)
|
||||
run_general_distill()
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
"""task distill script"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import argparse
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
|
@ -29,9 +29,11 @@ from mindspore import log as logger
|
|||
from src.dataset import create_tinybert_dataset, DataType
|
||||
from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
|
||||
from src.assessment_method import Accuracy, F1
|
||||
from src.td_config import phase1_cfg, phase2_cfg, eval_cfg, td_teacher_net_cfg, td_student_net_cfg
|
||||
from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
|
||||
from src.tinybert_model import BertModelCLS, BertModelNER
|
||||
from src.model_utils.config import config as args_opt, phase1_cfg, phase2_cfg, eval_cfg, td_teacher_net_cfg, td_student_net_cfg
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
_cur_dir = os.getcwd()
|
||||
td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt')
|
||||
|
@ -40,62 +42,9 @@ if not os.path.exists(td_phase1_save_ckpt_dir):
|
|||
os.makedirs(td_phase1_save_ckpt_dir)
|
||||
if not os.path.exists(td_phase2_save_ckpt_dir):
|
||||
os.makedirs(td_phase2_save_ckpt_dir)
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
parse args
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='tinybert task distill')
|
||||
parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
|
||||
help='device where the code will be implemented. (Default: Ascend)')
|
||||
parser.add_argument("--do_train", type=str, default="true", choices=["true", "false"],
|
||||
help="Do train task, default is true.")
|
||||
parser.add_argument("--do_eval", type=str, default="true", choices=["true", "false"],
|
||||
help="Do eval task, default is true.")
|
||||
parser.add_argument("--td_phase1_epoch_size", type=int, default=10,
|
||||
help="Epoch size for td phase 1, default is 10.")
|
||||
parser.add_argument("--td_phase2_epoch_size", type=int, default=3, help="Epoch size for td phase 2, default is 3.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
|
||||
help="Enable shuffle for dataset, default is true.")
|
||||
parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
|
||||
help="Enable data sink, default is true.")
|
||||
parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable data sink, default is true.")
|
||||
parser.add_argument("--max_ckpt_num", type=int, default=1, help="Enable data sink, default is true.")
|
||||
parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
|
||||
parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--load_gd_ckpt_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--load_td1_ckpt_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
|
||||
parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
|
||||
help="The type of the task to train.")
|
||||
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
|
||||
help="The name of the task to train.")
|
||||
parser.add_argument("--assessment_method", type=str, default="accuracy", choices=["accuracy", "bf1", "mf1"],
|
||||
help="assessment_method include: [accuracy, bf1, mf1], default is accuracy")
|
||||
parser.add_argument("--dataset_type", type=str, default="tfrecord",
|
||||
help="dataset type tfrecord/mindrecord, default is tfrecord")
|
||||
args = parser.parse_args()
|
||||
if args.do_train.lower() != "true" and args.do_eval.lower() != "true":
|
||||
raise ValueError("do train or do eval must have one be true, please confirm your config")
|
||||
if args.task_name in ["SST-2", "QNLI", "MNLI", "TNEWS"] and args.task_type != "classification":
|
||||
raise ValueError(f"{args.task_name} is a classification dataset, please set --task_type=classification")
|
||||
if args.task_name in ["CLUENER"] and args.task_type != "ner":
|
||||
raise ValueError(f"{args.task_name} is a ner dataset, please set --task_type=ner")
|
||||
if args.task_name in ["SST-2", "QNLI", "MNLI"] and \
|
||||
(td_teacher_net_cfg.vocab_size != 30522 or td_student_net_cfg.vocab_size != 30522):
|
||||
logger.warning(f"{args.task_name} is an English dataset. Usually, we use 21128 for CN vocabs and 30522 for "\
|
||||
"EN vocabs according to the origin paper.")
|
||||
if args.task_name in ["TNEWS", "CLUENER"] and \
|
||||
(td_teacher_net_cfg.vocab_size != 21128 or td_student_net_cfg.vocab_size != 21128):
|
||||
logger.warning(f"{args.task_name} is a Chinese dataset. Usually, we use 21128 for CN vocabs and 30522 for " \
|
||||
"EN vocabs according to the origin paper.")
|
||||
return args
|
||||
enable_loss_scale = True
|
||||
|
||||
|
||||
args_opt = parse_args()
|
||||
if args_opt.dataset_type == "tfrecord":
|
||||
dataset_type = DataType.TFRECORD
|
||||
elif args_opt.dataset_type == "mindrecord":
|
||||
|
@ -129,6 +78,8 @@ class Task:
|
|||
if self.task_name in task_params and "seq_length" in task_params[self.task_name]:
|
||||
return task_params[self.task_name]["seq_length"]
|
||||
return DEFAULT_SEQ_LENGTH
|
||||
|
||||
|
||||
task = Task(args_opt.task_name)
|
||||
|
||||
|
||||
|
@ -193,6 +144,7 @@ def run_predistill():
|
|||
dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
|
||||
sink_size=args_opt.data_sink_steps)
|
||||
|
||||
|
||||
def run_task_distill(ckpt_file):
|
||||
"""
|
||||
run task distill
|
||||
|
@ -269,6 +221,7 @@ def run_task_distill(ckpt_file):
|
|||
dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
|
||||
sink_size=args_opt.data_sink_steps)
|
||||
|
||||
|
||||
def eval_result_print(assessment_method="accuracy", callback=None):
|
||||
"""print eval result"""
|
||||
if assessment_method == "accuracy":
|
||||
|
@ -282,6 +235,7 @@ def eval_result_print(assessment_method="accuracy", callback=None):
|
|||
else:
|
||||
raise ValueError("Assessment method not supported, support: [accuracy, f1]")
|
||||
|
||||
|
||||
def do_eval_standalone():
|
||||
"""
|
||||
do eval standalone
|
||||
|
@ -332,12 +286,96 @@ def do_eval_standalone():
|
|||
print("==============================================================")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
global td_phase1_save_ckpt_dir
|
||||
global td_phase2_save_ckpt_dir
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, args_opt.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("Unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("Unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("Cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if args_opt.modelarts_dataset_unzip_name:
|
||||
zip_file_1 = os.path.join(args_opt.data_path, args_opt.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(args_opt.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
args_opt.device_id = get_device_id()
|
||||
td_phase1_save_ckpt_dir = os.path.join(args_opt.output_path, 'tinybert_td_phase1_save_ckpt')
|
||||
td_phase2_save_ckpt_dir = os.path.join(args_opt.output_path, 'tinybert_td_phase2_save_ckpt')
|
||||
if not os.path.exists(td_phase1_save_ckpt_dir):
|
||||
os.makedirs(td_phase1_save_ckpt_dir)
|
||||
if not os.path.exists(td_phase2_save_ckpt_dir):
|
||||
os.makedirs(td_phase2_save_ckpt_dir)
|
||||
args_opt.load_teacher_ckpt_path = os.path.join(_file_dir, args_opt.load_teacher_ckpt_path)
|
||||
args_opt.load_gd_ckpt_path = os.path.join(_file_dir, args_opt.load_gd_ckpt_path)
|
||||
args_opt.train_data_dir = os.path.join(args_opt.data_path, args_opt.train_data_dir)
|
||||
args_opt.schema_dir = os.path.join(args_opt.data_path, args_opt.schema_dir)
|
||||
args_opt.eval_data_dir = os.path.join(args_opt.data_path, args_opt.eval_data_dir)
|
||||
args_opt.load_td1_ckpt_path = os.path.join(_file_dir, args_opt.load_td1_ckpt_path)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_main():
|
||||
"""task_distill function"""
|
||||
global enable_loss_scale
|
||||
if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
|
||||
raise ValueError("do train or do eval must have one be true, please confirm your config")
|
||||
if args_opt.task_name in ["SST-2", "QNLI", "MNLI", "TNEWS"] and args_opt.task_type != "classification":
|
||||
raise ValueError(f"{args_opt.task_name} is a classification dataset, please set --task_type=classification")
|
||||
if args_opt.task_name in ["CLUENER"] and args_opt.task_type != "ner":
|
||||
raise ValueError(f"{args_opt.task_name} is a ner dataset, please set --task_type=ner")
|
||||
if args_opt.task_name in ["SST-2", "QNLI", "MNLI"] and \
|
||||
(td_teacher_net_cfg.vocab_size != 30522 or td_student_net_cfg.vocab_size != 30522):
|
||||
logger.warning(f"{args_opt.task_name} is an English dataset. Usually, we use 21128 for CN vocabs and 30522 for "
|
||||
f"EN vocabs according to the origin paper.")
|
||||
if args_opt.task_name in ["TNEWS", "CLUENER"] and \
|
||||
(td_teacher_net_cfg.vocab_size != 21128 or td_student_net_cfg.vocab_size != 21128):
|
||||
logger.warning(f"{args_opt.task_name} is a Chinese dataset. Usually, we use 21128 for CN vocabs and 30522 for "
|
||||
f"EN vocabs according to the origin paper.")
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
|
||||
reserve_class_name_in_scope=False)
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
enable_loss_scale = True
|
||||
if args_opt.device_target == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
if td_student_net_cfg.compute_type != mstype.float32:
|
||||
|
@ -363,7 +401,7 @@ if __name__ == '__main__':
|
|||
run_predistill()
|
||||
lists = os.listdir(td_phase1_save_ckpt_dir)
|
||||
if lists:
|
||||
lists.sort(key=lambda fn: os.path.getmtime(td_phase1_save_ckpt_dir+'/'+fn))
|
||||
lists.sort(key=lambda fn: os.path.getmtime(td_phase1_save_ckpt_dir + '/' + fn))
|
||||
name_ext = os.path.splitext(lists[-1])
|
||||
if name_ext[-1] != ".ckpt":
|
||||
raise ValueError("Invalid file, checkpoint file should be .ckpt file")
|
||||
|
@ -374,3 +412,7 @@ if __name__ == '__main__':
|
|||
raise ValueError("Checkpoint file not exists, please make sure ckpt file has been saved")
|
||||
else:
|
||||
do_eval_standalone()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_main()
|
||||
|
|
|
@ -54,6 +54,7 @@ do
|
|||
export GLOG_logtostderr=0
|
||||
env > env.log
|
||||
taskset -c $cmdopt python ${PROJECT_DIR}/../run_general_distill.py \
|
||||
--config_path="../../gd_config.yaml" \
|
||||
--distribute="true" \
|
||||
--device_target="Ascend" \
|
||||
--epoch_size=$EPOCH_SIZE \
|
||||
|
|
|
@ -31,6 +31,7 @@ PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
|||
|
||||
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
||||
python ${PROJECT_DIR}/../run_general_distill.py \
|
||||
--config_path="../../gd_config.yaml" \
|
||||
--distribute="true" \
|
||||
--device_target="GPU" \
|
||||
--epoch_size=$EPOCH_SIZE \
|
||||
|
|
|
@ -28,6 +28,7 @@ CUR_DIR=`pwd`
|
|||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
python ${PROJECT_DIR}/../run_general_distill.py \
|
||||
--config_path="../../gd_config.yaml" \
|
||||
--distribute="false" \
|
||||
--device_target="Ascend" \
|
||||
--epoch_size=3 \
|
||||
|
|
|
@ -26,6 +26,7 @@ CUR_DIR=`pwd`
|
|||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
python ${PROJECT_DIR}/../run_task_distill.py \
|
||||
--config_path="../../td_config/td_config_sst2.yaml" \
|
||||
--device_target="Ascend" \
|
||||
--device_id=0 \
|
||||
--do_train="true" \
|
||||
|
|
|
@ -1,74 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in dataset.py, run_general_distill.py and run_task_distill.py
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
import mindspore.common.dtype as mstype
|
||||
from .tinybert_model import BertConfig
|
||||
|
||||
common_cfg = edict({
|
||||
'batch_size': 32,
|
||||
'loss_scale_value': 2 ** 16,
|
||||
'scale_factor': 2,
|
||||
'scale_window': 1000,
|
||||
'AdamWeightDecay': edict({
|
||||
'learning_rate': 5e-5,
|
||||
'end_learning_rate': 1e-14,
|
||||
'power': 1.0,
|
||||
'weight_decay': 1e-4,
|
||||
'eps': 1e-6,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
}),
|
||||
})
|
||||
'''
|
||||
Including two kinds of network: \
|
||||
teacher network: The BERT-base network.
|
||||
student network: The network which is inherited from teacher network.
|
||||
'''
|
||||
bert_teacher_net_cfg = BertConfig(
|
||||
seq_length=128,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
||||
bert_student_net_cfg = BertConfig(
|
||||
seq_length=128,
|
||||
vocab_size=30522,
|
||||
hidden_size=384,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=1536,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
|
@ -0,0 +1,187 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pformat
|
||||
import yaml
|
||||
import mindspore.common.dtype as mstype
|
||||
from src.tinybert_model import BertConfig
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="pretrain_base_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
# print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def extra_operations(cfg):
|
||||
"""
|
||||
Do extra work on config
|
||||
|
||||
Args:
|
||||
config: Object after instantiation of class 'Config'.
|
||||
"""
|
||||
def create_filter_fun(keywords):
|
||||
return lambda x: not (True in [key in x.name.lower() for key in keywords])
|
||||
|
||||
if cfg.description == 'general_distill':
|
||||
cfg.common_cfg.loss_scale_value = 2 ** 16
|
||||
cfg.common_cfg.AdamWeightDecay.decay_filter = create_filter_fun(cfg.common_cfg.AdamWeightDecay.decay_filter)
|
||||
cfg.bert_teacher_net_cfg.dtype = mstype.float32
|
||||
cfg.bert_teacher_net_cfg.compute_type = mstype.float16
|
||||
cfg.bert_student_net_cfg.dtype = mstype.float32
|
||||
cfg.bert_student_net_cfg.compute_type = mstype.float16
|
||||
cfg.bert_teacher_net_cfg = BertConfig(**cfg.bert_teacher_net_cfg.__dict__)
|
||||
cfg.bert_student_net_cfg = BertConfig(**cfg.bert_student_net_cfg.__dict__)
|
||||
elif cfg.description == 'task_distill':
|
||||
cfg.phase1_cfg.loss_scale_value = 2 ** 8
|
||||
cfg.phase1_cfg.optimizer_cfg.AdamWeightDecay.decay_filter = create_filter_fun(
|
||||
cfg.phase1_cfg.optimizer_cfg.AdamWeightDecay.decay_filter)
|
||||
cfg.phase2_cfg.loss_scale_value = 2 ** 16
|
||||
cfg.phase2_cfg.optimizer_cfg.AdamWeightDecay.decay_filter = create_filter_fun(
|
||||
cfg.phase2_cfg.optimizer_cfg.AdamWeightDecay.decay_filter)
|
||||
cfg.td_teacher_net_cfg.dtype = mstype.float32
|
||||
cfg.td_teacher_net_cfg.compute_type = mstype.float16
|
||||
cfg.td_student_net_cfg.dtype = mstype.float32
|
||||
cfg.td_student_net_cfg.compute_type = mstype.float16
|
||||
cfg.td_teacher_net_cfg = BertConfig(**cfg.td_teacher_net_cfg.__dict__)
|
||||
cfg.td_student_net_cfg = BertConfig(**cfg.td_student_net_cfg.__dict__)
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
def get_abs_path(path_relative):
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
return os.path.join(current_dir, path_relative)
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
parser.add_argument("--config_path", type=get_abs_path, default="../../gd_config.yaml",
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
# pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
config_obj = Config(final_config)
|
||||
extra_operations(config_obj)
|
||||
return config_obj
|
||||
|
||||
|
||||
config = get_config()
|
||||
# td_teacher_net_cfg = config.td_teacher_net_cfg
|
||||
# td_student_net_cfg = config.td_student_net_cfg
|
||||
if config.description == 'general_distill':
|
||||
common_cfg = config.common_cfg
|
||||
bert_teacher_net_cfg = config.bert_teacher_net_cfg
|
||||
bert_student_net_cfg = config.bert_student_net_cfg
|
||||
elif config.description == 'task_distill':
|
||||
phase1_cfg = config.phase1_cfg
|
||||
phase2_cfg = config.phase2_cfg
|
||||
eval_cfg = config.eval_cfg
|
||||
td_teacher_net_cfg = config.td_teacher_net_cfg
|
||||
td_student_net_cfg = config.td_student_net_cfg
|
||||
else:
|
||||
pass
|
||||
if __name__ == '__main__':
|
||||
print(config)
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from src.model_utils.config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from mindspore.profiler import Profiler
|
||||
from src.model_utils.config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
# print("os.mknod({}) success".format(sync_lock))
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
if config.enable_profiling:
|
||||
profiler = Profiler()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
if config.enable_profiling:
|
||||
profiler.analyse()
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -1,98 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""config script for task distill"""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
import mindspore.common.dtype as mstype
|
||||
from .tinybert_model import BertConfig
|
||||
|
||||
phase1_cfg = edict({
|
||||
'batch_size': 32,
|
||||
'loss_scale_value': 2 ** 8,
|
||||
'scale_factor': 2,
|
||||
'scale_window': 50,
|
||||
'optimizer_cfg': edict({
|
||||
'AdamWeightDecay': edict({
|
||||
'learning_rate': 5e-5,
|
||||
'end_learning_rate': 1e-14,
|
||||
'power': 1.0,
|
||||
'weight_decay': 1e-4,
|
||||
'eps': 1e-6,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
|
||||
phase2_cfg = edict({
|
||||
'batch_size': 32,
|
||||
'loss_scale_value': 2 ** 16,
|
||||
'scale_factor': 2,
|
||||
'scale_window': 50,
|
||||
'optimizer_cfg': edict({
|
||||
'AdamWeightDecay': edict({
|
||||
'learning_rate': 2e-5,
|
||||
'end_learning_rate': 1e-14,
|
||||
'power': 1.0,
|
||||
'weight_decay': 1e-4,
|
||||
'eps': 1e-6,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
|
||||
eval_cfg = edict({
|
||||
'batch_size': 32,
|
||||
})
|
||||
|
||||
'''
|
||||
Including two kinds of network: \
|
||||
teacher network: The BERT-base network with finetune.
|
||||
student network: The model which is produced by GD phase.
|
||||
'''
|
||||
td_teacher_net_cfg = BertConfig(
|
||||
seq_length=128,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
||||
td_student_net_cfg = BertConfig(
|
||||
seq_length=128,
|
||||
vocab_size=30522,
|
||||
hidden_size=384,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=1536,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
|
@ -0,0 +1,160 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
|
||||
modelarts_dataset_unzip_name: ''
|
||||
folder_name_under_zip_file: ''
|
||||
# ==============================================================================
|
||||
description: 'task_distill'
|
||||
task_type: "classification"
|
||||
task_name: ""
|
||||
device_id: 0
|
||||
# task_distill related
|
||||
do_train: "true"
|
||||
do_eval: "true"
|
||||
td_phase1_epoch_size: 10
|
||||
td_phase2_epoch_size: 3
|
||||
do_shuffle: "true"
|
||||
enable_data_sink: "true"
|
||||
save_ckpt_step: 100
|
||||
max_ckpt_num: 1
|
||||
data_sink_steps: 1
|
||||
load_teacher_ckpt_path: ""
|
||||
load_gd_ckpt_path: ""
|
||||
load_td1_ckpt_path: ""
|
||||
train_data_dir: ""
|
||||
eval_data_dir: ""
|
||||
schema_dir: ""
|
||||
assessment_method: "accuracy"
|
||||
dataset_type: "tfrecord"
|
||||
# export related
|
||||
ckpt_file: ''
|
||||
file_name: "tinybert"
|
||||
file_format: "AIR"
|
||||
# postprocess related
|
||||
result_path: "./result_Files"
|
||||
label_path: "./preprocess_Result/label_ids.npy"
|
||||
phase1_cfg:
|
||||
batch_size: 32
|
||||
loss_scale_value: 256
|
||||
scale_factor: 2
|
||||
scale_window: 50
|
||||
optimizer_cfg:
|
||||
AdamWeightDecay:
|
||||
learning_rate: 0.00005 # 5e-5
|
||||
end_learning_rate: 0.00000000000001 # 1e-14
|
||||
power: 1.0
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
eps: 0.000001 # 1e-6
|
||||
decay_filter: ['layernorm', 'bias']
|
||||
|
||||
phase2_cfg:
|
||||
batch_size: 32
|
||||
loss_scale_value: 65536
|
||||
scale_factor: 2
|
||||
scale_window: 50
|
||||
optimizer_cfg:
|
||||
AdamWeightDecay:
|
||||
learning_rate: 0.00002 # 5e-5
|
||||
end_learning_rate: 0.00000000000001 # 1e-14
|
||||
power: 1.0
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
eps: 0.000001 # 1e-6
|
||||
decay_filter: ['layernorm', 'bias']
|
||||
|
||||
eval_cfg:
|
||||
batch_size: 32
|
||||
|
||||
td_teacher_net_cfg:
|
||||
seq_length: 128
|
||||
vocab_size: 21128
|
||||
hidden_size: 768
|
||||
num_hidden_layers: 12
|
||||
num_attention_heads: 12
|
||||
intermediate_size: 3072
|
||||
hidden_act: "gelu"
|
||||
hidden_dropout_prob: 0.1
|
||||
attention_probs_dropout_prob: 0.1
|
||||
max_position_embeddings: 512
|
||||
type_vocab_size: 2
|
||||
initializer_range: 0.02
|
||||
use_relative_positions: False
|
||||
dtype: mstype.float32
|
||||
compute_type: mstype.float16
|
||||
|
||||
td_student_net_cfg:
|
||||
seq_length: 128
|
||||
vocab_size: 21128
|
||||
hidden_size: 384
|
||||
num_hidden_layers: 4
|
||||
num_attention_heads: 12
|
||||
intermediate_size: 1536
|
||||
hidden_act: "gelu"
|
||||
hidden_dropout_prob: 0.1
|
||||
attention_probs_dropout_prob: 0.1
|
||||
max_position_embeddings: 512
|
||||
type_vocab_size: 2
|
||||
initializer_range: 0.02
|
||||
use_relative_positions: False
|
||||
dtype: mstype.float32
|
||||
compute_type: mstype.float16
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of the input data."
|
||||
output_path: "The location of the output file."
|
||||
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
modelarts_dataset_unzip_name: ''
|
||||
folder_name_under_zip_file: ''
|
||||
# task_distill related
|
||||
do_train: "Do train task, default is true."
|
||||
do_eval: "Do eval task, default is true."
|
||||
td_phase1_epoch_size: "Epoch size for td phase 1, default is 10."
|
||||
td_phase2_epoch_size: "Epoch size for td phase 2, default is 3."
|
||||
device_id: "Device id, default is 0."
|
||||
do_shuffle: "Enable shuffle for dataset, default is true."
|
||||
enable_data_sink: "Enable data sink, default is true."
|
||||
save_ckpt_step: ""
|
||||
max_ckpt_num: "Enable data sink, default is true."
|
||||
data_sink_steps: "Sink steps for each epoch, default is 1."
|
||||
load_teacher_ckpt_path: "Load checkpoint file path"
|
||||
load_gd_ckpt_path: "Load checkpoint file path"
|
||||
load_td1_ckpt_path: "Load checkpoint file path"
|
||||
train_data_dir: "Data path, it is better to use absolute path"
|
||||
eval_data_dir: "Data path, it is better to use absolute path"
|
||||
schema_dir: "Schema path, it is better to use absolute path"
|
||||
task_type: "The type of the task to train."
|
||||
task_name: "The name of the task to train."
|
||||
assessment_method: "assessment_method include: [accuracy, bf1, mf1], default is accuracy"
|
||||
dataset_type: "dataset type tfrecord/mindrecord, default is tfrecord"
|
||||
# export related
|
||||
ckpt_file: "tinybert ckpt file."
|
||||
file_name: "output file name."
|
||||
file_format: "file format"
|
||||
# postprocess related
|
||||
result_path: "result path"
|
||||
label_path: "label path"
|
||||
---
|
||||
device_target: ['Ascend', 'GPU', 'CPU']
|
||||
do_train: ["true", "false"]
|
||||
do_eval: ["true", "false"]
|
||||
do_shuffle: ["true", "false"]
|
||||
enable_data_sink: ["true", "false"]
|
||||
task_type: ["classification", "ner"]
|
||||
task_name: ["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"]
|
||||
assessment_method: ["accuracy", "bf1", "mf1"]
|
||||
file_format: ["AIR", "ONNX", "MINDIR"]
|
|
@ -0,0 +1,160 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
|
||||
modelarts_dataset_unzip_name: ''
|
||||
folder_name_under_zip_file: ''
|
||||
# ==============================================================================
|
||||
description: 'task_distill'
|
||||
task_type: "classification"
|
||||
task_name: ""
|
||||
device_id: 0
|
||||
# task_distill related
|
||||
do_train: "true"
|
||||
do_eval: "true"
|
||||
td_phase1_epoch_size: 10
|
||||
td_phase2_epoch_size: 3
|
||||
do_shuffle: "true"
|
||||
enable_data_sink: "true"
|
||||
save_ckpt_step: 100
|
||||
max_ckpt_num: 1
|
||||
data_sink_steps: 1
|
||||
load_teacher_ckpt_path: ""
|
||||
load_gd_ckpt_path: ""
|
||||
load_td1_ckpt_path: ""
|
||||
train_data_dir: ""
|
||||
eval_data_dir: ""
|
||||
schema_dir: ""
|
||||
assessment_method: "accuracy"
|
||||
dataset_type: "tfrecord"
|
||||
# export related
|
||||
ckpt_file: ''
|
||||
file_name: "tinybert"
|
||||
file_format: "AIR"
|
||||
# postprocess related
|
||||
result_path: "./result_Files"
|
||||
label_path: "./preprocess_Result/label_ids.npy"
|
||||
phase1_cfg:
|
||||
batch_size: 32
|
||||
loss_scale_value: 256
|
||||
scale_factor: 2
|
||||
scale_window: 50
|
||||
optimizer_cfg:
|
||||
AdamWeightDecay:
|
||||
learning_rate: 0.00005 # 5e-5
|
||||
end_learning_rate: 0.0 # 0.0
|
||||
power: 1.0
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
eps: 0.000001 # 1e-6
|
||||
decay_filter: ['layernorm', 'bias']
|
||||
|
||||
phase2_cfg:
|
||||
batch_size: 32
|
||||
loss_scale_value: 65536
|
||||
scale_factor: 2
|
||||
scale_window: 50
|
||||
optimizer_cfg:
|
||||
AdamWeightDecay:
|
||||
learning_rate: 0.00002 # 5e-5
|
||||
end_learning_rate: 0.0 # 0.0
|
||||
power: 1.0
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
eps: 0.000001 # 1e-6
|
||||
decay_filter: ['layernorm', 'bias']
|
||||
|
||||
eval_cfg:
|
||||
batch_size: 32
|
||||
|
||||
td_teacher_net_cfg:
|
||||
seq_length: 128
|
||||
vocab_size: 30522
|
||||
hidden_size: 768
|
||||
num_hidden_layers: 12
|
||||
num_attention_heads: 12
|
||||
intermediate_size: 3072
|
||||
hidden_act: "gelu"
|
||||
hidden_dropout_prob: 0.1
|
||||
attention_probs_dropout_prob: 0.1
|
||||
max_position_embeddings: 512
|
||||
type_vocab_size: 2
|
||||
initializer_range: 0.02
|
||||
use_relative_positions: False
|
||||
dtype: mstype.float32
|
||||
compute_type: mstype.float16
|
||||
|
||||
td_student_net_cfg:
|
||||
seq_length: 128
|
||||
vocab_size: 30522
|
||||
hidden_size: 384
|
||||
num_hidden_layers: 4
|
||||
num_attention_heads: 12
|
||||
intermediate_size: 1536
|
||||
hidden_act: "gelu"
|
||||
hidden_dropout_prob: 0.1
|
||||
attention_probs_dropout_prob: 0.1
|
||||
max_position_embeddings: 512
|
||||
type_vocab_size: 2
|
||||
initializer_range: 0.02
|
||||
use_relative_positions: False
|
||||
dtype: mstype.float32
|
||||
compute_type: mstype.float16
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of the input data."
|
||||
output_path: "The location of the output file."
|
||||
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
modelarts_dataset_unzip_name: ''
|
||||
folder_name_under_zip_file: ''
|
||||
# task_distill related
|
||||
do_train: "Do train task, default is true."
|
||||
do_eval: "Do eval task, default is true."
|
||||
td_phase1_epoch_size: "Epoch size for td phase 1, default is 10."
|
||||
td_phase2_epoch_size: "Epoch size for td phase 2, default is 3."
|
||||
device_id: "Device id, default is 0."
|
||||
do_shuffle: "Enable shuffle for dataset, default is true."
|
||||
enable_data_sink: "Enable data sink, default is true."
|
||||
save_ckpt_step: ""
|
||||
max_ckpt_num: "Enable data sink, default is true."
|
||||
data_sink_steps: "Sink steps for each epoch, default is 1."
|
||||
load_teacher_ckpt_path: "Load checkpoint file path"
|
||||
load_gd_ckpt_path: "Load checkpoint file path"
|
||||
load_td1_ckpt_path: "Load checkpoint file path"
|
||||
train_data_dir: "Data path, it is better to use absolute path"
|
||||
eval_data_dir: "Data path, it is better to use absolute path"
|
||||
schema_dir: "Schema path, it is better to use absolute path"
|
||||
task_type: "The type of the task to train."
|
||||
task_name: "The name of the task to train."
|
||||
assessment_method: "assessment_method include: [accuracy, bf1, mf1], default is accuracy"
|
||||
dataset_type: "dataset type tfrecord/mindrecord, default is tfrecord"
|
||||
# export related
|
||||
ckpt_file: "tinybert ckpt file."
|
||||
file_name: "output file name."
|
||||
file_format: "file format"
|
||||
# postprocess related
|
||||
result_path: "result path"
|
||||
label_path: "label path"
|
||||
---
|
||||
device_target: ['Ascend', 'GPU', 'CPU']
|
||||
do_train: ["true", "false"]
|
||||
do_eval: ["true", "false"]
|
||||
do_shuffle: ["true", "false"]
|
||||
enable_data_sink: ["true", "false"]
|
||||
task_type: ["classification", "ner"]
|
||||
task_name: ["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"]
|
||||
assessment_method: ["accuracy", "bf1", "mf1"]
|
||||
file_format: ["AIR", "ONNX", "MINDIR"]
|
|
@ -0,0 +1,160 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
|
||||
modelarts_dataset_unzip_name: ''
|
||||
folder_name_under_zip_file: ''
|
||||
# ==============================================================================
|
||||
description: 'task_distill'
|
||||
task_type: "classification"
|
||||
task_name: ""
|
||||
device_id: 0
|
||||
# task_distill related
|
||||
do_train: "true"
|
||||
do_eval: "true"
|
||||
td_phase1_epoch_size: 10
|
||||
td_phase2_epoch_size: 3
|
||||
do_shuffle: "true"
|
||||
enable_data_sink: "true"
|
||||
save_ckpt_step: 100
|
||||
max_ckpt_num: 1
|
||||
data_sink_steps: 1
|
||||
load_teacher_ckpt_path: ""
|
||||
load_gd_ckpt_path: ""
|
||||
load_td1_ckpt_path: ""
|
||||
train_data_dir: ""
|
||||
eval_data_dir: ""
|
||||
schema_dir: ""
|
||||
assessment_method: "accuracy"
|
||||
dataset_type: "tfrecord"
|
||||
# export related
|
||||
ckpt_file: ''
|
||||
file_name: "tinybert"
|
||||
file_format: "AIR"
|
||||
# postprocess related
|
||||
result_path: "./result_Files"
|
||||
label_path: "./preprocess_Result/label_ids.npy"
|
||||
phase1_cfg:
|
||||
batch_size: 32
|
||||
loss_scale_value: 256
|
||||
scale_factor: 2
|
||||
scale_window: 50
|
||||
optimizer_cfg:
|
||||
AdamWeightDecay:
|
||||
learning_rate: 0.00005 # 5e-5
|
||||
end_learning_rate: 0.00000000000001 # 1e-14
|
||||
power: 1.0
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
eps: 0.000001 # 1e-6
|
||||
decay_filter: ['layernorm', 'bias']
|
||||
|
||||
phase2_cfg:
|
||||
batch_size: 32
|
||||
loss_scale_value: 65536
|
||||
scale_factor: 2
|
||||
scale_window: 50
|
||||
optimizer_cfg:
|
||||
AdamWeightDecay:
|
||||
learning_rate: 0.00002 # 5e-5
|
||||
end_learning_rate: 0.00000000000001 # 1e-14
|
||||
power: 1.0
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
eps: 0.000001 # 1e-6
|
||||
decay_filter: ['layernorm', 'bias']
|
||||
|
||||
eval_cfg:
|
||||
batch_size: 32
|
||||
|
||||
td_teacher_net_cfg:
|
||||
seq_length: 128
|
||||
vocab_size: 21128
|
||||
hidden_size: 768
|
||||
num_hidden_layers: 12
|
||||
num_attention_heads: 12
|
||||
intermediate_size: 3072
|
||||
hidden_act: "gelu"
|
||||
hidden_dropout_prob: 0.1
|
||||
attention_probs_dropout_prob: 0.1
|
||||
max_position_embeddings: 512
|
||||
type_vocab_size: 2
|
||||
initializer_range: 0.02
|
||||
use_relative_positions: False
|
||||
dtype: mstype.float32
|
||||
compute_type: mstype.float16
|
||||
|
||||
td_student_net_cfg:
|
||||
seq_length: 128
|
||||
vocab_size: 21128
|
||||
hidden_size: 384
|
||||
num_hidden_layers: 4
|
||||
num_attention_heads: 12
|
||||
intermediate_size: 1536
|
||||
hidden_act: "gelu"
|
||||
hidden_dropout_prob: 0.1
|
||||
attention_probs_dropout_prob: 0.1
|
||||
max_position_embeddings: 512
|
||||
type_vocab_size: 2
|
||||
initializer_range: 0.02
|
||||
use_relative_positions: False
|
||||
dtype: mstype.float32
|
||||
compute_type: mstype.float16
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of the input data."
|
||||
output_path: "The location of the output file."
|
||||
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
modelarts_dataset_unzip_name: ''
|
||||
folder_name_under_zip_file: ''
|
||||
# task_distill related
|
||||
do_train: "Do train task, default is true."
|
||||
do_eval: "Do eval task, default is true."
|
||||
td_phase1_epoch_size: "Epoch size for td phase 1, default is 10."
|
||||
td_phase2_epoch_size: "Epoch size for td phase 2, default is 3."
|
||||
device_id: "Device id, default is 0."
|
||||
do_shuffle: "Enable shuffle for dataset, default is true."
|
||||
enable_data_sink: "Enable data sink, default is true."
|
||||
save_ckpt_step: ""
|
||||
max_ckpt_num: "Enable data sink, default is true."
|
||||
data_sink_steps: "Sink steps for each epoch, default is 1."
|
||||
load_teacher_ckpt_path: "Load checkpoint file path"
|
||||
load_gd_ckpt_path: "Load checkpoint file path"
|
||||
load_td1_ckpt_path: "Load checkpoint file path"
|
||||
train_data_dir: "Data path, it is better to use absolute path"
|
||||
eval_data_dir: "Data path, it is better to use absolute path"
|
||||
schema_dir: "Schema path, it is better to use absolute path"
|
||||
task_type: "The type of the task to train."
|
||||
task_name: "The name of the task to train."
|
||||
assessment_method: "assessment_method include: [accuracy, bf1, mf1], default is accuracy"
|
||||
dataset_type: "dataset type tfrecord/mindrecord, default is tfrecord"
|
||||
# export related
|
||||
ckpt_file: "tinybert ckpt file."
|
||||
file_name: "output file name."
|
||||
file_format: "file format"
|
||||
# postprocess related
|
||||
result_path: "result path"
|
||||
label_path: "label path"
|
||||
---
|
||||
device_target: ['Ascend', 'GPU', 'CPU']
|
||||
do_train: ["true", "false"]
|
||||
do_eval: ["true", "false"]
|
||||
do_shuffle: ["true", "false"]
|
||||
enable_data_sink: ["true", "false"]
|
||||
task_type: ["classification", "ner"]
|
||||
task_name: ["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"]
|
||||
assessment_method: ["accuracy", "bf1", "mf1"]
|
||||
file_format: ["AIR", "ONNX", "MINDIR"]
|
|
@ -0,0 +1,160 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
|
||||
modelarts_dataset_unzip_name: ''
|
||||
folder_name_under_zip_file: ''
|
||||
# ==============================================================================
|
||||
description: 'task_distill'
|
||||
task_type: "classification"
|
||||
task_name: ""
|
||||
device_id: 0
|
||||
# task_distill related
|
||||
do_train: "true"
|
||||
do_eval: "true"
|
||||
td_phase1_epoch_size: 10
|
||||
td_phase2_epoch_size: 3
|
||||
do_shuffle: "true"
|
||||
enable_data_sink: "true"
|
||||
save_ckpt_step: 100
|
||||
max_ckpt_num: 1
|
||||
data_sink_steps: 1
|
||||
load_teacher_ckpt_path: ""
|
||||
load_gd_ckpt_path: ""
|
||||
load_td1_ckpt_path: ""
|
||||
train_data_dir: ""
|
||||
eval_data_dir: ""
|
||||
schema_dir: ""
|
||||
assessment_method: "accuracy"
|
||||
dataset_type: "tfrecord"
|
||||
# export related
|
||||
ckpt_file: ''
|
||||
file_name: "tinybert"
|
||||
file_format: "AIR"
|
||||
# postprocess related
|
||||
result_path: "./result_Files"
|
||||
label_path: "./preprocess_Result/label_ids.npy"
|
||||
phase1_cfg:
|
||||
batch_size: 32
|
||||
loss_scale_value: 256
|
||||
scale_factor: 2
|
||||
scale_window: 50
|
||||
optimizer_cfg:
|
||||
AdamWeightDecay:
|
||||
learning_rate: 0.00005 # 5e-5
|
||||
end_learning_rate: 0.0 # 0.0
|
||||
power: 1.0
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
eps: 0.000001 # 1e-6
|
||||
decay_filter: ['layernorm', 'bias']
|
||||
|
||||
phase2_cfg:
|
||||
batch_size: 32
|
||||
loss_scale_value: 65536
|
||||
scale_factor: 2
|
||||
scale_window: 50
|
||||
optimizer_cfg:
|
||||
AdamWeightDecay:
|
||||
learning_rate: 0.00002 # 5e-5
|
||||
end_learning_rate: 0.0 # 0.0
|
||||
power: 1.0
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
eps: 0.000001 # 1e-6
|
||||
decay_filter: ['layernorm', 'bias']
|
||||
|
||||
eval_cfg:
|
||||
batch_size: 32
|
||||
|
||||
td_teacher_net_cfg:
|
||||
seq_length: 128
|
||||
vocab_size: 30522
|
||||
hidden_size: 768
|
||||
num_hidden_layers: 12
|
||||
num_attention_heads: 12
|
||||
intermediate_size: 3072
|
||||
hidden_act: "gelu"
|
||||
hidden_dropout_prob: 0.1
|
||||
attention_probs_dropout_prob: 0.1
|
||||
max_position_embeddings: 512
|
||||
type_vocab_size: 2
|
||||
initializer_range: 0.02
|
||||
use_relative_positions: False
|
||||
dtype: mstype.float32
|
||||
compute_type: mstype.float16
|
||||
|
||||
td_student_net_cfg:
|
||||
seq_length: 128
|
||||
vocab_size: 30522
|
||||
hidden_size: 384
|
||||
num_hidden_layers: 4
|
||||
num_attention_heads: 12
|
||||
intermediate_size: 1536
|
||||
hidden_act: "gelu"
|
||||
hidden_dropout_prob: 0.1
|
||||
attention_probs_dropout_prob: 0.1
|
||||
max_position_embeddings: 512
|
||||
type_vocab_size: 2
|
||||
initializer_range: 0.02
|
||||
use_relative_positions: False
|
||||
dtype: mstype.float32
|
||||
compute_type: mstype.float16
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of the input data."
|
||||
output_path: "The location of the output file."
|
||||
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
modelarts_dataset_unzip_name: ''
|
||||
folder_name_under_zip_file: ''
|
||||
# task_distill related
|
||||
do_train: "Do train task, default is true."
|
||||
do_eval: "Do eval task, default is true."
|
||||
td_phase1_epoch_size: "Epoch size for td phase 1, default is 10."
|
||||
td_phase2_epoch_size: "Epoch size for td phase 2, default is 3."
|
||||
device_id: "Device id, default is 0."
|
||||
do_shuffle: "Enable shuffle for dataset, default is true."
|
||||
enable_data_sink: "Enable data sink, default is true."
|
||||
save_ckpt_step: ""
|
||||
max_ckpt_num: "Enable data sink, default is true."
|
||||
data_sink_steps: "Sink steps for each epoch, default is 1."
|
||||
load_teacher_ckpt_path: "Load checkpoint file path"
|
||||
load_gd_ckpt_path: "Load checkpoint file path"
|
||||
load_td1_ckpt_path: "Load checkpoint file path"
|
||||
train_data_dir: "Data path, it is better to use absolute path"
|
||||
eval_data_dir: "Data path, it is better to use absolute path"
|
||||
schema_dir: "Schema path, it is better to use absolute path"
|
||||
task_type: "The type of the task to train."
|
||||
task_name: "The name of the task to train."
|
||||
assessment_method: "assessment_method include: [accuracy, bf1, mf1], default is accuracy"
|
||||
dataset_type: "dataset type tfrecord/mindrecord, default is tfrecord"
|
||||
# export related
|
||||
ckpt_file: "tinybert ckpt file."
|
||||
file_name: "output file name."
|
||||
file_format: "file format"
|
||||
# postprocess related
|
||||
result_path: "result path"
|
||||
label_path: "label path"
|
||||
---
|
||||
device_target: ['Ascend', 'GPU', 'CPU']
|
||||
do_train: ["true", "false"]
|
||||
do_eval: ["true", "false"]
|
||||
do_shuffle: ["true", "false"]
|
||||
enable_data_sink: ["true", "false"]
|
||||
task_type: ["classification", "ner"]
|
||||
task_name: ["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"]
|
||||
assessment_method: ["accuracy", "bf1", "mf1"]
|
||||
file_format: ["AIR", "ONNX", "MINDIR"]
|
|
@ -0,0 +1,160 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
|
||||
modelarts_dataset_unzip_name: ''
|
||||
folder_name_under_zip_file: ''
|
||||
# ==============================================================================
|
||||
description: 'task_distill'
|
||||
task_type: "classification"
|
||||
task_name: ""
|
||||
device_id: 0
|
||||
# task_distill related
|
||||
do_train: "true"
|
||||
do_eval: "true"
|
||||
td_phase1_epoch_size: 10
|
||||
td_phase2_epoch_size: 3
|
||||
do_shuffle: "true"
|
||||
enable_data_sink: "true"
|
||||
save_ckpt_step: 100
|
||||
max_ckpt_num: 1
|
||||
data_sink_steps: 1
|
||||
load_teacher_ckpt_path: ""
|
||||
load_gd_ckpt_path: ""
|
||||
load_td1_ckpt_path: ""
|
||||
train_data_dir: ""
|
||||
eval_data_dir: ""
|
||||
schema_dir: ""
|
||||
assessment_method: "accuracy"
|
||||
dataset_type: "tfrecord"
|
||||
# export related
|
||||
ckpt_file: ''
|
||||
file_name: "tinybert"
|
||||
file_format: "AIR"
|
||||
# postprocess related
|
||||
result_path: "./result_Files"
|
||||
label_path: "./preprocess_Result/label_ids.npy"
|
||||
phase1_cfg:
|
||||
batch_size: 32
|
||||
loss_scale_value: 256
|
||||
scale_factor: 2
|
||||
scale_window: 50
|
||||
optimizer_cfg:
|
||||
AdamWeightDecay:
|
||||
learning_rate: 0.00005 # 5e-5
|
||||
end_learning_rate: 0.0 # 0.0
|
||||
power: 1.0
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
eps: 0.000001 # 1e-6
|
||||
decay_filter: ['layernorm', 'bias']
|
||||
|
||||
phase2_cfg:
|
||||
batch_size: 32
|
||||
loss_scale_value: 65536
|
||||
scale_factor: 2
|
||||
scale_window: 50
|
||||
optimizer_cfg:
|
||||
AdamWeightDecay:
|
||||
learning_rate: 0.00002 # 5e-5
|
||||
end_learning_rate: 0.0 # 0.0
|
||||
power: 1.0
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
eps: 0.000001 # 1e-6
|
||||
decay_filter: ['layernorm', 'bias']
|
||||
|
||||
eval_cfg:
|
||||
batch_size: 32
|
||||
|
||||
td_teacher_net_cfg:
|
||||
seq_length: 64
|
||||
vocab_size: 30522
|
||||
hidden_size: 768
|
||||
num_hidden_layers: 12
|
||||
num_attention_heads: 12
|
||||
intermediate_size: 3072
|
||||
hidden_act: "gelu"
|
||||
hidden_dropout_prob: 0.1
|
||||
attention_probs_dropout_prob: 0.1
|
||||
max_position_embeddings: 512
|
||||
type_vocab_size: 2
|
||||
initializer_range: 0.02
|
||||
use_relative_positions: False
|
||||
dtype: mstype.float32
|
||||
compute_type: mstype.float16
|
||||
|
||||
td_student_net_cfg:
|
||||
seq_length: 64
|
||||
vocab_size: 30522
|
||||
hidden_size: 384
|
||||
num_hidden_layers: 4
|
||||
num_attention_heads: 12
|
||||
intermediate_size: 1536
|
||||
hidden_act: "gelu"
|
||||
hidden_dropout_prob: 0.1
|
||||
attention_probs_dropout_prob: 0.1
|
||||
max_position_embeddings: 512
|
||||
type_vocab_size: 2
|
||||
initializer_range: 0.02
|
||||
use_relative_positions: False
|
||||
dtype: mstype.float32
|
||||
compute_type: mstype.float16
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of the input data."
|
||||
output_path: "The location of the output file."
|
||||
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
modelarts_dataset_unzip_name: ''
|
||||
folder_name_under_zip_file: ''
|
||||
# task_distill related
|
||||
do_train: "Do train task, default is true."
|
||||
do_eval: "Do eval task, default is true."
|
||||
td_phase1_epoch_size: "Epoch size for td phase 1, default is 10."
|
||||
td_phase2_epoch_size: "Epoch size for td phase 2, default is 3."
|
||||
device_id: "Device id, default is 0."
|
||||
do_shuffle: "Enable shuffle for dataset, default is true."
|
||||
enable_data_sink: "Enable data sink, default is true."
|
||||
save_ckpt_step: ""
|
||||
max_ckpt_num: "Enable data sink, default is true."
|
||||
data_sink_steps: "Sink steps for each epoch, default is 1."
|
||||
load_teacher_ckpt_path: "Load checkpoint file path"
|
||||
load_gd_ckpt_path: "Load checkpoint file path"
|
||||
load_td1_ckpt_path: "Load checkpoint file path"
|
||||
train_data_dir: "Data path, it is better to use absolute path"
|
||||
eval_data_dir: "Data path, it is better to use absolute path"
|
||||
schema_dir: "Schema path, it is better to use absolute path"
|
||||
task_type: "The type of the task to train."
|
||||
task_name: "The name of the task to train."
|
||||
assessment_method: "assessment_method include: [accuracy, bf1, mf1], default is accuracy"
|
||||
dataset_type: "dataset type tfrecord/mindrecord, default is tfrecord"
|
||||
# export related
|
||||
ckpt_file: "tinybert ckpt file."
|
||||
file_name: "output file name."
|
||||
file_format: "file format"
|
||||
# postprocess related
|
||||
result_path: "result path"
|
||||
label_path: "label path"
|
||||
---
|
||||
device_target: ['Ascend', 'GPU', 'CPU']
|
||||
do_train: ["true", "false"]
|
||||
do_eval: ["true", "false"]
|
||||
do_shuffle: ["true", "false"]
|
||||
enable_data_sink: ["true", "false"]
|
||||
task_type: ["classification", "ner"]
|
||||
task_name: ["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"]
|
||||
assessment_method: ["accuracy", "bf1", "mf1"]
|
||||
file_format: ["AIR", "ONNX", "MINDIR"]
|
Loading…
Reference in New Issue