export script for tinybert and transformer

This commit is contained in:
yuzhenhua 2020-10-26 16:42:17 +08:00
parent 4bbb854d3c
commit 9417620bf5
2 changed files with 118 additions and 0 deletions

View File

@ -0,0 +1,79 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into air models"""
import re
import argparse
import numpy as np
from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.td_config import td_student_net_cfg
from src.tinybert_model import BertModelCLS
parser = argparse.ArgumentParser(description='tinybert task distill')
parser.add_argument('--ckpt_file', type=str, required=True, help='tinybert ckpt file.')
parser.add_argument('--output_file', type=str, default='tinybert.air', help='tinybert output air name.')
parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name')
args = parser.parse_args()
DEFAULT_NUM_LABELS = 2
DEFAULT_SEQ_LENGTH = 128
task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
"QNLI": {"num_labels": 2, "seq_length": 128},
"MNLI": {"num_labels": 3, "seq_length": 128}}
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Task:
"""
Encapsulation class of get the task parameter.
"""
def __init__(self, task_name):
self.task_name = task_name
@property
def num_labels(self):
if self.task_name in task_params and "num_labels" in task_params[self.task_name]:
return task_params[self.task_name]["num_labels"]
return DEFAULT_NUM_LABELS
@property
def seq_length(self):
if self.task_name in task_params and "seq_length" in task_params[self.task_name]:
return task_params[self.task_name]["seq_length"]
return DEFAULT_SEQ_LENGTH
if __name__ == '__main__':
task = Task(args.task_name)
td_student_net_cfg.seq_length = task.seq_length
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
param_dict = load_checkpoint(args.ckpt_file)
new_param_dict = {}
for key, value in param_dict.items():
new_key = re.sub('tinybert_', 'bert_', key)
new_key = re.sub('^bert.', '', new_key)
new_param_dict[new_key] = value
load_param_into_net(eval_model, new_param_dict)
eval_model.set_train(False)
input_ids = Tensor(np.zeros((td_student_net_cfg.batch_size, task.seq_length), np.int32))
token_type_id = Tensor(np.zeros((td_student_net_cfg.batch_size, task.seq_length), np.int32))
input_mask = Tensor(np.zeros((td_student_net_cfg.batch_size, task.seq_length), np.int32))
export(eval_model, input_ids, token_type_id, input_mask, file_name=args.output_file, file_format="AIR")

View File

@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into air models"""
import numpy as np
from mindspore import Tensor, context
from mindspore.train.serialization import load_param_into_net, export
from src.transformer_model import TransformerModel
from src.eval_config import cfg, transformer_net_cfg
from eval import load_weights
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if __name__ == '__main__':
tfm_model = TransformerModel(config=transformer_net_cfg, is_training=False, use_one_hot_embeddings=False)
parameter_dict = load_weights(cfg.model_file)
load_param_into_net(tfm_model, parameter_dict)
source_ids = Tensor(np.ones((1, 128)).astype(np.int32))
source_mask = Tensor(np.ones((1, 128)).astype(np.int32))
dec_len = transformer_net_cfg.max_decode_length
export(tfm_model, source_ids, source_mask, file_name="len" + str(dec_len) + ".air", file_format="AIR")