gnmt add hub_config

gaojing 2020-12-27 02:40:17 -05:00
parent c52a4381dd
commit 2f2a624281
6 changed files with 139 additions and 105 deletions

View File

@@ -135,11 +135,13 @@ The GNMT network script and code result are as follows:
 │   ├──lr_scheduler.py // Learning rate scheduler.
 │   ├──optimizer.py // Optimizer.
 ├── scripts
-│   ├──run_distributed_train_ascend.sh // shell script for distributed train on ascend.
-│   ├──run_standalone_eval_ascend.sh // shell script for standalone eval on ascend.
-│   ├──run_standalone_train_ascend.sh // shell script for standalone eval on ascend.
+│   ├──run_distributed_train_ascend.sh // Shell script for distributed train on ascend.
+│   ├──run_standalone_eval_ascend.sh // Shell script for standalone eval on ascend.
+│   ├──run_standalone_train_ascend.sh // Shell script for standalone train on ascend.
 ├── create_dataset.py // Dataset preparation.
 ├── eval.py // Infer API entry.
+├── export.py // Export checkpoint file into air models.
+├── mindspore_hub_conf.py // Hub config.
 ├── requirements.txt // Requirements of third party package.
 ├── train.py // Train API entry.
 ```

View File

@@ -65,33 +65,28 @@ if __name__ == '__main__':
     for param in params:
         value = param.data
-        name = param.name
-        if name not in weights:
-            raise ValueError(f"{name} is not found in weights.")
-        with open("weight_after_deal.txt", "a+") as f:
-            weights_name = name
-            f.write(weights_name)
-            f.write("\n")
-            if isinstance(value, Tensor):
-                print(name, value.asnumpy().shape)
-                if weights_name in weights:
-                    assert weights_name in weights
-                    if isinstance(weights[weights_name], Parameter):
-                        if param.data.dtype == "Float32":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
-                        elif param.data.dtype == "Float16":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
-                    elif isinstance(weights[weights_name], Tensor):
-                        param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
-                    elif isinstance(weights[weights_name], np.ndarray):
-                        param.set_data(Tensor(weights[weights_name], config.dtype))
-                    else:
-                        param.set_data(weights[weights_name])
-                else:
-                    print("weight not found in checkpoint: " + weights_name)
-                    param.set_data(zero_weight(value.asnumpy().shape))
-        f.close()
+        weights_name = param.name
+        if weights_name not in weights:
+            raise ValueError(f"{weights_name} is not found in weights.")
+        if isinstance(value, Tensor):
+            if weights_name in weights:
+                assert weights_name in weights
+                if isinstance(weights[weights_name], Parameter):
+                    if param.data.dtype == "Float32":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
+                    elif param.data.dtype == "Float16":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
+                elif isinstance(weights[weights_name], Tensor):
+                    param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
+                elif isinstance(weights[weights_name], np.ndarray):
+                    param.set_data(Tensor(weights[weights_name], config.dtype))
+                else:
+                    param.set_data(weights[weights_name])
+            else:
+                print("weight not found in checkpoint: " + weights_name)
+                param.set_data(zero_weight(value.asnumpy().shape))
     print(" | Load weights successfully.")
     tfm_infer = GNMTInferCell(tfm_model)
     tfm_infer.set_train(False)
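The loop above normalizes whatever the checkpoint yields (`Parameter`, `Tensor`, or NumPy array) before handing it to `set_data`. As a standalone illustration, here is a minimal sketch of that dispatch order; the helper name `assign_weight` is ours, and the zero-fill fallback is simplified relative to the guarded version above:

```python
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, Parameter

def assign_weight(param, weights, default_dtype=mstype.float32):
    """Sketch of the dispatch above: normalize a checkpoint value, then copy it into param."""
    if param.name not in weights:
        # Fallback used above for weights missing from the checkpoint: zero-fill.
        param.set_data(Tensor(np.zeros(param.data.asnumpy().shape), default_dtype))
        return
    value = weights[param.name]
    if isinstance(value, Parameter):
        # Match the destination parameter's own precision (Float32 vs Float16).
        dst = mstype.float16 if param.data.dtype == mstype.float16 else mstype.float32
        param.set_data(Tensor(value.data.asnumpy(), dst))
    elif isinstance(value, Tensor):
        param.set_data(Tensor(value.asnumpy(), default_dtype))
    elif isinstance(value, np.ndarray):
        param.set_data(Tensor(value, default_dtype))
    else:
        param.set_data(value)
```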

View File

@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""hub config."""
+import mindspore.common.dtype as mstype
+from config import GNMTConfig
+from src.gnmt_model import GNMTNetworkWithLoss, GNMT
+
+
+def get_config(config):
+    config = GNMTConfig.from_json_file(config)
+    config.compute_type = mstype.float16
+    config.dtype = mstype.float32
+    return config
+
+
+def create_network(name, *args, **kwargs):
+    """create gnmt network."""
+    if name == "gnmt":
+        if "config" in kwargs:
+            config = get_config(kwargs["config"])
+        else:
+            raise NotImplementedError(f"Please make sure the configuration file path is correct")
+        is_training = kwargs.get("is_training", False)
+        if is_training:
+            return GNMTNetworkWithLoss(config, is_training=is_training, *args)
+        return GNMT(config, *args)
+    raise NotImplementedError(f"{name} is not implemented in the repo")
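With this entry point in place, the network can be built by name. A hedged usage sketch follows; the `config/config.json` path is an assumption, so substitute the GNMT JSON config actually shipped with the repo:

```python
# Hypothetical usage of the new hub entry point; the config path is an assumption.
from mindspore_hub_conf import create_network

# Inference network: is_training defaults to False, so the bare GNMT is returned.
gnmt = create_network("gnmt", config="config/config.json")

# Training network: the model comes wrapped with its loss.
gnmt_with_loss = create_network("gnmt", config="config/config.json", is_training=True)
```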

View File

@@ -172,6 +172,7 @@ class BeamSearchDecoder(nn.Cell):
                  max_decode_length=64,
                  sos_id=2,
                  eos_id=3,
+                 is_using_while=False,
                  compute_type=mstype.float32):
         super(BeamSearchDecoder, self).__init__()
@@ -185,6 +186,7 @@ class BeamSearchDecoder(nn.Cell):
         self.cov_penalty_factor = cov_penalty_factor
         self.max_decode_length = max_decode_length
         self.decoder = decoder
+        self.is_using_while = is_using_while
         self.add = P.TensorAdd()
         self.expand = P.ExpandDims()
@@ -215,7 +217,12 @@ class BeamSearchDecoder(nn.Cell):
         self.gather_nd = P.GatherNd()
         self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
-        self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
+        if self.is_using_while:
+            self.start = Tensor(0, dtype=mstype.int32)
+            self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id),
+                                   mstype.int32)
+        else:
+            self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
         init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1])
         self.init_scores = Tensor(init_scores, mstype.float32)
@@ -259,7 +266,7 @@ class BeamSearchDecoder(nn.Cell):
         self.sub = P.Sub()

     def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
-                 state_seq, state_length, decoder_hidden_state=None, accu_attn_scores=None,
+                 state_seq, state_length, idx=None, decoder_hidden_state=None, accu_attn_scores=None,
                  state_finished=None):
         """
         Beam search one_step output.
@@ -359,7 +366,13 @@ class BeamSearchDecoder(nn.Cell):
                                                  self.hidden_size))
         # update state_seq
-        state_seq = self.concat((seq, self.expand(word_indices, -1)))
+        if self.is_using_while:
+            state_seq_new = self.cast(seq, mstype.float32)
+            word_indices_fp32 = self.cast(word_indices, mstype.float32)
+            state_seq_new[:, :, idx] = word_indices_fp32
+            state_seq = self.cast(state_seq_new, mstype.int32)
+        else:
+            state_seq = self.concat((seq, self.expand(word_indices, -1)))
         cur_input_ids = self.reshape(word_indices, (-1, 1))
         state_log_probs = topk_scores
@@ -388,11 +401,22 @@ class BeamSearchDecoder(nn.Cell):
         decoder_hidden_state = self.decoder_hidden_state
         accu_attn_scores = self.accu_attn_scores

-        for _ in range(self.max_decode_length + 1):
-            cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
-            state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
-                                           state_seq, state_length, decoder_hidden_state, accu_attn_scores,
-                                           state_finished)
+        if not self.is_using_while:
+            for _ in range(self.max_decode_length + 1):
+                cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
+                state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
+                                               state_seq, state_length, None, decoder_hidden_state, accu_attn_scores,
+                                               state_finished)
+        else:
+            idx = self.start + 1
+            ends = self.start + self.max_decode_length + 1
+            while idx < ends:
+                cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
+                state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
+                                               state_seq, state_length, idx, decoder_hidden_state, accu_attn_scores,
+                                               state_finished)
+                idx = idx + 1
         # add length penalty scores
         penalty_len = self.length_penalty(state_length)
         # return penalty_len
@@ -408,6 +432,9 @@ class BeamSearchDecoder(nn.Cell):
         gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
         # sort sequence and attention scores
         predicted_ids = self.gather_nd(state_seq, gather_indices)
-        predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)]
+        if not self.is_using_while:
+            predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)]
+        else:
+            predicted_ids = predicted_ids[:, 0:1, :self.max_decode_length]

         return predicted_ids
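The point of `is_using_while`: the unrolled `for` loop grows `state_seq` by concatenating one token per step, so the tensor's shape changes every iteration, whereas the `while` variant writes step `idx` into a buffer preallocated at `max_decode_length`, keeping all shapes static across iterations (the float32 round-trip around the in-place write appears to work around dtype restrictions on tensor item assignment; that reading is our inference, not stated in the commit). A minimal NumPy sketch of the two update styles:

```python
import numpy as np

batch, beam, max_len, sos = 2, 4, 8, 2
word = np.ones((batch, beam), np.int32)      # stand-in for word_indices at one step

# Unrolled style: the sequence grows by one slot per step, so its shape changes.
seq = np.full((batch, beam, 1), sos, np.int32)
seq = np.concatenate((seq, word[..., None]), axis=-1)   # now (2, 4, 2)

# While-loop style: fixed-shape buffer, in-place write at position idx.
seq_static = np.full((batch, beam, max_len), sos, np.int32)
idx = 1
seq_static[:, :, idx] = word                            # shape stays (2, 4, 8)
```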

View File

@@ -23,7 +23,6 @@ from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 from mindspore import context, Parameter
 from mindspore.train.model import Model
-from mindspore.train.serialization import load_checkpoint

 from src.dataset import load_dataset
 from .gnmt import GNMT
@@ -37,18 +36,6 @@ context.set_context(
     reserve_class_name_in_scope=False)

-def get_weight_and_variable(model_path, params):
-    print("model path is {}".format(model_path))
-    ms_ckpt = load_checkpoint(model_path)
-    with open("variable.txt", "w") as f:
-        for msname in ms_ckpt:
-            f.write(msname + "\n")
-    with open("weights.txt", "w") as f:
-        for param in params:
-            name = param.name
-            f.write(name + "\n")

 class GNMTInferCell(nn.Cell):
     """
     Encapsulation class of GNMT network infer.
@@ -92,38 +79,31 @@ def gnmt_infer(config, dataset):
                           use_one_hot_embeddings=False)

     params = tfm_model.trainable_params()
-    get_weight_and_variable(config.existed_ckpt, params)
     weights = load_infer_weights(config)

     for param in params:
         value = param.data
-        name = param.name
-        if name not in weights:
-            raise ValueError(f"{name} is not found in weights.")
-        with open("weight_after_deal.txt", "a+") as f:
-            weights_name = name
-            f.write(weights_name)
-            f.write("\n")
-            if isinstance(value, Tensor):
-                print(name, value.asnumpy().shape)
-                if weights_name in weights:
-                    assert weights_name in weights
-                    if isinstance(weights[weights_name], Parameter):
-                        if param.data.dtype == "Float32":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
-                        elif param.data.dtype == "Float16":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
-                    elif isinstance(weights[weights_name], Tensor):
-                        param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
-                    elif isinstance(weights[weights_name], np.ndarray):
-                        param.set_data(Tensor(weights[weights_name], config.dtype))
-                    else:
-                        param.set_data(weights[weights_name])
-                else:
-                    print("weight not found in checkpoint: " + weights_name)
-                    param.set_data(zero_weight(value.asnumpy().shape))
-        f.close()
+        weights_name = param.name
+        if weights_name not in weights:
+            raise ValueError(f"{weights_name} is not found in weights.")
+        if isinstance(value, Tensor):
+            if weights_name in weights:
+                assert weights_name in weights
+                if isinstance(weights[weights_name], Parameter):
+                    if param.data.dtype == "Float32":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
+                    elif param.data.dtype == "Float16":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
+                elif isinstance(weights[weights_name], Tensor):
+                    param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
+                elif isinstance(weights[weights_name], np.ndarray):
+                    param.set_data(Tensor(weights[weights_name], config.dtype))
+                else:
+                    param.set_data(weights[weights_name])
+            else:
+                print("weight not found in checkpoint: " + weights_name)
+                param.set_data(zero_weight(value.asnumpy().shape))
     print(" | Load weights successfully.")
     tfm_infer = GNMTInferCell(tfm_model)
     model = Model(tfm_infer)

View File

@@ -37,36 +37,26 @@ def load_infer_weights(config):
     ms_ckpt = load_checkpoint(model_path)
     is_npz = False
     weights = {}
-    with open("variable_after_deal.txt", "w") as f:
-        for param_name in ms_ckpt:
-            infer_name = param_name.replace("gnmt.gnmt.", "")
-            if infer_name.startswith("embedding_lookup."):
-                if is_npz:
-                    weights[infer_name] = ms_ckpt[param_name]
-                else:
-                    weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
-                f.write(infer_name)
-                f.write("\n")
-                infer_name = "beam_decoder.decoder." + infer_name
-                if is_npz:
-                    weights[infer_name] = ms_ckpt[param_name]
-                else:
-                    weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
-                f.write(infer_name)
-                f.write("\n")
-                continue
-            elif not infer_name.startswith("gnmt_encoder"):
-                if infer_name.startswith("gnmt_decoder."):
-                    infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
-                infer_name = "beam_decoder.decoder." + infer_name
-            if is_npz:
-                weights[infer_name] = ms_ckpt[param_name]
-            else:
-                weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
-            f.write(infer_name)
-            f.write("\n")
-    f.close()
+    for param_name in ms_ckpt:
+        infer_name = param_name.replace("gnmt.gnmt.", "")
+        if infer_name.startswith("embedding_lookup."):
+            if is_npz:
+                weights[infer_name] = ms_ckpt[param_name]
+            else:
+                weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
+            infer_name = "beam_decoder.decoder." + infer_name
+            if is_npz:
+                weights[infer_name] = ms_ckpt[param_name]
+            else:
+                weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
+            continue
+        elif not infer_name.startswith("gnmt_encoder"):
+            if infer_name.startswith("gnmt_decoder."):
+                infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
+            infer_name = "beam_decoder.decoder." + infer_name
+        if is_npz:
+            weights[infer_name] = ms_ckpt[param_name]
+        else:
+            weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
     return weights
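Dropping the debug-file writes leaves only the renaming rules, which are easier to see in isolation. A small self-contained demo of the mapping (the parameter names below are hypothetical, for illustration only):

```python
def map_name(param_name):
    """Replicates the renaming rules above for one checkpoint key."""
    infer_name = param_name.replace("gnmt.gnmt.", "")
    if infer_name.startswith("embedding_lookup."):
        # Embedding weights are registered twice: bare, and under the beam decoder.
        return [infer_name, "beam_decoder.decoder." + infer_name]
    if not infer_name.startswith("gnmt_encoder"):
        if infer_name.startswith("gnmt_decoder."):
            infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
        infer_name = "beam_decoder.decoder." + infer_name
    return [infer_name]

print(map_name("gnmt.gnmt.embedding_lookup.embedding_table"))
# ['embedding_lookup.embedding_table', 'beam_decoder.decoder.embedding_lookup.embedding_table']
print(map_name("gnmt.gnmt.gnmt_decoder.projection.weight"))
# ['beam_decoder.decoder.decoder.projection.weight']
print(map_name("gnmt.gnmt.gnmt_encoder.lstm.weight"))
# ['gnmt_encoder.lstm.weight']
```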