diff --git a/model_zoo/official/nlp/tinybert/export.py b/model_zoo/official/nlp/tinybert/export.py new file mode 100644 index 00000000000..6adc8ac7bcc --- /dev/null +++ b/model_zoo/official/nlp/tinybert/export.py @@ -0,0 +1,79 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""export checkpoint file into air models""" + +import re +import argparse +import numpy as np + +from mindspore import Tensor, context +from mindspore.train.serialization import load_checkpoint, load_param_into_net, export + +from src.td_config import td_student_net_cfg +from src.tinybert_model import BertModelCLS + +parser = argparse.ArgumentParser(description='tinybert task distill') +parser.add_argument('--ckpt_file', type=str, required=True, help='tinybert ckpt file.') +parser.add_argument('--output_file', type=str, default='tinybert.air', help='tinybert output air name.') +parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name') +args = parser.parse_args() + +DEFAULT_NUM_LABELS = 2 +DEFAULT_SEQ_LENGTH = 128 +task_params = {"SST-2": {"num_labels": 2, "seq_length": 64}, + "QNLI": {"num_labels": 2, "seq_length": 128}, + "MNLI": {"num_labels": 3, "seq_length": 128}} + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Task: + """ + Encapsulation class of get the task parameter. + """ + def __init__(self, task_name): + self.task_name = task_name + + @property + def num_labels(self): + if self.task_name in task_params and "num_labels" in task_params[self.task_name]: + return task_params[self.task_name]["num_labels"] + return DEFAULT_NUM_LABELS + + @property + def seq_length(self): + if self.task_name in task_params and "seq_length" in task_params[self.task_name]: + return task_params[self.task_name]["seq_length"] + return DEFAULT_SEQ_LENGTH + +if __name__ == '__main__': + task = Task(args.task_name) + td_student_net_cfg.seq_length = task.seq_length + + eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student") + param_dict = load_checkpoint(args.ckpt_file) + new_param_dict = {} + for key, value in param_dict.items(): + new_key = re.sub('tinybert_', 'bert_', key) + new_key = re.sub('^bert.', '', new_key) + new_param_dict[new_key] = value + + load_param_into_net(eval_model, new_param_dict) + eval_model.set_train(False) + + input_ids = Tensor(np.zeros((td_student_net_cfg.batch_size, task.seq_length), np.int32)) + token_type_id = Tensor(np.zeros((td_student_net_cfg.batch_size, task.seq_length), np.int32)) + input_mask = Tensor(np.zeros((td_student_net_cfg.batch_size, task.seq_length), np.int32)) + + export(eval_model, input_ids, token_type_id, input_mask, file_name=args.output_file, file_format="AIR") diff --git a/model_zoo/official/nlp/transformer/export.py b/model_zoo/official/nlp/transformer/export.py new file mode 100644 index 00000000000..0d462b9406a --- /dev/null +++ b/model_zoo/official/nlp/transformer/export.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""export checkpoint file into air models""" + +import numpy as np + +from mindspore import Tensor, context +from mindspore.train.serialization import load_param_into_net, export + +from src.transformer_model import TransformerModel +from src.eval_config import cfg, transformer_net_cfg +from eval import load_weights + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +if __name__ == '__main__': + tfm_model = TransformerModel(config=transformer_net_cfg, is_training=False, use_one_hot_embeddings=False) + + parameter_dict = load_weights(cfg.model_file) + load_param_into_net(tfm_model, parameter_dict) + + source_ids = Tensor(np.ones((1, 128)).astype(np.int32)) + source_mask = Tensor(np.ones((1, 128)).astype(np.int32)) + + dec_len = transformer_net_cfg.max_decode_length + + export(tfm_model, source_ids, source_mask, file_name="len" + str(dec_len) + ".air", file_format="AIR")