From 2f2a624281ec634beaa082226f0898960503fbf6 Mon Sep 17 00:00:00 2001
From: gaojing
Date: Sun, 27 Dec 2020 02:40:17 -0500
Subject: [PATCH] gnmt add hub_config

---
 model_zoo/official/nlp/gnmt_v2/README.md | 10 ++--
 model_zoo/official/nlp/gnmt_v2/export.py | 45 +++++++-------
 .../nlp/gnmt_v2/mindspore_hub_conf.py | 40 +++++++++++++
 .../nlp/gnmt_v2/src/gnmt_model/beam_search.py | 45 +++++++++++---
 .../gnmt_v2/src/gnmt_model/gnmt_for_infer.py | 60 +++++++------------
 .../nlp/gnmt_v2/src/utils/load_weights.py | 44 ++++++--------
 6 files changed, 139 insertions(+), 105 deletions(-)
 create mode 100644 model_zoo/official/nlp/gnmt_v2/mindspore_hub_conf.py

diff --git a/model_zoo/official/nlp/gnmt_v2/README.md b/model_zoo/official/nlp/gnmt_v2/README.md
index 521bd1be5a3..64d2745073b 100644
--- a/model_zoo/official/nlp/gnmt_v2/README.md
+++ b/model_zoo/official/nlp/gnmt_v2/README.md
@@ -135,11 +135,13 @@ The GNMT network script and code result are as follows:
 │ ├──lr_scheduler.py // Learning rate scheduler.
 │ ├──optimizer.py // Optimizer.
 ├── scripts
- │ ├──run_distributed_train_ascend.sh // shell script for distributed train on ascend.
- │ ├──run_standalone_eval_ascend.sh // shell script for standalone eval on ascend.
- │ ├──run_standalone_train_ascend.sh // shell script for standalone eval on ascend.
- ├── create_dataset.py // dataset preparation.
+ │ ├──run_distributed_train_ascend.sh // Shell script for distributed training on Ascend.
+ │ ├──run_standalone_eval_ascend.sh // Shell script for standalone evaluation on Ascend.
+ │ ├──run_standalone_train_ascend.sh // Shell script for standalone training on Ascend.
+ ├── create_dataset.py // Dataset preparation.
 ├── eval.py // Infer API entry.
+ ├── export.py // Export checkpoint file into an AIR model.
+ ├── mindspore_hub_conf.py // Hub config.
 ├── requirements.txt // Requirements of third party package.
 ├── train.py // Train API entry. 
``` diff --git a/model_zoo/official/nlp/gnmt_v2/export.py b/model_zoo/official/nlp/gnmt_v2/export.py index e50b722847a..822ebee1c58 100644 --- a/model_zoo/official/nlp/gnmt_v2/export.py +++ b/model_zoo/official/nlp/gnmt_v2/export.py @@ -65,33 +65,28 @@ if __name__ == '__main__': for param in params: value = param.data - name = param.name - if name not in weights: - raise ValueError(f"{name} is not found in weights.") - with open("weight_after_deal.txt", "a+") as f: - weights_name = name - f.write(weights_name) - f.write("\n") - if isinstance(value, Tensor): - print(name, value.asnumpy().shape) - if weights_name in weights: - assert weights_name in weights - if isinstance(weights[weights_name], Parameter): - if param.data.dtype == "Float32": - param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) - elif param.data.dtype == "Float16": - param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) + weights_name = param.name + if weights_name not in weights: + raise ValueError(f"{weights_name} is not found in weights.") + if isinstance(value, Tensor): + if weights_name in weights: + assert weights_name in weights + if isinstance(weights[weights_name], Parameter): + if param.data.dtype == "Float32": + param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) + elif param.data.dtype == "Float16": + param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) - elif isinstance(weights[weights_name], Tensor): - param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) - elif isinstance(weights[weights_name], np.ndarray): - param.set_data(Tensor(weights[weights_name], config.dtype)) - else: - param.set_data(weights[weights_name]) + elif isinstance(weights[weights_name], Tensor): + param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) + elif isinstance(weights[weights_name], np.ndarray): + param.set_data(Tensor(weights[weights_name], config.dtype)) else: - print("weight not found in checkpoint: " + weights_name) - param.set_data(zero_weight(value.asnumpy().shape)) - f.close() + param.set_data(weights[weights_name]) + else: + print("weight not found in checkpoint: " + weights_name) + param.set_data(zero_weight(value.asnumpy().shape)) + print(" | Load weights successfully.") tfm_infer = GNMTInferCell(tfm_model) tfm_infer.set_train(False) diff --git a/model_zoo/official/nlp/gnmt_v2/mindspore_hub_conf.py b/model_zoo/official/nlp/gnmt_v2/mindspore_hub_conf.py new file mode 100644 index 00000000000..c85ba28580d --- /dev/null +++ b/model_zoo/official/nlp/gnmt_v2/mindspore_hub_conf.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""hub config.""" +import mindspore.common.dtype as mstype + +from config import GNMTConfig +from src.gnmt_model import GNMTNetworkWithLoss, GNMT + + +def get_config(config): + config = GNMTConfig.from_json_file(config) + config.compute_type = mstype.float16 + config.dtype = mstype.float32 + return config + + +def create_network(name, *args, **kwargs): + """create gnmt network.""" + if name == "gnmt": + if "config" in kwargs: + config = get_config(kwargs["config"]) + else: + raise NotImplementedError(f"Please make sure the configuration file path is correct") + is_training = kwargs.get("is_training", False) + if is_training: + return GNMTNetworkWithLoss(config, is_training=is_training, *args) + return GNMT(config, *args) + raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py index 0c361283c85..fceb9870994 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py +++ b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py @@ -172,6 +172,7 @@ class BeamSearchDecoder(nn.Cell): max_decode_length=64, sos_id=2, eos_id=3, + is_using_while=False, compute_type=mstype.float32): super(BeamSearchDecoder, self).__init__() @@ -185,6 +186,7 @@ class BeamSearchDecoder(nn.Cell): self.cov_penalty_factor = cov_penalty_factor self.max_decode_length = max_decode_length self.decoder = decoder + self.is_using_while = is_using_while self.add = P.TensorAdd() self.expand = P.ExpandDims() @@ -215,7 +217,12 @@ class BeamSearchDecoder(nn.Cell): self.gather_nd = P.GatherNd() self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32) - self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32) + if self.is_using_while: + self.start = Tensor(0, dtype=mstype.int32) + self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id), + mstype.int32) + else: + self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32) init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1]) self.init_scores = Tensor(init_scores, mstype.float32) @@ -259,7 +266,7 @@ class BeamSearchDecoder(nn.Cell): self.sub = P.Sub() def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs, - state_seq, state_length, decoder_hidden_state=None, accu_attn_scores=None, + state_seq, state_length, idx=None, decoder_hidden_state=None, accu_attn_scores=None, state_finished=None): """ Beam search one_step output. 
@@ -359,7 +366,13 @@ class BeamSearchDecoder(nn.Cell): self.hidden_size)) # update state_seq - state_seq = self.concat((seq, self.expand(word_indices, -1))) + if self.is_using_while: + state_seq_new = self.cast(seq, mstype.float32) + word_indices_fp32 = self.cast(word_indices, mstype.float32) + state_seq_new[:, :, idx] = word_indices_fp32 + state_seq = self.cast(state_seq_new, mstype.int32) + else: + state_seq = self.concat((seq, self.expand(word_indices, -1))) cur_input_ids = self.reshape(word_indices, (-1, 1)) state_log_probs = topk_scores @@ -388,11 +401,22 @@ class BeamSearchDecoder(nn.Cell): decoder_hidden_state = self.decoder_hidden_state accu_attn_scores = self.accu_attn_scores - for _ in range(self.max_decode_length + 1): - cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \ - state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs, - state_seq, state_length, decoder_hidden_state, accu_attn_scores, - state_finished) + if not self.is_using_while: + for _ in range(self.max_decode_length + 1): + cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \ + state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs, + state_seq, state_length, None, decoder_hidden_state, accu_attn_scores, + state_finished) + else: + idx = self.start + 1 + ends = self.start + self.max_decode_length + 1 + while idx < ends: + cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \ + state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs, + state_seq, state_length, idx, decoder_hidden_state, accu_attn_scores, + state_finished) + idx = idx + 1 + # add length penalty scores penalty_len = self.length_penalty(state_length) # return penalty_len @@ -408,6 +432,9 @@ class BeamSearchDecoder(nn.Cell): gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1))) # sort sequence and attention scores predicted_ids = self.gather_nd(state_seq, gather_indices) - predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)] + if not self.is_using_while: + predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)] + else: + predicted_ids = predicted_ids[:, 0:1, :self.max_decode_length] return predicted_ids diff --git a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py index 079a654a19b..2c62caf517a 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py +++ b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py @@ -23,7 +23,6 @@ from mindspore.common.tensor import Tensor from mindspore.ops import operations as P from mindspore import context, Parameter from mindspore.train.model import Model -from mindspore.train.serialization import load_checkpoint from src.dataset import load_dataset from .gnmt import GNMT @@ -37,18 +36,6 @@ context.set_context( reserve_class_name_in_scope=False) -def get_weight_and_variable(model_path, params): - print("model path is {}".format(model_path)) - ms_ckpt = load_checkpoint(model_path) - with open("variable.txt", "w") as f: - for msname in ms_ckpt: - f.write(msname + "\n") - with open("weights.txt", "w") as f: - for param in params: - name = param.name - f.write(name + "\n") - - class GNMTInferCell(nn.Cell): """ Encapsulation class of GNMT network infer. 
@@ -92,38 +79,31 @@ def gnmt_infer(config, dataset): use_one_hot_embeddings=False) params = tfm_model.trainable_params() - get_weight_and_variable(config.existed_ckpt, params) weights = load_infer_weights(config) - for param in params: value = param.data - name = param.name - if name not in weights: - raise ValueError(f"{name} is not found in weights.") - with open("weight_after_deal.txt", "a+") as f: - weights_name = name - f.write(weights_name) - f.write("\n") - if isinstance(value, Tensor): - print(name, value.asnumpy().shape) - if weights_name in weights: - assert weights_name in weights - if isinstance(weights[weights_name], Parameter): - if param.data.dtype == "Float32": - param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) - elif param.data.dtype == "Float16": - param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) + weights_name = param.name + if weights_name not in weights: + raise ValueError(f"{weights_name} is not found in weights.") + if isinstance(value, Tensor): + if weights_name in weights: + assert weights_name in weights + if isinstance(weights[weights_name], Parameter): + if param.data.dtype == "Float32": + param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) + elif param.data.dtype == "Float16": + param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) - elif isinstance(weights[weights_name], Tensor): - param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) - elif isinstance(weights[weights_name], np.ndarray): - param.set_data(Tensor(weights[weights_name], config.dtype)) - else: - param.set_data(weights[weights_name]) + elif isinstance(weights[weights_name], Tensor): + param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) + elif isinstance(weights[weights_name], np.ndarray): + param.set_data(Tensor(weights[weights_name], config.dtype)) else: - print("weight not found in checkpoint: " + weights_name) - param.set_data(zero_weight(value.asnumpy().shape)) - f.close() + param.set_data(weights[weights_name]) + else: + print("weight not found in checkpoint: " + weights_name) + param.set_data(zero_weight(value.asnumpy().shape)) + print(" | Load weights successfully.") tfm_infer = GNMTInferCell(tfm_model) model = Model(tfm_infer) diff --git a/model_zoo/official/nlp/gnmt_v2/src/utils/load_weights.py b/model_zoo/official/nlp/gnmt_v2/src/utils/load_weights.py index a29c9b4d12f..9add9b5def9 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/utils/load_weights.py +++ b/model_zoo/official/nlp/gnmt_v2/src/utils/load_weights.py @@ -37,36 +37,26 @@ def load_infer_weights(config): ms_ckpt = load_checkpoint(model_path) is_npz = False weights = {} - with open("variable_after_deal.txt", "w") as f: - for param_name in ms_ckpt: - infer_name = param_name.replace("gnmt.gnmt.", "") - if infer_name.startswith("embedding_lookup."): - if is_npz: - weights[infer_name] = ms_ckpt[param_name] - else: - weights[infer_name] = ms_ckpt[param_name].data.asnumpy() - f.write(infer_name) - f.write("\n") - infer_name = "beam_decoder.decoder." + infer_name - if is_npz: - weights[infer_name] = ms_ckpt[param_name] - else: - weights[infer_name] = ms_ckpt[param_name].data.asnumpy() - f.write(infer_name) - f.write("\n") - continue - - elif not infer_name.startswith("gnmt_encoder"): - if infer_name.startswith("gnmt_decoder."): - infer_name = infer_name.replace("gnmt_decoder.", "decoder.") - infer_name = "beam_decoder.decoder." 
+ infer_name - + for param_name in ms_ckpt: + infer_name = param_name.replace("gnmt.gnmt.", "") + if infer_name.startswith("embedding_lookup."): if is_npz: weights[infer_name] = ms_ckpt[param_name] else: weights[infer_name] = ms_ckpt[param_name].data.asnumpy() - f.write(infer_name) - f.write("\n") + infer_name = "beam_decoder.decoder." + infer_name + if is_npz: + weights[infer_name] = ms_ckpt[param_name] + else: + weights[infer_name] = ms_ckpt[param_name].data.asnumpy() + continue + elif not infer_name.startswith("gnmt_encoder"): + if infer_name.startswith("gnmt_decoder."): + infer_name = infer_name.replace("gnmt_decoder.", "decoder.") + infer_name = "beam_decoder.decoder." + infer_name - f.close() + if is_npz: + weights[infer_name] = ms_ckpt[param_name] + else: + weights[infer_name] = ms_ckpt[param_name].data.asnumpy() return weights
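A minimal usage sketch for the new hub entry point in mindspore_hub_conf.py. The JSON path below is an assumption; point it at whichever GNMT config file your gnmt_v2 checkout provides.

    # Sketch only: drive create_network() directly. "config/config_test.json" is an
    # assumed path -- substitute the GNMT JSON config from your checkout.
    from mindspore_hub_conf import create_network

    # Inference graph: returns GNMT(config); is_training defaults to False.
    infer_net = create_network("gnmt", config="config/config_test.json")

    # Training graph: returns GNMTNetworkWithLoss(config, is_training=True).
    train_net = create_network("gnmt", config="config/config_test.json", is_training=True)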
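In beam_search.py, the new is_using_while path pre-allocates state_seq with max_decode_length positions and writes each step's word indices into column idx, while the original path concatenates one new column per step. A pure-NumPy sketch of the two update patterns (illustration only, not MindSpore ops; the shapes are made up):

    import numpy as np

    batch_size, beam_width, max_decode_length = 2, 3, 5
    sos_id = 2
    word_indices = np.ones((batch_size, beam_width), dtype=np.int32)  # stand-in for one step's beam ids

    # Original path: the sequence tensor grows by one column per decode step.
    seq_grow = np.full((batch_size, beam_width, 1), sos_id, dtype=np.int32)
    seq_grow = np.concatenate((seq_grow, word_indices[..., None]), axis=-1)  # shape becomes (2, 3, 2)

    # is_using_while path: the tensor is pre-allocated and column `idx` is
    # overwritten in place, so its shape stays fixed across iterations.
    seq_fixed = np.full((batch_size, beam_width, max_decode_length), sos_id, dtype=np.int32)
    idx = 1
    seq_fixed[:, :, idx] = word_indices  # shape stays (2, 3, 5)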
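load_infer_weights remaps training-checkpoint parameter names onto the inference graph: the "gnmt.gnmt." prefix is stripped, encoder keys are kept as-is, embedding keys are stored twice (plain and under the beam decoder), and all remaining keys are nested under "beam_decoder.decoder.". A plain-Python sketch of that renaming; the key in the usage line is hypothetical:

    def remap_checkpoint_key(param_name):
        """Sketch of the renaming in load_infer_weights, assuming the
        'gnmt.gnmt.<module>...' key layout of the training checkpoint."""
        infer_name = param_name.replace("gnmt.gnmt.", "")
        if infer_name.startswith("embedding_lookup."):
            # Embedding weights feed both the encoder side and the beam-search decoder.
            return [infer_name, "beam_decoder.decoder." + infer_name]
        if infer_name.startswith("gnmt_encoder"):
            return [infer_name]
        if infer_name.startswith("gnmt_decoder."):
            infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
        return ["beam_decoder.decoder." + infer_name]

    # Hypothetical key, for illustration only:
    print(remap_checkpoint_key("gnmt.gnmt.gnmt_decoder.basic_lstm_cell.weight"))
    # -> ['beam_decoder.decoder.decoder.basic_lstm_cell.weight']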